354 lines
13 KiB
Python
354 lines
13 KiB
Python
"""
|
|
数据库管理模块 (SQLAlchemy)
|
|
负责项目数据、任务状态、素材路径的持久化存储
|
|
支持 SQLite 和 PostgreSQL
|
|
"""
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import Dict, List, Any, Optional
|
|
|
|
from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func
|
|
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
|
from sqlalchemy.dialects.postgresql import JSONB
|
|
|
|
import config
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Declarative base shared by the ORM models in this module; its ``metadata``
# is what ``Base.metadata.create_all`` uses to create the tables.
Base = declarative_base()
|
|
|
|
class Project(Base):
    """ORM model for a project row: identity, workflow status and JSON payloads."""

    __tablename__ = 'projects'

    # Caller-supplied string primary key.
    id = Column(String, primary_key=True)
    name = Column(String)
    # Workflow status; values noted by the author: created, script_generated,
    # images_generated, videos_generated, completed.
    status = Column(String)
    # JSON-encoded text. Stored as Text rather than PostgreSQL JSONB so the
    # same schema works on both SQLite and PostgreSQL.
    product_info = Column(Text)
    script_data = Column(Text)  # JSON-encoded text
    # Unix timestamps in seconds (time.time()); updated_at is also refreshed
    # by SQLAlchemy on ORM UPDATEs via onupdate.
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
|
|
|
|
class SceneAsset(Base):
    """ORM model for one generated asset of a scene (e.g. image or video) and its task state.

    At most one row exists per (project_id, scene_id, asset_type), enforced by
    the ``uix_project_scene_asset`` unique constraint.
    """

    __tablename__ = 'scene_assets'

    # Surrogate auto-increment key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning project's id, indexed for per-project queries.
    # NOTE(review): not declared as a ForeignKey to projects.id, so the
    # database does not enforce referential integrity.
    project_id = Column(String, index=True)
    scene_id = Column(Integer)
    asset_type = Column(String)  # image, video
    status = Column(String)  # pending, processing, completed, failed
    local_path = Column(Text, nullable=True)
    remote_url = Column(Text, nullable=True)
    task_id = Column(String, nullable=True)  # Task ID assigned by the external API
    # JSON-encoded text stored in the DB column named "metadata". The Python
    # attribute is metadata_json because ``metadata`` is reserved on
    # declarative classes (it holds the MetaData of Base).
    metadata_json = Column("metadata", Text, nullable=True)
    # Unix timestamps in seconds (time.time()).
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    __table_args__ = (UniqueConstraint('project_id', 'scene_id', 'asset_type', name='uix_project_scene_asset'),)
|
|
|
|
class AppConfig(Base):
    """ORM model for key/value application settings (values stored as JSON text)."""

    __tablename__ = 'app_config'

    # Setting name; primary key.
    key = Column(String, primary_key=True)
    value = Column(Text)  # JSON-encoded text
    description = Column(Text, nullable=True)
    # Unix timestamp in seconds (time.time()); refreshed on ORM UPDATEs.
    updated_at = Column(Float, default=time.time, onupdate=time.time)
