Add engine startup code

xhy 2025-04-03 22:44:22 +08:00
parent 56ea878c29
commit bb2e09c885
9 changed files with 99 additions and 17 deletions

View File

@@ -27,6 +27,11 @@ class MainApp:
         self.config: AppConfig = None
         self.db_engine = None
+
+        # all engines
+        self.crawl_engine = None
+        self.evidence_engine = None
+        self.report_engine = None

     def parse_args(self):
         """Parse command-line arguments"""
         parser = argparse.ArgumentParser(description="Baidu Reporter")
@@ -124,6 +129,19 @@ class MainApp:
         # register the Ctrl+C handler to shut down all engines cleanly
         signal.signal(signal.SIGINT, self.exit_handler)

+        # start all engines
+        self.crawl_engine = CrawlEngine()
+        self.crawl_engine.start()
+        logger.info("crawl started")
+
+        self.evidence_engine = EvidenceEngine()
+        self.evidence_engine.start()
+        logger.info("evidence started")
+
+        self.report_engine = Reporter(["pc", "site", "wap"])
+        self.report_engine.start()
+        logger.info("report started")
+
         # start the web UI
         web_app = WebApp()
         asyncio.run(web_app.start())
@@ -171,5 +189,17 @@ class MainApp:
         return self.start_cli()

     def exit_handler(self, signum, frame):
-        # TODO: stop each engine here
-        print("CTRL+C called.")
+        # stop each engine here
+        logger.debug("CTRL+C called.")
+
+        self.crawl_engine.stop()
+        self.crawl_engine.cli_wait()
+        logger.info("crawl stopped")
+
+        self.evidence_engine.stop()
+        self.evidence_engine.wait()
+        logger.info("evidence stopped")
+
+        self.report_engine.stop()
+        self.report_engine.wait()
+        logger.info("report stopped")

View File

@@ -7,6 +7,7 @@ from loguru import logger
 from sqlmodel import Session, select

 from app.config.config import AppCtx
+from app.constants.domain import DomainStatus
 from app.models.domain import DomainModel
 from app.models.report_urls import ReportUrlModel
 from app.utils.dp import DPEngine
@@ -27,7 +28,7 @@ class CrawlEngine:
         # thread pool
         self.pool: list[threading.Thread] = []
-        self.worker_count = 2
+        self.worker_count = 1
         # work queue
         self.target_queue = queue.Queue(1024)
@@ -148,25 +149,32 @@ class CrawlEngine:
     def worker(self):
         """The actual worker; web-mode startup goes through this later"""
         logger.info("crawl worker start!")
-        while self.worker_status == 1:
+        while self.worker_status:
             # check the database for domains that are due to be crawled
             current_timestamp = int(time.time())
             with Session(AppCtx.g_db_engine) as session:
                 stmt = select(DomainModel).where(
-                    DomainModel.latest_crawl_time + DomainModel.crawl_interval <= current_timestamp
+                    DomainModel.latest_crawl_time + DomainModel.crawl_interval * 60 <= current_timestamp
                 )
                 domains = session.exec(stmt).all()

                 for domain_model in domains:
+                    # update the domain status before crawling
+                    domain_model.status = DomainStatus.CRAWLING.value
+                    session.add(domain_model)
+                    session.commit()
+
                     # crawl
                     surl_set = self.crawl(domain_model.domain)
                     # persist
                     if surl_set:
-                        self.save_surl(session, domain_model, surl_set)
+                        self.save_surl(session, domain_model.domain, surl_set)

                     domain_model.latest_crawl_time = int(time.time())
+                    domain_model.status = DomainStatus.READY.value
                     session.add(domain_model)
                     session.commit()
@@ -182,8 +190,8 @@ class CrawlEngine:
         try:
             # initial data
-            end_time = int(time.time())
-            start_time = end_time - 3600 * 24 * 30  # fetch the last month of data
+            # end_time = int(time.time())
+            # start_time = end_time - 3600 * 24 * 30  # fetch the last month of data

             # process each page in turn
             max_page = 10  # maximum number of pages; 0 means unlimited
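Side note on the `* 60` change above: the new predicate only makes sense if crawl_interval is stored in minutes while latest_crawl_time is an epoch timestamp in seconds. A minimal sketch of the due check under that assumption (the helper name is made up):

import time

def is_due(latest_crawl_time: int, crawl_interval_minutes: int) -> bool:
    # a domain is due once its interval (minutes) has elapsed
    # since the last crawl (epoch seconds)
    return latest_crawl_time + crawl_interval_minutes * 60 <= int(time.time())

# e.g. last crawled 20 minutes ago with a 15-minute interval -> due
assert is_due(int(time.time()) - 20 * 60, 15)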

View File

@@ -59,8 +59,8 @@ class EvidenceEngine:
                 logger.debug(f"Fetching report data for {target['surl']}")
                 self.get_screenshot_and_report_link(target)

-            # run once a minute
-            self.ev.wait(60)
+            # run every 10 seconds
+            self.ev.wait(10)

     def stop(self):
         """Stop the thread"""
@@ -69,6 +69,9 @@ class EvidenceEngine:
         self.dp_engine.close()
         self.wap_dp_engine.close()

+    def wait(self):
+        self.worker_thread.join()
+
     def get_surl_from_db(self):
         """Fetch data from the database"""
         result: list = []
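The new wait() pairs with the self.ev.wait(...) pattern used throughout this commit: unlike time.sleep, threading.Event.wait returns as soon as stop() sets the event, so shutdown never has to ride out the full interval. A minimal sketch of the pattern:

import threading
import time

ev = threading.Event()

def worker():
    while not ev.is_set():
        # ... do one round of work ...
        # returns True immediately once ev.set() is called,
        # False after the 10-second timeout otherwise
        ev.wait(10)

t = threading.Thread(target=worker)
t.start()
time.sleep(1)
ev.set()   # stop(): the worker wakes at once instead of up to 9 s later
t.join()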

View File

@@ -33,6 +33,7 @@ class Reporter:
     def wait(self):
         self.worker_thread.join()

+    # noinspection DuplicatedCode
     def cli_start(self):
         for mode in self.mode:
             if mode == "pc":
@@ -49,6 +50,7 @@ class Reporter:
         self.status = 0
         self.ev.set()

+    # noinspection DuplicatedCode
     def worker(self):
         while self.status:
             for mode in self.mode:

View File

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+
 class BaseReporter(ABC):
     """Base class for all reporters"""
@@ -7,3 +8,7 @@ class BaseReporter(ABC):
     def run(self):
         """Run the reporter; subclasses must implement this"""
         pass
+
+    def stop(self):
+        """Signal the reporter to stop"""
+        pass
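With the base class gaining a default no-op stop(), each concrete reporter can override it with the status-flag/event pair added in the files below. A condensed sketch of that shape (the loop body is hypothetical):

import threading
from abc import ABC, abstractmethod

class BaseReporter(ABC):
    @abstractmethod
    def run(self):
        """Run the reporter; subclasses must implement this"""

    def stop(self):
        """Signal the reporter to stop; no-op by default"""

class ExampleReporter(BaseReporter):
    def __init__(self):
        self.status = 1
        self.ev = threading.Event()

    def run(self):
        for row in range(3):       # stand-in for the DB rows
            if not self.status:    # cooperative cancellation point
                break
            # ... report one row ...
            self.ev.wait(5)        # interruptible pause between reports

    def stop(self):
        self.status = 0
        self.ev.set()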

View File

@@ -1,5 +1,6 @@
 import os.path
 import random
+import threading
 import time
 from urllib.parse import urlparse, parse_qs
@@ -19,6 +20,8 @@ from ...utils.ydm_verify import YdmVerify
 class PcReporter(BaseReporter):
     def __init__(self):
         self.engine_name = "PC_REPORTER"
+        self.status = 1
+        self.ev = threading.Event()

         self.database = AppCtx.g_db_engine
         self.upload_pic_url = "http://jubao.baidu.com/jubao/accu/upload"
@@ -45,6 +48,10 @@ class PcReporter(BaseReporter):
             "Cookie": "",
         }

+    def stop(self):
+        self.status = 0
+        self.ev.set()
+
     def run(self):
         with Session(self.database) as session:
             stmt = select(ReportUrlModel).where(ReportUrlModel.is_report_by_one == False)
@@ -53,6 +60,10 @@ class PcReporter(BaseReporter):
             logger.info(f"[{self.engine_name}] {len(rows)} records to report in total")

             for row in rows:
+                if not self.status:
+                    break
+
                 # pick a cookie
                 report_cookie = random.choice(get_all_cookies())
                 self.headers["Cookie"] = report_cookie
@@ -120,6 +131,8 @@ class PcReporter(BaseReporter):
                         retry += 1

+                self.ev.wait(5)
+
     def do_report(self, ds, tk, surl, token, title, q, upload=''):
         try:
             phone = generate_random_phone_number()

View File

@@ -1,6 +1,7 @@
 import os.path
 import random
 import re
+import threading
 import time

 import requests
@@ -19,6 +20,8 @@ from ...utils.ua import random_ua
 class SiteReporter(BaseReporter):
     def __init__(self):
         self.engine_name = "SITE_REPORTER"
+        self.status = 1
+        self.ev = threading.Event()

         self.upload_pic_url = "https://help.baidu.com/api/mpic"
         self.report_url = "https://help.baidu.com/jubaosubmit"
@@ -42,6 +45,10 @@ class SiteReporter(BaseReporter):
         self.token_pattern = r'name="submit_token" value="(.*?)"'

+    def stop(self):
+        self.status = 0
+        self.ev.set()
+
     def run(self):
         """Implements the PC-side reporting logic"""
         with Session(self.database) as session:
@@ -51,6 +58,10 @@ class SiteReporter(BaseReporter):
             logger.info(f"[{self.engine_name}] {len(rows)} records to report in total")

             for row in rows:
+                if not self.status:
+                    break
+
                 # build the base data the report needs
                 surl = row.surl
                 q = row.q
@@ -78,7 +89,7 @@ class SiteReporter(BaseReporter):
                 session.commit()

                 # wait 5 seconds, then continue reporting
-                time.sleep(5)
+                self.ev.wait(5)

     def upload_pic(self, img_path: str):
         try:

View File

@@ -2,6 +2,7 @@ import base64
 import json
 import os.path
 import random
+import threading
 import time

 import requests
@@ -20,6 +21,8 @@ class WapReporter(BaseReporter):
     def __init__(self):
         self.engine_name = "WAP_REPORTER"
+        self.status = 1
+        self.ev = threading.Event()

         self.report_url = "https://ufosdk.baidu.com/api?m=Client&a=postMsg"
         self.request = requests.session()
@@ -40,6 +43,10 @@ class WapReporter(BaseReporter):
         self.database = AppCtx.g_db_engine
         self.all_cookies = get_all_cookies()

+    def stop(self):
+        self.status = 0
+        self.ev.set()
+
     def run(self):
         """Implements the WAP-side reporting logic"""
         with Session(self.database) as session:
@@ -50,6 +57,9 @@ class WapReporter(BaseReporter):
             for row in rows:
+                if not self.status:
+                    break
+
                 # pick a cookie
                 report_cookie = random.choice(get_all_cookies())
                 report_site_cookie = GenCookie.run(report_cookie)
@@ -74,7 +84,7 @@ class WapReporter(BaseReporter):
                 session.add(row)
                 session.commit()

-                time.sleep(5)
+                self.ev.wait(5)

     def get_user_info(self):
         try:

View File

@@ -50,8 +50,8 @@ class DomainService:
         with Session(AppCtx.g_db_engine) as session:
             stmt = select(DomainModel).where(DomainModel.domain.in_(domains))
             try:
-                rows = session.exec(stmt)
-                return ApiResult.ok(rows.all())
+                rows = session.exec(stmt).all()
+                return ApiResult.ok(rows)
             except Exception as e:
                 session.rollback()
                 return ApiResult.error(ApiCode.DB_ERROR.value, f"Domain query failed: {e}")
@@ -62,8 +62,8 @@ class DomainService:
         with Session(AppCtx.g_db_engine) as session:
             stmt = select(DomainModel).where(DomainModel.id.in_(domain_ids))
             try:
-                rows = session.exec(stmt)
-                return ApiResult.ok(rows.all())
+                rows = session.exec(stmt).all()
+                return ApiResult.ok(rows)
             except Exception as e:
                 session.rollback()
                 return ApiResult.error(ApiCode.DB_ERROR.value, f"Domain query failed: {e}")
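Both hunks make the same change: session.exec(stmt) returns a lazy result object, and the rewrite materializes it with .all() immediately instead of on the way out. Behavior is the same here, since either way .all() runs inside the with Session(...) block, but the new form makes the invariant explicit: results must be materialized before the session closes, or consumers are left holding a result bound to a dead connection. A minimal self-contained sketch of the pattern (in-memory SQLite and a simplified model, both assumptions for illustration):

from sqlmodel import Field, Session, SQLModel, create_engine, select

class DomainModel(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    domain: str

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

def fetch_domains(domains: list[str]) -> list[DomainModel]:
    with Session(engine) as session:
        stmt = select(DomainModel).where(DomainModel.domain.in_(domains))
        # materialize inside the session; the returned list of model
        # instances remains usable after the session closes
        return list(session.exec(stmt).all())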