142 lines
6.0 KiB
Python
142 lines
6.0 KiB
Python
import time
|
||
from typing import Iterable, Optional
|
||
|
||
from loguru import logger
|
||
from sqlalchemy import delete, func, update
|
||
from sqlmodel import Session, select
|
||
|
||
from app.config.config import AppCtx
|
||
from app.constants.api_result import ApiCode
|
||
from app.constants.domain import DomainStatus
|
||
from app.models.domain import DomainModel
|
||
from app.models.report_urls import ReportUrlModel
|
||
from app.web.results import ApiResult
|
||
|
||
|
||
class DomainService:
    """Service layer for CRUD operations on ``DomainModel`` rows.

    Every public method opens its own short-lived ``Session`` on
    ``AppCtx.g_db_engine`` and wraps the outcome in an ``ApiResult``.
    Any exception raised while executing the statement(s) is logged, the
    session is rolled back, and an ``ApiCode.DB_ERROR`` result carrying the
    error text is returned instead of raising.
    """

    @classmethod
    def get_list(cls, page: int, page_size: int, domain: str, status: int):
        """Return one page of domains plus the total number of matching rows.

        Args:
            page: 1-based page number; values below 1 are treated as page 1.
            page_size: Maximum number of rows to return.
            domain: Substring filter on ``DomainModel.domain`` (SQL
                ``LIKE '%<domain>%'``); a falsy value disables the filter.
            status: Exact-match filter on ``DomainModel.status``; a falsy
                value (e.g. ``0``) disables the filter.

        Returns:
            ``ApiResult.ok({"total": <count>, "rows": <models>})`` on success,
            otherwise an ``ApiCode.DB_ERROR`` result.
        """
        # Build the filters once so the page query and the count query can
        # never disagree about which rows match.
        conditions = []
        if domain:
            conditions.append(DomainModel.domain.like(f"%{domain}%"))
        if status:
            conditions.append(DomainModel.status == status)

        # A negative OFFSET is rejected by some databases (PostgreSQL, MySQL)
        # and silently treated as 0 by others; clamp to the first page.
        offset = max(page - 1, 0) * page_size
        stmt = (
            select(DomainModel)
            .where(*conditions)
            # Without ORDER BY the row order is unspecified, so LIMIT/OFFSET
            # pages could overlap or skip rows between requests.
            .order_by(DomainModel.id)
            .offset(offset)
            .limit(page_size)
        )
        stmt_total = select(func.count(DomainModel.id)).where(*conditions)

        with Session(AppCtx.g_db_engine) as session:
            try:
                # Rows for the requested page.
                rows = session.exec(stmt).all()
                # Total number of rows matching the filters (ignores paging).
                total = session.exec(stmt_total).first()
                logger.debug(f"{total=}")
                return ApiResult.ok({"total": total, "rows": rows})
            except Exception as e:
                logger.exception(f"查询域名列表失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名列表失败,错误:{e}")

    @classmethod
    def _get_where_in(cls, column, values) -> ApiResult[Optional[list[DomainModel]]]:
        """Fetch every domain whose ``column`` value appears in ``values``.

        Shared implementation of ``get_by_domains`` / ``get_by_ids``.
        """
        stmt = select(DomainModel).where(column.in_(values))
        with Session(AppCtx.g_db_engine) as session:
            try:
                rows = session.exec(stmt).all()
                return ApiResult.ok(rows)
            except Exception as e:
                logger.exception(f"查询域名失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"查询域名失败,错误:{e}")

    @classmethod
    def get_by_domains(cls, domains: list[str]) -> ApiResult[Optional[list[DomainModel]]]:
        """Look up domains by exact domain name.

        Args:
            domains: Domain names to match against ``DomainModel.domain``.

        Returns:
            ``ApiResult.ok(<list of matching models>)`` (possibly empty), or an
            ``ApiCode.DB_ERROR`` result.
        """
        return cls._get_where_in(DomainModel.domain, domains)

    @classmethod
    def get_by_ids(cls, domain_ids: list[int]) -> ApiResult[Optional[list[DomainModel]]]:
        """Look up domains by primary key.

        Args:
            domain_ids: Ids to match against ``DomainModel.id``.

        Returns:
            ``ApiResult.ok(<list of matching models>)`` (possibly empty), or an
            ``ApiCode.DB_ERROR`` result.
        """
        return cls._get_where_in(DomainModel.id, domain_ids)

    @classmethod
    def add_domains(
        cls, interval: int, crawl_now: bool, domains: Iterable[str]
    ) -> ApiResult[Optional[int]]:
        """Insert one ``DomainModel`` per name in ``domains`` in a single commit.

        Args:
            interval: Value stored in ``crawl_interval`` for every new row.
            crawl_now: If true, ``latest_crawl_time`` is set to 0; otherwise it
                is set to the current Unix time (seconds). Presumably the
                crawler schedules a domain once ``interval`` has elapsed since
                that timestamp — TODO confirm against the scheduler.
            domains: Domain names to insert; consumed exactly once.

        Returns:
            ``ApiResult.ok(<number of rows inserted>)`` or an
            ``ApiCode.DB_ERROR`` result (nothing is inserted on error).
        """
        # One timestamp for the whole batch.
        latest_crawl_time = 0 if crawl_now else int(time.time())
        new_domains = [
            DomainModel(
                domain=name,
                status=DomainStatus.READY.value,
                crawl_interval=interval,
                latest_crawl_time=latest_crawl_time,
            )
            for name in domains
        ]

        with Session(AppCtx.g_db_engine) as session:
            session.add_all(new_domains)
            try:
                session.commit()
                return ApiResult.ok(len(new_domains))
            except Exception as e:
                logger.exception(f"添加域名到数据库失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"添加域名失败,错误:{e}")

    @classmethod
    def delete_domains(
        cls, domain_ids: list[int], remove_surl: bool = False
    ) -> ApiResult[Optional[int]]:
        """Delete domains by id, optionally with their ``report_url`` rows.

        Args:
            domain_ids: Ids of the domains to delete.
            remove_surl: If true, also delete ``ReportUrlModel`` rows whose
                ``domain_id`` is in ``domain_ids`` (same transaction).

        Returns:
            ``ApiResult.ok(len(domain_ids))`` or an ``ApiCode.DB_ERROR`` result
            (nothing is deleted on error).
        """
        with Session(AppCtx.g_db_engine) as session:
            try:
                # Remove dependent report_url rows before their parent domains.
                # If report_url.domain_id references the domain table via a
                # foreign key (not visible here), deleting parents first would
                # either violate a non-cascading FK or, with ON DELETE SET NULL,
                # make this delete match nothing. Child-first is safe in all
                # cases, since both statements share one transaction.
                if remove_surl:
                    session.exec(
                        delete(ReportUrlModel).where(ReportUrlModel.domain_id.in_(domain_ids))
                    )
                session.exec(delete(DomainModel).where(DomainModel.id.in_(domain_ids)))
                session.commit()
                # NOTE(review): this echoes the number of ids requested, not the
                # number of rows actually deleted (missing/duplicate ids count).
                return ApiResult.ok(len(domain_ids))
            except Exception as e:
                logger.exception(f"删除域名失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"删除域名失败,错误:{e}")

    @classmethod
    def _update_by_ids(
        cls, domain_ids: list[int], label: str, **values
    ) -> ApiResult[Optional[int]]:
        """Apply ``values`` to every domain whose id is in ``domain_ids``.

        ``label`` only names the updated field in log and error messages.
        Returns ``ApiResult.ok(len(domain_ids))`` (ids requested, not rows
        changed) or an ``ApiCode.DB_ERROR`` result.
        """
        stmt = update(DomainModel).where(DomainModel.id.in_(domain_ids)).values(**values)
        with Session(AppCtx.g_db_engine) as session:
            try:
                session.exec(stmt)
                session.commit()
                return ApiResult.ok(len(domain_ids))
            except Exception as e:
                logger.exception(f"更新域名 {label} 失败,错误:{e}")
                session.rollback()
                return ApiResult.error(ApiCode.DB_ERROR.value, f"更新域名 {label} 失败,错误:{e}")

    @classmethod
    def update_domain_interval(cls, domain_ids: list[int], interval: int) -> ApiResult[Optional[int]]:
        """Set ``crawl_interval`` to ``interval`` on the given domains."""
        return cls._update_by_ids(domain_ids, "interval", crawl_interval=interval)

    @classmethod
    def update_domain_status(cls, domain_ids: list[int], status: int) -> ApiResult[Optional[int]]:
        """Set ``status`` to ``status`` on the given domains."""
        return cls._update_by_ids(domain_ids, "status", status=status)