"""
Database management module (SQLAlchemy).

Persists project data, task status and asset (material) paths.
Supports SQLite and PostgreSQL.
"""
import hashlib
import hmac
import json
import logging
import secrets
import time
from typing import Dict, List, Any, Optional, Tuple

from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func, text, inspect
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.dialects.postgresql import JSONB

import config

logger = logging.getLogger(__name__)


# NOTE: Some deployments do not ship `modules/auth.py`.
# Keep a minimal, local auth helper here to avoid hard dependency.

# PBKDF2 iteration count. The count is not stored alongside the hash, so changing
# this value makes every previously stored password hash fail verification.
_PBKDF2_ITERATIONS = 120_000


def _hash_password(password: str, salt_hex: Optional[str] = None) -> Tuple[str, str]:
    """Derive a PBKDF2-HMAC-SHA256 hash of *password*.

    Args:
        password: Plain-text password; ``None`` is treated as ``""``.
        salt_hex: Hex-encoded salt to reuse (verification path). When falsy,
            a fresh random 16-byte salt is generated (registration path).

    Returns:
        ``(hash_hex, salt_hex)``.
    """
    salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16)
    dk = hashlib.pbkdf2_hmac("sha256", (password or "").encode("utf-8"), salt, _PBKDF2_ITERATIONS)
    return dk.hex(), salt.hex()


def _verify_password(password: str, pwd_hash_hex: str, salt_hex: str) -> bool:
    """Return True iff *password* hashes to *pwd_hash_hex* under *salt_hex*.

    The comparison is constant-time (``hmac.compare_digest``). Any error, for
    example a malformed hex salt, is treated as a failed verification.
    """
    try:
        cand, _ = _hash_password(password or "", salt_hex=salt_hex)
        return hmac.compare_digest(cand, pwd_hash_hex or "")
    except Exception:
        return False


def _new_session_token() -> str:
    """Return a new random, URL-safe session token (raw value handed to the client)."""
    return secrets.token_urlsafe(32)


def _hash_token(token: str) -> str:
    """Return the SHA-256 hex digest of *token*.

    Only this hash is stored for sessions, never the raw token.
    """
    return hashlib.sha256((token or "").encode("utf-8")).hexdigest()


Base = declarative_base()


class Project(Base):
    """A project and its JSON-encoded product info / script."""

    __tablename__ = 'projects'

    id = Column(String, primary_key=True)
    name = Column(String)
    status = Column(String)  # created, script_generated, images_generated, videos_generated, completed
    product_info = Column(Text)  # JSON string (SQLite) or JSONB (PG - using Text for compat)
    script_data = Column(Text)  # JSON string
    owner_user_id = Column(String, index=True, nullable=True)
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)


class SceneAsset(Base):
    """One asset (image/video) per (project, scene, asset_type)."""

    __tablename__ = 'scene_assets'

    id = Column(Integer, primary_key=True, autoincrement=True)
    project_id = Column(String, index=True)
    scene_id = Column(Integer)
    asset_type = Column(String)  # image, video
    status = Column(String)  # pending, processing, completed, failed
    local_path = Column(Text, nullable=True)
    remote_url = Column(Text, nullable=True)
    task_id = Column(String, nullable=True)  # task ID assigned by the external API
    # JSON string; the DB column is "metadata" but the attribute is renamed because
    # `metadata` is reserved on declarative classes.
    metadata_json = Column("metadata", Text, nullable=True)
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    __table_args__ = (UniqueConstraint('project_id', 'scene_id', 'asset_type', name='uix_project_scene_asset'),)


class AppConfig(Base):
    """Key/value application configuration; values are stored as JSON strings."""

    __tablename__ = 'app_config'

    key = Column(String, primary_key=True)
    value = Column(Text)  # JSON string
    description = Column(Text, nullable=True)
    updated_at = Column(Float, default=time.time, onupdate=time.time)


class User(Base):
    """An application user with a salted PBKDF2 password hash."""

    __tablename__ = "users"

    id = Column(String, primary_key=True)  # uuid hex
    username = Column(String, unique=True, index=True)
    password_hash = Column(String)
    password_salt = Column(String)
    role = Column(String, default="user")  # admin/user
    is_active = Column(Integer, default=1)  # 1/0 for portability
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
    last_login_at = Column(Float, nullable=True)


