"""采集进度存储""" import sqlite3 import os import logging from datetime import datetime from typing import Optional from contextlib import contextmanager from app.models import CrawlProgress from app.core.config import settings logger = logging.getLogger(__name__) class ProgressStore: """采集进度存储(SQLite)""" def __init__(self, db_path: str = None): self.db_path = db_path or settings.database.path os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True) self._init_db() def _init_db(self): """初始化数据库""" with self._get_conn() as conn: conn.execute(""" CREATE TABLE IF NOT EXISTS crawl_progress ( task_id TEXT PRIMARY KEY, last_start_offset INTEGER, total INTEGER DEFAULT 0, last_update TEXT, status TEXT DEFAULT 'idle', filtered_count INTEGER DEFAULT 0, produced_count INTEGER DEFAULT 0 ) """) conn.commit() @contextmanager def _get_conn(self): """获取数据库连接""" conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row try: yield conn finally: conn.close() def get_progress(self, task_id: str) -> Optional[CrawlProgress]: """获取采集进度""" with self._get_conn() as conn: cursor = conn.execute("SELECT * FROM crawl_progress WHERE task_id = ?", (task_id,)) row = cursor.fetchone() if row: return CrawlProgress( task_id=row["task_id"], last_start_offset=row["last_start_offset"], total=row["total"], last_update=row["last_update"] or "", status=row["status"] ) return None def save_progress(self, task_id: str, last_start_offset: int, total: int, status: str = "running", filtered_count: int = 0, produced_count: int = 0): """保存采集进度""" now = datetime.now().isoformat() with self._get_conn() as conn: conn.execute(""" INSERT INTO crawl_progress (task_id, last_start_offset, total, last_update, status, filtered_count, produced_count) VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT(task_id) DO UPDATE SET last_start_offset = excluded.last_start_offset, total = excluded.total, last_update = excluded.last_update, status = excluded.status, filtered_count = excluded.filtered_count, produced_count = excluded.produced_count """, (task_id, last_start_offset, total, now, status, filtered_count, produced_count)) conn.commit() def get_stats(self, task_id: str) -> dict: """获取统计信息""" with self._get_conn() as conn: cursor = conn.execute("SELECT * FROM crawl_progress WHERE task_id = ?", (task_id,)) row = cursor.fetchone() if row: return dict(row) return {} def reset_progress(self, task_id: str): """重置采集进度""" with self._get_conn() as conn: conn.execute("DELETE FROM crawl_progress WHERE task_id = ?", (task_id,)) conn.commit() progress_store = ProgressStore()