|
|
|
|
class DBManager:
    """Persistence facade over SQLAlchemy for projects, scene assets and app config.

    Every public method obtains a session from a thread-local ``scoped_session``
    registry, does its work, and closes the session before returning. Results
    are therefore plain dicts/lists, never live ORM objects.

    Error policy (unchanged per method): ``create_project`` and
    ``update_project_product_info`` log, roll back and re-raise. The other
    writers are best-effort: they log (with traceback) and roll back, but do
    not raise.
    """

    def __init__(self, connection_string: Optional[str] = None):
        """Create the engine and session factory, then ensure tables exist.

        Args:
            connection_string: SQLAlchemy database URL. Falls back to
                ``config.DB_CONNECTION_STRING`` when omitted or empty.
        """
        if not connection_string:
            connection_string = config.DB_CONNECTION_STRING

        # Recycle pooled connections after 3600 s so long-idle connections
        # are replaced instead of reused.
        self.engine = create_engine(connection_string, pool_recycle=3600)
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        self._init_db()

    def _init_db(self) -> None:
        """Create any tables declared on ``Base`` that do not exist yet."""
        Base.metadata.create_all(self.engine)

    def _get_session(self):
        """Return the current thread's session from the scoped registry."""
        return self.Session()

    # --- Private query/serialization helpers ---

    @staticmethod
    def _find_project(session, project_id: str):
        """Return the ``Project`` row with *project_id*, or ``None``."""
        return session.query(Project).filter_by(id=project_id).first()

    @staticmethod
    def _find_asset(session, project_id: str, scene_id: int, asset_type: str):
        """Return the ``SceneAsset`` row for the (project, scene, type) key, or ``None``."""
        return session.query(SceneAsset).filter_by(
            project_id=project_id,
            scene_id=scene_id,
            asset_type=asset_type,
        ).first()

    @staticmethod
    def _asset_to_dict(asset) -> Dict[str, Any]:
        """Convert a ``SceneAsset`` row to a plain dict with its metadata JSON decoded."""
        return {
            "id": asset.id,
            "project_id": asset.project_id,
            "scene_id": asset.scene_id,
            "asset_type": asset.asset_type,
            "status": asset.status,
            "local_path": asset.local_path,
            "remote_url": asset.remote_url,
            "task_id": asset.task_id,
            "metadata": json.loads(asset.metadata_json) if asset.metadata_json else {},
            "updated_at": asset.updated_at,
        }

    # --- Project Operations ---

    def create_project(self, project_id: str, name: str, product_info: Dict[str, Any]) -> None:
        """Insert a new project with status ``"created"``.

        If a project with *project_id* already exists, a warning is logged and
        the existing row is left untouched.

        Args:
            project_id: Primary key of the new project.
            name: Display name.
            product_info: JSON-serializable dict stored as JSON text.

        Raises:
            Exception: any serialization or database error, after logging it
                and rolling back the transaction.
        """
        session = self._get_session()
        try:
            if self._find_project(session, project_id) is not None:
                logger.warning("Project %s already exists.", project_id)
                return

            # One timestamp so created_at and updated_at are identical on insert.
            now = time.time()
            session.add(Project(
                id=project_id,
                name=name,
                status="created",
                product_info=json.dumps(product_info, ensure_ascii=False),
                created_at=now,
                updated_at=now,
            ))
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error creating project: %s", e)
            raise
        finally:
            session.close()

    def update_project_script(self, project_id: str, script: Dict[str, Any]) -> None:
        """Store *script* as the project's script JSON and set status ``"script_generated"``.

        Best-effort: a missing project only logs a warning. Database or
        serialization errors are logged and rolled back, not raised.
        """
        session = self._get_session()
        try:
            project = self._find_project(session, project_id)
            if project is None:
                logger.warning("Project %s not found; script not updated.", project_id)
                return
            project.script_data = json.dumps(script, ensure_ascii=False)
            project.status = "script_generated"
            project.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating script: %s", e)
        finally:
            session.close()

    def update_project_product_info(self, project_id: str, product_info: Dict[str, Any]) -> None:
        """Overwrite the project's ``product_info`` JSON.

        Used to persist editor state without changing the schema; the same
        database may be shared read-write with a PostgreSQL deployment.
        A missing project only logs a warning.

        Raises:
            Exception: any serialization or database error, after logging it
                and rolling back the transaction.
        """
        session = self._get_session()
        try:
            project = self._find_project(session, project_id)
            if project is None:
                logger.warning("Project %s not found; product_info not updated.", project_id)
                return
            project.product_info = json.dumps(product_info, ensure_ascii=False)
            project.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating product_info: %s", e)
            raise
        finally:
            session.close()

    def update_project_status(self, project_id: str, status: str) -> None:
        """Set the project's workflow status.

        Best-effort: a missing project only logs a warning. Database errors
        are logged and rolled back, not raised.
        """
        session = self._get_session()
        try:
            project = self._find_project(session, project_id)
            if project is None:
                logger.warning("Project %s not found; status not updated.", project_id)
                return
            project.status = status
            project.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating status: %s", e)
        finally:
            session.close()

    def get_project(self, project_id: str) -> Optional[Dict[str, Any]]:
        """Return the project as a dict, or ``None`` if it does not exist.

        ``product_info`` decodes to ``{}`` and ``script_data`` to ``None`` when
        the stored text is empty. Malformed stored JSON raises ``ValueError``.
        """
        session = self._get_session()
        try:
            project = self._find_project(session, project_id)
            if project is None:
                return None
            return {
                "id": project.id,
                "name": project.name,
                "status": project.status,
                "product_info": json.loads(project.product_info) if project.product_info else {},
                "script_data": json.loads(project.script_data) if project.script_data else None,
                "created_at": project.created_at,
                "updated_at": project.updated_at,
            }
        finally:
            session.close()

    def list_projects(self) -> List[Dict[str, Any]]:
        """Return a summary (id, name, status, updated_at) of every project, newest update first."""
        session = self._get_session()
        try:
            projects = session.query(Project).order_by(Project.updated_at.desc()).all()
            return [
                {"id": p.id, "name": p.name, "status": p.status, "updated_at": p.updated_at}
                for p in projects
            ]
        finally:
            session.close()

    # --- Asset/Task Operations ---

    def save_asset(self, project_id: str, scene_id: int, asset_type: str,
                   status: str, local_path: Optional[str] = None, remote_url: Optional[str] = None,
                   task_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Insert or update (upsert) the asset keyed by ``(project_id, scene_id, asset_type)``.

        On update every field is overwritten with the given arguments. Omitted
        optional arguments therefore reset the stored value to ``None``, and an
        omitted or empty *metadata* resets it to ``{}``. Use
        ``update_asset_metadata`` to merge metadata instead.

        Best-effort: database or serialization errors are logged and rolled
        back, not raised.
        """
        session = self._get_session()
        try:
            meta_json = json.dumps(metadata, ensure_ascii=False) if metadata else "{}"
            now = time.time()
            # NOTE(review): check-then-insert is not atomic; two concurrent
            # inserts for the same key would make one commit fail on the
            # unique constraint (logged below, not retried).
            asset = self._find_asset(session, project_id, scene_id, asset_type)
            if asset is None:
                session.add(SceneAsset(
                    project_id=project_id,
                    scene_id=scene_id,
                    asset_type=asset_type,
                    status=status,
                    local_path=local_path,
                    remote_url=remote_url,
                    task_id=task_id,
                    metadata_json=meta_json,
                    created_at=now,
                    updated_at=now,
                ))
            else:
                asset.status = status
                asset.local_path = local_path
                asset.remote_url = remote_url
                asset.task_id = task_id
                asset.metadata_json = meta_json
                asset.updated_at = now
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error saving asset: %s", e)
        finally:
            session.close()

    def get_assets(self, project_id: str, asset_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return all assets of a project as dicts, optionally filtered by *asset_type*.

        A falsy *asset_type* (``None`` or ``""``) disables the filter.
        """
        session = self._get_session()
        try:
            query = session.query(SceneAsset).filter_by(project_id=project_id)
            if asset_type:
                query = query.filter_by(asset_type=asset_type)
            return [self._asset_to_dict(a) for a in query.all()]
        finally:
            session.close()

    def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
        """Return the single asset for the (project, scene, type) key as a dict, or ``None``."""
        session = self._get_session()
        try:
            asset = self._find_asset(session, project_id, scene_id, asset_type)
            return self._asset_to_dict(asset) if asset is not None else None
        finally:
            session.close()

    def update_asset_metadata(self, project_id: str, scene_id: int, asset_type: str, patch: Dict[str, Any]) -> None:
        """Shallow-merge *patch* into the asset's metadata JSON without dropping other keys.

        Does nothing for an empty patch. A missing asset only logs a warning.
        Stored metadata that is unparsable or not a JSON object is treated as
        ``{}``. Database errors are logged and rolled back, not raised.
        """
        if not patch:
            return
        session = self._get_session()
        try:
            asset = self._find_asset(session, project_id, scene_id, asset_type)
            if asset is None:
                logger.warning(
                    "Asset %s/%s/%s not found; metadata not updated.",
                    project_id, scene_id, asset_type,
                )
                return
            try:
                existing = json.loads(asset.metadata_json) if asset.metadata_json else {}
            except (TypeError, ValueError):
                # Corrupt stored JSON: start over rather than fail the patch.
                existing = {}
            if not isinstance(existing, dict):
                existing = {}
            existing.update(patch)
            asset.metadata_json = json.dumps(existing, ensure_ascii=False)
            asset.updated_at = time.time()
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error updating asset metadata: %s", e)
        finally:
            session.close()

    # --- Config/Prompt Operations ---

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the decoded JSON value stored under *key*.

        Returns *default* when the key is absent. If the stored text is not
        valid JSON (or is ``NULL``), the raw stored value is returned as-is.
        """
        session = self._get_session()
        try:
            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg is None:
                return default
            try:
                return json.loads(cfg.value)
            except (TypeError, ValueError):
                # TypeError: NULL value; ValueError: not valid JSON text.
                return cfg.value
        finally:
            session.close()

    def set_config(self, key: str, value: Any, description: Optional[str] = None) -> None:
        """Store *value* (JSON-encoded) under *key*, inserting or updating the row.

        On update, the description is only replaced when a non-empty one is
        given. Best-effort: errors are logged and rolled back, not raised.
        """
        session = self._get_session()
        try:
            json_val = json.dumps(value, ensure_ascii=False)
            now = time.time()
            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg is None:
                session.add(AppConfig(
                    key=key,
                    value=json_val,
                    description=description,
                    updated_at=now,
                ))
            else:
                cfg.value = json_val
                if description:
                    cfg.description = description
                cfg.updated_at = now
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("Error setting config: %s", e)
        finally:
            session.close()
|
|
|
|
# Module-level singleton shared by importers of this module.
# NOTE: constructing it at import time builds the engine from
# config.DB_CONNECTION_STRING and runs Base.metadata.create_all (which
# connects to the database), so importing this module has DB side effects.
db = DBManager()
|