Add backend API endpoints

xhy 2025-04-03 22:11:20 +08:00
parent 552a09ee41
commit 56ea878c29
20 changed files with 597 additions and 3 deletions

app/__init__.py
View File

@@ -1 +1 @@
from .app import *
from .app import MainApp

app/app.py
View File

@@ -1,7 +1,9 @@
import argparse
import asyncio
import sys
import os
import time
import signal
from app.engines.report_engine import Reporter
@@ -14,6 +16,8 @@ from .models.base import connect_db, create_database
from loguru import logger
import sqlalchemy.exc
from .web.web import WebApp
class MainApp:
"""主应用"""
@@ -116,7 +120,15 @@
def start_web(self):
"""开启 Web 模式"""
pass
# 注册 ctrl+c 处理程序,正常结束所有的 engine
signal.signal(signal.SIGINT, self.exit_handler)
# 启动 web 页面
web_app = WebApp()
asyncio.run(web_app.start())
logger.info("web stop.")
def run(self):
"""运行应用"""
@@ -157,3 +169,7 @@ class MainApp:
else:
logger.info("启动 CLI 模式")
return self.start_cli()
def exit_handler(self, signum, frame):
# TODO: shut down every engine here
print("CTRL+C called.")

app/constants/__init__.py Normal file
View File

app/constants/api_result.py Normal file
View File

@@ -0,0 +1,8 @@
import enum
class ApiCode(enum.Enum):
OK = 20000
PARAM_ERROR = 30000
DB_ERROR = 40000
RUNTIME_ERROR = 50000

app/constants/domain.py Normal file
View File

@@ -0,0 +1,7 @@
import enum
class DomainStatus(enum.Enum):
READY = 1 # Back to this state after a crawl finishes; newly added domains default to it as well
QUEUEING = 2 # Queued: already pushed onto the task queue but not yet processed
CRAWLING = 3 # Crawl in progress

app/models/base.py
View File

@@ -26,10 +26,13 @@ def update_updated_at(mapper, connection, target):
target.updated_at = get_timestamp()
# noinspection PyUnresolvedReferences
def connect_db(config: AppConfig):
"""连接数据库"""
# 导入所有模型,为了自动创建数据表
from .domain import DomainModel
from .report_urls import ReportUrlModel
dsn = f"mysql+pymysql://{config.database.user}:{config.database.password}@{config.database.host}:{config.database.port}/{config.database.database}"
engine = create_engine(dsn, echo=False)

app/models/domain.py
View File

@@ -12,7 +12,7 @@ class DomainModel(BaseModel, table=True):
# Domain name
domain: str = Field(alias="domain", default="", sa_type=VARCHAR(1024))
# Crawl status; TODO: left empty for now, will hold this domain's task status once task control exists
# Crawl status, @see constants.DomainStatus
status: int = Field(alias="status", default=0)
# Crawl interval; the default is one week

app/web/__init__.py Normal file
View File

app/web/controller/__init__.py Normal file
View File

app/web/controller/domain.py Normal file
View File

@@ -0,0 +1,102 @@
from typing import Annotated
from fastapi import APIRouter, UploadFile, Form, Query
from app.constants.api_result import ApiCode
from app.web.request.domain_request import AddDomainRequest, DeleteDomainRequest, UpdateDomainRequest, \
GetDomainListRequest
from app.web.results import ApiResult
from app.web.service.domain_service import DomainService
router = APIRouter(prefix="/api/domain", tags=["域名管理"])
@router.get("/v1/list")
def get_all_domains(request: Annotated[GetDomainListRequest, Query()]):
"""获取所有的域名信息,支持根据域名、状态进行搜索,不传则返回全部数据,支持分页"""
return DomainService.get_list(request.page, request.size, request.domain, request.status)
@router.post("/v1/add")
def add_domains(request: AddDomainRequest):
"""添加域名"""
# 检查是否有重复的
result = DomainService.get_by_domains(request.domains)
if not result.success:
return result
existed_domains = [item.domain for item in result.data]
new_domains = [x for x in request.domains if x not in existed_domains]
if not new_domains:
return ApiResult.ok(0)
# Insert the new domains and return
return DomainService.add_domains(request.crawl_interval, request.crawl_now, new_domains)
@router.post("/v1/import")
def import_domains(
# When a file and form fields are submitted together, a form model cannot be used; each field has to be declared individually
file: UploadFile,
crawl_interval: int = Form(),
crawl_now: bool = Form(),
):
"""通过上传文件添加域名,如果单个文件很大,以后改成开新协程/线程处理"""
# 把文件内容读出来
domains = []
for line in file.file:
line = line.strip()
domains.append(line.decode("UTF-8"))
# Create a coroutine task
# asyncio.create_task(DomainService.add_domains(crawl_interval, crawl_now, domains))
# Check for duplicate domains
result = DomainService.get_by_domains(domains)
if not result.success:
return result
existed_domains = [item.domain for item in result.data]
new_domains = [x for x in domains if x not in existed_domains]
# Insert the new domains and return
return DomainService.add_domains(crawl_interval, crawl_now, new_domains)
# noinspection DuplicatedCode
@router.post("/v1/update")
def update_domain(request: UpdateDomainRequest):
"""更新域名的数据,主要是采集间隔,支持批量修改,传入多个 id"""
# 检查待更新的域名是否存在
result = DomainService.get_by_ids(request.domain_ids)
if not result.success:
return result
existed_domain_ids = [item.id for item in result.data]
for domain_id in request.domain_ids:
if domain_id not in existed_domain_ids:
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 ID {domain_id} 不存在")
# Update the crawl interval
return DomainService.update_domain_interval(request.domain_ids, request.crawl_interval)
# noinspection DuplicatedCode
@router.post("/v1/delete")
def delete_domain(request: DeleteDomainRequest):
"""删除域名,支持批量删除,传入多个 id"""
# 检查待删除的域名是否存在
result = DomainService.get_by_ids(request.domain_ids)
if not result.success:
return result
existed_domain_ids = [item.id for item in result.data]
for domain_id in request.domain_ids:
if domain_id not in existed_domain_ids:
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 ID {domain_id} 不存在")
# Delete the domains
return DomainService.delete_domains(request.domain_ids, request.remove_surl)
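
For reference, a minimal sketch (not part of this commit) of exercising the endpoints above once the web app is running; the host and port come from web.py further down, httpx is just one possible client, and the example domains and interval value are assumptions (1440 mirrors the value used in report.py):

    import httpx

    BASE = "http://127.0.0.1:3000/api/domain"

    # List domains, filtered by substring and paginated
    print(httpx.get(f"{BASE}/v1/list", params={"page": 1, "size": 20, "domain": "example"}).json())

    # Add two domains
    print(httpx.post(f"{BASE}/v1/add", json={
        "crawl_interval": 1440,
        "crawl_now": True,
        "domains": ["example.com", "example.org"],
    }).json())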

app/web/controller/report.py Normal file
View File

@@ -0,0 +1,80 @@
from typing import Annotated
from fastapi import APIRouter, Query
from app.web.request.report_request import AddUrlsRequest, CollectEvidenceRequest, ReportRequest, GetUrlListRequest
from app.web.service.domain_service import DomainService
from app.web.service.report_service import ReportURLService
router = APIRouter(prefix="/api/urls", tags=["URL管理"])
@router.get("/v1/list")
async def get_all_urls(request: Annotated[GetUrlListRequest, Query()]):
"""获取所有的URL支持根据域名、状态进行过滤不传则返回全部数据支持分页"""
return ReportURLService.get_list(
request.domain,
request.surl,
request.is_report_by_one,
request.is_report_by_site,
request.is_report_by_wap,
request.has_evidence,
request.page,
request.size
)
@router.post("/v1/add")
async def add_urls(request: AddUrlsRequest):
"""
手动添加 URL 到域名中支持批量添加
格式 [
{"domain": "", "surl": ""}, {"domain": "", "surl": ""} ...
]
添加之前先检查 domain 有没有没有的话就去创建一个 domain
"""
# 把所有的域名列表解出来,看看有没有不存在的,如果有就新建一个域名
# 这里还需要获取域名的 id
input_domains = [item.domain for item in request.urls]
result = DomainService.get_by_domains(input_domains)
if not result.success:
return result
# Create any domains that do not exist yet
existed_domains = [item.domain for item in result.data]
new_domains = [x for x in input_domains if x not in existed_domains]
if new_domains:
result = DomainService.add_domains(1440, True, new_domains)
if not result.success:
return result
# Fetch the domain models again
result = DomainService.get_by_domains(input_domains)
if not result.success:
return result
# Create the URLs
domain_map: dict[str, int] = {x.domain: x.id for x in result.data}
return ReportURLService.add_urls(domain_map, request.urls)
@router.post("/v1/evidence")
async def collect_evidence(request: CollectEvidenceRequest):
"""
强制手动触发证据收集任务支持批量传入已经收集过的 URL 也要强制收集
TODO:本来应该需要使用任务队列的为了简单先把数据库的相关标记改为 0 也能达到一样的效果
又不是不能用 XD
"""
return ReportURLService.batch_update_evidence_flag(request.ids)
@router.post("/v1/report")
async def report(request: ReportRequest):
"""举报指定的URL支持批量传入 id 批量举报
先通过改数据库然后等引擎自己调度实现
"""
return ReportURLService.batch_update_report_flag(
request.ids,
request.report_by_one,
request.report_by_site,
request.report_by_wap
)
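
A quick sketch (not part of this commit) of calling the URL endpoints above, following the request body format described in the add_urls docstring; httpx, the example URLs, and the ids are assumptions:

    import httpx

    BASE = "http://127.0.0.1:3000/api/urls"

    # Manually add URLs; missing domains are created automatically
    print(httpx.post(f"{BASE}/v1/add", json={
        "urls": [
            {"domain": "example.com", "surl": "https://example.com/landing"},
            {"domain": "example.org", "surl": "https://example.org/promo"},
        ]
    }).json())

    # Re-trigger evidence collection for two URL ids
    print(httpx.post(f"{BASE}/v1/evidence", json={"ids": [1, 2]}).json())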

app/web/controller/status.py Normal file
View File

@@ -0,0 +1,10 @@
from fastapi import APIRouter
router = APIRouter(tags=["Health check"])
@router.get("/status")
async def status():
return {
"status": "ok"
}

app/web/request/__init__.py Normal file
View File

app/web/request/domain_request.py Normal file
View File

@@ -0,0 +1,38 @@
from pydantic import BaseModel, Field
class GetDomainListRequest(BaseModel):
"""获取域名列表"""
# 分页参数
page: int = Field(default=1, gt=0)
size: int = Field(default=50, gt=0)
# Filter conditions
domain: str = ""
status: int = 0
class AddDomainRequest(BaseModel):
"""添加域名到数据库的请求参数"""
crawl_interval: int
crawl_now: bool = True
domains: list[str]
class ImportDomainFormRequest(BaseModel):
"""通过文件导入的"""
crawl_interval: int
crawl_now: bool = True
class DeleteDomainRequest(BaseModel):
"""删除域名的请求"""
domain_ids: list[int]
remove_surl: bool = False
class UpdateDomainRequest(BaseModel):
"""更新域名的请求"""
domain_ids: list[int]
crawl_interval: int

app/web/request/report_request.py Normal file
View File

@@ -0,0 +1,38 @@
from typing import Optional
from pydantic import BaseModel, Field
class GetUrlListRequest(BaseModel):
domain: str = ""
surl: str = ""
is_report_by_one: Optional[bool] = False
is_report_by_site: Optional[bool] = False
is_report_by_wap: Optional[bool] = False
has_evidence: Optional[bool] = False
page: int = Field(default=1, gt=0)
size: int = Field(default=50, gt=0)
class AddUrlItem(BaseModel):
domain: str
surl: str
class AddUrlsRequest(BaseModel):
"""手动添加URL的请求体"""
urls: list[AddUrlItem]
class CollectEvidenceRequest(BaseModel):
"""手动触发证据收集的请求体"""
ids: list[int]
class ReportRequest(BaseModel):
"""手动触发证据收集的请求体"""
ids: list[int]
report_by_one: bool
report_by_site: bool
report_by_wap: bool

app/web/results.py Normal file
View File

@@ -0,0 +1,24 @@
from dataclasses import dataclass
from typing import Any, Generic
from typing_extensions import TypeVar
from app.constants.api_result import ApiCode
T = TypeVar("T")
@dataclass
class ApiResult(Generic[T]):
code: int
message: str
success: bool
data: T | None = None
@staticmethod
def ok(data: T | None = None) -> 'ApiResult[T]':
return ApiResult(code=ApiCode.OK.value, message="ok", success=True, data=data)
@staticmethod
def error(code: int, message: str) -> 'ApiResult[None]':
return ApiResult(code=code, message=message, success=False, data=None)
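
A small illustration (not part of the commit) of how the services below wrap their results with this class:

    from app.constants.api_result import ApiCode
    from app.web.results import ApiResult

    ok = ApiResult.ok({"total": 1, "rows": []})
    err = ApiResult.error(ApiCode.PARAM_ERROR.value, "domain ID 42 does not exist")

    assert ok.success and ok.code == ApiCode.OK.value                 # 20000
    assert not err.success and err.code == ApiCode.PARAM_ERROR.value  # 30000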

app/web/service/__init__.py Normal file
View File

app/web/service/domain_service.py Normal file
View File

@@ -0,0 +1,126 @@
import time
from typing import Iterable, Optional
from loguru import logger
from sqlalchemy import delete, func, update
from sqlmodel import Session, select
from app.config.config import AppCtx
from app.constants.api_result import ApiCode
from app.constants.domain import DomainStatus
from app.models.domain import DomainModel
from app.models.report_urls import ReportUrlModel
from app.web.results import ApiResult
class DomainService:
@classmethod
def get_list(cls, page: int, page_size: int, domain: str, status: int):
"""获取域名列表"""
with Session(AppCtx.g_db_engine) as session:
stmt = select(DomainModel)
stmt_total = select(func.count()).select_from(DomainModel)
if domain:
stmt = stmt.where(DomainModel.domain.like(f"%{domain}%"))
stmt_total = stmt_total.where(DomainModel.domain.like(f"%{domain}%"))
if status:
stmt = stmt.where(DomainModel.status == status)
stmt_total = stmt_total.where(DomainModel.status == status)
# Apply pagination
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
try:
# Domain rows
rows = session.exec(stmt).all()
# Total number of rows matching the filters
total = session.exec(stmt_total).first()
return ApiResult.ok({"total": total, "rows": rows})
except Exception as e:
session.rollback()
logger.exception(f"查询域名列表失败,错误:{e}")
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名列表失败,错误:{e}")
@classmethod
def get_by_domains(cls, domains: list[str]) -> ApiResult[Optional[DomainModel]]:
"""根据域名查询"""
with Session(AppCtx.g_db_engine) as session:
stmt = select(DomainModel).where(DomainModel.domain.in_(domains))
try:
rows = session.exec(stmt)
return ApiResult.ok(rows.all())
except Exception as e:
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")
@classmethod
def get_by_ids(cls, domain_ids: list[int]) -> ApiResult[Optional[DomainModel]]:
"""根据id查询"""
with Session(AppCtx.g_db_engine) as session:
stmt = select(DomainModel).where(DomainModel.id.in_(domain_ids))
try:
rows = session.exec(stmt)
return ApiResult.ok(rows.all())
except Exception as e:
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")
@classmethod
def add_domains(cls, interval: int, crawl_now: bool, domains: Iterable[str]):
"""批量添加域名"""
with Session(AppCtx.g_db_engine) as session:
new_domains = [
DomainModel(
domain=x,
status=DomainStatus.READY.value,
crawl_interval=interval,
latest_crawl_time=0 if not crawl_now else int(time.time())
) for x in domains
]
session.add_all(new_domains)
try:
session.commit()
return ApiResult.ok(len(new_domains))
except Exception as e:
logger.error(f"添加域名到数据库失败,错误:{e}")
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, f"添加域名失败,错误:{e}")
@classmethod
def delete_domains(cls, domain_ids: list[int], remove_surl: bool = False):
"""批量删除域名remove_surl 表示是否同时删除 report_url 中该域名相关的数据"""
with Session(AppCtx.g_db_engine) as session:
stmt = delete(DomainModel).where(DomainModel.id.in_(domain_ids))
try:
session.exec(stmt)
# If remove_surl is True, also delete the related rows in report_url
if remove_surl:
stmt = delete(ReportUrlModel).where(ReportUrlModel.domain_id.in_(domain_ids))
session.exec(stmt)
session.commit()
return ApiResult.ok(len(domain_ids))
except Exception as e:
logger.error(f"删除域名失败,错误:{e}")
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, f"删除域名失败,错误:{e}")
@classmethod
def update_domain_interval(cls, domain_ids: list[int], interval: int) -> ApiResult[Optional[int]]:
"""批量更新域名的 interval 值"""
with Session(AppCtx.g_db_engine) as session:
stmt = update(DomainModel).where(DomainModel.id.in_(domain_ids)).values(crawl_interval=interval)
try:
session.exec(stmt)
session.commit()
return ApiResult.ok(len(domain_ids))
except Exception as e:
logger.error(f"更新域名 interval 失败,错误:{e}")
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, f"更新域名 interval 失败,错误:{e}")

app/web/service/report_service.py Normal file
View File

@@ -0,0 +1,116 @@
from typing import Optional
from loguru import logger
from sqlalchemy import update, func
from sqlmodel import Session, select
from app.config.config import AppCtx
from app.constants.api_result import ApiCode
from app.models.report_urls import ReportUrlModel
from app.web.request.report_request import AddUrlItem
from app.web.results import ApiResult
class ReportURLService:
@classmethod
def get_list(
cls, domain: str, surl: str, is_report_by_one: Optional[bool], is_report_by_site: Optional[bool],
is_report_by_wap: Optional[bool], has_evidence: Optional[bool], page: int, size: int):
with Session(AppCtx.g_db_engine) as session:
stmt = select(ReportUrlModel)
total_stmt = select(func.count()).select_from(ReportUrlModel)
if domain:
stmt = stmt.where(ReportUrlModel.domain.like(f"%{domain}%"))
total_stmt = total_stmt.where(ReportUrlModel.domain.like(f"%{domain}%"))
if surl:
stmt = stmt.where(ReportUrlModel.surl.like(f"%{surl}%"))
total_stmt = total_stmt.where(ReportUrlModel.surl.like(f"%{surl}%"))
if is_report_by_one is not None:
stmt = stmt.where(ReportUrlModel.is_report_by_one == is_report_by_one)
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_one == is_report_by_one)
if is_report_by_site is not None:
stmt = stmt.where(ReportUrlModel.is_report_by_site == is_report_by_site)
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_site == is_report_by_site)
if is_report_by_wap is not None:
stmt = stmt.where(ReportUrlModel.is_report_by_wap == is_report_by_wap)
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_wap == is_report_by_wap)
if has_evidence is not None:
stmt = stmt.where(ReportUrlModel.has_evidence == has_evidence)
total_stmt = total_stmt.where(ReportUrlModel.has_evidence == has_evidence)
# Apply pagination
stmt = stmt.offset((page - 1) * size).limit(size)
try:
total = session.exec(total_stmt).first()
urls = session.exec(stmt).all()
return ApiResult.ok({
"total": total,
"data": urls,
})
except Exception as e:
logger.error(f"获取URL列表失败: {e}")
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
@classmethod
def add_urls(cls, domain_map: dict[str, int], urls: list[AddUrlItem]) -> ApiResult[Optional[int]]:
"""添加URL"""
if not urls:
return ApiResult.ok(0)
models = []
for url in urls:
domain_id = domain_map.get(url.domain, None)
if not domain_id:
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 {url.domain} 不存在")
models.append(ReportUrlModel(
domain_id=domain_id,
domain=url.domain,
surl=url.surl,
))
with Session(AppCtx.g_db_engine) as session:
try:
session.add_all(models)
session.commit()
return ApiResult.ok(len(models))
except Exception as e:
logger.error(f"添加URL失败: {e}")
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
@classmethod
def batch_update_evidence_flag(cls, url_ids: list[int]):
"""批量更新URL的has_evidence字段"""
with Session(AppCtx.g_db_engine) as session:
try:
stmt = update(ReportUrlModel).where(ReportUrlModel.id.in_(url_ids)).values(has_evidence=False)
session.exec(stmt)
session.commit()
return ApiResult.ok(len(url_ids))
except Exception as e:
logger.error(f"批量更新URL的has_evidence字段失败: {e}")
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
@classmethod
def batch_update_report_flag(cls, ids: list[int], report_by_one: bool, report_by_site: bool, report_by_wap: bool):
with Session(AppCtx.g_db_engine) as session:
try:
stmt = update(ReportUrlModel).where(ReportUrlModel.id.in_(ids))
if report_by_wap:
stmt = stmt.values(is_report_by_wap=False)
elif report_by_site:
stmt = stmt.values(is_report_by_site=False)
elif report_by_one:
stmt = stmt.values(is_report_by_one=False)
session.exec(stmt)
session.commit()
return ApiResult.ok(len(ids))
except Exception as e:
logger.error(f"批量更新URL的has_evidence字段失败: {e}")
session.rollback()
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))

app/web/web.py Normal file
View File

@@ -0,0 +1,26 @@
import uvicorn
from fastapi import FastAPI
from .controller.domain import router as domain_router
from .controller.report import router as report_router
from .controller.status import router as status_router
class WebApp:
def __init__(self):
self.app = FastAPI()
@staticmethod
async def start():
app = FastAPI()
# Register the routers
app.include_router(status_router)
app.include_router(report_router)
app.include_router(domain_router)
# TODO: hard-coded for now; read this from the config file later
cfg = uvicorn.Config(app, host="127.0.0.1", port=3000)
server = uvicorn.Server(cfg)
await server.serve()