"""
Database management module (SQLAlchemy).

Persists project data, task status and asset paths.
Supports SQLite and PostgreSQL.
"""
import json
import logging
import time
from typing import Dict, List, Any, Optional

from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.dialects.postgresql import JSONB

import config

logger = logging.getLogger(__name__)

Base = declarative_base()


class Project(Base):
    """A project row: name, pipeline status, product info and generated script."""

    __tablename__ = 'projects'

    id = Column(String, primary_key=True)
    name = Column(String)
    status = Column(String)  # created, script_generated, images_generated, videos_generated, completed
    # JSON string. Stored as Text (not JSONB) so SQLite and PostgreSQL share one schema.
    product_info = Column(Text)
    script_data = Column(Text)  # JSON string
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)


class SceneAsset(Base):
    """One asset (image or video) for a single scene of a project.

    ``(project_id, scene_id, asset_type)`` is unique, so each scene holds at
    most one asset per type.
    """

    __tablename__ = 'scene_assets'

    id = Column(Integer, primary_key=True, autoincrement=True)
    project_id = Column(String, index=True)
    scene_id = Column(Integer)
    asset_type = Column(String)  # image, video
    status = Column(String)  # pending, processing, completed, failed
    local_path = Column(Text, nullable=True)
    remote_url = Column(Text, nullable=True)
    task_id = Column(String, nullable=True)  # task ID issued by the external API
    # The DB column is called "metadata"; the Python attribute is renamed because
    # declarative models reserve the `metadata` attribute name.
    metadata_json = Column("metadata", Text, nullable=True)  # JSON string
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    __table_args__ = (UniqueConstraint('project_id', 'scene_id', 'asset_type', name='uix_project_scene_asset'),)


class AppConfig(Base):
    """Key/value application settings (including prompts); values are JSON strings."""

    __tablename__ = 'app_config'

    key = Column(String, primary_key=True)
    value = Column(Text)  # JSON string
    description = Column(Text, nullable=True)
    updated_at = Column(Float, default=time.time, onupdate=time.time)


