923 lines
34 KiB
Python
923 lines
34 KiB
Python
"""
|
||
Database management module (SQLAlchemy).
Handles persistent storage of project data, task status, and asset paths.
Supports SQLite and PostgreSQL.
|
||
"""
|
||
import json
|
||
import logging
|
||
import time
|
||
import secrets
|
||
from typing import Dict, List, Any, Optional, Tuple
|
||
|
||
from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func, text, inspect
|
||
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||
from sqlalchemy.dialects.postgresql import JSONB
|
||
|
||
import config
|
||
|
||
# NOTE: Some deployments do not ship `modules/auth.py`.
|
||
# Keep a minimal, local auth helper here to avoid hard dependency.
|
||
import hashlib
|
||
import hmac
|
||
|
||
|
||
def _hash_password(password: str, salt_hex: str = None) -> tuple[str, str]:
    """Derive a PBKDF2-HMAC-SHA256 digest (120,000 rounds) for *password*.

    A falsy ``salt_hex`` triggers generation of a fresh 16-byte random salt;
    otherwise the given hex string is decoded and reused. A falsy password is
    hashed as the empty string.

    Returns:
        A ``(digest_hex, salt_hex)`` pair.
    """
    if salt_hex:
        salt_bytes = bytes.fromhex(salt_hex)
    else:
        salt_bytes = secrets.token_bytes(16)
    secret = (password or "").encode("utf-8")
    digest = hashlib.pbkdf2_hmac("sha256", secret, salt_bytes, 120_000)
    return digest.hex(), salt_bytes.hex()
|
||
|
||
|
||
def _verify_password(password: str, pwd_hash_hex: str, salt_hex: str) -> bool:
    """Check *password* against a stored digest/salt pair.

    The candidate digest is recomputed with the stored salt and compared with
    ``hmac.compare_digest``. Any exception raised along the way (for example
    a salt that is not valid hex) is treated as a mismatch.
    """
    try:
        candidate = _hash_password(password or "", salt_hex=salt_hex)[0]
        outcome = hmac.compare_digest(candidate, pwd_hash_hex or "")
    except Exception:
        return False
    return outcome
|
||
|
||
|
||
def _new_session_token() -> str:
    """Return a fresh URL-safe random token built from 32 random bytes."""
    token = secrets.token_urlsafe(32)
    return token
