baidu-reporter/app/engines/crawl_engine.py

import queue
import threading
import time

from DrissionPage.errors import ElementNotFoundError
from loguru import logger
from sqlmodel import Session, select

from app.config.config import AppCtx
from app.models.domain import DomainModel
from app.models.report_urls import ReportUrlModel
from app.utils.dp import DPEngine


class CrawlEngine:
    """URL collector for illicit sites: pages through Baidu search results for
    each target domain and saves the URLs it finds.

    Corresponds to the getBaiDuIncludeUrls method in the original project.
    """

    def __init__(self):
        self.ev = threading.Event()
        # Worker run flag: 1 = running, 0 = stopped
        self.worker_status = 1
        # Marks whether all tasks have been queued (the normal-exit signal for CLI mode):
        # 1 = everything queued, no new tasks coming; 0 = tasks still pending
        self.finish_task = 0
        # Thread pool
        self.pool: list[threading.Thread] = []
        self.worker_count = 2
        # Work queue
        self.target_queue = queue.Queue(1024)
        # Create a browser instance
        self.dp_engine = DPEngine()
        self.database = AppCtx.g_db_engine
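
    # Concurrency model: the workers are daemon threads that pull domains from the
    # bounded target_queue, and the shared threading.Event doubles as an interruptible
    # sleep (self.ev.wait(n)) and as the wake-up/stop signal set by stop().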

    def cli_start(self, target_domains: str, target_domain_filepath: str):
        """Start in CLI mode.

        target_domains: comma-separated list of domains
        target_domain_filepath: path to a file with one target domain per line
        """
        # Store the input domains in the database first
        domains = self.add_domain(target_domains, target_domain_filepath)
        # Start the thread pool
        for idx in range(self.worker_count):
            # CLI startup logic differs from web startup, so CLI uses its own cli_worker
            thread = threading.Thread(target=self.cli_worker, name=f"crawl_engine_{idx}", daemon=True)
            self.pool.append(thread)
            thread.start()
        # Push the tasks onto the queue
        for domain in domains:
            logger.debug(f"Queueing {domain}")
            while True:
                try:
                    self.target_queue.put_nowait(domain)
                    break
                except queue.Full:
                    self.ev.wait(5)
                    continue
        # Everything is queued; flip the finish flag
        self.finish_task = 1

    def cli_wait(self):
        """Block until all worker threads have exited."""
        for worker in self.pool:
            worker.join()
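
    # Expected CLI call order: cli_start() queues the domains and starts the workers,
    # cli_wait() blocks until they drain the queue and exit, and stop() interrupts a
    # run early (see the usage sketch at the end of the file).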

    def cli_worker(self):
        """Worker for CLI mode: only handles the input domains and ignores the refresh conditions stored in the database."""
        while True:
            try:
                # Exit when the control flag is cleared
                if not self.worker_status:
                    logger.debug(f"{threading.current_thread().name} exiting: control flag cleared")
                    break
                # Exit when the queue is empty and every task has been pushed
                if self.target_queue.empty() and self.finish_task:
                    logger.debug(f"{threading.current_thread().name} exiting: queue drained")
                    break
                # Fetch a domain and start crawling
                domain = self.target_queue.get_nowait()
                surl = self.crawl(domain)
                if not surl:
                    logger.debug(f"{threading.current_thread().name} no results for {domain}, moving on")
                    continue
                # Save to the database
                with Session(self.database) as session:
                    self.save_surl(session, domain, surl)
            except queue.Empty:
                # Queue is empty; wait a second and try again
                self.ev.wait(1)
                continue
            except Exception as e:
                logger.exception(f"{threading.current_thread().name} error: {e}")
                continue
        logger.info(f"{threading.current_thread().name} exited")

    def add_domain(self, input_domains: str, input_domain_filepath: str) -> list[str]:
        """Store the input domains in the database."""
        # Build the full list of domains to crawl
        domains = []
        if input_domains:
            domains.extend([d.strip() for d in input_domains.split(",") if d.strip()])
        if input_domain_filepath:
            with open(input_domain_filepath, "r") as fp:
                for line in fp:
                    line = line.strip()
                    if line:
                        domains.append(line)
        logger.debug(f"{len(domains)} domains provided in total")
        # Only add domains that are not already in the database
        for domain in domains:
            with Session(self.database) as session:
                stmt = select(DomainModel).where(DomainModel.domain == domain)
                result = session.exec(stmt).first()
                if not result:
                    # crawl_interval is compared against Unix timestamps in worker(), i.e. it is in seconds
                    example = DomainModel(
                        domain=domain, status=1, crawl_interval=60 * 7 * 24, latest_crawl_time=0,
                    )
                    session.add(example)
                    session.commit()
                    logger.info(f"Added domain {domain} to the database")
        return domains

    def stop(self):
        """Stop the crawler (shared by CLI and web mode)."""
        self.ev.set()
        self.worker_status = 0
        self.dp_engine.browser.quit()

    def start(self):
        """Start the crawler; this is the entry point when running in web mode."""
        for idx in range(self.worker_count):
            thread = threading.Thread(target=self.worker, name=f"crawl_engine_{idx}", daemon=True)
            self.pool.append(thread)
            thread.start()
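
    # Web-mode lifecycle: start() launches worker() threads that periodically poll the
    # database for domains whose crawl_interval has elapsed, and stop() shuts them down.
    # Unlike CLI mode, no task queue is involved here.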

    def worker(self):
        """The actual worker loop; used when the engine is started in web mode."""
        logger.info("crawl worker start!")
        while self.worker_status == 1:
            # Query the database for domains that are due for crawling
            current_timestamp = int(time.time())
            with Session(AppCtx.g_db_engine) as session:
                stmt = select(DomainModel).where(
                    DomainModel.latest_crawl_time + DomainModel.crawl_interval <= current_timestamp
                )
                domains = session.exec(stmt).all()
                for domain_model in domains:
                    # Crawl
                    surl_set = self.crawl(domain_model.domain)
                    # Persist the results
                    if surl_set:
                        self.save_surl(session, domain_model.domain, surl_set)
                    domain_model.latest_crawl_time = int(time.time())
                    session.add(domain_model)
                    session.commit()
            self.ev.wait(60)
        logger.info("crawl worker stop!")

    def crawl(self, domain: str) -> set[str] | None:
        """Crawl Baidu for URLs under the given domain.

        Returns the set of URLs found (possibly empty), or None if an unexpected error occurred.
        """
        logger.debug(f"{threading.current_thread().name} crawling: {domain}")
        tab = self.dp_engine.browser.new_tab()
        surl_set: set[str] = set()
        try:
            # Initial parameters; start_time/end_time are only used by the direct-URL
            # approach that is commented out below
            end_time = int(time.time())
            start_time = end_time - 3600 * 24 * 30  # only look at the last month
            # Process the result pages one by one
            max_page = 10  # maximum number of pages; 0 means no limit
            current_page = 0  # current page number
            # Open the search page first
            tab.get("https://www.baidu.com/")
            tab.wait.eles_loaded("#kw")
            tab.ele("#kw").input(f"site:{domain}\n", clear=True)
            tab.wait.eles_loaded("#container")
            tab.wait.eles_loaded("#timeRlt")
            logger.debug(f"{threading.current_thread().name} #timeRlt loaded!")
            # Restrict the search time range
            self.ev.wait(2)
            tab.ele("#timeRlt").click(True)
            tab.wait.eles_loaded("@class:time_pop_")
            self.ev.wait(2)
            # Pick the "within one month" option from Baidu's time filter (" 一月内 " is the option text)
            tab.ele("t:li@@text()= 一月内 ").click(True)
            tab.wait.eles_loaded(["#container", ".content_none", "#content_left"], any_one=True)
            while True:
                try:
                    # Advance the page counter
                    current_page += 1
                    logger.debug(f"{threading.current_thread().name} crawling page {current_page} of {domain}")
                    # Visiting the result URL directly triggers a captcha, so we click through instead:
                    # tab.get(
                    #     f"https://www.baidu.com/s?wd=site%3A{domain}&gpc=stf%3D{start_time}%2C{end_time}%7Cstftype%3D1&pn={(current_page - 1) * 10}")
                    # tab.get(f"https://www.baidu.com/s?wd=site%3A{domain}&pn={(current_page - 1) * 10}")
                    # Stop condition: page limit reached (max_page == 0 disables the limit)
                    if max_page and current_page > max_page:
                        logger.debug(f"{threading.current_thread().name} reached the page limit, stopping")
                        break
                    # logger.debug(f"tab.html: {tab.html}")
                    self.ev.wait(0.3)
                    # "未找到相关结果" is Baidu's "no results found" message
                    if "未找到相关结果" in tab.html:
                        logger.debug(f"{threading.current_thread().name} no results found, stopping")
                        break
                    # Collect the results
                    tab.wait.eles_loaded("@id=content_left")
                    results = tab.ele("@id=content_left").eles("@class:result")
                    # temp = [result.attr("mu") for result in results if result.attr("mu") is not None]
                    # logger.debug(f"{len(results)=}")
                    for result in results:
                        # logger.debug(f"{result=}")
                        surl = result.attr("mu")
                        if not surl:
                            continue
                        # Only keep URLs that actually belong to the target domain
                        if domain not in surl:
                            logger.debug(f"{threading.current_thread().name} URL {surl} is unrelated to {domain}, skipping")
                        else:
                            surl_set.add(surl)
                            logger.debug(f"{threading.current_thread().name} found {surl}")
                    # Slow down a little before paging
                    self.ev.wait(0.3)
                    # If there is no next page ("下一页") this lookup finds nothing, with a 10-second timeout
                    next_btn = tab.ele("t:a@@text():下一页")
                    if not next_btn:
                        logger.debug(f"{threading.current_thread().name} no next page")
                        break
                    next_btn.click(True)
                except ElementNotFoundError as e:
                    logger.error(f"HTML element not found, skipping; details: {e}")
                    break
            return surl_set
        except Exception as e:
            # Log the full traceback and fall through, returning None
            logger.exception(f"{threading.current_thread().name} error while crawling {domain}: {e}")
        finally:
            tab.close()

    @staticmethod
    def save_surl(session: Session, domain: str, surl_set: set[str]):
        """Persist the collected URLs."""
        for surl in surl_set:
            # Quick sanity check that the URL contains the target domain
            if domain not in surl:
                logger.debug(f"Not saving {surl}: it does not match the target domain {domain}")
                continue
            # Skip URLs that are already stored
            stmt = select(ReportUrlModel).where(ReportUrlModel.surl == surl)
            exist = session.exec(stmt).first()
            if exist:
                continue
            # Make sure the domain exists in the database
            domain_model = session.exec(
                select(DomainModel).where(DomainModel.domain == domain)
            ).first()
            if not domain_model:
                logger.warning(f"Domain {domain} is not in the database")
                continue
            example = ReportUrlModel(
                domain_id=domain_model.id,
                domain=domain_model.domain,
                surl=surl,
            )
            session.add(example)
            session.commit()
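

# Minimal CLI usage sketch (illustrative only): it assumes AppCtx.g_db_engine has
# already been initialized by the application's own startup code, and "example.com"
# is a placeholder domain.
if __name__ == "__main__":
    engine = CrawlEngine()
    try:
        # Queue the input domains and block until the workers drain the queue.
        engine.cli_start(target_domains="example.com", target_domain_filepath="")
        engine.cli_wait()
    finally:
        # Signal any remaining workers to exit and close the browser.
        engine.stop()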