增加后端API接口
This commit is contained in:
parent
552a09ee41
commit
56ea878c29
@ -1 +1 @@
|
||||
from .app import *
|
||||
from .app import MainApp
|
||||
18
app/app.py
18
app/app.py
@ -1,7 +1,9 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import signal
|
||||
|
||||
from app.engines.report_engine import Reporter
|
||||
|
||||
@ -14,6 +16,8 @@ from .models.base import connect_db, create_database
|
||||
from loguru import logger
|
||||
import sqlalchemy.exc
|
||||
|
||||
from .web.web import WebApp
|
||||
|
||||
|
||||
class MainApp:
|
||||
"""主应用"""
|
||||
@ -116,7 +120,15 @@ class MainApp:
|
||||
|
||||
def start_web(self):
|
||||
"""开启 Web 模式"""
|
||||
pass
|
||||
|
||||
# 注册 ctrl+c 处理程序,正常结束所有的 engine
|
||||
signal.signal(signal.SIGINT, self.exit_handler)
|
||||
|
||||
# 启动 web 页面
|
||||
web_app = WebApp()
|
||||
asyncio.run(web_app.start())
|
||||
|
||||
logger.info("web stop.")
|
||||
|
||||
def run(self):
|
||||
"""运行应用"""
|
||||
@ -157,3 +169,7 @@ class MainApp:
|
||||
else:
|
||||
logger.info("启动 CLI 模式")
|
||||
return self.start_cli()
|
||||
|
||||
def exit_handler(self, signum, frame):
|
||||
# TODO 在这里结束各个 engine
|
||||
print("CTRL+C called.")
|
||||
0
app/constants/__init__.py
Normal file
0
app/constants/__init__.py
Normal file
8
app/constants/api_result.py
Normal file
8
app/constants/api_result.py
Normal file
@ -0,0 +1,8 @@
|
||||
import enum
|
||||
|
||||
|
||||
class ApiCode(enum.Enum):
|
||||
OK = 20000
|
||||
PARAM_ERROR = 30000
|
||||
DB_ERROR = 40000
|
||||
RUNTIME_ERROR = 50000
|
||||
7
app/constants/domain.py
Normal file
7
app/constants/domain.py
Normal file
@ -0,0 +1,7 @@
|
||||
import enum
|
||||
|
||||
|
||||
class DomainStatus(enum.Enum):
|
||||
READY = 1 # 采集结束之后回到这个状态,新添加的默认也是这个状态
|
||||
QUEUEING = 2 # 排队中,已经压入任务队列了,但是还没轮到处理
|
||||
CRAWLING = 3 # 采集中
|
||||
@ -26,10 +26,13 @@ def update_updated_at(mapper, connection, target):
|
||||
target.updated_at = get_timestamp()
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def connect_db(config: AppConfig):
|
||||
"""连接数据库"""
|
||||
|
||||
# 导入所有模型,为了自动创建数据表
|
||||
from .domain import DomainModel
|
||||
from .report_urls import ReportUrlModel
|
||||
|
||||
dsn = f"mysql+pymysql://{config.database.user}:{config.database.password}@{config.database.host}:{config.database.port}/{config.database.database}"
|
||||
engine = create_engine(dsn, echo=False)
|
||||
|
||||
@ -12,7 +12,7 @@ class DomainModel(BaseModel, table=True):
|
||||
# 域名
|
||||
domain: str = Field(alias="domain", default="", sa_type=VARCHAR(1024))
|
||||
|
||||
# 爬取状态,TODO:先空着,后续有任务控制之后,用这个字段表示这个域名的任务状态
|
||||
# 爬取状态,@see constants.DomainStatus
|
||||
status: int = Field(alias="status", default=0)
|
||||
|
||||
# 爬取间隔,默认间隔为1周
|
||||
|
||||
0
app/web/__init__.py
Normal file
0
app/web/__init__.py
Normal file
0
app/web/controller/__init__.py
Normal file
0
app/web/controller/__init__.py
Normal file
102
app/web/controller/domain.py
Normal file
102
app/web/controller/domain.py
Normal file
@ -0,0 +1,102 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, UploadFile, Form, Query
|
||||
|
||||
from app.constants.api_result import ApiCode
|
||||
from app.web.request.domain_request import AddDomainRequest, DeleteDomainRequest, UpdateDomainRequest, \
|
||||
GetDomainListRequest
|
||||
from app.web.results import ApiResult
|
||||
from app.web.service.domain_service import DomainService
|
||||
|
||||
router = APIRouter(prefix="/api/domain", tags=["域名管理"])
|
||||
|
||||
|
||||
@router.get("/v1/list")
|
||||
def get_all_domains(request: Annotated[GetDomainListRequest, Query()]):
|
||||
"""获取所有的域名信息,支持根据域名、状态进行搜索,不传则返回全部数据,支持分页"""
|
||||
return DomainService.get_list(request.page, request.size, request.domain, request.status)
|
||||
|
||||
|
||||
@router.post("/v1/add")
|
||||
def add_domains(request: AddDomainRequest):
|
||||
"""添加域名"""
|
||||
|
||||
# 检查是否有重复的
|
||||
result = DomainService.get_by_domains(request.domains)
|
||||
if not result.success:
|
||||
return result
|
||||
|
||||
existed_domains = [item.domain for item in result.data]
|
||||
new_domains = [x for x in request.domains if x not in existed_domains]
|
||||
if not new_domains:
|
||||
return ApiResult.ok(0)
|
||||
|
||||
# 添加并返回
|
||||
return DomainService.add_domains(request.crawl_interval, request.crawl_now, new_domains)
|
||||
|
||||
|
||||
@router.post("/v1/import")
|
||||
def import_domains(
|
||||
# 同时提交文件和参数的时候,没办法使用 FormModel 的形式,必须一个一个定义
|
||||
file: UploadFile,
|
||||
crawl_interval: int = Form(),
|
||||
crawl_now: bool = Form(),
|
||||
):
|
||||
"""通过上传文件添加域名,如果单个文件很大,以后改成开新协程/线程处理"""
|
||||
|
||||
# 把文件内容读出来
|
||||
domains = []
|
||||
for line in file.file:
|
||||
line = line.strip()
|
||||
domains.append(line.decode("UTF-8"))
|
||||
|
||||
# 创建协程任务
|
||||
# asyncio.create_task(DomainService.add_domains(crawl_interval, crawl_now, domains))
|
||||
|
||||
# 检查是否有重复域名
|
||||
result = DomainService.get_by_domains(domains)
|
||||
if not result.success:
|
||||
return result
|
||||
existed_domains = [item.domain for item in result.data]
|
||||
new_domains = [x for x in domains if x not in existed_domains]
|
||||
|
||||
# 添加并返回
|
||||
return DomainService.add_domains(crawl_interval, crawl_now, new_domains)
|
||||
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
@router.post("/v1/update")
|
||||
def update_domain(request: UpdateDomainRequest):
|
||||
"""更新域名的数据,主要是采集间隔,支持批量修改,传入多个 id"""
|
||||
|
||||
# 检查待更新的域名是否存在
|
||||
result = DomainService.get_by_ids(request.domain_ids)
|
||||
if not result.success:
|
||||
return result
|
||||
|
||||
existed_domain_ids = [item.id for item in result.data]
|
||||
for domain_id in request.domain_ids:
|
||||
if domain_id not in existed_domain_ids:
|
||||
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 ID {domain_id} 不存在")
|
||||
|
||||
# 更新刷新时间
|
||||
return DomainService.update_domain_interval(request.domain_ids, request.crawl_interval)
|
||||
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
@router.post("/v1/delete")
|
||||
def delete_domain(request: DeleteDomainRequest):
|
||||
"""删除域名,支持批量删除,传入多个 id"""
|
||||
|
||||
# 检查待删除的域名是否存在
|
||||
result = DomainService.get_by_ids(request.domain_ids)
|
||||
if not result.success:
|
||||
return result
|
||||
|
||||
existed_domain_ids = [item.id for item in result.data]
|
||||
for domain_id in request.domain_ids:
|
||||
if domain_id not in existed_domain_ids:
|
||||
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 ID {domain_id} 不存在")
|
||||
|
||||
# 删除域名
|
||||
return DomainService.delete_domains(request.domain_ids, request.remove_surl)
|
||||
80
app/web/controller/report.py
Normal file
80
app/web/controller/report.py
Normal file
@ -0,0 +1,80 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from app.web.request.report_request import AddUrlsRequest, CollectEvidenceRequest, ReportRequest, GetUrlListRequest
|
||||
from app.web.service.domain_service import DomainService
|
||||
from app.web.service.report_service import ReportURLService
|
||||
|
||||
router = APIRouter(prefix="/api/urls", tags=["URL管理"])
|
||||
|
||||
|
||||
@router.get("/v1/list")
|
||||
async def get_all_urls(request: Annotated[GetUrlListRequest, Query()]):
|
||||
"""获取所有的URL,支持根据域名、状态进行过滤,不传则返回全部数据,支持分页"""
|
||||
return ReportURLService.get_list(
|
||||
request.domain,
|
||||
request.surl,
|
||||
request.is_report_by_one,
|
||||
request.is_report_by_site,
|
||||
request.is_report_by_wap,
|
||||
request.has_evidence,
|
||||
request.page,
|
||||
request.size
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/add")
|
||||
async def add_urls(request: AddUrlsRequest):
|
||||
"""
|
||||
手动添加 URL 到域名中,支持批量添加
|
||||
格式 [
|
||||
{"domain": "", "surl": ""}, {"domain": "", "surl": ""} ...
|
||||
]
|
||||
添加之前先检查 domain 有没有,没有的话就去创建一个 domain
|
||||
"""
|
||||
# 把所有的域名列表解出来,看看有没有不存在的,如果有就新建一个域名
|
||||
# 这里还需要获取域名的 id
|
||||
input_domains = [item.domain for item in request.urls]
|
||||
result = DomainService.get_by_domains(input_domains)
|
||||
if not result.success:
|
||||
return result
|
||||
|
||||
# 创建新域名
|
||||
new_domains = [x for x in input_domains if x not in result.data]
|
||||
if new_domains:
|
||||
result = DomainService.add_domains(1440, True, new_domains)
|
||||
if not result.success:
|
||||
return result
|
||||
|
||||
# 再获取一遍域名模型
|
||||
result = DomainService.get_by_domains(input_domains)
|
||||
if not result.success:
|
||||
return result
|
||||
|
||||
# 创建 URL
|
||||
domain_map: dict[str, int] = {x.domain: x.id for x in result.data}
|
||||
return ReportURLService.add_urls(domain_map, request.urls)
|
||||
|
||||
|
||||
@router.post("/v1/evidence")
|
||||
async def collect_evidence(request: CollectEvidenceRequest):
|
||||
"""
|
||||
强制手动触发证据收集任务,支持批量传入,已经收集过的 URL 也要强制收集
|
||||
TODO:本来应该需要使用任务队列的,为了简单先把数据库的相关标记改为 0 ,也能达到一样的效果
|
||||
又不是不能用 XD
|
||||
"""
|
||||
return ReportURLService.batch_update_evidence_flag(request.ids)
|
||||
|
||||
|
||||
@router.post("/v1/report")
|
||||
async def report(request: ReportRequest):
|
||||
"""举报指定的URL,支持批量传入 id 批量举报
|
||||
先通过改数据库,然后等引擎自己调度实现
|
||||
"""
|
||||
return ReportURLService.batch_update_report_flag(
|
||||
request.ids,
|
||||
request.report_by_one,
|
||||
request.report_by_site,
|
||||
request.report_by_wap
|
||||
)
|
||||
10
app/web/controller/status.py
Normal file
10
app/web/controller/status.py
Normal file
@ -0,0 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
router = APIRouter(tags=["健康检查"])
|
||||
|
||||
@router.get("/status")
|
||||
async def status():
|
||||
return {
|
||||
"status": "ok"
|
||||
}
|
||||
0
app/web/request/__init__.py
Normal file
0
app/web/request/__init__.py
Normal file
38
app/web/request/domain_request.py
Normal file
38
app/web/request/domain_request.py
Normal file
@ -0,0 +1,38 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GetDomainListRequest(BaseModel):
|
||||
"""获取域名列表"""
|
||||
|
||||
# 分页参数
|
||||
page: int = Field(default=1, gt=0)
|
||||
size: int = Field(default=50, gt=0)
|
||||
|
||||
# 过滤条件
|
||||
domain: str = ""
|
||||
status: int = 0
|
||||
|
||||
|
||||
class AddDomainRequest(BaseModel):
|
||||
"""添加域名到数据库的请求参数"""
|
||||
crawl_interval: int
|
||||
crawl_now: bool = True
|
||||
domains: list[str]
|
||||
|
||||
|
||||
class ImportDomainFormRequest(BaseModel):
|
||||
"""通过文件导入的"""
|
||||
crawl_interval: int
|
||||
crawl_now: bool = True
|
||||
|
||||
|
||||
class DeleteDomainRequest(BaseModel):
|
||||
"""删除域名的请求"""
|
||||
domain_ids: list[int]
|
||||
remove_surl: bool = False
|
||||
|
||||
|
||||
class UpdateDomainRequest(BaseModel):
|
||||
"""更新域名的请求"""
|
||||
domain_ids: list[int]
|
||||
crawl_interval: int
|
||||
38
app/web/request/report_request.py
Normal file
38
app/web/request/report_request.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GetUrlListRequest(BaseModel):
|
||||
domain: str = ""
|
||||
surl: str = ""
|
||||
is_report_by_one: Optional[bool] = False
|
||||
is_report_by_site: Optional[bool] = False
|
||||
is_report_by_wap: Optional[bool] = False
|
||||
has_evidence: Optional[bool] = False
|
||||
|
||||
page: int = Field(default=1, gt=0)
|
||||
size: int = Field(default=50, gt=0)
|
||||
|
||||
|
||||
class AddUrlItem(BaseModel):
|
||||
domain: str
|
||||
surl: str
|
||||
|
||||
|
||||
class AddUrlsRequest(BaseModel):
|
||||
"""手动添加URL的请求体"""
|
||||
urls: list[AddUrlItem]
|
||||
|
||||
|
||||
class CollectEvidenceRequest(BaseModel):
|
||||
"""手动触发证据收集的请求体"""
|
||||
ids: list[int]
|
||||
|
||||
|
||||
class ReportRequest(BaseModel):
|
||||
"""手动触发证据收集的请求体"""
|
||||
ids: list[int]
|
||||
report_by_one: bool
|
||||
report_by_site: bool
|
||||
report_by_wap: bool
|
||||
24
app/web/results.py
Normal file
24
app/web/results.py
Normal file
@ -0,0 +1,24 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from app.constants.api_result import ApiCode
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiResult(Generic[T]):
|
||||
code: int
|
||||
message: str
|
||||
success: bool
|
||||
data: T | None = None
|
||||
|
||||
@staticmethod
|
||||
def ok(data: T | None = None) -> 'ApiResult[T]':
|
||||
return ApiResult(code=ApiCode.OK.value, message="ok", success=True, data=data)
|
||||
|
||||
@staticmethod
|
||||
def error(code: int, message: str) -> 'ApiResult[None]':
|
||||
return ApiResult(code=code, message=message, success=False, data=None)
|
||||
0
app/web/service/__init__.py
Normal file
0
app/web/service/__init__.py
Normal file
126
app/web/service/domain_service.py
Normal file
126
app/web/service/domain_service.py
Normal file
@ -0,0 +1,126 @@
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import delete, func, update
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.config.config import AppCtx
|
||||
from app.constants.api_result import ApiCode
|
||||
from app.constants.domain import DomainStatus
|
||||
from app.models.domain import DomainModel
|
||||
from app.models.report_urls import ReportUrlModel
|
||||
from app.web.results import ApiResult
|
||||
|
||||
|
||||
class DomainService:
|
||||
|
||||
@classmethod
|
||||
def get_list(cls, page: int, page_size: int, domain: str, status: int):
|
||||
"""获取域名列表"""
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
stmt = select(DomainModel)
|
||||
stmt_total = select(func.count())
|
||||
if domain:
|
||||
stmt = stmt.where(DomainModel.domain.like(f"%{domain}%"))
|
||||
stmt_total = stmt_total.where(DomainModel.domain.like(f"%{domain}%"))
|
||||
if status:
|
||||
stmt = stmt.where(DomainModel.status == status)
|
||||
stmt_total = stmt_total.where(DomainModel.status == status)
|
||||
|
||||
# 设置分页
|
||||
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
try:
|
||||
# 域名列表
|
||||
rows = session.exec(stmt).all()
|
||||
|
||||
# 查询符合筛选条件的总数量
|
||||
total = session.exec(stmt_total).first()
|
||||
|
||||
return ApiResult.ok({"total": total, "rows": rows})
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.exception(f"查询域名列表失败,错误:{e}")
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名列表失败,错误:{e}")
|
||||
|
||||
@classmethod
|
||||
def get_by_domains(cls, domains: list[str]) -> ApiResult[Optional[DomainModel]]:
|
||||
"""根据域名查询"""
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
stmt = select(DomainModel).where(DomainModel.domain.in_(domains))
|
||||
try:
|
||||
rows = session.exec(stmt)
|
||||
return ApiResult.ok(rows.all())
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")
|
||||
|
||||
@classmethod
|
||||
def get_by_ids(cls, domain_ids: list[int]) -> ApiResult[Optional[DomainModel]]:
|
||||
"""根据id查询"""
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
stmt = select(DomainModel).where(DomainModel.id.in_(domain_ids))
|
||||
try:
|
||||
rows = session.exec(stmt)
|
||||
return ApiResult.ok(rows.all())
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")
|
||||
|
||||
@classmethod
|
||||
def add_domains(cls, interval: int, crawl_now: bool, domains: Iterable[str]):
|
||||
"""批量添加域名"""
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
new_domains = [
|
||||
DomainModel(
|
||||
domain=x,
|
||||
status=DomainStatus.READY.value,
|
||||
crawl_interval=interval,
|
||||
latest_crawl_time=0 if not crawl_now else int(time.time())
|
||||
) for x in domains
|
||||
]
|
||||
|
||||
session.add_all(new_domains)
|
||||
try:
|
||||
session.commit()
|
||||
return ApiResult.ok(len(new_domains))
|
||||
except Exception as e:
|
||||
logger.error(f"添加域名到数据库失败,错误:{e}")
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, f"添加域名失败,错误:{e}")
|
||||
|
||||
@classmethod
|
||||
def delete_domains(cls, domain_ids: list[int], remove_surl: bool = False):
|
||||
"""批量删除域名,remove_surl 表示是否同时删除 report_url 中该域名相关的数据"""
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
stmt = delete(DomainModel).where(DomainModel.id.in_(domain_ids))
|
||||
try:
|
||||
session.exec(stmt)
|
||||
|
||||
# 如果设置了 remove_surl 为 True,则删除 report_url 中该域名相关的数据
|
||||
if remove_surl:
|
||||
stmt = delete(ReportUrlModel).where(ReportUrlModel.domain_id.in_(domain_ids))
|
||||
session.exec(stmt)
|
||||
|
||||
session.commit()
|
||||
|
||||
return ApiResult.ok(len(domain_ids))
|
||||
except Exception as e:
|
||||
logger.error(f"删除域名失败,错误:{e}")
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, f"删除域名失败,错误:{e}")
|
||||
|
||||
@classmethod
|
||||
def update_domain_interval(cls, domain_ids: list[int], interval: int) -> ApiResult[Optional[int]]:
|
||||
"""批量更新域名的 interval 值"""
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
stmt = update(DomainModel).where(DomainModel.id.in_(domain_ids)).values(crawl_interval=interval)
|
||||
try:
|
||||
session.exec(stmt)
|
||||
session.commit()
|
||||
return ApiResult.ok(len(domain_ids))
|
||||
except Exception as e:
|
||||
logger.error(f"更新域名 interval 失败,错误:{e}")
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, f"更新域名 interval 失败,错误:{e}")
|
||||
116
app/web/service/report_service.py
Normal file
116
app/web/service/report_service.py
Normal file
@ -0,0 +1,116 @@
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import update, func
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.config.config import AppCtx
|
||||
from app.constants.api_result import ApiCode
|
||||
from app.models.report_urls import ReportUrlModel
|
||||
from app.web.request.report_request import AddUrlItem
|
||||
from app.web.results import ApiResult
|
||||
|
||||
|
||||
class ReportURLService:
|
||||
|
||||
@classmethod
|
||||
def get_list(
|
||||
cls, domain: str, surl: str, is_report_by_one: Optional[bool], is_report_by_site: Optional[bool],
|
||||
is_report_by_wap: Optional[bool], has_evidence: Optional[bool], page: int, size: int):
|
||||
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
stmt = select(ReportUrlModel)
|
||||
total_stmt = select(func.count())
|
||||
if domain:
|
||||
stmt = stmt.where(ReportUrlModel.domain.like(f"%{domain}%"))
|
||||
total_stmt = total_stmt.where(ReportUrlModel.domain.like(f"%{domain}%"))
|
||||
if surl:
|
||||
stmt = stmt.where(ReportUrlModel.surl.like(f"%{surl}%"))
|
||||
total_stmt = total_stmt.where(ReportUrlModel.surl.like(f"%{surl}%"))
|
||||
if is_report_by_one is not None:
|
||||
stmt = stmt.where(ReportUrlModel.is_report_by_one == is_report_by_one)
|
||||
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_one == is_report_by_one)
|
||||
if is_report_by_site is not None:
|
||||
stmt = stmt.where(ReportUrlModel.is_report_by_site == is_report_by_site)
|
||||
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_site == is_report_by_site)
|
||||
if is_report_by_wap is not None:
|
||||
stmt = stmt.where(ReportUrlModel.is_report_by_wap == is_report_by_wap)
|
||||
total_stmt = total_stmt.where(ReportUrlModel.is_report_by_wap == is_report_by_wap)
|
||||
if has_evidence is not None:
|
||||
stmt = stmt.where(ReportUrlModel.has_evidence == has_evidence)
|
||||
total_stmt = total_stmt.where(ReportUrlModel.has_evidence == has_evidence)
|
||||
|
||||
# 设置分页
|
||||
stmt = stmt.offset((page - 1) * size).limit(size)
|
||||
|
||||
try:
|
||||
total = session.exec(total_stmt).first()
|
||||
urls = session.exec(stmt).all()
|
||||
return ApiResult.ok({
|
||||
"total": total,
|
||||
"data": urls,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"获取URL列表失败: {e}")
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||
|
||||
@classmethod
|
||||
def add_urls(cls, domain_map: dict[str, int], urls: list[AddUrlItem]) -> ApiResult[Optional[int]]:
|
||||
"""添加URL"""
|
||||
if not urls:
|
||||
return ApiResult.ok(0)
|
||||
|
||||
models = []
|
||||
|
||||
for url in urls:
|
||||
domain_id = domain_map.get(url.domain, None)
|
||||
if not domain_id:
|
||||
return ApiResult.error(ApiCode.PARAM_ERROR.value, f"域名 {url.domain} 不存在")
|
||||
models.append(ReportUrlModel(
|
||||
domain_id=domain_id,
|
||||
domain=url.domain,
|
||||
surl=url.surl,
|
||||
))
|
||||
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
try:
|
||||
session.add_all(models)
|
||||
session.commit()
|
||||
return ApiResult.ok(len(models))
|
||||
except Exception as e:
|
||||
logger.error(f"添加URL失败: {e}")
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||
|
||||
@classmethod
|
||||
def batch_update_evidence_flag(cls, url_ids: list[int]):
|
||||
"""批量更新URL的has_evidence字段"""
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
try:
|
||||
stmt = update(ReportUrlModel).where(ReportUrlModel.id.in_(url_ids)).values(has_evidence=False)
|
||||
session.exec(stmt)
|
||||
session.commit()
|
||||
return ApiResult.ok(len(url_ids))
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新URL的has_evidence字段失败: {e}")
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||
|
||||
@classmethod
|
||||
def batch_update_report_flag(cls, ids: list[int], report_by_one: bool, report_by_site: bool, report_by_wap: bool):
|
||||
with Session(AppCtx.g_db_engine) as session:
|
||||
try:
|
||||
stmt = update(ReportUrlModel).where(ReportUrlModel.id.in_(ids))
|
||||
if report_by_wap:
|
||||
stmt = stmt.values(is_report_by_wap=False)
|
||||
elif report_by_site:
|
||||
stmt = stmt.values(is_report_by_site=False)
|
||||
elif report_by_one:
|
||||
stmt = stmt.values(is_report_by_one=False)
|
||||
session.exec(stmt)
|
||||
session.commit()
|
||||
return ApiResult.ok(len(ids))
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新URL的has_evidence字段失败: {e}")
|
||||
session.rollback()
|
||||
return ApiResult.error(ApiCode.DB_ERROR.value, str(e))
|
||||
26
app/web/web.py
Normal file
26
app/web/web.py
Normal file
@ -0,0 +1,26 @@
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .controller.domain import router as domain_router
|
||||
from .controller.report import router as report_router
|
||||
from .controller.status import router as status_router
|
||||
|
||||
|
||||
class WebApp:
|
||||
|
||||
def __init__(self):
|
||||
self.app = FastAPI()
|
||||
|
||||
@staticmethod
|
||||
async def start():
|
||||
app = FastAPI()
|
||||
|
||||
# 导入路由
|
||||
app.include_router(status_router)
|
||||
app.include_router(report_router)
|
||||
app.include_router(domain_router)
|
||||
|
||||
# TODO 先写死,后面从配置文件里取
|
||||
cfg = uvicorn.Config(app, host="127.0.0.1", port=3000)
|
||||
server = uvicorn.Server(cfg)
|
||||
await server.serve()
|
||||
Loading…
x
Reference in New Issue
Block a user