|
||
|
||
|
||
def _hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of *token* (a falsy token hashes as "").

    Sessions store only this digest, never the raw token.
    """
    raw = (token or "").encode("utf-8")
    hasher = hashlib.sha256()
    hasher.update(raw)
    return hasher.hexdigest()
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Declarative base class shared by every ORM model defined below.
Base = declarative_base()
|
||
|
||
class Project(Base):
    """ORM model for the ``projects`` table."""

    __tablename__ = "projects"

    id = Column(String, primary_key=True)
    name = Column(String)
    # Lifecycle: created, script_generated, images_generated,
    # videos_generated, completed
    status = Column(String)
    # JSON string (Text is used on both SQLite and PostgreSQL for compatibility)
    product_info = Column(Text)
    # JSON string
    script_data = Column(Text)
    owner_user_id = Column(String, nullable=True, index=True)
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||
|
||
class SceneAsset(Base):
    """ORM model for the ``scene_assets`` table.

    Each (project_id, scene_id, asset_type) combination is unique.
    """

    __tablename__ = "scene_assets"

    id = Column(Integer, primary_key=True, autoincrement=True)
    project_id = Column(String, index=True)
    scene_id = Column(Integer)
    # Asset kind: image, video
    asset_type = Column(String)
    # pending, processing, completed, failed
    status = Column(String)
    local_path = Column(Text, nullable=True)
    remote_url = Column(Text, nullable=True)
    # Task ID assigned by the external API
    task_id = Column(String, nullable=True)
    # JSON string. The column is called "metadata"; the Python attribute is
    # renamed to avoid a conflict with the declarative ``metadata`` attribute.
    metadata_json = Column("metadata", Text, nullable=True)
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    __table_args__ = (
        UniqueConstraint(
            "project_id",
            "scene_id",
            "asset_type",
            name="uix_project_scene_asset",
        ),
    )
|
||
|
||
class AppConfig(Base):
    """ORM model for the ``app_config`` key/value table."""

    __tablename__ = "app_config"

    key = Column(String, primary_key=True)
    # JSON string
    value = Column(Text)
    description = Column(Text, nullable=True)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||
|
||
|
||
class User(Base):
    """ORM model for the ``users`` table (credentials, role, activity flag)."""

    __tablename__ = "users"

    # uuid hex
    id = Column(String, primary_key=True)
    username = Column(String, index=True, unique=True)
    password_hash = Column(String)
    password_salt = Column(String)
    # admin/user
    role = Column(String, default="user")
    # 1/0 instead of a boolean, for portability
    is_active = Column(Integer, default=1)
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
    last_login_at = Column(Float, nullable=True)
|
||
|
||
|
||
class UserSession(Base):
    """ORM model for the ``user_sessions`` table.

    Stores a hash of the session token together with expiry and client info.
    """

    __tablename__ = "user_sessions"

    # uuid hex
    id = Column(String, primary_key=True)
    user_id = Column(String, index=True)
    token_hash = Column(String, index=True)
    expires_at = Column(Float)
    created_at = Column(Float, default=time.time)
    last_seen_at = Column(Float, default=time.time)
    ip = Column(String, nullable=True)
    user_agent = Column(String, nullable=True)
|
||
|
||
|
||
class UserPrompt(Base):
    """ORM model for the ``user_prompts`` table; (user_id, key) is unique."""

    __tablename__ = "user_prompts"

    # uuid hex
    id = Column(String, primary_key=True)
    user_id = Column(String, index=True)
    key = Column(String, index=True)
    value = Column(Text)
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    __table_args__ = (
        UniqueConstraint("user_id", "key", name="uix_user_prompt"),
    )
|
||
|
||
|
||
class RenderJob(Base):
    """
    Status table for render jobs (async compose pipeline).

    Kept intentionally minimal and portable across SQLite/Postgres.
    """

    __tablename__ = "render_jobs"

    # task_id
    id = Column(String, primary_key=True)
    project_id = Column(String, index=True)
    # queued/running/success/failed/cancelled
    status = Column(String, default="queued")
    progress = Column(Float, default=0.0)
    message = Column(Text, default="")
    output_path = Column(Text, nullable=True)
    output_url = Column(Text, nullable=True)
    error = Column(Text, nullable=True)
    # JSON string of ComposeRequest
    request_json = Column(Text, nullable=True)
    parent_id = Column(String, index=True, nullable=True)
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||
|
||
class DBManager:
|
||
def __init__(self, connection_string: str = None):
    """Create the engine and session factory, then prepare the schema.

    Args:
        connection_string: SQLAlchemy database URL. Falls back to
            ``config.DB_CONNECTION_STRING`` when falsy.

    After the schema is ready, the bootstrap admin account is ensured. Its id
    is stored in ``self._bootstrap_admin_id`` (``None`` if that step failed).
    """
    if not connection_string:
        connection_string = config.DB_CONNECTION_STRING

    self.engine = create_engine(connection_string, pool_recycle=3600)
    self.Session = scoped_session(sessionmaker(bind=self.engine))
    self._init_db()

    # Bootstrap admin (safe to call repeatedly).
    # SECURITY: the built-in defaults are well-known credentials. They can be
    # overridden through optional ``config.BOOTSTRAP_ADMIN_USERNAME`` /
    # ``config.BOOTSTRAP_ADMIN_PASSWORD``; without them behavior is unchanged.
    admin_username = getattr(config, "BOOTSTRAP_ADMIN_USERNAME", None) or "admin"
    admin_password = getattr(config, "BOOTSTRAP_ADMIN_PASSWORD", None) or "admin1234"
    try:
        self._bootstrap_admin_id = self.ensure_admin_user(admin_username, admin_password)
    except Exception:
        # Best-effort: a failed bootstrap must not prevent construction.
        self._bootstrap_admin_id = None
|
||
|
||
def _init_db(self):
    """Create missing tables, then run the lightweight self-migration (no Alembic)."""
    engine = self.engine
    Base.metadata.create_all(engine)
    self._ensure_schema()
|
||
|
||
def _ensure_schema(self) -> None:
    """
    Ensure newer columns/tables exist in both Postgres and SQLite.

    ``create_all`` does not add columns to existing tables, so a minimal
    ALTER is issued here for ``projects.owner_user_id``, followed by creation
    of its index. Every failure is logged and treated as non-fatal.
    """
    try:
        insp = inspect(self.engine)
        # projects.owner_user_id
        try:
            cols = [c["name"] for c in insp.get_columns("projects")]
        except Exception:
            cols = []
        if "owner_user_id" not in cols:
            logger.info("Migrating: add projects.owner_user_id")
            with self.engine.begin() as conn:
                conn.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id VARCHAR"))

            # The index is created in its own transaction(s). On PostgreSQL a
            # failed statement aborts the surrounding transaction, so sharing
            # one with the ALTER above could silently roll the new column back.
            index_statements = (
                "CREATE INDEX IF NOT EXISTS ix_projects_owner_user_id ON projects (owner_user_id)",
                # some engines (older sqlite) may not support IF NOT EXISTS
                "CREATE INDEX ix_projects_owner_user_id ON projects (owner_user_id)",
            )
            for ddl in index_statements:
                try:
                    with self.engine.begin() as conn:
                        conn.execute(text(ddl))
                    break
                except Exception as index_err:
                    logger.debug(f"Index creation attempt failed (non-fatal): {index_err}")
    except Exception as e:
        logger.warning(f"Schema ensure skipped/failed (non-fatal): {e}")

    # Ensure render_jobs table exists (create_all should handle this; keep for safety)
    try:
        insp = inspect(self.engine)
        if "render_jobs" not in insp.get_table_names():
            logger.info("Migrating: create render_jobs table")
            Base.metadata.create_all(self.engine)
    except Exception as e:
        logger.warning(f"render_jobs ensure skipped/failed (non-fatal): {e}")
|
||
|
||
def _get_session(self):
    """Return the session produced by calling the ``self.Session`` factory."""
    session_factory = self.Session
    return session_factory()
|
||
|
||
# --- Project Operations ---
|
||
|
||
def create_project(self, project_id: str, name: str, product_info: Dict[str, Any], owner_user_id: Optional[str] = None):
    """Insert a new project row with status ``created``.

    If a project with ``project_id`` already exists, a warning is logged and
    nothing is written. ``product_info`` is stored as a JSON string.

    Raises:
        Exception: any error during the insert is rolled back, logged and
            re-raised.
    """
    db_session = self._get_session()
    try:
        # Skip the insert when the id is already taken.
        if db_session.query(Project).filter_by(id=project_id).first():
            logger.warning(f"Project {project_id} already exists.")
            return

        record = Project(
            id=project_id,
            name=name,
            status="created",
            product_info=json.dumps(product_info, ensure_ascii=False),
            owner_user_id=owner_user_id,
            created_at=time.time(),
            updated_at=time.time(),
        )
        db_session.add(record)
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error creating project: {e}")
        raise
    finally:
        db_session.close()
|
||
|
||
def update_project_script(self, project_id: str, script: Dict[str, Any]):
    """Store the script (as JSON) on a project and mark it ``script_generated``.

    Unknown project ids are ignored. Errors are rolled back and logged, but
    not re-raised.
    """
    db_session = self._get_session()
    try:
        record = db_session.query(Project).filter_by(id=project_id).first()
        if not record:
            return
        record.script_data = json.dumps(script, ensure_ascii=False)
        record.status = "script_generated"
        record.updated_at = time.time()
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error updating script: {e}")
    finally:
        db_session.close()
|
||
|
||
def update_project_product_info(self, project_id: str, product_info: Dict[str, Any]):
    """
    Overwrite ``project.product_info`` with the JSON form of ``product_info``.

    Used to persist editor state without changing the schema. Unknown project
    ids are ignored.

    Raises:
        Exception: any error is rolled back, logged and re-raised.
    """
    db_session = self._get_session()
    try:
        record = db_session.query(Project).filter_by(id=project_id).first()
        if not record:
            return
        record.product_info = json.dumps(product_info, ensure_ascii=False)
        record.updated_at = time.time()
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error updating product_info: {e}")
        raise
    finally:
        db_session.close()
|
||
|
||
def update_project_status(self, project_id: str, status: str):
    """Set the status of a project and refresh its ``updated_at`` timestamp.

    Unknown project ids are ignored. Errors are rolled back and logged, but
    not re-raised.
    """
    db_session = self._get_session()
    try:
        record = db_session.query(Project).filter_by(id=project_id).first()
        if not record:
            return
        record.status = status
        record.updated_at = time.time()
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error updating status: {e}")
    finally:
        db_session.close()
|
||
|
||
def get_project(self, project_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one project as a plain dict, or ``None`` when it does not exist.

    ``product_info`` and ``script_data`` are decoded from JSON; an empty
    ``product_info`` becomes ``{}`` and an empty ``script_data`` becomes
    ``None``.
    """
    db_session = self._get_session()
    try:
        row = db_session.query(Project).filter_by(id=project_id).first()
        if not row:
            return None
        product_info = json.loads(row.product_info) if row.product_info else {}
        script_data = json.loads(row.script_data) if row.script_data else None
        return {
            "id": row.id,
            "name": row.name,
            "status": row.status,
            "product_info": product_info,
            "script_data": script_data,
            "owner_user_id": getattr(row, "owner_user_id", None),
            "created_at": row.created_at,
            "updated_at": row.updated_at,
        }
    finally:
        db_session.close()
