chore: sync code and project files
This commit is contained in:
48
modules/auth.py
Normal file
48
modules/auth.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Auth helpers: password hashing + cookie token hashing.
|
||||
|
||||
We intentionally avoid heavy dependencies. Password hashing uses PBKDF2-HMAC-SHA256.
|
||||
Session tokens are random and stored server-side as SHA256(token) hashes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
PBKDF2_ITERS = 200_000
|
||||
|
||||
|
||||
def hash_password(password: str, salt_hex: Optional[str] = None) -> Tuple[str, str]:
|
||||
salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16)
|
||||
dk = hashlib.pbkdf2_hmac("sha256", (password or "").encode("utf-8"), salt, PBKDF2_ITERS)
|
||||
return dk.hex(), salt.hex()
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str, salt_hex: str) -> bool:
|
||||
cand, _ = hash_password(password, salt_hex=salt_hex)
|
||||
return cand == (password_hash or "")
|
||||
|
||||
|
||||
def new_session_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
return hashlib.sha256((token or "").encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -502,11 +502,20 @@ class VideoComposer:
|
||||
current_video = fancy_path
|
||||
|
||||
# 7. 添加 BGM
|
||||
# 说明:add_bgm 的 ducking=True 路径使用 sidechaincompress,但该滤镜本身不做“混音”,
|
||||
# 在某些 ffmpeg 版本/参数组合下会导致 BGM 听起来像“没加上”。
|
||||
# 我们在 compose() 里已禁用 ducking,这里保持一致,使用 amix 叠加并提高默认音量。
|
||||
if bgm_path:
|
||||
bgm_output = str(Path(temp_root) / f"{output_name}_bgm.mp4")
|
||||
ffmpeg_utils.add_bgm(
|
||||
current_video, bgm_path, bgm_output,
|
||||
bgm_volume=0.15
|
||||
current_video,
|
||||
bgm_path,
|
||||
bgm_output,
|
||||
bgm_volume=0.20,
|
||||
ducking=False,
|
||||
duck_gain_db=-6.0,
|
||||
fade_in=1.0,
|
||||
fade_out=1.0,
|
||||
)
|
||||
self._add_temp(bgm_output)
|
||||
current_video = bgm_output
|
||||
|
||||
@@ -6,14 +6,43 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
import secrets
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
|
||||
from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func
|
||||
from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func, text, inspect
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
import config
|
||||
|
||||
# NOTE: Some deployments do not ship `modules/auth.py`.
|
||||
# Keep a minimal, local auth helper here to avoid hard dependency.
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
|
||||
def _hash_password(password: str, salt_hex: str = None) -> tuple[str, str]:
|
||||
salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16)
|
||||
dk = hashlib.pbkdf2_hmac("sha256", (password or "").encode("utf-8"), salt, 120_000)
|
||||
return dk.hex(), salt.hex()
|
||||
|
||||
|
||||
def _verify_password(password: str, pwd_hash_hex: str, salt_hex: str) -> bool:
|
||||
try:
|
||||
cand, _ = _hash_password(password or "", salt_hex=salt_hex)
|
||||
return hmac.compare_digest(cand, pwd_hash_hex or "")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _new_session_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _hash_token(token: str) -> str:
|
||||
# store only hash for sessions
|
||||
return hashlib.sha256((token or "").encode("utf-8")).hexdigest()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Base = declarative_base()
|
||||
@@ -26,6 +55,7 @@ class Project(Base):
|
||||
status = Column(String) # created, script_generated, images_generated, videos_generated, completed
|
||||
product_info = Column(Text) # JSON string (SQLite) or JSONB (PG - using Text for compat)
|
||||
script_data = Column(Text) # JSON string
|
||||
owner_user_id = Column(String, index=True, nullable=True)
|
||||
created_at = Column(Float, default=time.time)
|
||||
updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||||
|
||||
@@ -54,6 +84,62 @@ class AppConfig(Base):
|
||||
description = Column(Text, nullable=True)
|
||||
updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(String, primary_key=True) # uuid hex
|
||||
username = Column(String, unique=True, index=True)
|
||||
password_hash = Column(String)
|
||||
password_salt = Column(String)
|
||||
role = Column(String, default="user") # admin/user
|
||||
is_active = Column(Integer, default=1) # 1/0 for portability
|
||||
created_at = Column(Float, default=time.time)
|
||||
updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||||
last_login_at = Column(Float, nullable=True)
|
||||
|
||||
|
||||
class UserSession(Base):
|
||||
__tablename__ = "user_sessions"
|
||||
id = Column(String, primary_key=True) # uuid hex
|
||||
user_id = Column(String, index=True)
|
||||
token_hash = Column(String, index=True)
|
||||
expires_at = Column(Float)
|
||||
created_at = Column(Float, default=time.time)
|
||||
last_seen_at = Column(Float, default=time.time)
|
||||
ip = Column(String, nullable=True)
|
||||
user_agent = Column(String, nullable=True)
|
||||
|
||||
|
||||
class UserPrompt(Base):
|
||||
__tablename__ = "user_prompts"
|
||||
id = Column(String, primary_key=True) # uuid hex
|
||||
user_id = Column(String, index=True)
|
||||
key = Column(String, index=True)
|
||||
value = Column(Text)
|
||||
updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||||
__table_args__ = (UniqueConstraint("user_id", "key", name="uix_user_prompt"),)
|
||||
|
||||
|
||||
class RenderJob(Base):
|
||||
"""
|
||||
Render job status table (async compose pipeline).
|
||||
This is intentionally minimal and portable across SQLite/Postgres.
|
||||
"""
|
||||
__tablename__ = "render_jobs"
|
||||
|
||||
id = Column(String, primary_key=True) # task_id
|
||||
project_id = Column(String, index=True)
|
||||
status = Column(String, default="queued") # queued/running/success/failed/cancelled
|
||||
progress = Column(Float, default=0.0)
|
||||
message = Column(Text, default="")
|
||||
output_path = Column(Text, nullable=True)
|
||||
output_url = Column(Text, nullable=True)
|
||||
error = Column(Text, nullable=True)
|
||||
request_json = Column(Text, nullable=True) # JSON string of ComposeRequest
|
||||
parent_id = Column(String, nullable=True, index=True)
|
||||
created_at = Column(Float, default=time.time)
|
||||
updated_at = Column(Float, default=time.time, onupdate=time.time)
|
||||
|
||||
class DBManager:
|
||||
def __init__(self, connection_string: str = None):
|
||||
if not connection_string:
|
||||
@@ -62,17 +148,59 @@ class DBManager:
|
||||
self.engine = create_engine(connection_string, pool_recycle=3600)
|
||||
self.Session = scoped_session(sessionmaker(bind=self.engine))
|
||||
self._init_db()
|
||||
# bootstrap admin (safe to call repeatedly)
|
||||
try:
|
||||
self._bootstrap_admin_id = self.ensure_admin_user("admin", "admin1234")
|
||||
except Exception:
|
||||
self._bootstrap_admin_id = None
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化表结构"""
|
||||
"""初始化表结构 + 轻量自迁移(不依赖 Alembic)"""
|
||||
Base.metadata.create_all(self.engine)
|
||||
self._ensure_schema()
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
"""
|
||||
Ensure newer columns/tables exist in both Postgres and SQLite.
|
||||
create_all will not add columns to existing tables, so we do a minimal ALTER here.
|
||||
"""
|
||||
try:
|
||||
insp = inspect(self.engine)
|
||||
# projects.owner_user_id
|
||||
try:
|
||||
cols = [c["name"] for c in insp.get_columns("projects")]
|
||||
except Exception:
|
||||
cols = []
|
||||
if "owner_user_id" not in cols:
|
||||
logger.info("Migrating: add projects.owner_user_id")
|
||||
with self.engine.begin() as conn:
|
||||
conn.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id VARCHAR"))
|
||||
try:
|
||||
conn.execute(text("CREATE INDEX IF NOT EXISTS ix_projects_owner_user_id ON projects (owner_user_id)"))
|
||||
except Exception:
|
||||
# some engines (older sqlite) may not support IF NOT EXISTS
|
||||
try:
|
||||
conn.execute(text("CREATE INDEX ix_projects_owner_user_id ON projects (owner_user_id)"))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Schema ensure skipped/failed (non-fatal): {e}")
|
||||
|
||||
# Ensure render_jobs table exists (create_all should handle this; keep for safety)
|
||||
try:
|
||||
insp = inspect(self.engine)
|
||||
if "render_jobs" not in insp.get_table_names():
|
||||
logger.info("Migrating: create render_jobs table")
|
||||
Base.metadata.create_all(self.engine)
|
||||
except Exception as e:
|
||||
logger.warning(f"render_jobs ensure skipped/failed (non-fatal): {e}")
|
||||
|
||||
def _get_session(self):
|
||||
return self.Session()
|
||||
|
||||
# --- Project Operations ---
|
||||
|
||||
def create_project(self, project_id: str, name: str, product_info: Dict[str, Any]):
|
||||
def create_project(self, project_id: str, name: str, product_info: Dict[str, Any], owner_user_id: Optional[str] = None):
|
||||
session = self._get_session()
|
||||
try:
|
||||
# Check if exists
|
||||
@@ -86,6 +214,7 @@ class DBManager:
|
||||
name=name,
|
||||
status="created",
|
||||
product_info=json.dumps(product_info, ensure_ascii=False),
|
||||
owner_user_id=owner_user_id,
|
||||
created_at=time.time(),
|
||||
updated_at=time.time()
|
||||
)
|
||||
@@ -157,6 +286,7 @@ class DBManager:
|
||||
"status": project.status,
|
||||
"product_info": json.loads(project.product_info) if project.product_info else {},
|
||||
"script_data": json.loads(project.script_data) if project.script_data else None,
|
||||
"owner_user_id": getattr(project, "owner_user_id", None),
|
||||
"created_at": project.created_at,
|
||||
"updated_at": project.updated_at
|
||||
}
|
||||
@@ -175,12 +305,288 @@ class DBManager:
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"status": p.status,
|
||||
"owner_user_id": getattr(p, "owner_user_id", None),
|
||||
"updated_at": p.updated_at
|
||||
})
|
||||
return results
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
# --- User/Auth Operations ---
|
||||
|
||||
def ensure_admin_user(self, username: str = "admin", password: str = "admin1234") -> str:
|
||||
"""Create bootstrap admin user if missing. Returns admin user_id."""
|
||||
session = self._get_session()
|
||||
try:
|
||||
u = session.query(User).filter_by(username=username).first()
|
||||
if u:
|
||||
return u.id
|
||||
pwd_hash, salt_hex = _hash_password(password)
|
||||
uid = secrets.token_hex(16)
|
||||
u = User(
|
||||
id=uid,
|
||||
username=username,
|
||||
password_hash=pwd_hash,
|
||||
password_salt=salt_hex,
|
||||
role="admin",
|
||||
is_active=1,
|
||||
created_at=time.time(),
|
||||
updated_at=time.time(),
|
||||
)
|
||||
session.add(u)
|
||||
session.commit()
|
||||
logger.warning("Bootstrap admin created (please change password asap).")
|
||||
return uid
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"ensure_admin_user failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
session = self._get_session()
|
||||
try:
|
||||
u = session.query(User).filter_by(username=username).first()
|
||||
if not u or int(getattr(u, "is_active", 1) or 0) != 1:
|
||||
return None
|
||||
if not _verify_password(password or "", getattr(u, "password_hash", ""), getattr(u, "password_salt", "")):
|
||||
return None
|
||||
u.last_login_at = time.time()
|
||||
session.commit()
|
||||
return {"id": u.id, "username": u.username, "role": u.role, "is_active": int(u.is_active or 0)}
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"authenticate_user error: {e}")
|
||||
return None
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def create_session(self, user_id: str, *, ttl_seconds: int = 7 * 24 * 3600, ip: str = None, user_agent: str = None) -> str:
|
||||
"""Returns raw session token (store hash in DB)."""
|
||||
token = _new_session_token()
|
||||
token_hash = _hash_token(token)
|
||||
sid = secrets.token_hex(16)
|
||||
session = self._get_session()
|
||||
try:
|
||||
now = time.time()
|
||||
s = UserSession(
|
||||
id=sid,
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
expires_at=now + int(ttl_seconds),
|
||||
created_at=now,
|
||||
last_seen_at=now,
|
||||
ip=ip,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
session.add(s)
|
||||
session.commit()
|
||||
return token
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"create_session error: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def validate_session(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
if not token:
|
||||
return None
|
||||
token_hash = _hash_token(token)
|
||||
session = self._get_session()
|
||||
try:
|
||||
now = time.time()
|
||||
s = session.query(UserSession).filter_by(token_hash=token_hash).first()
|
||||
if not s or (s.expires_at and float(s.expires_at) < now):
|
||||
return None
|
||||
u = session.query(User).filter_by(id=s.user_id).first()
|
||||
if not u or int(getattr(u, "is_active", 1) or 0) != 1:
|
||||
return None
|
||||
s.last_seen_at = now
|
||||
session.commit()
|
||||
return {"id": u.id, "username": u.username, "role": u.role, "is_active": int(u.is_active or 0)}
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"validate_session error: {e}")
|
||||
return None
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def revoke_session(self, token: str) -> None:
|
||||
if not token:
|
||||
return
|
||||
token_hash = _hash_token(token)
|
||||
session = self._get_session()
|
||||
try:
|
||||
s = session.query(UserSession).filter_by(token_hash=token_hash).first()
|
||||
if s:
|
||||
session.delete(s)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"revoke_session error: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def list_users(self) -> List[Dict[str, Any]]:
|
||||
session = self._get_session()
|
||||
try:
|
||||
users = session.query(User).order_by(User.created_at.asc()).all()
|
||||
return [
|
||||
{
|
||||
"id": u.id,
|
||||
"username": u.username,
|
||||
"role": u.role,
|
||||
"is_active": int(u.is_active or 0),
|
||||
"created_at": u.created_at,
|
||||
"last_login_at": u.last_login_at,
|
||||
}
|
||||
for u in users
|
||||
]
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def upsert_user(self, username: str, password: str, *, role: str = "user", is_active: int = 1) -> str:
|
||||
session = self._get_session()
|
||||
try:
|
||||
u = session.query(User).filter_by(username=username).first()
|
||||
pwd_hash, salt_hex = _hash_password(password)
|
||||
now = time.time()
|
||||
if u:
|
||||
u.password_hash = pwd_hash
|
||||
u.password_salt = salt_hex
|
||||
u.role = role
|
||||
u.is_active = int(is_active)
|
||||
u.updated_at = now
|
||||
session.commit()
|
||||
return u.id
|
||||
uid = secrets.token_hex(16)
|
||||
u = User(
|
||||
id=uid,
|
||||
username=username,
|
||||
password_hash=pwd_hash,
|
||||
password_salt=salt_hex,
|
||||
role=role,
|
||||
is_active=int(is_active),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(u)
|
||||
session.commit()
|
||||
return uid
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"upsert_user error: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def set_user_active(self, user_id: str, is_active: int) -> None:
|
||||
session = self._get_session()
|
||||
try:
|
||||
u = session.query(User).filter_by(id=user_id).first()
|
||||
if not u:
|
||||
return
|
||||
u.is_active = int(is_active)
|
||||
u.updated_at = time.time()
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"set_user_active error: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def reset_user_password(self, user_id: str, new_password: str) -> None:
|
||||
session = self._get_session()
|
||||
try:
|
||||
u = session.query(User).filter_by(id=user_id).first()
|
||||
if not u:
|
||||
return
|
||||
pwd_hash, salt_hex = _hash_password(new_password)
|
||||
u.password_hash = pwd_hash
|
||||
u.password_salt = salt_hex
|
||||
u.updated_at = time.time()
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"reset_user_password error: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_user_prompt(self, user_id: str, key: str) -> Optional[str]:
|
||||
if not user_id or not key:
|
||||
return None
|
||||
session = self._get_session()
|
||||
try:
|
||||
p = session.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
|
||||
return p.value if p else None
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def set_user_prompt(self, user_id: str, key: str, value: Any) -> None:
|
||||
if not user_id or not key:
|
||||
return
|
||||
session = self._get_session()
|
||||
try:
|
||||
now = time.time()
|
||||
v = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
|
||||
p = session.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
|
||||
if p:
|
||||
p.value = v
|
||||
p.updated_at = now
|
||||
else:
|
||||
p = UserPrompt(id=secrets.token_hex(16), user_id=user_id, key=key, value=v, updated_at=now)
|
||||
session.add(p)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"set_user_prompt error: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
# --- RBAC helpers ---
|
||||
def list_projects_for_user(self, user: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
if not user:
|
||||
return []
|
||||
if user.get("role") == "admin":
|
||||
return self.list_projects()
|
||||
uid = user.get("id")
|
||||
session = self._get_session()
|
||||
try:
|
||||
projects = session.query(Project).filter_by(owner_user_id=uid).order_by(Project.updated_at.desc()).all()
|
||||
return [{"id": p.id, "name": p.name, "status": p.status, "owner_user_id": p.owner_user_id, "updated_at": p.updated_at} for p in projects]
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_project_for_user(self, project_id: str, user: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
data = self.get_project(project_id)
|
||||
if not data or not user:
|
||||
return None
|
||||
if user.get("role") == "admin":
|
||||
return data
|
||||
if data.get("owner_user_id") != user.get("id"):
|
||||
return None
|
||||
return data
|
||||
|
||||
def migrate_projects_owner_to(self, owner_user_id: str) -> int:
|
||||
"""Assign owner_user_id for legacy projects where it is NULL/empty."""
|
||||
session = self._get_session()
|
||||
try:
|
||||
q = session.query(Project).filter((Project.owner_user_id == None) | (Project.owner_user_id == "")) # noqa: E711
|
||||
rows = q.all()
|
||||
for p in rows:
|
||||
p.owner_user_id = owner_user_id
|
||||
p.updated_at = time.time()
|
||||
session.commit()
|
||||
return len(rows)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"migrate_projects_owner_to error: {e}")
|
||||
return 0
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
# --- Asset/Task Operations ---
|
||||
|
||||
def save_asset(self, project_id: str, scene_id: int, asset_type: str,
|
||||
@@ -253,6 +659,92 @@ class DBManager:
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
# --- Render Job Operations ---
|
||||
|
||||
def create_render_job(
|
||||
self,
|
||||
job_id: str,
|
||||
project_id: str,
|
||||
*,
|
||||
status: str = "queued",
|
||||
progress: float = 0.0,
|
||||
message: str = "",
|
||||
request: Optional[Dict[str, Any]] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
) -> None:
|
||||
session = self._get_session()
|
||||
try:
|
||||
existing = session.query(RenderJob).filter_by(id=job_id).first()
|
||||
if existing:
|
||||
return
|
||||
now = time.time()
|
||||
rj = RenderJob(
|
||||
id=job_id,
|
||||
project_id=project_id,
|
||||
status=status,
|
||||
progress=float(progress or 0.0),
|
||||
message=message or "",
|
||||
request_json=json.dumps(request, ensure_ascii=False) if request is not None else None,
|
||||
parent_id=parent_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(rj)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"create_render_job error: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def update_render_job(self, job_id: str, patch: Dict[str, Any]) -> None:
|
||||
if not job_id or not patch:
|
||||
return
|
||||
session = self._get_session()
|
||||
try:
|
||||
rj = session.query(RenderJob).filter_by(id=job_id).first()
|
||||
if not rj:
|
||||
return
|
||||
for k, v in patch.items():
|
||||
if not hasattr(rj, k):
|
||||
continue
|
||||
setattr(rj, k, v)
|
||||
rj.updated_at = time.time()
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"update_render_job error: {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_render_job(self, job_id: str) -> Optional[Dict[str, Any]]:
|
||||
session = self._get_session()
|
||||
try:
|
||||
rj = session.query(RenderJob).filter_by(id=job_id).first()
|
||||
if not rj:
|
||||
return None
|
||||
try:
|
||||
req = json.loads(rj.request_json) if rj.request_json else None
|
||||
except Exception:
|
||||
req = None
|
||||
return {
|
||||
"id": rj.id,
|
||||
"project_id": rj.project_id,
|
||||
"status": rj.status,
|
||||
"progress": float(rj.progress or 0.0),
|
||||
"message": rj.message or "",
|
||||
"output_path": rj.output_path,
|
||||
"output_url": rj.output_url,
|
||||
"error": rj.error,
|
||||
"request": req,
|
||||
"parent_id": rj.parent_id,
|
||||
"created_at": rj.created_at,
|
||||
"updated_at": rj.updated_at,
|
||||
}
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
|
||||
session = self._get_session()
|
||||
try:
|
||||
@@ -279,6 +771,30 @@ class DBManager:
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_asset_by_id(self, asset_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""通过自增 id 获取素材记录(用于 file proxy)。"""
|
||||
session = self._get_session()
|
||||
try:
|
||||
a = session.query(SceneAsset).filter_by(id=int(asset_id)).first()
|
||||
if not a:
|
||||
return None
|
||||
return {
|
||||
"id": a.id,
|
||||
"project_id": a.project_id,
|
||||
"scene_id": a.scene_id,
|
||||
"asset_type": a.asset_type,
|
||||
"status": a.status,
|
||||
"local_path": a.local_path,
|
||||
"remote_url": a.remote_url,
|
||||
"task_id": a.task_id,
|
||||
"metadata": json.loads(a.metadata_json) if a.metadata_json else {},
|
||||
"updated_at": a.updated_at,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def update_asset_metadata(self, project_id: str, scene_id: int, asset_type: str, patch: Dict[str, Any]) -> None:
|
||||
"""Merge-patch asset.metadata JSON without overwriting other fields."""
|
||||
if not patch:
|
||||
|
||||
@@ -172,7 +172,8 @@ def get_video_info(video_path: str) -> Dict[str, Any]:
|
||||
def concat_videos(
|
||||
video_paths: List[str],
|
||||
output_path: str,
|
||||
target_size: Tuple[int, int] = (1080, 1920)
|
||||
target_size: Tuple[int, int] = (1080, 1920),
|
||||
fades: Optional[List[Dict[str, float]]] = None
|
||||
) -> str:
|
||||
"""
|
||||
使用 FFmpeg concat demuxer 拼接多段视频
|
||||
@@ -197,10 +198,72 @@ def concat_videos(
|
||||
filter_parts = []
|
||||
for i in range(len(video_paths)):
|
||||
# scale 保持宽高比,pad 填充黑边居中
|
||||
filter_parts.append(
|
||||
chain = (
|
||||
f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
|
||||
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
|
||||
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1"
|
||||
)
|
||||
# 可选:片段末尾“火山式转场”(不改时长、不重叠)
|
||||
if fades and i < len(fades):
|
||||
fx = fades[i] or {}
|
||||
fi = float(fx.get("in", 0) or 0.0)
|
||||
fo = float(fx.get("out", 0) or 0.0)
|
||||
t_type = str(fx.get("type") or "")
|
||||
t_dur = float(fx.get("dur") or 0.0)
|
||||
|
||||
try:
|
||||
dur = float(get_video_info(video_paths[i]).get("duration") or 0.0)
|
||||
except Exception:
|
||||
dur = 0.0
|
||||
|
||||
# 基础淡入/淡出
|
||||
if fi > 0:
|
||||
chain += f",fade=t=in:st=0:d={fi}"
|
||||
if fo > 0 and dur > 0:
|
||||
st = max(dur - fo, 0.0)
|
||||
chain += f",fade=t=out:st={st}:d={fo}"
|
||||
|
||||
# 末尾动效(WYSIWYG:前端预览必须与此一致)
|
||||
if t_type and t_dur > 0 and dur > 0:
|
||||
st = max(dur - t_dur, 0.0)
|
||||
td = max(t_dur, 0.001)
|
||||
p = f"if(between(t\\,{st}\\,{dur})\\,(t-{st})/{td}\\,0)"
|
||||
if t_type == "fade":
|
||||
chain += f",fade=t=out:st={st}:d={t_dur}"
|
||||
elif t_type == "fadeWhite":
|
||||
chain += f",fade=t=out:st={st}:d={t_dur}:color=white"
|
||||
elif t_type == "blurOut":
|
||||
chain += f",gblur=sigma='10*{p}':steps=1"
|
||||
elif t_type == "blurFade":
|
||||
chain += f",gblur=sigma='8*{p}':steps=1,fade=t=out:st={st}:d={t_dur}"
|
||||
elif t_type == "flash":
|
||||
chain += f",eq=brightness='0.7*(1-abs(0.5-{p})*2)'"
|
||||
elif t_type == "desaturate":
|
||||
chain += f",hue=s='1-0.9*{p}'"
|
||||
elif t_type == "colorPop":
|
||||
chain += f",hue=s='1+0.8*{p}',eq=contrast='1+0.3*{p}'"
|
||||
elif t_type == "hueShift":
|
||||
chain += f",hue=h='60*{p}'"
|
||||
elif t_type == "darken":
|
||||
chain += f",eq=brightness='-0.4*{p}'"
|
||||
elif t_type in ("slideLeft", "slideRight", "slideUp", "slideDown"):
|
||||
off = 80
|
||||
if t_type == "slideLeft":
|
||||
chain += f",pad={width+off}:{height}:{off/2}-{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
|
||||
if t_type == "slideRight":
|
||||
chain += f",pad={width+off}:{height}:{off/2}+{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
|
||||
if t_type == "slideUp":
|
||||
chain += f",pad={width}:{height+off}:0:{off/2}-{off}*{p}:black,crop={width}:{height}:0:{off/2}"
|
||||
if t_type == "slideDown":
|
||||
chain += f",pad={width}:{height+off}:0:{off/2}+{off}*{p}:black,crop={width}:{height}:0:{off/2}"
|
||||
elif t_type in ("zoomOut", "zoomIn"):
|
||||
if t_type == "zoomOut":
|
||||
chain += f",scale=w='{width}*(1-0.10*{p})':h='{height}*(1-0.10*{p})':eval=frame,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black"
|
||||
else:
|
||||
chain += f",scale=w='{width}*(1+0.10*{p})':h='{height}*(1+0.10*{p})':eval=frame,crop={width}:{height}"
|
||||
elif t_type == "rotateOut":
|
||||
chain += f",rotate=a='0.12*{p}':c=black@1:ow={width}:oh={height}"
|
||||
chain += f"[v{i}]"
|
||||
filter_parts.append(chain)
|
||||
|
||||
# 拼接所有视频流
|
||||
concat_inputs = "".join([f"[v{i}]" for i in range(len(video_paths))])
|
||||
@@ -232,7 +295,8 @@ def concat_videos(
|
||||
def concat_videos_with_audio(
|
||||
video_paths: List[str],
|
||||
output_path: str,
|
||||
target_size: Tuple[int, int] = (1080, 1920)
|
||||
target_size: Tuple[int, int] = (1080, 1920),
|
||||
fades: Optional[List[Dict[str, float]]] = None
|
||||
) -> str:
|
||||
"""
|
||||
拼接视频并保留音频轨道
|
||||
@@ -250,10 +314,70 @@ def concat_videos_with_audio(
|
||||
|
||||
# 视频处理
|
||||
for i in range(n):
|
||||
filter_parts.append(
|
||||
chain = (
|
||||
f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
|
||||
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
|
||||
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1"
|
||||
)
|
||||
# 可选:片段末尾“火山式转场”(不改时长、不重叠)
|
||||
if fades and i < len(fades):
|
||||
fx = fades[i] or {}
|
||||
fi = float(fx.get("in", 0) or 0.0)
|
||||
fo = float(fx.get("out", 0) or 0.0)
|
||||
t_type = str(fx.get("type") or "")
|
||||
t_dur = float(fx.get("dur") or 0.0)
|
||||
|
||||
try:
|
||||
dur = float(get_video_info(video_paths[i]).get("duration") or 0.0)
|
||||
except Exception:
|
||||
dur = 0.0
|
||||
|
||||
if fi > 0:
|
||||
chain += f",fade=t=in:st=0:d={fi}"
|
||||
if fo > 0 and dur > 0:
|
||||
st = max(dur - fo, 0.0)
|
||||
chain += f",fade=t=out:st={st}:d={fo}"
|
||||
|
||||
if t_type and t_dur > 0 and dur > 0:
|
||||
st = max(dur - t_dur, 0.0)
|
||||
td = max(t_dur, 0.001)
|
||||
p = f"if(between(t\\,{st}\\,{dur})\\,(t-{st})/{td}\\,0)"
|
||||
if t_type == "fade":
|
||||
chain += f",fade=t=out:st={st}:d={t_dur}"
|
||||
elif t_type == "fadeWhite":
|
||||
chain += f",fade=t=out:st={st}:d={t_dur}:color=white"
|
||||
elif t_type == "blurOut":
|
||||
chain += f",gblur=sigma='10*{p}':steps=1"
|
||||
elif t_type == "blurFade":
|
||||
chain += f",gblur=sigma='8*{p}':steps=1,fade=t=out:st={st}:d={t_dur}"
|
||||
elif t_type == "flash":
|
||||
chain += f",eq=brightness='0.7*(1-abs(0.5-{p})*2)'"
|
||||
elif t_type == "desaturate":
|
||||
chain += f",hue=s='1-0.9*{p}'"
|
||||
elif t_type == "colorPop":
|
||||
chain += f",hue=s='1+0.8*{p}',eq=contrast='1+0.3*{p}'"
|
||||
elif t_type == "hueShift":
|
||||
chain += f",hue=h='60*{p}'"
|
||||
elif t_type == "darken":
|
||||
chain += f",eq=brightness='-0.4*{p}'"
|
||||
elif t_type in ("slideLeft", "slideRight", "slideUp", "slideDown"):
|
||||
off = 80
|
||||
if t_type == "slideLeft":
|
||||
chain += f",pad={width+off}:{height}:{off/2}-{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
|
||||
if t_type == "slideRight":
|
||||
chain += f",pad={width+off}:{height}:{off/2}+{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
|
||||
if t_type == "slideUp":
|
||||
chain += f",pad={width}:{height+off}:0:{off/2}-{off}*{p}:black,crop={width}:{height}:0:{off/2}"
|
||||
if t_type == "slideDown":
|
||||
chain += f",pad={width}:{height+off}:0:{off/2}+{off}*{p}:black,crop={width}:{height}:0:{off/2}"
|
||||
elif t_type in ("zoomOut", "zoomIn"):
|
||||
if t_type == "zoomOut":
|
||||
chain += f",scale=w='{width}*(1-0.10*{p})':h='{height}*(1-0.10*{p})':eval=frame,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black"
|
||||
else:
|
||||
chain += f",scale=w='{width}*(1+0.10*{p})':h='{height}*(1+0.10*{p})':eval=frame,crop={width}:{height}"
|
||||
elif t_type == "rotateOut":
|
||||
chain += f",rotate=a='0.12*{p}':c=black@1:ow={width}:oh={height}"
|
||||
chain += f"[v{i}]"
|
||||
filter_parts.append(chain)
|
||||
|
||||
# 音频处理(静音填充如果没有音频)
|
||||
for i in range(n):
|
||||
@@ -469,6 +593,127 @@ def adjust_audio_duration(
|
||||
return output_path
|
||||
|
||||
|
||||
def _atempo_chain(speed: float) -> str:
|
||||
"""
|
||||
构造 atempo 链,支持 <0.5 或 >2.0 的倍速(通过链式 atempo)。
|
||||
"""
|
||||
try:
|
||||
s = float(speed)
|
||||
except Exception:
|
||||
s = 1.0
|
||||
if s <= 0:
|
||||
s = 1.0
|
||||
parts = []
|
||||
# atempo 支持 0.5~2.0
|
||||
while s > 2.0:
|
||||
parts.append("atempo=2.0")
|
||||
s /= 2.0
|
||||
while s < 0.5:
|
||||
parts.append("atempo=0.5")
|
||||
s /= 0.5
|
||||
parts.append(f"atempo={s}")
|
||||
return ",".join(parts)
|
||||
|
||||
|
||||
def change_audio_speed(input_path: str, speed: float, output_path: str) -> str:
|
||||
"""改变音频播放倍速(纯播放倍速)。"""
|
||||
if not os.path.exists(input_path):
|
||||
return None
|
||||
af = _atempo_chain(speed)
|
||||
cmd = [FFMPEG_PATH, "-y", "-i", input_path, "-filter:a", af, output_path]
|
||||
_run_ffmpeg(cmd)
|
||||
return output_path
|
||||
|
||||
|
||||
def fit_audio_to_duration_by_speed(input_path: str, target_duration: float, output_path: str) -> str:
|
||||
"""
|
||||
通过“改变播放倍速”来贴合目标时长(可快可慢),并裁剪/补齐到严格时长。
|
||||
适用于旁白:用户拉伸片段期望语速变化,而不是静音补齐。
|
||||
"""
|
||||
if not os.path.exists(input_path):
|
||||
return None
|
||||
try:
|
||||
td = float(target_duration or 0)
|
||||
except Exception:
|
||||
td = 0.0
|
||||
if td <= 0:
|
||||
import shutil
|
||||
shutil.copy(input_path, output_path)
|
||||
return output_path
|
||||
|
||||
cur = float(get_audio_info(input_path).get("duration") or 0.0)
|
||||
if cur <= 0:
|
||||
import shutil
|
||||
shutil.copy(input_path, output_path)
|
||||
return output_path
|
||||
|
||||
speed = cur / td
|
||||
af_speed = _atempo_chain(speed)
|
||||
# 贴合后仍做一次 atrim+apad 保证严格时长(避免累计误差)
|
||||
af = f"{af_speed},atrim=0:{td},apad=pad_dur=0,atrim=0:{td}"
|
||||
cmd = [FFMPEG_PATH, "-y", "-i", input_path, "-filter:a", af, output_path]
|
||||
_run_ffmpeg(cmd)
|
||||
return output_path
|
||||
|
||||
|
||||
def force_audio_duration(input_path: str, target_duration: float, output_path: str) -> str:
|
||||
"""不改变倍速,仅裁剪/补齐到严格时长(用于倍速已在上游完成的场景)。"""
|
||||
if not os.path.exists(input_path):
|
||||
return None
|
||||
try:
|
||||
td = float(target_duration or 0)
|
||||
except Exception:
|
||||
td = 0.0
|
||||
if td <= 0:
|
||||
import shutil
|
||||
shutil.copy(input_path, output_path)
|
||||
return output_path
|
||||
af = f"atrim=0:{td},apad=pad_dur=0,atrim=0:{td}"
|
||||
cmd = [FFMPEG_PATH, "-y", "-i", input_path, "-filter:a", af, output_path]
|
||||
_run_ffmpeg(cmd)
|
||||
return output_path
|
||||
|
||||
|
||||
def _which(cmd: str) -> Optional[str]:
|
||||
try:
|
||||
import shutil
|
||||
return shutil.which(cmd)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def normalize_sticker_to_png(input_path: str, output_path: str) -> str:
|
||||
"""
|
||||
将贴纸规范化为 PNG(用于 ffmpeg overlay)。
|
||||
- PNG/WEBP:直接返回原图或拷贝
|
||||
- SVG:优先用 rsvg-convert 转 PNG;否则尝试 ffmpeg 直接解码
|
||||
"""
|
||||
if not input_path or not os.path.exists(input_path):
|
||||
return None
|
||||
ext = Path(input_path).suffix.lower()
|
||||
if ext in [".png"]:
|
||||
return input_path
|
||||
if ext in [".webp"]:
|
||||
# 转 PNG,避免某些 ffmpeg build 对 webp 支持不一致
|
||||
cmd = [FFMPEG_PATH, "-y", "-i", input_path, output_path]
|
||||
_run_ffmpeg(cmd)
|
||||
return output_path
|
||||
if ext == ".svg":
|
||||
rsvg = _which("rsvg-convert")
|
||||
if rsvg:
|
||||
import subprocess
|
||||
subprocess.check_call([rsvg, "-o", output_path, input_path])
|
||||
return output_path
|
||||
# fallback: ffmpeg decode svg(依赖 build)
|
||||
cmd = [FFMPEG_PATH, "-y", "-i", input_path, output_path]
|
||||
_run_ffmpeg(cmd)
|
||||
return output_path
|
||||
# 其他格式:尽量用 ffmpeg 转
|
||||
cmd = [FFMPEG_PATH, "-y", "-i", input_path, output_path]
|
||||
_run_ffmpeg(cmd)
|
||||
return output_path
|
||||
|
||||
|
||||
def get_audio_info(file_path: str) -> Dict[str, Any]:
|
||||
"""获取音频信息"""
|
||||
return get_video_info(file_path)
|
||||
@@ -830,6 +1075,12 @@ def add_bgm(
|
||||
loop: bool = True,
|
||||
ducking: bool = True,
|
||||
duck_gain_db: float = -6.0,
|
||||
# 新增:按时间段闪避(更可控,和旁白时间轴严格对齐)
|
||||
duck_volume: float = 0.25,
|
||||
duck_ranges: Optional[List[Tuple[float, float]]] = None,
|
||||
# 新增:BGM 片段可有起点/时长(不强制从 0 覆盖整段视频)
|
||||
start_time: float = 0.0,
|
||||
clip_duration: Optional[float] = None,
|
||||
fade_in: float = 1.0,
|
||||
fade_out: float = 1.0
|
||||
) -> str:
|
||||
@@ -856,30 +1107,46 @@ def add_bgm(
|
||||
info = get_video_info(video_path)
|
||||
video_duration = info["duration"]
|
||||
|
||||
if loop:
|
||||
bgm_chain = (
|
||||
f"[1:a]aloop=-1:size=2e+09,asetpts=N/SR/TB,"
|
||||
f"atrim=0:{video_duration},"
|
||||
f"afade=t=in:st=0:d={fade_in},"
|
||||
f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out},"
|
||||
f"volume={bgm_volume}[bgm]"
|
||||
)
|
||||
else:
|
||||
bgm_chain = (
|
||||
f"[1:a]"
|
||||
f"afade=t=in:st=0:d={fade_in},"
|
||||
f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out},"
|
||||
f"volume={bgm_volume}[bgm]"
|
||||
)
|
||||
# 片段时长:默认覆盖整段视频
|
||||
dur = float(clip_duration) if (clip_duration is not None and float(clip_duration) > 0) else float(video_duration)
|
||||
st = max(0.0, float(start_time or 0.0))
|
||||
end_for_fade = max(dur - float(fade_out or 0.0), 0.0)
|
||||
|
||||
if ducking:
|
||||
# 使用安全参数的 sidechaincompress,避免 unsupported 参数
|
||||
# 基础链:loop/trim -> fades -> base volume
|
||||
if loop:
|
||||
bgm_chain = f"[1:a]aloop=-1:size=2e+09,asetpts=N/SR/TB,atrim=0:{dur}"
|
||||
else:
|
||||
bgm_chain = f"[1:a]atrim=0:{dur}"
|
||||
|
||||
bgm_chain += f",afade=t=in:st=0:d={float(fade_in or 0.0)},afade=t=out:st={end_for_fade}:d={float(fade_out or 0.0)},volume={bgm_volume}"
|
||||
|
||||
# 延迟到 start_time
|
||||
if st > 1e-6:
|
||||
ms = int(st * 1000)
|
||||
bgm_chain += f",adelay={ms}|{ms}"
|
||||
|
||||
# 闪避(按时间段)
|
||||
# 注意:使用 enable 让 filter 只在区间内生效(外部直接 passthrough)
|
||||
if ducking and duck_ranges:
|
||||
dv = max(0.05, min(1.0, float(duck_volume or 0.25)))
|
||||
for (rs, re) in duck_ranges:
|
||||
rsf = max(0.0, float(rs))
|
||||
ref = max(rsf, float(re))
|
||||
bgm_chain += f",volume={dv}:enable='between(t,{rsf},{ref})'"
|
||||
|
||||
bgm_chain += "[bgm]"
|
||||
|
||||
# 如果提供了 duck_ranges,就用确定性的 amix(ducking 已在 bgm_chain 内完成)
|
||||
if ducking and duck_ranges:
|
||||
filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first:dropout_transition=0:normalize=0[outa]"
|
||||
elif ducking:
|
||||
# 否则退回 sidechaincompress(对原视频音频进行侧链压缩)
|
||||
filter_complex = (
|
||||
f"{bgm_chain};"
|
||||
f"[0:a][bgm]sidechaincompress=threshold=0.1:ratio=4:attack=5:release=250:makeup=1:mix=1:level_in=1:level_sc=1[outa]"
|
||||
)
|
||||
else:
|
||||
filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first[outa]"
|
||||
filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first:dropout_transition=0:normalize=0[outa]"
|
||||
|
||||
cmd = [
|
||||
FFMPEG_PATH, "-y",
|
||||
|
||||
@@ -246,3 +246,9 @@ def normalize_legacy_project(doc: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Optional, Tuple
|
||||
|
||||
LEGACY_HOST_TEMP_PREFIX = "/root/video-flow/temp/"
|
||||
LEGACY_HOST_OUTPUT_PREFIX = "/root/video-flow/output/"
|
||||
LEGACY_HOST_PREFIX = "/root/video-flow/"
|
||||
|
||||
# Container mount points (see docker-compose.yml)
|
||||
LEGACY_CONTAINER_TEMP_DIR = "/legacy/temp"
|
||||
@@ -42,18 +43,27 @@ def map_legacy_local_path(local_path: Optional[str]) -> Tuple[Optional[str], Opt
|
||||
if os.path.exists(local_path):
|
||||
return local_path, None
|
||||
|
||||
# Legacy host -> container mapping by basename
|
||||
# Legacy host path -> current container workspace path (same repo but different prefix)
|
||||
# Example:
|
||||
# /root/video-flow/temp/projects/... -> /app/temp/projects/...
|
||||
# This covers cases where we don't mount /legacy/* but the files were copied into current stack.
|
||||
if local_path.startswith(LEGACY_HOST_PREFIX):
|
||||
rest = local_path[len(LEGACY_HOST_PREFIX):].lstrip("/")
|
||||
candidate = str(Path("/app") / rest)
|
||||
if os.path.exists(candidate):
|
||||
return candidate, None
|
||||
|
||||
# Legacy host -> container mapping (preserve relative path)
|
||||
if local_path.startswith(LEGACY_HOST_TEMP_PREFIX):
|
||||
name = Path(local_path).name
|
||||
container_path = str(Path(LEGACY_CONTAINER_TEMP_DIR) / name)
|
||||
url = f"{LEGACY_STATIC_TEMP_PREFIX}{name}"
|
||||
return container_path, url
|
||||
rel = local_path[len(LEGACY_HOST_TEMP_PREFIX):].lstrip("/")
|
||||
container_path = str(Path(LEGACY_CONTAINER_TEMP_DIR) / rel)
|
||||
# 静态路由通常只覆盖目录根(不包含子目录);这里交给 /api/assets/file 做 FileResponse 更稳
|
||||
return container_path, None
|
||||
|
||||
if local_path.startswith(LEGACY_HOST_OUTPUT_PREFIX):
|
||||
name = Path(local_path).name
|
||||
container_path = str(Path(LEGACY_CONTAINER_OUTPUT_DIR) / name)
|
||||
url = f"{LEGACY_STATIC_OUTPUT_PREFIX}{name}"
|
||||
return container_path, url
|
||||
rel = local_path[len(LEGACY_HOST_OUTPUT_PREFIX):].lstrip("/")
|
||||
container_path = str(Path(LEGACY_CONTAINER_OUTPUT_DIR) / rel)
|
||||
return container_path, None
|
||||
|
||||
# Unknown path: keep as-is
|
||||
return local_path, None
|
||||
@@ -64,3 +74,9 @@ def map_legacy_local_path(local_path: Optional[str]) -> Tuple[Optional[str], Opt
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
102
modules/preview_proxy.py
Normal file
102
modules/preview_proxy.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Generate and cache low-bitrate preview proxies for browser playback.
|
||||
|
||||
Goal:
|
||||
- Improve Remotion Player preview smoothness by serving smaller/faster-to-decode videos.
|
||||
- Keep original `source_path` for accurate export; only swap `source_url` for preview.
|
||||
|
||||
Design:
|
||||
- Deterministic cache key based on (path, mtime, size).
|
||||
- Proxy lives under config.TEMP_DIR / "proxy".
|
||||
- Generated with ffmpeg: scale/pad + fps downsample + faststart + no audio.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import config
|
||||
from modules import ffmpeg_utils
|
||||
|
||||
|
||||
def _key_for_file(path: str) -> str:
|
||||
st = os.stat(path)
|
||||
raw = f"{path}|{st.st_mtime_ns}|{st.st_size}".encode("utf-8")
|
||||
return hashlib.sha1(raw).hexdigest() # short, deterministic
|
||||
|
||||
|
||||
def ensure_video_proxy(
|
||||
source_path: str,
|
||||
*,
|
||||
target_w: int = 540,
|
||||
target_h: int = 960,
|
||||
target_fps: int = 24,
|
||||
crf: int = 28,
|
||||
preset: str = "veryfast",
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Ensure a preview proxy exists for source_path.
|
||||
|
||||
Returns: (proxy_path, proxy_url) or (None, None) if source_path invalid.
|
||||
"""
|
||||
if not source_path or not os.path.exists(source_path):
|
||||
return None, None
|
||||
|
||||
proxy_dir = Path(config.TEMP_DIR) / "proxy"
|
||||
proxy_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
key = _key_for_file(source_path)
|
||||
out_name = f"proxy_{key}.mp4"
|
||||
out_path = proxy_dir / out_name
|
||||
|
||||
if out_path.exists() and out_path.stat().st_size > 1024:
|
||||
return str(out_path), f"/static/temp/proxy/{out_name}"
|
||||
|
||||
vf = (
|
||||
f"scale={target_w}:{target_h}:force_original_aspect_ratio=decrease,"
|
||||
f"pad={target_w}:{target_h}:(ow-iw)/2:(oh-ih)/2:black,"
|
||||
f"fps={target_fps}"
|
||||
)
|
||||
|
||||
cmd = [
|
||||
ffmpeg_utils.FFMPEG_PATH,
|
||||
"-y",
|
||||
"-i",
|
||||
source_path,
|
||||
"-an", # preview 不要音轨,减少解码负担(旁白/BGM 走单独轨道)
|
||||
"-vf",
|
||||
vf,
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
preset,
|
||||
"-crf",
|
||||
str(crf),
|
||||
"-tune",
|
||||
"fastdecode",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-movflags",
|
||||
"+faststart",
|
||||
str(out_path),
|
||||
]
|
||||
|
||||
ffmpeg_utils._run_ffmpeg(cmd)
|
||||
if out_path.exists() and out_path.stat().st_size > 1024:
|
||||
return str(out_path), f"/static/temp/proxy/{out_name}"
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
@@ -27,7 +28,10 @@ class ScriptGenerator:
|
||||
# OpenAI-compatible client for ShuBiaoBiao (supports multiple models incl. GPT)
|
||||
self.shubiaobiao_client = OpenAI(
|
||||
api_key=config.SHUBIAOBIAO_KEY,
|
||||
base_url=config.SHUBIAOBIAO_BASE_URL
|
||||
base_url=config.SHUBIAOBIAO_BASE_URL,
|
||||
# IMPORTANT: OpenAI SDK default timeout is 10 minutes; cap it to keep UX responsive.
|
||||
timeout=config.SHUBIAOBIAO_CHAT_TIMEOUT_S,
|
||||
max_retries=config.SHUBIAOBIAO_CHAT_MAX_RETRIES,
|
||||
)
|
||||
|
||||
# Default System Prompt
|
||||
@@ -139,15 +143,23 @@ class ScriptGenerator:
|
||||
product_name: str,
|
||||
product_info: Dict[str, Any],
|
||||
image_paths: List[str] = None,
|
||||
model_provider: str = "shubiaobiao" # "shubiaobiao" or "doubao"
|
||||
model_provider: str = "shubiaobiao", # "shubiaobiao" or "doubao"
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成分镜脚本
|
||||
"""
|
||||
logger.info(f"Generating script for: {product_name} (Provider: {model_provider})")
|
||||
|
||||
# 1. 构造 Prompt (优先从数据库读取配置)
|
||||
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
|
||||
# 1. 构造 Prompt (优先按 user_id 读取;否则回退到全局配置,再回退默认)
|
||||
system_prompt = None
|
||||
if user_id:
|
||||
try:
|
||||
system_prompt = db.get_user_prompt(user_id, "prompt_script_gen")
|
||||
except Exception:
|
||||
system_prompt = None
|
||||
if not system_prompt:
|
||||
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
|
||||
user_prompt = self._build_user_prompt(product_name, product_info)
|
||||
|
||||
# Branch for Doubao (Volcengine)
|
||||
@@ -293,21 +305,40 @@ class ScriptGenerator:
|
||||
ShuBiaoBiao OpenAI-compatible multimodal chat.
|
||||
IMPORTANT: For ShuBiaoBiao models, we pass image URLs (R2 public URLs), not base64.
|
||||
"""
|
||||
t0 = time.time()
|
||||
# Use WARNING level so it shows up even if Streamlit/root logger is not configured to INFO.
|
||||
logger.warning(
|
||||
f"[script_gen] start shubiaobiao chat model={model_name} images={len(image_paths or [])} "
|
||||
f"timeout_s={getattr(config, 'SHUBIAOBIAO_CHAT_TIMEOUT_S', 'n/a')} "
|
||||
f"max_retries={getattr(config, 'SHUBIAOBIAO_CHAT_MAX_RETRIES', 'n/a')}"
|
||||
)
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
user_content: List[Dict[str, Any]] = []
|
||||
# Images first (URL), then text
|
||||
t_upload0 = time.time()
|
||||
urls = self._upload_images_to_r2(image_paths or [], limit=10)
|
||||
logger.warning(
|
||||
f"[script_gen] r2_upload done urls={len(urls)} elapsed_s={time.time() - t_upload0:.2f}"
|
||||
)
|
||||
for url in urls:
|
||||
user_content.append({"type": "image_url", "image_url": {"url": url}})
|
||||
user_content.append({"type": "text", "text": user_prompt})
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
try:
|
||||
resp = self.shubiaobiao_client.chat.completions.create(
|
||||
client = self.shubiaobiao_client.with_options(
|
||||
timeout=config.SHUBIAOBIAO_CHAT_TIMEOUT_S,
|
||||
max_retries=config.SHUBIAOBIAO_CHAT_MAX_RETRIES,
|
||||
)
|
||||
t_call0 = time.time()
|
||||
resp = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
)
|
||||
logger.warning(
|
||||
f"[script_gen] shubiaobiao chat done elapsed_s={time.time() - t_call0:.2f} total_s={time.time() - t0:.2f}"
|
||||
)
|
||||
content_text = (resp.choices[0].message.content or "").strip()
|
||||
script_json = self._extract_json_from_response(content_text)
|
||||
if script_json is None:
|
||||
@@ -323,7 +354,9 @@ class ScriptGenerator:
|
||||
}
|
||||
return final_script
|
||||
except Exception as e:
|
||||
logger.error(f"shubiaobiao script generation failed ({model_name}): {e}")
|
||||
logger.error(
|
||||
f"shubiaobiao script generation failed ({model_name}) after {time.time() - t0:.2f}s: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _postprocess_selling_points(self, product_info: Dict[str, Any], selling_points: Any) -> List[str]:
|
||||
@@ -582,7 +615,33 @@ class ScriptGenerator:
|
||||
|
||||
def _validate_and_fix_script(self, script: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""校验并修复脚本结构"""
|
||||
# 简单校验,确保必要字段存在
|
||||
if "scenes" not in script:
|
||||
if not isinstance(script, dict):
|
||||
return {"scenes": []}
|
||||
|
||||
# Ensure fields exist
|
||||
if "scenes" not in script or not isinstance(script.get("scenes"), list):
|
||||
script["scenes"] = []
|
||||
|
||||
# Normalize: keep visual_anchor at top-level, but avoid repeating the full anchor in every scene.visual_prompt.
|
||||
# Reason: repeating a long anchor 4-5 times explodes tokens and makes UI look like "only three sections",
|
||||
# while image generation already supports passing visual_anchor separately and prepending it at runtime.
|
||||
visual_anchor = script.get("visual_anchor") or ""
|
||||
if isinstance(visual_anchor, str) and visual_anchor.strip() and script["scenes"]:
|
||||
anchor = visual_anchor.strip()
|
||||
prefix = f"[{anchor}]"
|
||||
for scene in script["scenes"]:
|
||||
if not isinstance(scene, dict):
|
||||
continue
|
||||
vp = scene.get("visual_prompt")
|
||||
if not isinstance(vp, str) or not vp.strip():
|
||||
continue
|
||||
s = vp.strip()
|
||||
# Strip exact "[anchor]" prefix if present
|
||||
if s.startswith(prefix):
|
||||
s = s[len(prefix):].lstrip()
|
||||
# If the model output copied the raw anchor without brackets, strip it too
|
||||
elif s.startswith(anchor):
|
||||
s = s[len(anchor):].lstrip()
|
||||
scene["visual_prompt"] = s
|
||||
|
||||
return script
|
||||
|
||||
@@ -90,6 +90,79 @@ class TextRenderer:
|
||||
return color
|
||||
return (0, 0, 0, 255)
|
||||
|
||||
def _wrap_text_to_width(self, text: str, font: ImageFont.FreeTypeFont, max_width: int) -> str:
|
||||
"""
|
||||
将文本按最大宽度自动换行(支持中英文混排)。
|
||||
- 保留原始换行符为段落边界
|
||||
- 英文优先按空格断词;中文按字符贪心换行
|
||||
"""
|
||||
try:
|
||||
mw = int(max_width or 0)
|
||||
except Exception:
|
||||
mw = 0
|
||||
if mw <= 0:
|
||||
return text
|
||||
|
||||
# 兼容:去掉末尾多余空行
|
||||
raw_paras = (text or "").split("\n")
|
||||
out_lines: List[str] = []
|
||||
|
||||
# 1x1 dummy draw 用于测量
|
||||
dummy_draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
|
||||
|
||||
def text_w(s: str) -> float:
|
||||
try:
|
||||
return float(dummy_draw.textlength(s, font=font))
|
||||
except Exception:
|
||||
bbox = dummy_draw.textbbox((0, 0), s, font=font)
|
||||
return float((bbox[2] - bbox[0]) if bbox else 0)
|
||||
|
||||
for para in raw_paras:
|
||||
p = (para or "").rstrip()
|
||||
if not p:
|
||||
out_lines.append("")
|
||||
continue
|
||||
|
||||
# 英文/混排:尝试按空格分词,否则按字符
|
||||
use_words = (" " in p)
|
||||
tokens = p.split(" ") if use_words else list(p)
|
||||
|
||||
cur = ""
|
||||
for tok in tokens:
|
||||
cand = (cur + (" " if (use_words and cur) else "") + tok) if cur else tok
|
||||
if text_w(cand) <= mw:
|
||||
cur = cand
|
||||
continue
|
||||
|
||||
# 当前行放不下:先落一行
|
||||
if cur:
|
||||
out_lines.append(cur)
|
||||
cur = tok
|
||||
else:
|
||||
# 单 token 超宽:强制按字符拆
|
||||
if use_words:
|
||||
chars = list(tok)
|
||||
else:
|
||||
chars = [tok]
|
||||
buf = ""
|
||||
for ch in chars:
|
||||
cand2 = buf + ch
|
||||
if text_w(cand2) <= mw or not buf:
|
||||
buf = cand2
|
||||
else:
|
||||
out_lines.append(buf)
|
||||
buf = ch
|
||||
cur = buf
|
||||
|
||||
if cur:
|
||||
out_lines.append(cur)
|
||||
|
||||
# 去掉尾部空行(保持中间空行)
|
||||
while out_lines and out_lines[-1] == "":
|
||||
out_lines.pop()
|
||||
|
||||
return "\n".join(out_lines)
|
||||
|
||||
def render(self, text: str, style: Union[Dict[str, Any], str], cache: bool = True) -> str:
|
||||
"""
|
||||
渲染文本并返回图片路径
|
||||
@@ -122,14 +195,29 @@ class TextRenderer:
|
||||
font_size = style.get("font_size", 60)
|
||||
font = self._get_font(font_path, font_size)
|
||||
font_color = self._parse_color(style.get("font_color", "#FFFFFF"))
|
||||
bold = bool(style.get("bold", False))
|
||||
italic = bool(style.get("italic", False))
|
||||
underline = bool(style.get("underline", False))
|
||||
|
||||
# 3. 测量文本尺寸
|
||||
# 3. 自动换行(可选)
|
||||
max_width = style.get("max_width") or style.get("maxWidth") or style.get("text_box_width")
|
||||
try:
|
||||
max_width = int(max_width) if max_width is not None else 0
|
||||
except Exception:
|
||||
max_width = 0
|
||||
if max_width > 0:
|
||||
text = self._wrap_text_to_width(text, font, max_width)
|
||||
|
||||
# 4. 测量文本尺寸(支持多行)
|
||||
dummy_draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
|
||||
bbox = dummy_draw.textbbox((0, 0), text, font=font)
|
||||
text_w = bbox[2] - bbox[0]
|
||||
text_h = bbox[3] - bbox[1]
|
||||
try:
|
||||
bbox = dummy_draw.multiline_textbbox((0, 0), text, font=font, spacing=int(font_size * 0.25), align="center")
|
||||
except Exception:
|
||||
bbox = dummy_draw.textbbox((0, 0), text, font=font)
|
||||
text_w = (bbox[2] - bbox[0]) if bbox else 0
|
||||
text_h = (bbox[3] - bbox[1]) if bbox else 0
|
||||
|
||||
# 4. 计算总尺寸 (包含 padding, stroke, shadow)
|
||||
# 5. 计算总尺寸 (包含 padding, stroke, shadow)
|
||||
strokes = style.get("stroke", [])
|
||||
if isinstance(strokes, dict): strokes = [strokes] # 兼容旧格式
|
||||
|
||||
@@ -156,7 +244,7 @@ class TextRenderer:
|
||||
canvas_w = content_w + extra_margin * 2
|
||||
canvas_h = content_h + extra_margin * 2
|
||||
|
||||
# 5. 创建画布
|
||||
# 6. 创建画布
|
||||
img = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
@@ -164,7 +252,7 @@ class TextRenderer:
|
||||
center_x = canvas_w // 2
|
||||
center_y = canvas_h // 2
|
||||
|
||||
# 6. 绘制顺序: 阴影 -> 背景 -> 描边 -> 文本
|
||||
# 7. 绘制顺序: 阴影 -> 背景 -> 描边 -> 文本
|
||||
|
||||
# --- 绘制阴影 (针对整个块) ---
|
||||
if shadow:
|
||||
@@ -183,7 +271,11 @@ class TextRenderer:
|
||||
# 文字阴影
|
||||
txt_x = center_x - text_w / 2
|
||||
txt_y = center_y - text_h / 2
|
||||
shadow_draw.text((txt_x, txt_y), text, font=font, fill=shadow_color)
|
||||
# 多行阴影
|
||||
try:
|
||||
shadow_draw.multiline_text((txt_x, txt_y), text, font=font, fill=shadow_color, spacing=int(font_size * 0.25), align="center")
|
||||
except Exception:
|
||||
shadow_draw.text((txt_x, txt_y), text, font=font, fill=shadow_color)
|
||||
# 描边阴影
|
||||
for s in strokes:
|
||||
width = s.get("width", 0)
|
||||
@@ -217,10 +309,55 @@ class TextRenderer:
|
||||
width = s.get("width", 0)
|
||||
if width > 0:
|
||||
# 通过偏移模拟描边 (Pillow stroke_width 效果一般,但这里先用原生参数)
|
||||
draw.text((txt_x, txt_y), text, font=font, fill=color, stroke_width=width, stroke_fill=color)
|
||||
try:
|
||||
draw.multiline_text((txt_x, txt_y), text, font=font, fill=color, spacing=int(font_size * 0.25), align="center", stroke_width=width, stroke_fill=color)
|
||||
except Exception:
|
||||
draw.text((txt_x, txt_y), text, font=font, fill=color, stroke_width=width, stroke_fill=color)
|
||||
|
||||
# --- 绘制文字 ---
|
||||
draw.text((txt_x, txt_y), text, font=font, fill=font_color)
|
||||
# italic:通过仿射变换做简单斜体(先绘制到单独图层,再 shear)
|
||||
# bold:通过多次微小偏移叠加模拟加粗(比改 stroke 更接近“字重”)
|
||||
if italic:
|
||||
text_layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
||||
text_draw = ImageDraw.Draw(text_layer)
|
||||
if bold:
|
||||
for dx in (0, 1):
|
||||
try:
|
||||
text_draw.multiline_text((txt_x + dx, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
|
||||
except Exception:
|
||||
text_draw.text((txt_x + dx, txt_y), text, font=font, fill=font_color)
|
||||
else:
|
||||
try:
|
||||
text_draw.multiline_text((txt_x, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
|
||||
except Exception:
|
||||
text_draw.text((txt_x, txt_y), text, font=font, fill=font_color)
|
||||
shear = 0.22 # 经验值:适中倾斜
|
||||
text_layer = text_layer.transform(
|
||||
text_layer.size,
|
||||
Image.AFFINE,
|
||||
(1, shear, 0, 0, 1, 0),
|
||||
resample=Image.BICUBIC
|
||||
)
|
||||
img = Image.alpha_composite(img, text_layer)
|
||||
draw = ImageDraw.Draw(img)
|
||||
else:
|
||||
if bold:
|
||||
for dx in (0, 1):
|
||||
try:
|
||||
draw.multiline_text((txt_x + dx, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
|
||||
except Exception:
|
||||
draw.text((txt_x + dx, txt_y), text, font=font, fill=font_color)
|
||||
else:
|
||||
try:
|
||||
draw.multiline_text((txt_x, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
|
||||
except Exception:
|
||||
draw.text((txt_x, txt_y), text, font=font, fill=font_color)
|
||||
|
||||
# underline:在文本底部画线(与字号相关)
|
||||
if underline:
|
||||
line_y = txt_y + text_h + max(2, int(font_size * 0.08))
|
||||
line_th = max(2, int(font_size * 0.06))
|
||||
draw.rectangle([txt_x, line_y, txt_x + text_w, line_y + line_th], fill=font_color)
|
||||
|
||||
# 7. 裁剪多余透明区域
|
||||
bbox = img.getbbox()
|
||||
|
||||
Reference in New Issue
Block a user