diff --git a/app/app.py b/app/app.py index bd4aa3e..d038f27 100644 --- a/app/app.py +++ b/app/app.py @@ -27,6 +27,11 @@ class MainApp: self.config: AppConfig = None self.db_engine = None + # 所有的engine + self.crawl_engine = None + self.evidence_engine = None + self.report_engine = None + def parse_args(self): """解析命令行参数""" parser = argparse.ArgumentParser(description="Baidu Reporter") @@ -124,6 +129,19 @@ class MainApp: # 注册 ctrl+c 处理程序,正常结束所有的 engine signal.signal(signal.SIGINT, self.exit_handler) + # 启动所有的 engine + self.crawl_engine = CrawlEngine() + self.crawl_engine.start() + logger.info("crawl 启动") + + self.evidence_engine = EvidenceEngine() + self.evidence_engine.start() + logger.info("evidence 启动") + + self.report_engine = Reporter(["pc", "site", "wap"]) + self.report_engine.start() + logger.info("report 启动") + # 启动 web 页面 web_app = WebApp() asyncio.run(web_app.start()) @@ -171,5 +189,17 @@ class MainApp: return self.start_cli() def exit_handler(self, signum, frame): - # TODO 在这里结束各个 engine - print("CTRL+C called.") \ No newline at end of file + # 在这里结束各个 engine + logger.debug("CTRL+C called.") + + self.crawl_engine.stop() + self.crawl_engine.cli_wait() + logger.info("crawl 退出") + + self.evidence_engine.stop() + self.evidence_engine.wait() + logger.info("evidence 退出") + + self.report_engine.stop() + self.report_engine.wait() + logger.info("report 退出") diff --git a/app/engines/crawl_engine.py b/app/engines/crawl_engine.py index 9a24000..40e985d 100644 --- a/app/engines/crawl_engine.py +++ b/app/engines/crawl_engine.py @@ -7,6 +7,7 @@ from loguru import logger from sqlmodel import Session, select from app.config.config import AppCtx +from app.constants.domain import DomainStatus from app.models.domain import DomainModel from app.models.report_urls import ReportUrlModel from app.utils.dp import DPEngine @@ -27,7 +28,7 @@ class CrawlEngine: # 线程池 self.pool: list[threading.Thread] = [] - self.worker_count = 2 + self.worker_count = 1 # 工作队列 self.target_queue = queue.Queue(1024) @@ -148,25 +149,32 @@ class CrawlEngine: def worker(self): """真正的工作函数,后续以Web模式启动的时候,走这个""" logger.info("crawl worker start!") - while self.worker_status == 1: + while self.worker_status: # 检查数据库,从中获取需要爬取的域名 current_timestamp = int(time.time()) with Session(AppCtx.g_db_engine) as session: stmt = select(DomainModel).where( - DomainModel.latest_crawl_time + DomainModel.crawl_interval <= current_timestamp + DomainModel.latest_crawl_time + DomainModel.crawl_interval * 60 <= current_timestamp ) domains = session.exec(stmt).all() for domain_model in domains: + + # 采集前修改状态 + domain_model.status = DomainStatus.CRAWLING.value + session.add(domain_model) + session.commit() + # 采集 surl_set = self.crawl(domain_model.domain) # 存储 if surl_set: - self.save_surl(session, domain_model, surl_set) + self.save_surl(session, domain_model.domain, surl_set) domain_model.latest_crawl_time = int(time.time()) + domain_model.status = DomainStatus.READY.value session.add(domain_model) session.commit() @@ -182,8 +190,8 @@ class CrawlEngine: try: # 初始数据 - end_time = int(time.time()) - start_time = end_time - 3600 * 24 * 30 # 获取最近一个月的数据 + # end_time = int(time.time()) + # start_time = end_time - 3600 * 24 * 30 # 获取最近一个月的数据 # 依次每一页处理 max_page = 10 # 最大页码数量,0表示不限制最大数量 diff --git a/app/engines/evidence_engine.py b/app/engines/evidence_engine.py index 9cc4d00..40c3ded 100644 --- a/app/engines/evidence_engine.py +++ b/app/engines/evidence_engine.py @@ -59,8 +59,8 @@ class EvidenceEngine: logger.debug(f"开始获取 {target['surl']} 的举报数据") self.get_screenshot_and_report_link(target) - # 每分钟跑一次 - self.ev.wait(60) + # 每10秒跑一次 + self.ev.wait(10) def stop(self): """结束线程""" @@ -69,6 +69,9 @@ class EvidenceEngine: self.dp_engine.close() self.wap_dp_engine.close() + def wait(self): + self.worker_thread.join() + def get_surl_from_db(self): """从数据库中获取数据""" result: list = [] diff --git a/app/engines/report_engine.py b/app/engines/report_engine.py index c11a7d7..f3bedde 100644 --- a/app/engines/report_engine.py +++ b/app/engines/report_engine.py @@ -33,6 +33,7 @@ class Reporter: def wait(self): self.worker_thread.join() + # noinspection DuplicatedCode def cli_start(self): for mode in self.mode: if mode == "pc": @@ -49,6 +50,7 @@ class Reporter: self.status = 0 self.ev.set() + # noinspection DuplicatedCode def worker(self): while self.status: for mode in self.mode: diff --git a/app/engines/reporters/base.py b/app/engines/reporters/base.py index adefa81..3b6c85a 100644 --- a/app/engines/reporters/base.py +++ b/app/engines/reporters/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class BaseReporter(ABC): """所有 reporter 的基类""" @@ -7,3 +8,7 @@ class BaseReporter(ABC): def run(self): """运行 reporter,子类必须实现此方法""" pass + + def stop(self): + """控制结束""" + pass diff --git a/app/engines/reporters/pc_reporter.py b/app/engines/reporters/pc_reporter.py index fa56b42..815b305 100644 --- a/app/engines/reporters/pc_reporter.py +++ b/app/engines/reporters/pc_reporter.py @@ -1,5 +1,6 @@ import os.path import random +import threading import time from urllib.parse import urlparse, parse_qs @@ -19,6 +20,8 @@ from ...utils.ydm_verify import YdmVerify class PcReporter(BaseReporter): def __init__(self): self.engine_name = "PC_REPORTER" + self.status = 1 + self.ev = threading.Event() self.database = AppCtx.g_db_engine self.upload_pic_url = "http://jubao.baidu.com/jubao/accu/upload" @@ -45,6 +48,10 @@ class PcReporter(BaseReporter): "Cookie": "", } + def stop(self): + self.status = 0 + self.ev.set() + def run(self): with Session(self.database) as session: stmt = select(ReportUrlModel).where(ReportUrlModel.is_report_by_one == False) @@ -53,6 +60,10 @@ class PcReporter(BaseReporter): logger.info(f"[{self.engine_name}] 共计 {len(rows)} 条记录需要举报") for row in rows: + + if not self.status: + break + # 选个 cookie report_cookie = random.choice(get_all_cookies()) self.headers["Cookie"] = report_cookie @@ -120,6 +131,8 @@ class PcReporter(BaseReporter): retry += 1 + self.ev.wait(5) + def do_report(self, ds, tk, surl, token, title, q, upload=''): try: phone = generate_random_phone_number() diff --git a/app/engines/reporters/site_reporter.py b/app/engines/reporters/site_reporter.py index 874828b..ed9c38c 100644 --- a/app/engines/reporters/site_reporter.py +++ b/app/engines/reporters/site_reporter.py @@ -1,6 +1,7 @@ import os.path import random import re +import threading import time import requests @@ -19,6 +20,8 @@ from ...utils.ua import random_ua class SiteReporter(BaseReporter): def __init__(self): self.engine_name = "SITE_REPORTER" + self.status = 1 + self.ev = threading.Event() self.upload_pic_url = "https://help.baidu.com/api/mpic" self.report_url = "https://help.baidu.com/jubaosubmit" @@ -42,6 +45,10 @@ class SiteReporter(BaseReporter): self.token_pattern = r'name="submit_token" value="(.*?)"' + def stop(self): + self.status = 0 + self.ev.set() + def run(self): """实现 PC 端的举报逻辑""" with Session(self.database) as session: @@ -51,6 +58,10 @@ class SiteReporter(BaseReporter): logger.info(f"[{self.engine_name}] 共计 {len(rows)} 条需要举报") for row in rows: + + if not self.status: + break + # 生成举报需要的基础数据 surl = row.surl q = row.q @@ -77,8 +88,8 @@ class SiteReporter(BaseReporter): session.add(row) session.commit() - # 等待5秒继续举报 - time.sleep(5) + # 等待5秒继续举报 + self.ev.wait(5) def upload_pic(self, img_path: str): try: diff --git a/app/engines/reporters/wap_reporter.py b/app/engines/reporters/wap_reporter.py index 4a92dab..51ad315 100644 --- a/app/engines/reporters/wap_reporter.py +++ b/app/engines/reporters/wap_reporter.py @@ -2,6 +2,7 @@ import base64 import json import os.path import random +import threading import time import requests @@ -20,6 +21,8 @@ class WapReporter(BaseReporter): def __init__(self): self.engine_name = "WAP_REPORTER" + self.status = 1 + self.ev = threading.Event() self.report_url = "https://ufosdk.baidu.com/api?m=Client&a=postMsg" self.request = requests.session() @@ -40,6 +43,10 @@ class WapReporter(BaseReporter): self.database = AppCtx.g_db_engine self.all_cookies = get_all_cookies() + def stop(self): + self.status = 0 + self.ev.set() + def run(self): """实现 WAP 端的举报逻辑""" with Session(self.database) as session: @@ -50,6 +57,9 @@ class WapReporter(BaseReporter): for row in rows: + if not self.status: + break + # 选个 cookie report_cookie = random.choice(get_all_cookies()) report_site_cookie = GenCookie.run(report_cookie) @@ -74,7 +84,7 @@ class WapReporter(BaseReporter): session.add(row) session.commit() - time.sleep(5) + self.ev.wait(5) def get_user_info(self): try: diff --git a/app/web/service/domain_service.py b/app/web/service/domain_service.py index 8b1dfd3..c3c303b 100644 --- a/app/web/service/domain_service.py +++ b/app/web/service/domain_service.py @@ -50,8 +50,8 @@ class DomainService: with Session(AppCtx.g_db_engine) as session: stmt = select(DomainModel).where(DomainModel.domain.in_(domains)) try: - rows = session.exec(stmt) - return ApiResult.ok(rows.all()) + rows = session.exec(stmt).all() + return ApiResult.ok(rows) except Exception as e: session.rollback() return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}") @@ -62,8 +62,8 @@ class DomainService: with Session(AppCtx.g_db_engine) as session: stmt = select(DomainModel).where(DomainModel.id.in_(domain_ids)) try: - rows = session.exec(stmt) - return ApiResult.ok(rows.all()) + rows = session.exec(stmt).all() + return ApiResult.ok(rows) except Exception as e: session.rollback() return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")