|
||
|
||
def list_projects(self) -> List[Dict[str, Any]]:
    """Return summary dicts for all projects, most recently updated first."""
    db_session = self._get_session()
    try:
        ordered = db_session.query(Project).order_by(Project.updated_at.desc())
        return [
            {
                "id": row.id,
                "name": row.name,
                "status": row.status,
                "owner_user_id": getattr(row, "owner_user_id", None),
                "updated_at": row.updated_at,
            }
            for row in ordered.all()
        ]
    finally:
        db_session.close()
|
||
|
||
# --- User/Auth Operations ---
|
||
|
||
def ensure_admin_user(self, username: str = "admin", password: str = "admin1234") -> str:
    """Create the bootstrap admin user if missing and return its user id.

    When a user named ``username`` already exists, its id is returned and
    nothing is modified (the password is not reset).

    Raises:
        Exception: any error is rolled back, logged and re-raised.
    """
    db_session = self._get_session()
    try:
        found = db_session.query(User).filter_by(username=username).first()
        if found:
            return found.id

        digest_hex, salt_hex = _hash_password(password)
        new_id = secrets.token_hex(16)
        admin = User(
            id=new_id,
            username=username,
            password_hash=digest_hex,
            password_salt=salt_hex,
            role="admin",
            is_active=1,
            created_at=time.time(),
            updated_at=time.time(),
        )
        db_session.add(admin)
        db_session.commit()
        logger.warning("Bootstrap admin created (please change password asap).")
        return new_id
    except Exception as e:
        db_session.rollback()
        logger.error(f"ensure_admin_user failed: {e}")
        raise
    finally:
        db_session.close()