class UserSession(Base):
    """A login session; only the SHA-256 hash of the token is persisted."""

    __tablename__ = "user_sessions"

    id = Column(String, primary_key=True)  # uuid hex
    user_id = Column(String, index=True)
    token_hash = Column(String, index=True)
    expires_at = Column(Float)
    created_at = Column(Float, default=time.time)
    last_seen_at = Column(Float, default=time.time)
    ip = Column(String, nullable=True)
    user_agent = Column(String, nullable=True)


class UserPrompt(Base):
    """Per-user prompt overrides, unique per (user_id, key)."""

    __tablename__ = "user_prompts"

    id = Column(String, primary_key=True)  # uuid hex
    user_id = Column(String, index=True)
    key = Column(String, index=True)
    value = Column(Text)
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    __table_args__ = (UniqueConstraint("user_id", "key", name="uix_user_prompt"),)


class RenderJob(Base):
    """
    Render job status table (async compose pipeline).
    This is intentionally minimal and portable across SQLite/Postgres.
    """

    __tablename__ = "render_jobs"

    id = Column(String, primary_key=True)  # task_id
    project_id = Column(String, index=True)
    status = Column(String, default="queued")  # queued/running/success/failed/cancelled
    progress = Column(Float, default=0.0)
    message = Column(Text, default="")
    output_path = Column(Text, nullable=True)
    output_url = Column(Text, nullable=True)
    error = Column(Text, nullable=True)
    request_json = Column(Text, nullable=True)  # JSON string of ComposeRequest
    parent_id = Column(String, nullable=True, index=True)
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)


# Columns that `DBManager.update_render_job` may overwrite. Restricting the patch to
# mapped columns keeps callers from clobbering ORM internals (e.g. `_sa_instance_state`,
# `metadata`). The primary key is excluded so a job cannot be re-keyed in place.
_RENDER_JOB_PATCHABLE_FIELDS = frozenset(c.name for c in RenderJob.__table__.columns) - {"id"}


# --- Row -> dict serializers (shared by several DBManager methods) ---

def _project_summary(p: Project) -> Dict[str, Any]:
    """Return the short project dict used by the project listing endpoints."""
    return {
        "id": p.id,
        "name": p.name,
        "status": p.status,
        "owner_user_id": getattr(p, "owner_user_id", None),
        "updated_at": p.updated_at,
    }


def _user_public(u: User) -> Dict[str, Any]:
    """Return the non-secret user fields handed back after auth/session checks."""
    return {"id": u.id, "username": u.username, "role": u.role, "is_active": int(u.is_active or 0)}


def _is_user_active(u: User) -> bool:
    """Return True when the user's ``is_active`` flag equals 1 (a missing attribute counts as active)."""
    return int(getattr(u, "is_active", 1) or 0) == 1


def _asset_to_dict(a: SceneAsset) -> Dict[str, Any]:
    """Serialize a SceneAsset row; ``metadata`` is JSON-decoded (empty dict when unset)."""
    return {
        "id": a.id,
        "project_id": a.project_id,
        "scene_id": a.scene_id,
        "asset_type": a.asset_type,
        "status": a.status,
        "local_path": a.local_path,
        "remote_url": a.remote_url,
        "task_id": a.task_id,
        "metadata": json.loads(a.metadata_json) if a.metadata_json else {},
        "updated_at": a.updated_at,
    }


def _reset_asset(asset: SceneAsset, status: str) -> None:
    """Set *status* on *asset* and drop its paths, task id and metadata; the row itself is kept."""
    asset.status = status
    asset.local_path = None
    asset.remote_url = None
    asset.task_id = None
    asset.metadata_json = "{}"
    asset.updated_at = time.time()


