95 lines
2.5 KiB
Python
95 lines
2.5 KiB
Python
"""配置管理"""
|
|
import os
|
|
import yaml
|
|
from typing import Optional, List
|
|
from pydantic import BaseModel
|
|
from functools import lru_cache
|
|
|
|
|
|
class AppConfig(BaseModel):
    """General application metadata, populated from the ``app`` config section."""

    name: str = 'job-crawler'  # application name
    version: str = '1.0.0'  # application version string
    debug: bool = False  # debug flag; how it is consumed is decided outside this file
|
|
|
|
|
|
class TaskConfig(BaseModel):
    """Configuration of a single crawl task (one entry of ``api.tasks``)."""

    # Required task identifier; presumably the remote API's task ID — verify.
    id: str
    # Optional human-readable label for the task.
    name: str = ''
    # Tasks with enabled=False are dropped by Settings.get_enabled_tasks().
    enabled: bool = True
|
|
|
|
|
|
class ApiConfig(BaseModel):
    """Connection settings for the bazhuayu OpenAPI plus the tasks to crawl."""

    # Root URL of the remote API.
    base_url: str = 'https://openapi.bazhuayu.com'
    # API credentials; empty by default, expected to be supplied via the config file.
    username: str = ''
    password: str = ''
    # Records per batch (presumably per API request/page — confirm with the crawler).
    batch_size: int = 100
    # Crawl tasks. pydantic copies mutable defaults per instance, so this list
    # is not shared between ApiConfig objects.
    tasks: List[TaskConfig] = []
|
|
|
|
|
|
class RabbitMQConfig(BaseModel):
    """Broker connection and queue settings (``rabbitmq`` config section)."""

    host: str = 'localhost'
    port: int = 5672
    # Defaults match RabbitMQ's stock guest account; override them in real deployments.
    username: str = 'guest'
    password: str = 'guest'
    # Queue name (presumably where crawled job records are published — confirm).
    queue: str = 'job_data'
    # Message TTL in milliseconds: 7 days (== 604800000).
    message_ttl: int = 7 * 24 * 60 * 60 * 1000
|
|
|
|
|
|
class CrawlerConfig(BaseModel):
    """Scheduling and concurrency knobs for the crawler (``crawler`` section)."""

    # Pause between crawl rounds; presumably seconds — confirm with the scheduler.
    interval: int = 300
    # Age window in days; presumably records older than this are skipped — confirm.
    filter_days: int = 7
    # Upper bound on concurrent workers (presumably a worker-pool size).
    max_workers: int = 5
    # Threshold for the number of consecutive expired batches.
    max_expired_batches: int = 3
    # Automatically start crawling when the container starts.
    auto_start: bool = True
|
|
|
|
|
|
class DatabaseConfig(BaseModel):
    """Location of the crawl-progress database (``database`` config section)."""

    path: str = 'data/crawl_progress.db'  # database file path; the default is relative
|
|
|
|
|
|
class Settings(BaseModel):
    """Application settings, aggregating every configuration section.

    Each attribute corresponds to a top-level key of the YAML config file
    (``app``, ``api``, ``rabbitmq``, ``crawler``, ``database``). A Settings
    built without arguments uses every section's defaults.
    """

    app: AppConfig = AppConfig()
    api: ApiConfig = ApiConfig()
    rabbitmq: RabbitMQConfig = RabbitMQConfig()
    crawler: CrawlerConfig = CrawlerConfig()
    database: DatabaseConfig = DatabaseConfig()

    @classmethod
    def from_yaml(cls, config_path: str) -> "Settings":
        """Load settings from a YAML file.

        A missing file yields all-default settings. A top-level section that is
        absent, or present with no value (e.g. a bare ``api:`` line, which YAML
        parses as ``None``), falls back to that section's defaults. The same
        applies to an empty ``api.tasks`` entry.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            The populated ``Settings`` instance.

        Raises:
            ValueError: If the YAML document, or one of its sections, is not a
                mapping.
            yaml.YAMLError: If the file is not valid YAML.
            pydantic.ValidationError: If a value fails field validation.
        """
        if not os.path.exists(config_path):
            return cls()

        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load only builds plain YAML types; an empty file yields None.
            data = yaml.safe_load(f) or {}

        if not isinstance(data, dict):
            raise ValueError(
                f"Config file {config_path!r} must contain a mapping at the top "
                f"level, got {type(data).__name__}"
            )

        def section(key: str) -> dict:
            # data.get(key, {}) returns None for a key that is present but
            # empty, which would break ** unpacking; treat it as "no overrides".
            value = data.get(key)
            if value is None:
                return {}
            if not isinstance(value, dict):
                raise ValueError(
                    f"Config section {key!r} in {config_path!r} must be a "
                    f"mapping, got {type(value).__name__}"
                )
            return value

        # Parse the task list separately. Copy first so the loaded dict is not
        # mutated by pop().
        api_data = dict(section('api'))
        tasks_data = api_data.pop('tasks', None) or []
        tasks = [TaskConfig(**t) for t in tasks_data]
        api_config = ApiConfig(**api_data, tasks=tasks)

        return cls(
            app=AppConfig(**section('app')),
            api=api_config,
            rabbitmq=RabbitMQConfig(**section('rabbitmq')),
            crawler=CrawlerConfig(**section('crawler')),
            database=DatabaseConfig(**section('database')),
        )

    def get_enabled_tasks(self) -> List[TaskConfig]:
        """Return the configured tasks whose ``enabled`` flag is true, in order."""
        return [t for t in self.api.tasks if t.enabled]
|
|
|
|
|
|
@lru_cache()
def get_settings() -> Settings:
    """Return the application settings, loading them only once.

    The YAML file path comes from the ``CONFIG_PATH`` environment variable,
    falling back to ``config/config.yml``. Because of ``lru_cache`` the result
    is memoized: the variable and the file are read on the first call only.
    Every later call returns the same ``Settings`` object until
    ``get_settings.cache_clear()`` is called.
    """
    yaml_path = os.getenv("CONFIG_PATH", "config/config.yml")
    loaded = Settings.from_yaml(yaml_path)
    return loaded
|
|
|
|
|
|
# Module-level settings instance, built at import time. Importing this module
# therefore reads the config file. Later get_settings() calls return this same
# cached object unless get_settings.cache_clear() is called.
settings: Settings = get_settings()
|