|
||
|
||
def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
    """Verify credentials and return the user's public fields.

    Returns ``None`` when the user is unknown, inactive, the password does not
    match, or an error occurs (errors are rolled back and logged). On success
    ``last_login_at`` is updated.
    """
    db_session = self._get_session()
    try:
        account = db_session.query(User).filter_by(username=username).first()
        if not account:
            return None
        if int(getattr(account, "is_active", 1) or 0) != 1:
            return None

        stored_hash = getattr(account, "password_hash", "")
        stored_salt = getattr(account, "password_salt", "")
        if not _verify_password(password or "", stored_hash, stored_salt):
            return None

        account.last_login_at = time.time()
        db_session.commit()
        return {
            "id": account.id,
            "username": account.username,
            "role": account.role,
            "is_active": int(account.is_active or 0),
        }
    except Exception as e:
        db_session.rollback()
        logger.error(f"authenticate_user error: {e}")
        return None
    finally:
        db_session.close()
|
||
|
||
def create_session(self, user_id: str, *, ttl_seconds: int = 7 * 24 * 3600, ip: str = None, user_agent: str = None) -> str:
    """Create a login session and return the raw session token.

    Only the token's hash is written to the database. The session expires
    ``ttl_seconds`` after creation.

    Raises:
        Exception: any error is rolled back, logged and re-raised.
    """
    raw_token = _new_session_token()
    hashed = _hash_token(raw_token)
    session_id = secrets.token_hex(16)
    db_session = self._get_session()
    try:
        created = time.time()
        row = UserSession(
            id=session_id,
            user_id=user_id,
            token_hash=hashed,
            expires_at=created + int(ttl_seconds),
            created_at=created,
            last_seen_at=created,
            ip=ip,
            user_agent=user_agent,
        )
        db_session.add(row)
        db_session.commit()
        return raw_token
    except Exception as e:
        db_session.rollback()
        logger.error(f"create_session error: {e}")
        raise
    finally:
        db_session.close()
|
||
|
||
def validate_session(self, token: str) -> Optional[Dict[str, Any]]:
    """Resolve a raw session token to its active user.

    Returns the user's public fields, or ``None`` when the token is empty,
    unknown or expired, the user is missing or inactive, or an error occurs.
    A successful lookup refreshes the session's ``last_seen_at``.
    """
    if not token:
        return None
    hashed = _hash_token(token)
    db_session = self._get_session()
    try:
        current = time.time()
        sess_row = db_session.query(UserSession).filter_by(token_hash=hashed).first()
        if not sess_row:
            return None
        if sess_row.expires_at and float(sess_row.expires_at) < current:
            return None

        account = db_session.query(User).filter_by(id=sess_row.user_id).first()
        if not account:
            return None
        if int(getattr(account, "is_active", 1) or 0) != 1:
            return None

        sess_row.last_seen_at = current
        db_session.commit()
        return {
            "id": account.id,
            "username": account.username,
            "role": account.role,
            "is_active": int(account.is_active or 0),
        }
    except Exception as e:
        db_session.rollback()
        logger.error(f"validate_session error: {e}")
        return None
    finally:
        db_session.close()
|
||
|
||
def revoke_session(self, token: str) -> None:
    """Delete the session row matching a raw token, if there is one.

    Empty tokens are ignored. Errors are rolled back and logged, but not
    re-raised.
    """
    if not token:
        return
    hashed = _hash_token(token)
    db_session = self._get_session()
    try:
        sess_row = db_session.query(UserSession).filter_by(token_hash=hashed).first()
        if not sess_row:
            return
        db_session.delete(sess_row)
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"revoke_session error: {e}")
    finally:
        db_session.close()
