import time
from typing import Iterable, Optional

from loguru import logger
from sqlalchemy import delete, func, update
from sqlmodel import Session, select

from app.config.config import AppCtx
from app.constants.api_result import ApiCode
from app.constants.domain import DomainStatus
from app.models.domain import DomainModel
from app.models.report_urls import ReportUrlModel
from app.web.results import ApiResult


class DomainService:
    """CRUD operations on ``DomainModel`` rows.

    Every method opens its own ``Session`` on ``AppCtx.g_db_engine`` and wraps the
    outcome in an ``ApiResult``: ``ApiResult.ok(...)`` on success, or
    ``ApiResult.error(ApiCode.DB_ERROR.value, <message>)`` when the database call
    raises (the session is rolled back in that case).
    """

    @classmethod
    def get_list(cls, page: int, page_size: int, domain: str, status: int):
        """Return one page of domains plus the total number matching the filters.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: maximum number of rows to return.
            domain: substring filter on ``DomainModel.domain`` (SQL ``LIKE %domain%``);
                a falsy value disables the filter.
            status: exact-match filter on ``DomainModel.status``; a falsy value
                (including 0) disables the filter.

        Returns:
            ``ApiResult.ok({"total": <count>, "rows": <list of DomainModel>})``, or a
            ``DB_ERROR`` result if the query fails.
        """
        # Build the filter conditions once so the page query and the count query
        # are guaranteed to apply exactly the same criteria.
        filters = []
        if domain:
            filters.append(DomainModel.domain.like(f"%{domain}%"))
        if status:
            filters.append(DomainModel.status == status)

        # A negative OFFSET is rejected by some databases; treat page < 1 as page 1.
        page = max(page, 1)

        stmt = (
            select(DomainModel)
            .where(*filters)
            # OFFSET/LIMIT without ORDER BY yields an unspecified row order, so
            # consecutive pages could overlap or skip rows.
            .order_by(DomainModel.id)
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        # Explicit FROM: without it, an unfiltered `select(func.count())` renders as
        # `SELECT count(*)` with no table and always returns 1.
        stmt_total = select(func.count()).select_from(DomainModel).where(*filters)

        with Session(AppCtx.g_db_engine) as session:
            try:
                rows = session.exec(stmt).all()
                total = session.exec(stmt_total).one()
                return ApiResult.ok({"total": total, "rows": rows})
            except Exception as e:
                session.rollback()
                logger.exception(f"查询域名列表失败,错误:{e}")
                return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名列表失败,错误:{e}")

    @classmethod
    def get_by_domains(cls, domains: list[str]) -> ApiResult[list[DomainModel]]:
        """Fetch every row whose ``domain`` value appears in *domains*.

        Returns:
            ``ApiResult.ok(<list of DomainModel>)`` (possibly empty), or a
            ``DB_ERROR`` result if the query fails.
        """
        with Session(AppCtx.g_db_engine) as session:
            stmt = select(DomainModel).where(DomainModel.domain.in_(domains))
            try:
                rows = session.exec(stmt).all()
                return ApiResult.ok(rows)
            except Exception as e:
                session.rollback()
                logger.exception(f"查询域名失败,错误:{e}")
                return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")

    @classmethod
    def get_by_ids(cls, domain_ids: list[int]) -> ApiResult[list[DomainModel]]:
        """Fetch every row whose primary key ``id`` appears in *domain_ids*.

        Returns:
            ``ApiResult.ok(<list of DomainModel>)`` (possibly empty), or a
            ``DB_ERROR`` result if the query fails.
        """
        with Session(AppCtx.g_db_engine) as session:
            stmt = select(DomainModel).where(DomainModel.id.in_(domain_ids))
            try:
                rows = session.exec(stmt).all()
                return ApiResult.ok(rows)
            except Exception as e:
                session.rollback()
                logger.exception(f"查询域名失败,错误:{e}")
                return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")

    @classmethod
    def add_domains(cls, interval: int, crawl_now: bool, domains: Iterable[str]):
        """Insert one ``DomainModel`` row per entry in *domains* in a single commit.

        Each new row gets ``status=DomainStatus.READY.value`` and
        ``crawl_interval=interval``. ``latest_crawl_time`` is the current Unix
        timestamp (seconds) when *crawl_now* is true, otherwise 0.

        Args:
            interval: value stored in ``crawl_interval`` for every new row.
            crawl_now: selects the initial ``latest_crawl_time`` (see above).
            domains: domain strings to insert; consumed exactly once.

        Returns:
            ``ApiResult.ok(<number of rows added>)``, or a ``DB_ERROR`` result. On
            failure the whole batch is rolled back, so no row from it is kept.
        """
        # NOTE(review): if the scheduler treats `latest_crawl_time` as "time of the
        # last crawl", crawl_now=True (timestamp = now) postpones the first crawl by
        # one interval, while crawl_now=False (0) makes it due immediately. That
        # looks inverted; confirm against the scheduler before relying on it.
        latest_crawl_time = int(time.time()) if crawl_now else 0
        with Session(AppCtx.g_db_engine) as session:
            new_domains = [
                DomainModel(
                    domain=x,
                    status=DomainStatus.READY.value,
                    crawl_interval=interval,
                    latest_crawl_time=latest_crawl_time,
                )
                for x in domains
            ]
            session.add_all(new_domains)
            try:
                session.commit()
                return ApiResult.ok(len(new_domains))
            except Exception as e:
                logger.exception(f"添加域名到数据库失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"添加域名失败,错误:{e}")

    @classmethod
    def delete_domains(cls, domain_ids: list[int], remove_surl: bool = False):
        """Delete the domains whose ``id`` is in *domain_ids*.

        Args:
            domain_ids: primary keys of the domains to delete.
            remove_surl: when true, also delete ``ReportUrlModel`` rows whose
                ``domain_id`` is in *domain_ids*; both deletes share one commit.

        Returns:
            ``ApiResult.ok(len(domain_ids))`` (the number of ids requested, not the
            number of rows actually deleted), or a ``DB_ERROR`` result after rollback.
        """
        with Session(AppCtx.g_db_engine) as session:
            stmt = delete(DomainModel).where(DomainModel.id.in_(domain_ids))
            try:
                session.exec(stmt)
                # Optionally remove the report_url rows tied to these domains in the
                # same transaction, so either both deletes persist or neither does.
                if remove_surl:
                    stmt = delete(ReportUrlModel).where(ReportUrlModel.domain_id.in_(domain_ids))
                    session.exec(stmt)
                session.commit()
                return ApiResult.ok(len(domain_ids))
            except Exception as e:
                logger.exception(f"删除域名失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"删除域名失败,错误:{e}")

    @classmethod
    def update_domain_interval(cls, domain_ids: list[int], interval: int) -> ApiResult[Optional[int]]:
        """Set ``crawl_interval`` to *interval* for every domain whose ``id`` is in *domain_ids*.

        Returns:
            ``ApiResult.ok(len(domain_ids))`` (the number of ids requested, not the
            number of rows actually changed), or a ``DB_ERROR`` result after rollback.
        """
        with Session(AppCtx.g_db_engine) as session:
            stmt = update(DomainModel).where(DomainModel.id.in_(domain_ids)).values(crawl_interval=interval)
            try:
                session.exec(stmt)
                session.commit()
                return ApiResult.ok(len(domain_ids))
            except Exception as e:
                logger.exception(f"更新域名 interval 失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"更新域名 interval 失败,错误:{e}")