class DBManager:
    """Persistence facade over SQLAlchemy.

    Each public method takes a session from a ``scoped_session`` registry,
    does its work and closes the session before returning. Read methods
    therefore return plain dicts/lists rather than ORM instances.
    """

    def __init__(self, connection_string: Optional[str] = None):
        """Create the engine and session factory, then ensure tables exist.

        Args:
            connection_string: SQLAlchemy database URL. Falls back to
                ``config.DB_CONNECTION_STRING`` when empty or None.
        """
        if not connection_string:
            connection_string = config.DB_CONNECTION_STRING
        # Pooled connections older than an hour are replaced on checkout, so
        # connections the server has dropped while idle are not reused.
        self.engine = create_engine(connection_string, pool_recycle=3600)
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        self._init_db()

    def _init_db(self):
        """Create any missing tables (existing tables are left untouched)."""
        Base.metadata.create_all(self.engine)

    def _get_session(self):
        """Return the current scope's session from the scoped_session registry."""
        return self.Session()

    # --- Internal helpers ---

    @staticmethod
    def _query_asset(session, project_id: str, scene_id: int, asset_type: str) -> Optional[SceneAsset]:
        """Fetch the unique asset row for (project, scene, type), or None."""
        return session.query(SceneAsset).filter_by(
            project_id=project_id,
            scene_id=scene_id,
            asset_type=asset_type,
        ).first()

    @staticmethod
    def _asset_to_dict(asset: SceneAsset) -> Dict[str, Any]:
        """Serialize an asset row into the dict shape returned by the getters."""
        return {
            "id": asset.id,
            "project_id": asset.project_id,
            "scene_id": asset.scene_id,
            "asset_type": asset.asset_type,
            "status": asset.status,
            "local_path": asset.local_path,
            "remote_url": asset.remote_url,
            "task_id": asset.task_id,
            "metadata": json.loads(asset.metadata_json) if asset.metadata_json else {},
            "updated_at": asset.updated_at,
        }

    @staticmethod
    def _reset_asset(asset: SceneAsset, status: str, now: float) -> None:
        """Blank an asset's paths, task id and metadata in place and set ``status``."""
        asset.status = status
        asset.local_path = None
        asset.remote_url = None
        asset.task_id = None
        asset.metadata_json = "{}"
        asset.updated_at = now

    # --- Project Operations ---

    def create_project(self, project_id: str, name: str, product_info: Dict[str, Any]):
        """Insert a new project with status "created".

        If ``project_id`` already exists, a warning is logged and nothing is written.

        Raises:
            Exception: any database error is logged, rolled back and re-raised.
        """
        session = self._get_session()
        try:
            if session.get(Project, project_id) is not None:
                logger.warning("Project %s already exists.", project_id)
                return
            # Take a single timestamp so a fresh row has created_at == updated_at.
            now = time.time()
            session.add(Project(
                id=project_id,
                name=name,
                status="created",
                product_info=json.dumps(product_info, ensure_ascii=False),
                created_at=now,
                updated_at=now,
            ))
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error creating project: %s", e)
            raise
        finally:
            session.close()

    def update_project_script(self, project_id: str, script: Dict[str, Any]):
        """Store the generated script and set status to "script_generated".

        Best-effort: a missing project only logs a warning. Database errors
        are logged and rolled back, not raised.
        """
        session = self._get_session()
        try:
            project = session.get(Project, project_id)
            if project is None:
                logger.warning("Project %s not found; script not saved.", project_id)
                return
            project.script_data = json.dumps(script, ensure_ascii=False)
            project.status = "script_generated"
            project.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating script: %s", e)
        finally:
            session.close()

    def update_project_product_info(self, project_id: str, product_info: Dict[str, Any]):
        """Replace ``project.product_info`` with ``product_info`` (JSON-encoded).

        Used to persist editor state without changing the schema. A missing
        project only logs a warning.

        Raises:
            Exception: any database error is logged, rolled back and re-raised.
        """
        session = self._get_session()
        try:
            project = session.get(Project, project_id)
            if project is None:
                logger.warning("Project %s not found; product_info not saved.", project_id)
                return
            project.product_info = json.dumps(product_info, ensure_ascii=False)
            project.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating product_info: %s", e)
            raise
        finally:
            session.close()

    def update_project_status(self, project_id: str, status: str):
        """Set the project's status.

        Best-effort: a missing project only logs a warning. Database errors
        are logged and rolled back, not raised.
        """
        session = self._get_session()
        try:
            project = session.get(Project, project_id)
            if project is None:
                logger.warning("Project %s not found; status not updated.", project_id)
                return
            project.status = status
            project.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating status: %s", e)
        finally:
            session.close()

    def get_project(self, project_id: str) -> Optional[Dict[str, Any]]:
        """Return the project as a dict with its JSON fields decoded, or None.

        ``product_info`` defaults to ``{}`` and ``script_data`` to None when
        the stored value is empty.
        """
        session = self._get_session()
        try:
            project = session.get(Project, project_id)
            if project is None:
                return None
            return {
                "id": project.id,
                "name": project.name,
                "status": project.status,
                "product_info": json.loads(project.product_info) if project.product_info else {},
                "script_data": json.loads(project.script_data) if project.script_data else None,
                "created_at": project.created_at,
                "updated_at": project.updated_at,
            }
        finally:
            session.close()

    def list_projects(self) -> List[Dict[str, Any]]:
        """Return id/name/status/updated_at for every project, most recently updated first."""
        session = self._get_session()
        try:
            projects = session.query(Project).order_by(Project.updated_at.desc()).all()
            return [
                {"id": p.id, "name": p.name, "status": p.status, "updated_at": p.updated_at}
                for p in projects
            ]
        finally:
            session.close()

    # --- Asset/Task Operations ---

    def save_asset(self, project_id: str, scene_id: int, asset_type: str, status: str,
                   local_path: Optional[str] = None, remote_url: Optional[str] = None,
                   task_id: Optional[str] = None, metadata: Optional[Dict] = None):
        """Insert or update (UPSERT) the asset for (project, scene, type).

        On update, every mutable field is overwritten with the given values:
        omitted optional arguments become None, and a falsy ``metadata`` is
        stored as ``"{}"``. Best-effort: errors are logged and rolled back,
        not raised.

        NOTE(review): the lookup-then-insert is not atomic. Two concurrent
        writers for the same key can violate the unique constraint, and that
        error is only logged.
        """
        session = self._get_session()
        try:
            asset = self._query_asset(session, project_id, scene_id, asset_type)
            meta_json = json.dumps(metadata, ensure_ascii=False) if metadata else "{}"
            now = time.time()
            if asset is not None:
                asset.status = status
                asset.local_path = local_path
                asset.remote_url = remote_url
                asset.task_id = task_id
                asset.metadata_json = meta_json
                asset.updated_at = now
            else:
                session.add(SceneAsset(
                    project_id=project_id,
                    scene_id=scene_id,
                    asset_type=asset_type,
                    status=status,
                    local_path=local_path,
                    remote_url=remote_url,
                    task_id=task_id,
                    metadata_json=meta_json,
                    created_at=now,
                    updated_at=now,
                ))
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error saving asset: %s", e)
        finally:
            session.close()

    def get_assets(self, project_id: str, asset_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return all assets of a project as dicts, optionally filtered by ``asset_type``."""
        session = self._get_session()
        try:
            query = session.query(SceneAsset).filter_by(project_id=project_id)
            if asset_type:
                query = query.filter_by(asset_type=asset_type)
            return [self._asset_to_dict(asset) for asset in query.all()]
        finally:
            session.close()

    def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
        """Return the asset for (project, scene, type) as a dict, or None if absent."""
        session = self._get_session()
        try:
            asset = self._query_asset(session, project_id, scene_id, asset_type)
            return self._asset_to_dict(asset) if asset is not None else None
        finally:
            session.close()

    def update_asset_metadata(self, project_id: str, scene_id: int, asset_type: str, patch: Dict[str, Any]) -> None:
        """Shallow-merge ``patch`` into the asset's metadata JSON without dropping other keys.

        No-op for an empty patch or a missing asset. Existing metadata that is
        unparseable or not a JSON object is replaced by ``patch``.
        Best-effort: database errors are logged and rolled back, not raised.
        """
        if not patch:
            return
        session = self._get_session()
        try:
            asset = self._query_asset(session, project_id, scene_id, asset_type)
            if asset is None:
                return
            try:
                existing = json.loads(asset.metadata_json) if asset.metadata_json else {}
            except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
                existing = {}
            if not isinstance(existing, dict):
                existing = {}
            existing.update(patch)
            asset.metadata_json = json.dumps(existing, ensure_ascii=False)
            asset.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating asset metadata: %s", e)
        finally:
            session.close()

    def clear_asset(self, project_id: str, scene_id: int, asset_type: str, *, status: str = "pending") -> None:
        """Clear one asset record in place.

        The row is kept, but its paths, task id and metadata are removed and
        its status is set to ``status``. Used to invalidate stale videos after
        images are regenerated. Best-effort: errors are logged, not raised.
        """
        session = self._get_session()
        try:
            asset = self._query_asset(session, project_id, scene_id, asset_type)
            if asset is None:
                return
            self._reset_asset(asset, status, time.time())
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error clearing asset: %s", e)
        finally:
            session.close()

    def clear_assets(self, project_id: str, asset_type: str, *, status: str = "pending") -> int:
        """Clear every asset of ``asset_type`` for a project (see ``clear_asset``).

        Returns:
            Number of rows cleared. Returns 0 on error; errors are logged, not raised.
        """
        session = self._get_session()
        try:
            rows = session.query(SceneAsset).filter_by(project_id=project_id, asset_type=asset_type).all()
            if not rows:
                return 0
            now = time.time()
            for asset in rows:
                self._reset_asset(asset, status, now)
            session.commit()
            return len(rows)
        except Exception as e:
            session.rollback()
            logger.exception("Error clearing assets: %s", e)
            return 0
        finally:
            session.close()

    # --- Config/Prompt Operations ---

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the JSON-decoded value stored under ``key``.

        If the stored text is not valid JSON (or is NULL), it is returned
        unchanged. ``default`` is returned only when the key does not exist.
        """
        session = self._get_session()
        try:
            cfg = session.get(AppConfig, key)
            if cfg is None:
                return default
            try:
                return json.loads(cfg.value)
            except (TypeError, ValueError):  # NULL value or non-JSON text
                return cfg.value
        finally:
            session.close()

    def set_config(self, key: str, value: Any, description: Optional[str] = None):
        """Insert or update a config entry with ``value`` JSON-encoded.

        On update, the description is overwritten only when a truthy
        ``description`` is given. Best-effort: errors (including a value that
        is not JSON-serializable) are logged and rolled back, not raised.
        """
        session = self._get_session()
        try:
            json_val = json.dumps(value, ensure_ascii=False)
            now = time.time()
            cfg = session.get(AppConfig, key)
            if cfg is not None:
                cfg.value = json_val
                if description:
                    cfg.description = description
                cfg.updated_at = now
            else:
                session.add(AppConfig(
                    key=key,
                    value=json_val,
                    description=description,
                    updated_at=now,
                ))
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error setting config: %s", e)
        finally:
            session.close()


# Singleton instance
db = DBManager()