|
||
|
||
def list_users(self) -> List[Dict[str, Any]]:
    """Return all users as dicts, oldest first (no password fields included)."""
    db_session = self._get_session()
    try:
        listing = []
        for account in db_session.query(User).order_by(User.created_at.asc()).all():
            listing.append(
                {
                    "id": account.id,
                    "username": account.username,
                    "role": account.role,
                    "is_active": int(account.is_active or 0),
                    "created_at": account.created_at,
                    "last_login_at": account.last_login_at,
                }
            )
        return listing
    finally:
        db_session.close()
|
||
|
||
def upsert_user(self, username: str, password: str, *, role: str = "user", is_active: int = 1) -> str:
    """Create a user, or overwrite an existing user's password, role and flag.

    Returns:
        The id of the created or updated user.

    Raises:
        Exception: any error is rolled back, logged and re-raised.
    """
    db_session = self._get_session()
    try:
        account = db_session.query(User).filter_by(username=username).first()
        digest_hex, salt_hex = _hash_password(password)
        stamp = time.time()

        if not account:
            # No such user yet: insert a fresh row.
            new_id = secrets.token_hex(16)
            db_session.add(
                User(
                    id=new_id,
                    username=username,
                    password_hash=digest_hex,
                    password_salt=salt_hex,
                    role=role,
                    is_active=int(is_active),
                    created_at=stamp,
                    updated_at=stamp,
                )
            )
            db_session.commit()
            return new_id

        # Existing user: replace credentials, role and active flag.
        account.password_hash = digest_hex
        account.password_salt = salt_hex
        account.role = role
        account.is_active = int(is_active)
        account.updated_at = stamp
        db_session.commit()
        return account.id
    except Exception as e:
        db_session.rollback()
        logger.error(f"upsert_user error: {e}")
        raise
    finally:
        db_session.close()
|
||
|
||
def set_user_active(self, user_id: str, is_active: int) -> None:
    """Set a user's ``is_active`` flag (stored as an int).

    Unknown user ids are ignored. Errors are rolled back and logged, but not
    re-raised.
    """
    db_session = self._get_session()
    try:
        account = db_session.query(User).filter_by(id=user_id).first()
        if account:
            account.is_active = int(is_active)
            account.updated_at = time.time()
            db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"set_user_active error: {e}")
    finally:
        db_session.close()
|
||
|
||
def reset_user_password(self, user_id: str, new_password: str) -> None:
    """Replace a user's password hash and salt with ones for ``new_password``.

    Unknown user ids are ignored. Errors are rolled back and logged, but not
    re-raised.
    """
    db_session = self._get_session()
    try:
        account = db_session.query(User).filter_by(id=user_id).first()
        if account:
            digest_hex, salt_hex = _hash_password(new_password)
            account.password_hash = digest_hex
            account.password_salt = salt_hex
            account.updated_at = time.time()
            db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"reset_user_password error: {e}")
    finally:
        db_session.close()
|
||
|
||
def get_user_prompt(self, user_id: str, key: str) -> Optional[str]:
    """Return the stored prompt value for (user_id, key), or ``None``.

    ``None`` is also returned when either argument is empty.
    """
    if not (user_id and key):
        return None
    db_session = self._get_session()
    try:
        entry = db_session.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
        if entry:
            return entry.value
        return None
    finally:
        db_session.close()
|
||
|
||
def set_user_prompt(self, user_id: str, key: str, value: Any) -> None:
    """Insert or update the prompt value stored for (user_id, key).

    String values are stored as-is; anything else is stored as JSON. Calls
    with an empty ``user_id`` or ``key`` are ignored. Errors are rolled back
    and logged, but not re-raised.
    """
    if not (user_id and key):
        return
    db_session = self._get_session()
    try:
        stamp = time.time()
        if isinstance(value, str):
            stored = value
        else:
            stored = json.dumps(value, ensure_ascii=False)

        entry = db_session.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
        if not entry:
            entry = UserPrompt(
                id=secrets.token_hex(16),
                user_id=user_id,
                key=key,
                value=stored,
                updated_at=stamp,
            )
            db_session.add(entry)
        else:
            entry.value = stored
            entry.updated_at = stamp
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"set_user_prompt error: {e}")
    finally:
        db_session.close()