class DBManager:
    """Facade over the SQLAlchemy engine/session for projects, users, assets, jobs and config.

    Every public method opens its own session and closes it before returning.
    Dict-shaped results are returned instead of ORM instances, so callers never
    touch detached objects.
    """

    def __init__(self, connection_string: Optional[str] = None):
        """Create the engine, ensure the schema exists, and bootstrap an admin user.

        Args:
            connection_string: SQLAlchemy URL; falls back to ``config.DB_CONNECTION_STRING``.
        """
        if not connection_string:
            connection_string = config.DB_CONNECTION_STRING
        self.engine = create_engine(connection_string, pool_recycle=3600)
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        self._init_db()

        # Optional overrides for the bootstrap credentials. When `config` does not
        # define them, the legacy hard-coded defaults are used.
        admin_username = getattr(config, "BOOTSTRAP_ADMIN_USERNAME", None) or "admin"
        admin_password = getattr(config, "BOOTSTRAP_ADMIN_PASSWORD", None) or "admin1234"
        # bootstrap admin (safe to call repeatedly)
        try:
            self._bootstrap_admin_id = self.ensure_admin_user(admin_username, admin_password)
        except Exception:
            # ensure_admin_user has already logged the error; startup must not abort.
            self._bootstrap_admin_id = None

    def _init_db(self):
        """Create tables and apply a lightweight self-migration (no Alembic dependency)."""
        Base.metadata.create_all(self.engine)
        self._ensure_schema()

    def _ensure_schema(self) -> None:
        """
        Ensure newer columns/tables exist in both Postgres and SQLite.
        create_all will not add columns to existing tables, so we do a minimal ALTER here.
        All failures are logged as warnings and never raised.
        """
        try:
            insp = inspect(self.engine)

            # projects.owner_user_id
            try:
                cols = [c["name"] for c in insp.get_columns("projects")]
            except Exception:
                cols = []

            if "owner_user_id" not in cols:
                logger.info("Migrating: add projects.owner_user_id")
                # Commit the column on its own. A failed index statement on Postgres
                # aborts its transaction and must not roll the ALTER back with it.
                with self.engine.begin() as conn:
                    conn.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id VARCHAR"))
                self._ensure_owner_index()
        except Exception as e:
            logger.warning(f"Schema ensure skipped/failed (non-fatal): {e}")

        # Ensure render_jobs table exists (create_all should handle this; keep for safety)
        try:
            insp = inspect(self.engine)
            if "render_jobs" not in insp.get_table_names():
                logger.info("Migrating: create render_jobs table")
                Base.metadata.create_all(self.engine)
        except Exception as e:
            logger.warning(f"render_jobs ensure skipped/failed (non-fatal): {e}")

    def _ensure_owner_index(self) -> None:
        """Best-effort creation of the projects.owner_user_id index.

        The ``IF NOT EXISTS`` form is tried first. The plain form is a fallback
        because some engines (older SQLite) may not support ``IF NOT EXISTS``.
        Each attempt runs in its own transaction. If both fail, the error is ignored.
        """
        statements = (
            "CREATE INDEX IF NOT EXISTS ix_projects_owner_user_id ON projects (owner_user_id)",
            "CREATE INDEX ix_projects_owner_user_id ON projects (owner_user_id)",
        )
        for ddl in statements:
            try:
                with self.engine.begin() as conn:
                    conn.execute(text(ddl))
                return
            except Exception:
                continue

    def _get_session(self):
        """Return a session from the scoped session registry."""
        return self.Session()

    # --- Project Operations ---

    def create_project(self, project_id: str, name: str, product_info: Dict[str, Any], owner_user_id: Optional[str] = None):
        """Insert a new project with status ``"created"``.

        If a project with *project_id* already exists, a warning is logged and
        nothing is changed. Database errors are logged and re-raised.
        """
        session = self._get_session()
        try:
            existing = session.query(Project).filter_by(id=project_id).first()
            if existing:
                logger.warning(f"Project {project_id} already exists.")
                return

            now = time.time()
            new_project = Project(
                id=project_id,
                name=name,
                status="created",
                product_info=json.dumps(product_info, ensure_ascii=False),
                owner_user_id=owner_user_id,
                created_at=now,
                updated_at=now,
            )
            session.add(new_project)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error creating project: {e}")
            raise
        finally:
            session.close()

    def update_project_script(self, project_id: str, script: Dict[str, Any]):
        """Store *script* as JSON and set status to ``"script_generated"``.

        An unknown project is a no-op. Errors are logged, not raised.
        """
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.script_data = json.dumps(script, ensure_ascii=False)
                project.status = "script_generated"
                project.updated_at = time.time()
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating script: {e}")
        finally:
            session.close()

    def update_project_product_info(self, project_id: str, product_info: Dict[str, Any]):
        """
        Update project.product_info JSON (read-write with Postgres shared DB).
        Used to persist editor state without changing schema.
        An unknown project is a no-op. Errors are logged and re-raised.
        """
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.product_info = json.dumps(product_info, ensure_ascii=False)
                project.updated_at = time.time()
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating product_info: {e}")
            raise
        finally:
            session.close()

    def update_project_status(self, project_id: str, status: str):
        """Set the project's status. An unknown project is a no-op. Errors are logged, not raised."""
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.status = status
                project.updated_at = time.time()
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating status: {e}")
        finally:
            session.close()

    def get_project(self, project_id: str) -> Optional[Dict[str, Any]]:
        """Return the full project dict (JSON fields decoded), or None if not found."""
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project is None:
                return None
            return {
                "id": project.id,
                "name": project.name,
                "status": project.status,
                "product_info": json.loads(project.product_info) if project.product_info else {},
                "script_data": json.loads(project.script_data) if project.script_data else None,
                "owner_user_id": getattr(project, "owner_user_id", None),
                "created_at": project.created_at,
                "updated_at": project.updated_at,
            }
        finally:
            session.close()

    def list_projects(self) -> List[Dict[str, Any]]:
        """Return summaries of all projects, most recently updated first."""
        session = self._get_session()
        try:
            projects = session.query(Project).order_by(Project.updated_at.desc()).all()
            return [_project_summary(p) for p in projects]
        finally:
            session.close()

    # --- User/Auth Operations ---

    def ensure_admin_user(self, username: str = "admin", password: str = "admin1234") -> str:
        """Create bootstrap admin user if missing. Returns admin user_id.

        An existing user with *username* is returned as-is; its password and role
        are not touched. Errors are logged and re-raised.
        """
        session = self._get_session()
        try:
            u = session.query(User).filter_by(username=username).first()
            if u:
                return u.id

            pwd_hash, salt_hex = _hash_password(password)
            uid = secrets.token_hex(16)
            now = time.time()
            u = User(
                id=uid,
                username=username,
                password_hash=pwd_hash,
                password_salt=salt_hex,
                role="admin",
                is_active=1,
                created_at=now,
                updated_at=now,
            )
            session.add(u)
            session.commit()
            logger.warning("Bootstrap admin created (please change password asap).")
            return uid
        except Exception as e:
            session.rollback()
            logger.error(f"ensure_admin_user failed: {e}")
            raise
        finally:
            session.close()

    def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
        """Check credentials and record the login time.

        Returns:
            The public user dict on success. None for an unknown user, an inactive
            user, a wrong password, or any database error (which is logged).
        """
        session = self._get_session()
        try:
            u = session.query(User).filter_by(username=username).first()
            if not u or not _is_user_active(u):
                return None
            if not _verify_password(password or "", getattr(u, "password_hash", ""), getattr(u, "password_salt", "")):
                return None
            u.last_login_at = time.time()
            session.commit()
            return _user_public(u)
        except Exception as e:
            session.rollback()
            logger.error(f"authenticate_user error: {e}")
            return None
        finally:
            session.close()

    def create_session(self, user_id: str, *, ttl_seconds: int = 7 * 24 * 3600, ip: Optional[str] = None, user_agent: Optional[str] = None) -> str:
        """Returns raw session token (store hash in DB).

        The session expires *ttl_seconds* from now. Errors are logged and re-raised.
        """
        token = _new_session_token()
        token_hash = _hash_token(token)
        sid = secrets.token_hex(16)
        session = self._get_session()
        try:
            now = time.time()
            s = UserSession(
                id=sid,
                user_id=user_id,
                token_hash=token_hash,
                expires_at=now + int(ttl_seconds),
                created_at=now,
                last_seen_at=now,
                ip=ip,
                user_agent=user_agent,
            )
            session.add(s)
            session.commit()
            return token
        except Exception as e:
            session.rollback()
            logger.error(f"create_session error: {e}")
            raise
        finally:
            session.close()

    def validate_session(self, token: str) -> Optional[Dict[str, Any]]:
        """Resolve a raw session token to its active user and touch ``last_seen_at``.

        Returns None for an empty, unknown or expired token, for an inactive or
        missing user, and on database errors (which are logged). Expired sessions
        are not deleted here.
        """
        if not token:
            return None
        token_hash = _hash_token(token)
        session = self._get_session()
        try:
            now = time.time()
            s = session.query(UserSession).filter_by(token_hash=token_hash).first()
            if not s or (s.expires_at and float(s.expires_at) < now):
                return None
            u = session.query(User).filter_by(id=s.user_id).first()
            if not u or not _is_user_active(u):
                return None
            s.last_seen_at = now
            session.commit()
            return _user_public(u)
        except Exception as e:
            session.rollback()
            logger.error(f"validate_session error: {e}")
            return None
        finally:
            session.close()

    def revoke_session(self, token: str) -> None:
        """Delete the session matching the raw *token*, if any. Errors are logged, not raised."""
        if not token:
            return
        token_hash = _hash_token(token)
        session = self._get_session()
        try:
            s = session.query(UserSession).filter_by(token_hash=token_hash).first()
            if s:
                session.delete(s)
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"revoke_session error: {e}")
        finally:
            session.close()

    def list_users(self) -> List[Dict[str, Any]]:
        """Return all users (without password fields), oldest first."""
        session = self._get_session()
        try:
            users = session.query(User).order_by(User.created_at.asc()).all()
            return [
                {
                    "id": u.id,
                    "username": u.username,
                    "role": u.role,
                    "is_active": int(u.is_active or 0),
                    "created_at": u.created_at,
                    "last_login_at": u.last_login_at,
                }
                for u in users
            ]
        finally:
            session.close()

    def upsert_user(self, username: str, password: str, *, role: str = "user", is_active: int = 1) -> str:
        """Create the user, or overwrite password/role/active flag if *username* exists.

        Returns:
            The user id. Errors are logged and re-raised.
        """
        session = self._get_session()
        try:
            u = session.query(User).filter_by(username=username).first()
            pwd_hash, salt_hex = _hash_password(password)
            now = time.time()
            if u:
                u.password_hash = pwd_hash
                u.password_salt = salt_hex
                u.role = role
                u.is_active = int(is_active)
                u.updated_at = now
                session.commit()
                return u.id

            uid = secrets.token_hex(16)
            u = User(
                id=uid,
                username=username,
                password_hash=pwd_hash,
                password_salt=salt_hex,
                role=role,
                is_active=int(is_active),
                created_at=now,
                updated_at=now,
            )
            session.add(u)
            session.commit()
            return uid
        except Exception as e:
            session.rollback()
            logger.error(f"upsert_user error: {e}")
            raise
        finally:
            session.close()

    def set_user_active(self, user_id: str, is_active: int) -> None:
        """Set the user's active flag. An unknown user is a no-op. Errors are logged, not raised."""
        session = self._get_session()
        try:
            u = session.query(User).filter_by(id=user_id).first()
            if not u:
                return
            u.is_active = int(is_active)
            u.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"set_user_active error: {e}")
        finally:
            session.close()

    def reset_user_password(self, user_id: str, new_password: str) -> None:
        """Re-hash and store a new password with a fresh salt.

        An unknown user is a no-op. Errors are logged, not raised.
        """
        session = self._get_session()
        try:
            u = session.query(User).filter_by(id=user_id).first()
            if not u:
                return
            pwd_hash, salt_hex = _hash_password(new_password)
            u.password_hash = pwd_hash
            u.password_salt = salt_hex
            u.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"reset_user_password error: {e}")
        finally:
            session.close()

    def get_user_prompt(self, user_id: str, key: str) -> Optional[str]:
        """Return the stored raw string for (user_id, key), or None when unset or args are empty."""
        if not user_id or not key:
            return None
        session = self._get_session()
        try:
            p = session.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
            return p.value if p else None
        finally:
            session.close()

    def set_user_prompt(self, user_id: str, key: str, value: Any) -> None:
        """Upsert a per-user prompt.

        Strings are stored verbatim; other values are JSON-encoded. Errors are
        logged, not raised.
        """
        if not user_id or not key:
            return
        session = self._get_session()
        try:
            now = time.time()
            v = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
            p = session.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
            if p:
                p.value = v
                p.updated_at = now
            else:
                p = UserPrompt(id=secrets.token_hex(16), user_id=user_id, key=key, value=v, updated_at=now)
                session.add(p)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"set_user_prompt error: {e}")
        finally:
            session.close()

    # --- RBAC helpers ---

    def list_projects_for_user(self, user: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Admins see every project; other users see only projects they own."""
        if not user:
            return []
        if user.get("role") == "admin":
            return self.list_projects()
        uid = user.get("id")
        session = self._get_session()
        try:
            projects = (
                session.query(Project)
                .filter_by(owner_user_id=uid)
                .order_by(Project.updated_at.desc())
                .all()
            )
            return [_project_summary(p) for p in projects]
        finally:
            session.close()

    def get_project_for_user(self, project_id: str, user: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the project if *user* is an admin or its owner, else None."""
        data = self.get_project(project_id)
        if not data or not user:
            return None
        if user.get("role") == "admin":
            return data
        if data.get("owner_user_id") != user.get("id"):
            return None
        return data

    def migrate_projects_owner_to(self, owner_user_id: str) -> int:
        """Assign owner_user_id for legacy projects where it is NULL/empty.

        Returns:
            The number of projects updated, or 0 on error (the error is logged).
        """
        session = self._get_session()
        try:
            rows = (
                session.query(Project)
                .filter(Project.owner_user_id.is_(None) | (Project.owner_user_id == ""))
                .all()
            )
            for p in rows:
                p.owner_user_id = owner_user_id
                p.updated_at = time.time()
            session.commit()
            return len(rows)
        except Exception as e:
            session.rollback()
            logger.error(f"migrate_projects_owner_to error: {e}")
            return 0
        finally:
            session.close()

    # --- Asset/Task Operations ---

    def save_asset(self, project_id: str, scene_id: int, asset_type: str, status: str,
                   local_path: Optional[str] = None, remote_url: Optional[str] = None,
                   task_id: Optional[str] = None, metadata: Optional[Dict] = None):
        """Insert or update the asset for (project_id, scene_id, asset_type) (upsert).

        All fields are overwritten on update, including metadata (an empty or
        ``None`` value becomes ``"{}"``). Errors are logged, not raised.
        """
        session = self._get_session()
        try:
            asset = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type,
            ).first()

            meta_json = json.dumps(metadata, ensure_ascii=False) if metadata else "{}"
            now = time.time()

            if asset:
                asset.status = status
                asset.local_path = local_path
                asset.remote_url = remote_url
                asset.task_id = task_id
                asset.metadata_json = meta_json
                asset.updated_at = now
            else:
                session.add(SceneAsset(
                    project_id=project_id,
                    scene_id=scene_id,
                    asset_type=asset_type,
                    status=status,
                    local_path=local_path,
                    remote_url=remote_url,
                    task_id=task_id,
                    metadata_json=meta_json,
                    created_at=now,
                    updated_at=now,
                ))
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error saving asset: {e}")
        finally:
            session.close()

    def get_assets(self, project_id: str, asset_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return all assets of a project, optionally filtered by *asset_type*."""
        session = self._get_session()
        try:
            query = session.query(SceneAsset).filter_by(project_id=project_id)
            if asset_type:
                query = query.filter_by(asset_type=asset_type)
            return [_asset_to_dict(a) for a in query.all()]
        finally:
            session.close()

    # --- Render Job Operations ---

    def create_render_job(
        self,
        job_id: str,
        project_id: str,
        *,
        status: str = "queued",
        progress: float = 0.0,
        message: str = "",
        request: Optional[Dict[str, Any]] = None,
        parent_id: Optional[str] = None,
    ) -> None:
        """Insert a render job. An existing *job_id* is left untouched (idempotent create).

        Errors are logged and re-raised.
        """
        session = self._get_session()
        try:
            existing = session.query(RenderJob).filter_by(id=job_id).first()
            if existing:
                return
            now = time.time()
            rj = RenderJob(
                id=job_id,
                project_id=project_id,
                status=status,
                progress=float(progress or 0.0),
                message=message or "",
                request_json=json.dumps(request, ensure_ascii=False) if request is not None else None,
                parent_id=parent_id,
                created_at=now,
                updated_at=now,
            )
            session.add(rj)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"create_render_job error: {e}")
            raise
        finally:
            session.close()

    def update_render_job(self, job_id: str, patch: Dict[str, Any]) -> None:
        """Apply *patch* (column name -> value) to a render job and bump ``updated_at``.

        Only mapped RenderJob columns other than ``id`` are written. Any other key
        is ignored, so a patch cannot overwrite ORM internals or the primary key.
        An unknown job is a no-op. Errors are logged, not raised.
        """
        if not job_id or not patch:
            return
        session = self._get_session()
        try:
            rj = session.query(RenderJob).filter_by(id=job_id).first()
            if not rj:
                return
            for k, v in patch.items():
                if k not in _RENDER_JOB_PATCHABLE_FIELDS:
                    continue
                setattr(rj, k, v)
            rj.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"update_render_job error: {e}")
        finally:
            session.close()

    def get_render_job(self, job_id: str) -> Optional[Dict[str, Any]]:
        """Return the render job as a dict (``request`` JSON-decoded, or None if undecodable)."""
        session = self._get_session()
        try:
            rj = session.query(RenderJob).filter_by(id=job_id).first()
            if not rj:
                return None
            try:
                req = json.loads(rj.request_json) if rj.request_json else None
            except Exception:
                req = None
            return {
                "id": rj.id,
                "project_id": rj.project_id,
                "status": rj.status,
                "progress": float(rj.progress or 0.0),
                "message": rj.message or "",
                "output_path": rj.output_path,
                "output_url": rj.output_url,
                "error": rj.error,
                "request": req,
                "parent_id": rj.parent_id,
                "created_at": rj.created_at,
                "updated_at": rj.updated_at,
            }
        finally:
            session.close()

    def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
        """Return the asset for (project_id, scene_id, asset_type), or None."""
        session = self._get_session()
        try:
            a = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type,
            ).first()
            return _asset_to_dict(a) if a else None
        finally:
            session.close()

    def get_asset_by_id(self, asset_id: int) -> Optional[Dict[str, Any]]:
        """Fetch an asset record by its autoincrement id (used by the file proxy).

        Returns None when the asset is missing or on any error, including a
        non-integer *asset_id* or undecodable metadata.
        """
        session = self._get_session()
        try:
            a = session.query(SceneAsset).filter_by(id=int(asset_id)).first()
            if not a:
                return None
            return _asset_to_dict(a)
        except Exception:
            return None
        finally:
            session.close()

    def update_asset_metadata(self, project_id: str, scene_id: int, asset_type: str, patch: Dict[str, Any]) -> None:
        """Merge-patch asset.metadata JSON without overwriting other fields.

        Existing metadata that is missing, not valid JSON, or not a JSON object is
        replaced by *patch*. A missing asset is a no-op. Errors are logged, not raised.
        """
        if not patch:
            return
        session = self._get_session()
        try:
            asset = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type,
            ).first()
            if not asset:
                return
            try:
                existing = json.loads(asset.metadata_json) if asset.metadata_json else {}
            except Exception:
                existing = {}
            if not isinstance(existing, dict):
                existing = {}
            existing.update(patch)
            asset.metadata_json = json.dumps(existing, ensure_ascii=False)
            asset.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating asset metadata: {e}")
        finally:
            session.close()

    def clear_asset(self, project_id: str, scene_id: int, asset_type: str, *, status: str = "pending") -> None:
        """
        Clear an asset record (keep the row, but remove paths/task/metadata).
        Used to invalidate stale videos after images are regenerated.
        A missing asset is a no-op. Errors are logged, not raised.
        """
        session = self._get_session()
        try:
            asset = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type,
            ).first()
            if not asset:
                return
            _reset_asset(asset, status)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error clearing asset: {e}")
        finally:
            session.close()

    def clear_assets(self, project_id: str, asset_type: str, *, status: str = "pending") -> int:
        """
        Clear all assets of a type for a project.
        Returns number of rows affected (0 on error; the error is logged).
        """
        session = self._get_session()
        try:
            rows = session.query(SceneAsset).filter_by(project_id=project_id, asset_type=asset_type).all()
            if not rows:
                return 0
            for asset in rows:
                _reset_asset(asset, status)
            session.commit()
            return len(rows)
        except Exception as e:
            session.rollback()
            logger.error(f"Error clearing assets: {e}")
            return 0
        finally:
            session.close()

    # --- Config/Prompt Operations ---

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the JSON-decoded config value for *key*, or *default* if absent.

        A stored value that is not valid JSON is returned as the raw string.
        """
        session = self._get_session()
        try:
            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg is None:
                return default
            try:
                return json.loads(cfg.value)
            except (TypeError, ValueError):
                # TypeError: NULL value; ValueError (JSONDecodeError): legacy non-JSON text.
                return cfg.value
        finally:
            session.close()

    def set_config(self, key: str, value: Any, description: Optional[str] = None):
        """Upsert a JSON-encoded config value.

        On update, *description* replaces the existing one only when it is truthy.
        Errors are logged, not raised.
        """
        session = self._get_session()
        try:
            json_val = json.dumps(value, ensure_ascii=False)
            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg:
                cfg.value = json_val
                if description:
                    cfg.description = description
                cfg.updated_at = time.time()
            else:
                session.add(AppConfig(
                    key=key,
                    value=json_val,
                    description=description,
                    updated_at=time.time(),
                ))
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error setting config: {e}")
        finally:
            session.close()


# Singleton instance (connects to the configured database at import time).
db = DBManager()