增加后端API接口
This commit is contained in:
parent
552a09ee41
commit
56ea878c29
@ -1 +1 @@
|
|||||||
from .app import *
|
from .app import MainApp
|
||||||
18
app/app.py
18
app/app.py
@ -1,7 +1,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import signal
|
||||||
|
|
||||||
from app.engines.report_engine import Reporter
|
from app.engines.report_engine import Reporter
|
||||||
|
|
||||||
@ -14,6 +16,8 @@ from .models.base import connect_db, create_database
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
import sqlalchemy.exc
|
import sqlalchemy.exc
|
||||||
|
|
||||||
|
from .web.web import WebApp
|
||||||
|
|
||||||
|
|
||||||
class MainApp:
|
class MainApp:
|
||||||
"""主应用"""
|
"""主应用"""
|
||||||
@ -116,7 +120,15 @@ class MainApp:
|
|||||||
|
|
||||||
def start_web(self):
|
def start_web(self):
|
||||||
"""开启 Web 模式"""
|
"""开启 Web 模式"""
|
||||||
pass
|
|
||||||
|
# 注册 ctrl+c 处理程序,正常结束所有的 engine
|
||||||
|
signal.signal(signal.SIGINT, self.exit_handler)
|
||||||
|
|
||||||
|
# 启动 web 页面
|
||||||
|
web_app = WebApp()
|
||||||
|
asyncio.run(web_app.start())
|
||||||
|
|
||||||
|
logger.info("web stop.")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""运行应用"""
|
"""运行应用"""
|
||||||
@ -157,3 +169,7 @@ class MainApp:
|
|||||||
else:
|
else:
|
||||||
logger.info("启动 CLI 模式")
|
logger.info("启动 CLI 模式")
|
||||||
return self.start_cli()
|
return self.start_cli()
|
||||||
|
|
||||||
|
def exit_handler(self, signum, frame):
|
||||||
|
# TODO 在这里结束各个 engine
|
||||||
|
print("CTRL+C called.")
|
||||||
0
app/constants/__init__.py
Normal file
0
app/constants/__init__.py
Normal file
8
app/constants/api_result.py
Normal file
8
app/constants/api_result.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class ApiCode(enum.Enum):
|
||||||
|
OK = 20000
|
||||||
|
PARAM_ERROR = 30000
|
||||||
|
DB_ERROR = 40000
|
||||||
|
RUNTIME_ERROR = 50000
|
||||||
7
app/constants/domain.py
Normal file
7
app/constants/domain.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class DomainStatus(enum.Enum):
|
||||||
|
READY = 1 # 采集结束之后回到这个状态,新添加的默认也是这个状态
|
||||||
|
QUEUEING = 2 # 排队中,已经压入任务队列了,但是还没轮到处理
|
||||||
|
CRAWLING = 3 # 采集中
|
||||||
@ -26,10 +26,13 @@ def update_updated_at(mapper, connection, target):
|
|||||||
target.updated_at = get_timestamp()
|
target.updated_at = get_timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
def connect_db(config: AppConfig):
|
def connect_db(config: AppConfig):
|
||||||
"""连接数据库"""
|
"""连接数据库"""
|
||||||
|
|
||||||
# 导入所有模型,为了自动创建数据表
|
# 导入所有模型,为了自动创建数据表
|
||||||
|
from .domain import DomainModel
|
||||||
|
from .report_urls import ReportUrlModel
|
||||||
|
|
||||||
dsn = f"mysql+pymysql://{config.database.user}:{config.database.password}@{config.database.host}:{config.database.port}/{config.database.database}"
|
dsn = f"mysql+pymysql://{config.database.user}:{config.database.password}@{config.database.host}:{config.database.port}/{config.database.database}"
|
||||||
engine = create_engine(dsn, echo=False)
|
engine = create_engine(dsn, echo=False)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ class DomainModel(BaseModel, table=True):
|
|||||||
# 域名
|
# 域名
|
||||||
domain: str = Field(alias="domain", default="", sa_type=VARCHAR(1024))
|
domain: str = Field(alias="domain", default="", sa_type=VARCHAR(1024))
|
||||||
|
|
||||||
# 爬取状态,TODO:先空着,后续有任务控制之后,用这个字段表示这个域名的任务状态
|
# 爬取状态,@see constants.DomainStatus
|
||||||
status: int = Field(alias="status", default=0)
|
status: int = Field(alias="status", default=0)
|
||||||
|
|
||||||
# 爬取间隔,默认间隔为1周
|
# 爬取间隔,默认间隔为1周
|
||||||
|
|||||||
0
app/web/__init__.py
Normal file
0
app/web/__init__.py
Normal file
0
app/web/controller/__init__.py
Normal file
0
app/web/controller/__init__.py
Normal file
102
app/web/controller/domain.py
Normal file
102
app/web/controller/domain.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, UploadFile, Form, Query
|
||||||
|
|
||||||
|
from app.constants.api_result import ApiCode
|
||||||
|
from app.web.request.domain_request import AddDomainRequest, DeleteDomainRequest, UpdateDomainRequest, \
|
||||||
|
GetDomainListRequest
|
||||||
|
from app.web.results import ApiResult
|
||||||
|
from app.web.service.domain_service import DomainService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/domain", tags=["域名管理"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/list")
|
||||||
|
def get_all_domains(request: Annotated[GetDomainListRequest, Query()]):
|
||||||
|
"""获取所有的域名信息,支持根据域名、状态进行搜索,不传则返回全部数据,支持分页"""
|
||||||
|
return DomainService.get_list(request.page, request.size, request.domain, request.status)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/add")
|
||||||
|
def add_domains(request: AddDomainRequest):
|
||||||
|
"""添加域名"""
|
||||||
|
|
||||||
|
# 检查是否有重复的
|
||||||
|
result = DomainService.get_by_domains(request.domains)
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
|
||||||
|
existed_domains = [item.domain for item in result.data]
|
||||||
|
new_domains = [x for x in request.domains if x not in existed_domains]
|
||||||
|
if not new_domains:
|
||||||
|
return ApiResult.ok(0)
|
||||||
|
|
||||||
|
# 添加并返回
|
||||||
|
return DomainService.add_domains(request.crawl_interval, request.crawl_now, new_domains)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/import")
|
||||||
|
def import_domains(
|
||||||
|
# 同时提交文件和参数的时候,没办法使用 FormModel 的形式,必须一个一个定义
|
||||||
|
file: UploadFile,
|
||||||
|
crawl_interval: int = Form(),
|
||||||
|
crawl_now: bool = Form(),
|
||||||
|
):
|
||||||
|
"""通过上传文件添加域名,如果单个文件很大,以后改成开新协程/线程处理"""
|
||||||
|
|
||||||
|
# 把文件内容读出来
|
||||||
|
domains = []
|
||||||
|
for line in file.file:
|
||||||
|
line = line.strip()
|
||||||
|
domains.append(line.decode("UTF-8"))
|
||||||
|
|
||||||
|
# 创建协程任务
|
||||||
|
# asyncio.create_task(DomainService.add_domains(crawl_interval, crawl_now, domains))
|
||||||
|
|
||||||
|
# 检查是否有重复域名
|
||||||
|
result = DomainService.get_by_domains(domains)
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
existed_domains = [item.domain for item in result.data]
|
||||||
|
new_domains = [x for x in domains if x not in existed_domains]
|
||||||
|
|
||||||
|
# 添加并返回
|
||||||
|
return DomainService.add_domains(crawl_interval, crawl_now, new_domains)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection DuplicatedCode
|
||||||
|
@router.post("/v1/update")
|
||||||
|
def update_domain(request: UpdateDomainRequest):
|
||||||
|
"""更新域名的数据,主要是采集间隔,支持批量修改,传入多个 id"""
|
||||||
|
|
||||||
|
# 检查待更新的域名是否存在
|
||||||
|
result = DomainService.get_by_ids(request.domain_ids)
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
|
||||||
|
existed_domain_ids = [item.id for item in result.data]
|
||||||
|
for domain_id in request.domain_ids:
|
||||||
|
if domain_id not in existed_domain_ids:
|
||||||
|
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 ID {domain_id} 不存在")
|
||||||
|
|
||||||
|
# 更新刷新时间
|
||||||
|
return DomainService.update_domain_interval(request.domain_ids, request.crawl_interval)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection DuplicatedCode
|
||||||
|
@router.post("/v1/delete")
|
||||||
|
def delete_domain(request: DeleteDomainRequest):
|
||||||
|
"""删除域名,支持批量删除,传入多个 id"""
|
||||||
|
|
||||||
|
# 检查待删除的域名是否存在
|
||||||
|
result = DomainService.get_by_ids(request.domain_ids)
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
|
||||||
|
existed_domain_ids = [item.id for item in result.data]
|
||||||
|
for domain_id in request.domain_ids:
|
||||||
|
if domain_id not in existed_domain_ids:
|
||||||
|
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 ID {domain_id} 不存在")
|
||||||
|
|
||||||
|
# 删除域名
|
||||||
|
return DomainService.delete_domains(request.domain_ids, request.remove_surl)
|
||||||
80
app/web/controller/report.py
Normal file
80
app/web/controller/report.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query
|
||||||
|
|
||||||
|
from app.web.request.report_request import AddUrlsRequest, CollectEvidenceRequest, ReportRequest, GetUrlListRequest
|
||||||
|
from app.web.service.domain_service import DomainService
|
||||||
|
from app.web.service.report_service import ReportURLService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/urls", tags=["URL管理"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/list")
|
||||||
|
async def get_all_urls(request: Annotated[GetUrlListRequest, Query()]):
|
||||||
|
"""获取所有的URL,支持根据域名、状态进行过滤,不传则返回全部数据,支持分页"""
|
||||||
|
return ReportURLService.get_list(
|
||||||
|
request.domain,
|
||||||
|
request.surl,
|
||||||
|
request.is_report_by_one,
|
||||||
|
request.is_report_by_site,
|
||||||
|
request.is_report_by_wap,
|
||||||
|
request.has_evidence,
|
||||||
|
request.page,
|
||||||
|
request.size
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/add")
|
||||||
|
async def add_urls(request: AddUrlsRequest):
|
||||||
|
"""
|
||||||
|
手动添加 URL 到域名中,支持批量添加
|
||||||
|
格式 [
|
||||||
|
{"domain": "", "surl": ""}, {"domain": "", "surl": ""} ...
|
||||||
|
]
|
||||||
|
添加之前先检查 domain 有没有,没有的话就去创建一个 domain
|
||||||
|
"""
|
||||||
|
# 把所有的域名列表解出来,看看有没有不存在的,如果有就新建一个域名
|
||||||
|
# 这里还需要获取域名的 id
|
||||||
|
input_domains = [item.domain for item in request.urls]
|
||||||
|
result = DomainService.get_by_domains(input_domains)
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 创建新域名
|
||||||
|
new_domains = [x for x in input_domains if x not in result.data]
|
||||||
|
if new_domains:
|
||||||
|
result = DomainService.add_domains(1440, True, new_domains)
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 再获取一遍域名模型
|
||||||
|
result = DomainService.get_by_domains(input_domains)
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 创建 URL
|
||||||
|
domain_map: dict[str, int] = {x.domain: x.id for x in result.data}
|
||||||
|
return ReportURLService.add_urls(domain_map, request.urls)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/evidence")
|
||||||
|
async def collect_evidence(request: CollectEvidenceRequest):
|
||||||
|
"""
|
||||||
|
强制手动触发证据收集任务,支持批量传入,已经收集过的 URL 也要强制收集
|
||||||
|
TODO:本来应该需要使用任务队列的,为了简单先把数据库的相关标记改为 0 ,也能达到一样的效果
|
||||||
|
又不是不能用 XD
|
||||||
|
"""
|
||||||
|
return ReportURLService.batch_update_evidence_flag(request.ids)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/report")
|
||||||
|
async def report(request: ReportRequest):
|
||||||
|
"""举报指定的URL,支持批量传入 id 批量举报
|
||||||
|
先通过改数据库,然后等引擎自己调度实现
|
||||||
|
"""
|
||||||
|
return ReportURLService.batch_update_report_flag(
|
||||||
|
request.ids,
|
||||||
|
request.report_by_one,
|
||||||
|
request.report_by_site,
|
||||||
|
request.report_by_wap
|
||||||
|
)
|
||||||
10
app/web/controller/status.py
Normal file
10
app/web/controller/status.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(tags=["健康检查"])
|
||||||
|
|
||||||
|
@router.get("/status")
|
||||||
|
async def status():
|
||||||
|
return {
|
||||||
|
"status": "ok"
|
||||||
|
}
|
||||||
0
app/web/request/__init__.py
Normal file
0
app/web/request/__init__.py
Normal file
38
app/web/request/domain_request.py
Normal file
38
app/web/request/domain_request.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class GetDomainListRequest(BaseModel):
|
||||||
|
"""获取域名列表"""
|
||||||
|
|
||||||
|
# 分页参数
|
||||||
|
page: int = Field(default=1, gt=0)
|
||||||
|
size: int = Field(default=50, gt=0)
|
||||||
|
|
||||||
|
# 过滤条件
|
||||||
|
domain: str = ""
|
||||||
|
status: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AddDomainRequest(BaseModel):
|
||||||
|
"""添加域名到数据库的请求参数"""
|
||||||
|
crawl_interval: int
|
||||||
|
crawl_now: bool = True
|
||||||
|
domains: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ImportDomainFormRequest(BaseModel):
|
||||||
|
"""通过文件导入的"""
|
||||||
|
crawl_interval: int
|
||||||
|
crawl_now: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteDomainRequest(BaseModel):
|
||||||
|
"""删除域名的请求"""
|
||||||
|
domain_ids: list[int]
|
||||||
|
remove_surl: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateDomainRequest(BaseModel):
|
||||||
|
"""更新域名的请求"""
|
||||||
|
domain_ids: list[int]
|
||||||
|
crawl_interval: int
|
||||||
38
app/web/request/report_request.py
Normal file
38
app/web/request/report_request.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class GetUrlListRequest(BaseModel):
|
||||||
|
domain: str = ""
|
||||||
|
surl: str = ""
|
||||||
|
is_report_by_one: Optional[bool] = False
|
||||||
|
is_report_by_site: Optional[bool] = False
|
||||||
|
is_report_by_wap: Optional[bool] = False
|
||||||
|
has_evidence: Optional[bool] = False
|
||||||
|
|
||||||
|
page: int = Field(default=1, gt=0)
|
||||||
|
size: int = Field(default=50, gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
class AddUrlItem(BaseModel):
|
||||||
|
domain: str
|
||||||
|
surl: str
|
||||||
|
|
||||||
|
|
||||||
|
class AddUrlsRequest(BaseModel):
|
||||||
|
"""手动添加URL的请求体"""
|
||||||
|
urls: list[AddUrlItem]
|
||||||
|
|
||||||
|
|
||||||
|
class CollectEvidenceRequest(BaseModel):
|
||||||
|
"""手动触发证据收集的请求体"""
|
||||||
|
ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class ReportRequest(BaseModel):
|
||||||
|
"""手动触发证据收集的请求体"""
|
||||||
|
ids: list[int]
|
||||||
|
report_by_one: bool
|
||||||
|
report_by_site: bool
|
||||||
|
report_by_wap: bool
|
||||||
24
app/web/results.py
Normal file
24
app/web/results.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic
|
||||||
|
|
||||||
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
|
from app.constants.api_result import ApiCode
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ApiResult(Generic[T]):
|
||||||
|
code: int
|
||||||
|
message: str
|
||||||
|
success: bool
|
||||||
|
data: T | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ok(data: T | None = None) -> 'ApiResult[T]':
|
||||||
|
return ApiResult(code=ApiCode.OK.value, message="ok", success=True, data=data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def error(code: int, message: str) -> 'ApiResult[None]':
|
||||||
|
return ApiResult(code=code, message=message, success=False, data=None)
|
||||||
0
app/web/service/__init__.py
Normal file
0
app/web/service/__init__.py
Normal file
126
app/web/service/domain_service.py
Normal file
126
app/web/service/domain_service.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
import time
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from sqlalchemy import delete, func, update
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from app.config.config import AppCtx
|
||||||
|
from app.constants.api_result import ApiCode
|
||||||
|
from app.constants.domain import DomainStatus
|
||||||
|
from app.models.domain import DomainModel
|
||||||
|
from app.models.report_urls import ReportUrlModel
|
||||||
|
from app.web.results import ApiResult
|
||||||
|
|
||||||
|
|
||||||
|
class DomainService:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_list(cls, page: int, page_size: int, domain: str, status: int):
|
||||||
|
"""获取域名列表"""
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
stmt = select(DomainModel)
|
||||||
|
stmt_total = select(func.count())
|
||||||
|
if domain:
|
||||||
|
stmt = stmt.where(DomainModel.domain.like(f"%{domain}%"))
|
||||||
|
stmt_total = stmt_total.where(DomainModel.domain.like(f"%{domain}%"))
|
||||||
|
if status:
|
||||||
|
stmt = stmt.where(DomainModel.status == status)
|
||||||
|
stmt_total = stmt_total.where(DomainModel.status == status)
|
||||||
|
|
||||||
|
# 设置分页
|
||||||
|
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 域名列表
|
||||||
|
rows = session.exec(stmt).all()
|
||||||
|
|
||||||
|
# 查询符合筛选条件的总数量
|
||||||
|
total = session.exec(stmt_total).first()
|
||||||
|
|
||||||
|
return ApiResult.ok({"total": total, "rows": rows})
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
logger.exception(f"查询域名列表失败,错误:{e}")
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名列表失败,错误:{e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_by_domains(cls, domains: list[str]) -> ApiResult[Optional[DomainModel]]:
|
||||||
|
"""根据域名查询"""
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
stmt = select(DomainModel).where(DomainModel.domain.in_(domains))
|
||||||
|
try:
|
||||||
|
rows = session.exec(stmt)
|
||||||
|
return ApiResult.ok(rows.all())
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_by_ids(cls, domain_ids: list[int]) -> ApiResult[Optional[DomainModel]]:
|
||||||
|
"""根据id查询"""
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
stmt = select(DomainModel).where(DomainModel.id.in_(domain_ids))
|
||||||
|
try:
|
||||||
|
rows = session.exec(stmt)
|
||||||
|
return ApiResult.ok(rows.all())
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_domains(cls, interval: int, crawl_now: bool, domains: Iterable[str]):
|
||||||
|
"""批量添加域名"""
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
new_domains = [
|
||||||
|
DomainModel(
|
||||||
|
domain=x,
|
||||||
|
status=DomainStatus.READY.value,
|
||||||
|
crawl_interval=interval,
|
||||||
|
latest_crawl_time=0 if not crawl_now else int(time.time())
|
||||||
|
) for x in domains
|
||||||
|
]
|
||||||
|
|
||||||
|
session.add_all(new_domains)
|
||||||
|
try:
|
||||||
|
session.commit()
|
||||||
|
return ApiResult.ok(len(new_domains))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加域名到数据库失败,错误:{e}")
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, f"添加域名失败,错误:{e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_domains(cls, domain_ids: list[int], remove_surl: bool = False):
|
||||||
|
"""批量删除域名,remove_surl 表示是否同时删除 report_url 中该域名相关的数据"""
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
stmt = delete(DomainModel).where(DomainModel.id.in_(domain_ids))
|
||||||
|
try:
|
||||||
|
session.exec(stmt)
|
||||||
|
|
||||||
|
# 如果设置了 remove_surl 为 True,则删除 report_url 中该域名相关的数据
|
||||||
|
if remove_surl:
|
||||||
|
stmt = delete(ReportUrlModel).where(ReportUrlModel.domain_id.in_(domain_ids))
|
||||||
|
session.exec(stmt)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return ApiResult.ok(len(domain_ids))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除域名失败,错误:{e}")
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, f"删除域名失败,错误:{e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_domain_interval(cls, domain_ids: list[int], interval: int) -> ApiResult[Optional[int]]:
|
||||||
|
"""批量更新域名的 interval 值"""
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
stmt = update(DomainModel).where(DomainModel.id.in_(domain_ids)).values(crawl_interval=interval)
|
||||||
|
try:
|
||||||
|
session.exec(stmt)
|
||||||
|
session.commit()
|
||||||
|
return ApiResult.ok(len(domain_ids))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新域名 interval 失败,错误:{e}")
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, f"更新域名 interval 失败,错误:{e}")
|
||||||
116
app/web/service/report_service.py
Normal file
116
app/web/service/report_service.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from sqlalchemy import update, func
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from app.config.config import AppCtx
|
||||||
|
from app.constants.api_result import ApiCode
|
||||||
|
from app.models.report_urls import ReportUrlModel
|
||||||
|
from app.web.request.report_request import AddUrlItem
|
||||||
|
from app.web.results import ApiResult
|
||||||
|
|
||||||
|
|
||||||
|
class ReportURLService:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_list(
|
||||||
|
cls, domain: str, surl: str, is_report_by_one: Optional[bool], is_report_by_site: Optional[bool],
|
||||||
|
is_report_by_wap: Optional[bool], has_evidence: Optional[bool], page: int, size: int):
|
||||||
|
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
stmt = select(ReportUrlModel)
|
||||||
|
total_stmt = select(func.count())
|
||||||
|
if domain:
|
||||||
|
stmt = stmt.where(ReportUrlModel.domain.like(f"%{domain}%"))
|
||||||
|
total_stmt = total_stmt.where(ReportUrlModel.domain.like(f"%{domain}%"))
|
||||||
|
if surl:
|
||||||
|
stmt = stmt.where(ReportUrlModel.surl.like(f"%{surl}%"))
|
||||||
|
total_stmt = total_stmt.where(ReportUrlModel.surl.like(f"%{surl}%"))
|
||||||
|
if is_report_by_one is not None:
|
||||||
|
stmt = stmt.where(ReportUrlModel.is_report_by_one == is_report_by_one)
|
||||||
|
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_one == is_report_by_one)
|
||||||
|
if is_report_by_site is not None:
|
||||||
|
stmt = stmt.where(ReportUrlModel.is_report_by_site == is_report_by_site)
|
||||||
|
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_site == is_report_by_site)
|
||||||
|
if is_report_by_wap is not None:
|
||||||
|
stmt = stmt.where(ReportUrlModel.is_report_by_wap == is_report_by_wap)
|
||||||
|
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_wap == is_report_by_wap)
|
||||||
|
if has_evidence is not None:
|
||||||
|
stmt = stmt.where(ReportUrlModel.has_evidence == has_evidence)
|
||||||
|
total_stmt = total_stmt.where(ReportUrlModel.has_evidence == has_evidence)
|
||||||
|
|
||||||
|
# 设置分页
|
||||||
|
stmt = stmt.offset((page - 1) * size).limit(size)
|
||||||
|
|
||||||
|
try:
|
||||||
|
total = session.exec(total_stmt).first()
|
||||||
|
urls = session.exec(stmt).all()
|
||||||
|
return ApiResult.ok({
|
||||||
|
"total": total,
|
||||||
|
"data": urls,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取URL列表失败: {e}")
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_urls(cls, domain_map: dict[str, int], urls: list[AddUrlItem]) -> ApiResult[Optional[int]]:
|
||||||
|
"""添加URL"""
|
||||||
|
if not urls:
|
||||||
|
return ApiResult.ok(0)
|
||||||
|
|
||||||
|
models = []
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
domain_id = domain_map.get(url.domain, None)
|
||||||
|
if not domain_id:
|
||||||
|
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 {url.domain} 不存在")
|
||||||
|
models.append(ReportUrlModel(
|
||||||
|
domain_id=domain_id,
|
||||||
|
domain=url.domain,
|
||||||
|
surl=url.surl,
|
||||||
|
))
|
||||||
|
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
try:
|
||||||
|
session.add_all(models)
|
||||||
|
session.commit()
|
||||||
|
return ApiResult.ok(len(models))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加URL失败: {e}")
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_update_evidence_flag(cls, url_ids: list[int]):
|
||||||
|
"""批量更新URL的has_evidence字段"""
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
try:
|
||||||
|
stmt = update(ReportUrlModel).where(ReportUrlModel.id.in_(url_ids)).values(has_evidence=False)
|
||||||
|
session.exec(stmt)
|
||||||
|
session.commit()
|
||||||
|
return ApiResult.ok(len(url_ids))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量更新URL的has_evidence字段失败: {e}")
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_update_report_flag(cls, ids: list[int], report_by_one: bool, report_by_site: bool, report_by_wap: bool):
|
||||||
|
with Session(AppCtx.g_db_engine) as session:
|
||||||
|
try:
|
||||||
|
stmt = update(ReportUrlModel).where(ReportUrlModel.id.in_(ids))
|
||||||
|
if report_by_wap:
|
||||||
|
stmt = stmt.values(is_report_by_wap=False)
|
||||||
|
elif report_by_site:
|
||||||
|
stmt = stmt.values(is_report_by_site=False)
|
||||||
|
elif report_by_one:
|
||||||
|
stmt = stmt.values(is_report_by_one=False)
|
||||||
|
session.exec(stmt)
|
||||||
|
session.commit()
|
||||||
|
return ApiResult.ok(len(ids))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量更新URL的has_evidence字段失败: {e}")
|
||||||
|
session.rollback()
|
||||||
|
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||||
26
app/web/web.py
Normal file
26
app/web/web.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from .controller.domain import router as domain_router
|
||||||
|
from .controller.report import router as report_router
|
||||||
|
from .controller.status import router as status_router
|
||||||
|
|
||||||
|
|
||||||
|
class WebApp:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.app = FastAPI()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def start():
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# 导入路由
|
||||||
|
app.include_router(status_router)
|
||||||
|
app.include_router(report_router)
|
||||||
|
app.include_router(domain_router)
|
||||||
|
|
||||||
|
# TODO 先写死,后面从配置文件里取
|
||||||
|
cfg = uvicorn.Config(app, host="127.0.0.1", port=3000)
|
||||||
|
server = uvicorn.Server(cfg)
|
||||||
|
await server.serve()
|
||||||
Loading…
x
Reference in New Issue
Block a user