|
||
|
||
# --- RBAC helpers ---
|
||
def list_projects_for_user(self, user: Dict[str, Any]) -> List[Dict[str, Any]]:
    """List the projects visible to ``user``, most recently updated first.

    Args:
        user: user dict; its ``role`` and ``id`` keys are read.

    Returns:
        All projects for an admin, otherwise only those whose
        ``owner_user_id`` equals the user's id. An empty list is returned for
        a falsy ``user`` or a non-admin user without an id.
    """
    if not user:
        return []
    if user.get("role") == "admin":
        return self.list_projects()

    uid = user.get("id")
    if not uid:
        # Without this guard a missing id would filter on owner_user_id=None
        # and expose every legacy project that has no owner.
        return []

    session = self._get_session()
    try:
        projects = (
            session.query(Project)
            .filter_by(owner_user_id=uid)
            .order_by(Project.updated_at.desc())
            .all()
        )
        return [
            {
                "id": p.id,
                "name": p.name,
                "status": p.status,
                "owner_user_id": p.owner_user_id,
                "updated_at": p.updated_at,
            }
            for p in projects
        ]
    finally:
        session.close()
|
||
|
||
def get_project_for_user(self, project_id: str, user: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the project dict if ``user`` is allowed to see it, else ``None``.

    Admins may see every project. Other users may only see projects whose
    ``owner_user_id`` equals their own non-empty id.
    """
    data = self.get_project(project_id)
    if not data or not user:
        return None
    if user.get("role") == "admin":
        return data

    uid = user.get("id")
    # A missing id must never match a project without an owner
    # (None == None would otherwise grant access to legacy projects).
    if not uid or data.get("owner_user_id") != uid:
        return None
    return data
|
||
|
||
def migrate_projects_owner_to(self, owner_user_id: str) -> int:
    """Assign ``owner_user_id`` to legacy projects whose owner is NULL/empty.

    Returns:
        The number of projects reassigned, or 0 when an error occurred
        (errors are rolled back and logged).
    """
    db_session = self._get_session()
    try:
        unowned = (Project.owner_user_id == None) | (Project.owner_user_id == "")  # noqa: E711
        legacy_rows = db_session.query(Project).filter(unowned).all()
        for row in legacy_rows:
            row.owner_user_id = owner_user_id
            row.updated_at = time.time()
        db_session.commit()
        return len(legacy_rows)
    except Exception as e:
        db_session.rollback()
        logger.error(f"migrate_projects_owner_to error: {e}")
        return 0
    finally:
        db_session.close()
|
||
|
||
# --- Asset/Task Operations ---
|
||
|
||
def save_asset(self, project_id: str, scene_id: int, asset_type: str,
               status: str, local_path: str = None, remote_url: str = None,
               task_id: str = None, metadata: Dict = None):
    """Save or update an asset record (upsert on project/scene/asset type).

    An existing row has its status, paths, task id and metadata overwritten
    with the given values, including ``None``. Falsy ``metadata`` is stored
    as ``"{}"``. Errors are rolled back and logged, but not re-raised.
    """
    db_session = self._get_session()
    try:
        existing = db_session.query(SceneAsset).filter_by(
            project_id=project_id,
            scene_id=scene_id,
            asset_type=asset_type,
        ).first()

        if metadata:
            meta_json = json.dumps(metadata, ensure_ascii=False)
        else:
            meta_json = "{}"

        # Fields written on both the insert and the update path.
        payload = {
            "status": status,
            "local_path": local_path,
            "remote_url": remote_url,
            "task_id": task_id,
            "metadata_json": meta_json,
        }

        if not existing:
            db_session.add(
                SceneAsset(
                    project_id=project_id,
                    scene_id=scene_id,
                    asset_type=asset_type,
                    created_at=time.time(),
                    updated_at=time.time(),
                    **payload,
                )
            )
        else:
            for field_name, field_value in payload.items():
                setattr(existing, field_name, field_value)
            existing.updated_at = time.time()

        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error saving asset: {e}")
    finally:
        db_session.close()
|
||
|
||
def get_assets(self, project_id: str, asset_type: str = None) -> List[Dict[str, Any]]:
    """Return the asset dicts of a project, optionally filtered by asset type.

    Each asset's stored metadata is decoded from JSON (``{}`` when empty).
    """
    db_session = self._get_session()
    try:
        selection = db_session.query(SceneAsset).filter_by(project_id=project_id)
        if asset_type:
            selection = selection.filter_by(asset_type=asset_type)

        return [
            {
                "id": row.id,
                "project_id": row.project_id,
                "scene_id": row.scene_id,
                "asset_type": row.asset_type,
                "status": row.status,
                "local_path": row.local_path,
                "remote_url": row.remote_url,
                "task_id": row.task_id,
                "metadata": json.loads(row.metadata_json) if row.metadata_json else {},
                "updated_at": row.updated_at,
            }
            for row in selection.all()
        ]
    finally:
        db_session.close()
|
||
|
||
# --- Render Job Operations ---
|
||
|
||
def create_render_job(
    self,
    job_id: str,
    project_id: str,
    *,
    status: str = "queued",
    progress: float = 0.0,
    message: str = "",
    request: Optional[Dict[str, Any]] = None,
    parent_id: Optional[str] = None,
) -> None:
    """Insert a render job row unless a job with ``job_id`` already exists.

    ``request`` is stored as a JSON string (``None`` stays ``None``).

    Raises:
        Exception: any error is rolled back, logged and re-raised.
    """
    db_session = self._get_session()
    try:
        if db_session.query(RenderJob).filter_by(id=job_id).first():
            return

        stamp = time.time()
        if request is not None:
            request_json = json.dumps(request, ensure_ascii=False)
        else:
            request_json = None

        job = RenderJob(
            id=job_id,
            project_id=project_id,
            status=status,
            progress=float(progress or 0.0),
            message=message or "",
            request_json=request_json,
            parent_id=parent_id,
            created_at=stamp,
            updated_at=stamp,
        )
        db_session.add(job)
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"create_render_job error: {e}")
        raise
    finally:
        db_session.close()
|
||
|
||
def update_render_job(self, job_id: str, patch: Dict[str, Any]) -> None:
    """Apply a partial update to a render job row.

    Args:
        job_id: id of the job to update; an empty id is ignored.
        patch: mapping of column name to new value; an empty patch is ignored.
            Keys that are not columns of ``render_jobs`` are skipped, and so
            is the primary key ``id``.

    Unknown job ids are ignored. Errors are rolled back and logged, but not
    re-raised.
    """
    if not job_id or not patch:
        return
    session = self._get_session()
    try:
        rj = session.query(RenderJob).filter_by(id=job_id).first()
        if not rj:
            return
        # Only mapped columns may be patched. A plain hasattr() check would
        # also accept ORM internals (``metadata``, ``_sa_instance_state``) and
        # method names, letting a stray key corrupt the instance.
        patchable = set(RenderJob.__table__.columns.keys()) - {"id"}
        for k, v in patch.items():
            if k in patchable:
                setattr(rj, k, v)
        rj.updated_at = time.time()
        session.commit()
    except Exception as e:
        session.rollback()
        logger.error(f"update_render_job error: {e}")
    finally:
        session.close()
|
||
|
||
def get_render_job(self, job_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a render job as a plain dict, or ``None`` when it does not exist.

    The stored request JSON is decoded into ``request``; an empty or
    undecodable value yields ``None``.
    """
    db_session = self._get_session()
    try:
        job = db_session.query(RenderJob).filter_by(id=job_id).first()
        if job is None:
            return None

        request_payload = None
        try:
            if job.request_json:
                request_payload = json.loads(job.request_json)
        except Exception:
            request_payload = None

        return {
            "id": job.id,
            "project_id": job.project_id,
            "status": job.status,
            "progress": float(job.progress or 0.0),
            "message": job.message or "",
            "output_path": job.output_path,
            "output_url": job.output_url,
            "error": job.error,
            "request": request_payload,
            "parent_id": job.parent_id,
            "created_at": job.created_at,
            "updated_at": job.updated_at,
        }
    finally:
        db_session.close()
|
||
|
||
def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
    """Fetch the asset for (project_id, scene_id, asset_type) as a dict.

    Returns ``None`` when no such row exists. Stored metadata is decoded from
    JSON (``{}`` when empty).
    """
    db_session = self._get_session()
    try:
        row = db_session.query(SceneAsset).filter_by(
            project_id=project_id,
            scene_id=scene_id,
            asset_type=asset_type,
        ).first()
        if not row:
            return None

        decoded_meta = json.loads(row.metadata_json) if row.metadata_json else {}
        return {
            "id": row.id,
            "project_id": row.project_id,
            "scene_id": row.scene_id,
            "asset_type": row.asset_type,
            "status": row.status,
            "local_path": row.local_path,
            "remote_url": row.remote_url,
            "task_id": row.task_id,
            "metadata": decoded_meta,
            "updated_at": row.updated_at,
        }
    finally:
        db_session.close()
|
||
|
||
def get_asset_by_id(self, asset_id: int) -> Optional[Dict[str, Any]]:
    """Fetch an asset record by its auto-increment id (used by the file proxy).

    Returns:
        The asset as a dict, or ``None`` when the row does not exist or the
        lookup fails (for example ``asset_id`` cannot be converted to int, or
        the stored metadata is not valid JSON). Failures are logged.
    """
    session = self._get_session()
    try:
        a = session.query(SceneAsset).filter_by(id=int(asset_id)).first()
        if not a:
            return None
        return {
            "id": a.id,
            "project_id": a.project_id,
            "scene_id": a.scene_id,
            "asset_type": a.asset_type,
            "status": a.status,
            "local_path": a.local_path,
            "remote_url": a.remote_url,
            "task_id": a.task_id,
            "metadata": json.loads(a.metadata_json) if a.metadata_json else {},
            "updated_at": a.updated_at,
        }
    except Exception as e:
        # Still best-effort (callers get None), but no longer silent.
        logger.warning(f"get_asset_by_id error: {e}")
        return None
    finally:
        session.close()
|
||
|
||
def update_asset_metadata(self, project_id: str, scene_id: int, asset_type: str, patch: Dict[str, Any]) -> None:
    """Merge ``patch`` into an asset's metadata JSON, leaving other fields alone.

    An empty patch or a missing asset is a no-op. Stored metadata that is
    unreadable or not a JSON object is replaced by an empty dict before the
    merge. Errors are rolled back and logged, but not re-raised.
    """
    if not patch:
        return
    db_session = self._get_session()
    try:
        row = db_session.query(SceneAsset).filter_by(
            project_id=project_id,
            scene_id=scene_id,
            asset_type=asset_type,
        ).first()
        if not row:
            return

        merged = {}
        try:
            if row.metadata_json:
                merged = json.loads(row.metadata_json)
        except Exception:
            merged = {}
        if not isinstance(merged, dict):
            merged = {}

        merged.update(patch)
        row.metadata_json = json.dumps(merged, ensure_ascii=False)
        row.updated_at = time.time()
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error updating asset metadata: {e}")
    finally:
        db_session.close()
|
||
|
||
def clear_asset(self, project_id: str, scene_id: int, asset_type: str, *, status: str = "pending") -> None:
    """
    Reset one asset record: the row is kept, but its paths, task id and
    metadata are removed and its status is set to ``status``.

    Used to invalidate stale videos after images are regenerated. A missing
    asset is a no-op. Errors are rolled back and logged, but not re-raised.
    """
    db_session = self._get_session()
    try:
        row = db_session.query(SceneAsset).filter_by(
            project_id=project_id,
            scene_id=scene_id,
            asset_type=asset_type,
        ).first()
        if not row:
            return

        cleared = {
            "status": status,
            "local_path": None,
            "remote_url": None,
            "task_id": None,
            "metadata_json": "{}",
        }
        for field_name, field_value in cleared.items():
            setattr(row, field_name, field_value)
        row.updated_at = time.time()
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error clearing asset: {e}")
    finally:
        db_session.close()
|
||
|
||
def clear_assets(self, project_id: str, asset_type: str, *, status: str = "pending") -> int:
    """
    Reset every asset of one type for a project (rows are kept, paths, task id
    and metadata are removed).

    Returns:
        The number of rows affected; 0 when nothing matched or an error
        occurred (errors are rolled back and logged).
    """
    db_session = self._get_session()
    try:
        matched = db_session.query(SceneAsset).filter_by(
            project_id=project_id,
            asset_type=asset_type,
        ).all()
        if not matched:
            return 0

        for row in matched:
            row.status = status
            row.local_path = row.remote_url = row.task_id = None
            row.metadata_json = "{}"
            row.updated_at = time.time()
        db_session.commit()
        return len(matched)
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error clearing assets: {e}")
        return 0
    finally:
        db_session.close()
|
||
|
||
# --- Config/Prompt Operations ---
|
||
|
||
def get_config(self, key: str, default: Any = None) -> Any:
    """Read a configuration value.

    Args:
        key: configuration key to look up.
        default: value returned when the key is not stored.

    Returns:
        The stored value decoded from JSON; the raw stored value when it is
        not valid JSON; ``default`` when the key does not exist.
    """
    session = self._get_session()
    try:
        cfg = session.query(AppConfig).filter_by(key=key).first()
        if not cfg:
            return default
        try:
            return json.loads(cfg.value)
        except (TypeError, ValueError):
            # Not JSON (JSONDecodeError is a ValueError) or not a string:
            # fall back to the raw value. A bare ``except`` here would also
            # swallow KeyboardInterrupt/SystemExit.
            return cfg.value
    finally:
        session.close()
|
||
|
||
def set_config(self, key: str, value: Any, description: str = None):
    """Insert or update a configuration entry; ``value`` is stored as JSON.

    On update the description is only replaced when a truthy ``description``
    is given. Errors are rolled back and logged, but not re-raised.
    """
    db_session = self._get_session()
    try:
        encoded = json.dumps(value, ensure_ascii=False)

        entry = db_session.query(AppConfig).filter_by(key=key).first()
        if not entry:
            db_session.add(
                AppConfig(
                    key=key,
                    value=encoded,
                    description=description,
                    updated_at=time.time(),
                )
            )
        else:
            entry.value = encoded
            if description:
                entry.description = description
            entry.updated_at = time.time()
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Error setting config: {e}")
    finally:
        db_session.close()
|
||
|
||
# Singleton instance
|
||
# Singleton instance. Constructing it at import time creates the engine,
# prepares the schema and ensures the bootstrap admin user.
db = DBManager()
|