chore: sync code and project files

This commit is contained in:
Tony Zhang
2026-01-09 14:09:16 +08:00
parent 3d1fb37769
commit 30d7eb4b35
94 changed files with 12706 additions and 255 deletions

48
modules/auth.py Normal file
View File

@@ -0,0 +1,48 @@
"""
Auth helpers: password hashing + cookie token hashing.
We intentionally avoid heavy dependencies. Password hashing uses PBKDF2-HMAC-SHA256.
Session tokens are random and stored server-side as SHA256(token) hashes.
"""
from __future__ import annotations
import hashlib
import secrets
from typing import Optional, Tuple
# PBKDF2 work factor; raising it invalidates nothing (salt+iterations are
# re-derived per verification) but slows both hashing and verification.
PBKDF2_ITERS = 200_000


def hash_password(password: str, salt_hex: Optional[str] = None) -> Tuple[str, str]:
    """
    Derive a PBKDF2-HMAC-SHA256 hash of *password*.

    Args:
        password: Plaintext password; None is treated as "".
        salt_hex: Hex-encoded salt to reuse (e.g. during verification).
                  A fresh random 16-byte salt is generated when omitted.

    Returns:
        Tuple of (hash_hex, salt_hex).
    """
    salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16)
    dk = hashlib.pbkdf2_hmac("sha256", (password or "").encode("utf-8"), salt, PBKDF2_ITERS)
    return dk.hex(), salt.hex()


def verify_password(password: str, password_hash: str, salt_hex: str) -> bool:
    """
    Check *password* against a stored (hash, salt) pair.

    Uses hmac.compare_digest instead of `==` so the comparison runs in
    constant time (no timing side channel), matching the behavior of the
    equivalent helper in the database module.
    """
    import hmac  # local import: avoids touching this module's import block
    cand, _ = hash_password(password, salt_hex=salt_hex)
    return hmac.compare_digest(cand, password_hash or "")
def new_session_token() -> str:
    """Generate a fresh random session token (32 bytes of entropy, url-safe)."""
    return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
    """Hex SHA-256 digest of *token*; None/"" both hash the empty string."""
    data = (token or "").encode("utf-8")
    return hashlib.sha256(data).hexdigest()

View File

@@ -502,11 +502,20 @@ class VideoComposer:
current_video = fancy_path
# 7. 添加 BGM
# 说明add_bgm 的 ducking=True 路径使用 sidechaincompress但该滤镜本身不做“混音”
# 在某些 ffmpeg 版本/参数组合下会导致 BGM 听起来像“没加上”。
# 我们在 compose() 里已禁用 ducking这里保持一致使用 amix 叠加并提高默认音量。
if bgm_path:
bgm_output = str(Path(temp_root) / f"{output_name}_bgm.mp4")
ffmpeg_utils.add_bgm(
current_video, bgm_path, bgm_output,
bgm_volume=0.15
current_video,
bgm_path,
bgm_output,
bgm_volume=0.20,
ducking=False,
duck_gain_db=-6.0,
fade_in=1.0,
fade_out=1.0,
)
self._add_temp(bgm_output)
current_video = bgm_output

View File

@@ -6,14 +6,43 @@
import json
import logging
import time
from typing import Dict, List, Any, Optional
import secrets
from typing import Dict, List, Any, Optional, Tuple
from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func
from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func, text, inspect
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
from sqlalchemy.dialects.postgresql import JSONB
import config
# NOTE: Some deployments do not ship `modules/auth.py`.
# Keep a minimal, local auth helper here to avoid hard dependency.
import hashlib
import hmac
def _hash_password(password: str, salt_hex: str = None) -> tuple[str, str]:
salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16)
dk = hashlib.pbkdf2_hmac("sha256", (password or "").encode("utf-8"), salt, 120_000)
return dk.hex(), salt.hex()
def _verify_password(password: str, pwd_hash_hex: str, salt_hex: str) -> bool:
try:
cand, _ = _hash_password(password or "", salt_hex=salt_hex)
return hmac.compare_digest(cand, pwd_hash_hex or "")
except Exception:
return False
def _new_session_token() -> str:
    """Random url-safe session token (32 bytes entropy); only its hash is persisted."""
    return secrets.token_urlsafe(32)
def _hash_token(token: str) -> str:
# store only hash for sessions
return hashlib.sha256((token or "").encode("utf-8")).hexdigest()
logger = logging.getLogger(__name__)
Base = declarative_base()
@@ -26,6 +55,7 @@ class Project(Base):
status = Column(String) # created, script_generated, images_generated, videos_generated, completed
product_info = Column(Text) # JSON string (SQLite) or JSONB (PG - using Text for compat)
script_data = Column(Text) # JSON string
owner_user_id = Column(String, index=True, nullable=True)
created_at = Column(Float, default=time.time)
updated_at = Column(Float, default=time.time, onupdate=time.time)
@@ -54,6 +84,62 @@ class AppConfig(Base):
description = Column(Text, nullable=True)
updated_at = Column(Float, default=time.time, onupdate=time.time)
class User(Base):
    """Account record for username/password login."""
    __tablename__ = "users"
    id = Column(String, primary_key=True)  # uuid hex
    username = Column(String, unique=True, index=True)
    password_hash = Column(String)  # hex PBKDF2 digest (see _hash_password)
    password_salt = Column(String)  # hex salt paired with password_hash
    role = Column(String, default="user")  # admin/user
    is_active = Column(Integer, default=1)  # 1/0 for portability
    created_at = Column(Float, default=time.time)  # epoch seconds
    updated_at = Column(Float, default=time.time, onupdate=time.time)
    last_login_at = Column(Float, nullable=True)  # None until first successful login
class UserSession(Base):
    """Server-side login session; only SHA-256 of the cookie token is stored."""
    __tablename__ = "user_sessions"
    id = Column(String, primary_key=True)  # uuid hex
    user_id = Column(String, index=True)  # User.id (no FK constraint -- presumably for SQLite/PG portability, confirm)
    token_hash = Column(String, index=True)  # _hash_token(raw token)
    expires_at = Column(Float)  # epoch seconds; checked in validate_session
    created_at = Column(Float, default=time.time)
    last_seen_at = Column(Float, default=time.time)  # touched on each validate
    ip = Column(String, nullable=True)
    user_agent = Column(String, nullable=True)
class UserPrompt(Base):
    """Per-user prompt override, unique per (user_id, key)."""
    __tablename__ = "user_prompts"
    id = Column(String, primary_key=True)  # uuid hex
    user_id = Column(String, index=True)
    key = Column(String, index=True)  # prompt identifier
    value = Column(Text)  # raw string, or JSON for non-string values (see set_user_prompt)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
    __table_args__ = (UniqueConstraint("user_id", "key", name="uix_user_prompt"),)
class RenderJob(Base):
    """
    Render job status table (async compose pipeline).
    This is intentionally minimal and portable across SQLite/Postgres.
    """
    __tablename__ = "render_jobs"
    id = Column(String, primary_key=True)  # task_id
    project_id = Column(String, index=True)
    status = Column(String, default="queued")  # queued/running/success/failed/cancelled
    progress = Column(Float, default=0.0)  # numeric progress -- scale (0..1 vs percent) set by the worker, confirm
    message = Column(Text, default="")  # human-readable progress message
    output_path = Column(Text, nullable=True)  # local path of the rendered file
    output_url = Column(Text, nullable=True)  # URL handed back to the client
    error = Column(Text, nullable=True)
    request_json = Column(Text, nullable=True)  # JSON string of ComposeRequest
    parent_id = Column(String, nullable=True, index=True)  # presumably links re-runs to the originating job; confirm
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)
class DBManager:
def __init__(self, connection_string: str = None):
if not connection_string:
@@ -62,17 +148,59 @@ class DBManager:
self.engine = create_engine(connection_string, pool_recycle=3600)
self.Session = scoped_session(sessionmaker(bind=self.engine))
self._init_db()
# bootstrap admin (safe to call repeatedly)
try:
self._bootstrap_admin_id = self.ensure_admin_user("admin", "admin1234")
except Exception:
self._bootstrap_admin_id = None
def _init_db(self):
"""初始化表结构"""
"""初始化表结构 + 轻量自迁移(不依赖 Alembic"""
Base.metadata.create_all(self.engine)
self._ensure_schema()
def _ensure_schema(self) -> None:
    """
    Ensure newer columns/tables exist in both Postgres and SQLite.

    create_all will not add columns to existing tables, so we do a minimal
    ALTER here. All steps are best-effort: failures are logged as warnings
    and never abort startup.
    """
    try:
        insp = inspect(self.engine)
        # projects.owner_user_id: added for RBAC; may be missing on old DBs
        try:
            cols = [c["name"] for c in insp.get_columns("projects")]
        except Exception:
            cols = []  # table introspection failed; fall through and try the ALTER path
        if "owner_user_id" not in cols:
            logger.info("Migrating: add projects.owner_user_id")
            with self.engine.begin() as conn:
                conn.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id VARCHAR"))
                try:
                    conn.execute(text("CREATE INDEX IF NOT EXISTS ix_projects_owner_user_id ON projects (owner_user_id)"))
                except Exception:
                    # some engines (older sqlite) may not support IF NOT EXISTS
                    try:
                        conn.execute(text("CREATE INDEX ix_projects_owner_user_id ON projects (owner_user_id)"))
                    except Exception:
                        pass  # index is an optimization only; safe to skip
    except Exception as e:
        logger.warning(f"Schema ensure skipped/failed (non-fatal): {e}")
    # Ensure render_jobs table exists (create_all should handle this; keep for safety)
    try:
        insp = inspect(self.engine)
        if "render_jobs" not in insp.get_table_names():
            logger.info("Migrating: create render_jobs table")
            Base.metadata.create_all(self.engine)
    except Exception as e:
        logger.warning(f"render_jobs ensure skipped/failed (non-fatal): {e}")
def _get_session(self):
    """Return a session from the scoped-session registry; callers must close() it."""
    return self.Session()
# --- Project Operations ---
def create_project(self, project_id: str, name: str, product_info: Dict[str, Any]):
def create_project(self, project_id: str, name: str, product_info: Dict[str, Any], owner_user_id: Optional[str] = None):
session = self._get_session()
try:
# Check if exists
@@ -86,6 +214,7 @@ class DBManager:
name=name,
status="created",
product_info=json.dumps(product_info, ensure_ascii=False),
owner_user_id=owner_user_id,
created_at=time.time(),
updated_at=time.time()
)
@@ -157,6 +286,7 @@ class DBManager:
"status": project.status,
"product_info": json.loads(project.product_info) if project.product_info else {},
"script_data": json.loads(project.script_data) if project.script_data else None,
"owner_user_id": getattr(project, "owner_user_id", None),
"created_at": project.created_at,
"updated_at": project.updated_at
}
@@ -175,12 +305,288 @@ class DBManager:
"id": p.id,
"name": p.name,
"status": p.status,
"owner_user_id": getattr(p, "owner_user_id", None),
"updated_at": p.updated_at
})
return results
finally:
session.close()
# --- User/Auth Operations ---
def ensure_admin_user(self, username: str = "admin", password: str = "admin1234") -> str:
    """Create bootstrap admin user if missing. Returns admin user_id.

    Idempotent: an existing username short-circuits without changes.
    SECURITY NOTE: the default credentials are well-known; deployments must
    change the password immediately (a warning is logged on creation).
    Raises on DB errors (after rollback).
    """
    session = self._get_session()
    try:
        u = session.query(User).filter_by(username=username).first()
        if u:
            return u.id
        pwd_hash, salt_hex = _hash_password(password)
        uid = secrets.token_hex(16)
        u = User(
            id=uid,
            username=username,
            password_hash=pwd_hash,
            password_salt=salt_hex,
            role="admin",
            is_active=1,
            created_at=time.time(),
            updated_at=time.time(),
        )
        session.add(u)
        session.commit()
        logger.warning("Bootstrap admin created (please change password asap).")
        return uid
    except Exception as e:
        session.rollback()
        logger.error(f"ensure_admin_user failed: {e}")
        raise
    finally:
        session.close()
def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
    """Verify credentials; return a small user dict on success, else None.

    Inactive users fail authentication. On success last_login_at is stamped.
    Any DB error is logged and reported as a failed login (None).
    """
    session = self._get_session()
    try:
        u = session.query(User).filter_by(username=username).first()
        if not u or int(getattr(u, "is_active", 1) or 0) != 1:
            return None
        if not _verify_password(password or "", getattr(u, "password_hash", ""), getattr(u, "password_salt", "")):
            return None
        u.last_login_at = time.time()
        session.commit()
        return {"id": u.id, "username": u.username, "role": u.role, "is_active": int(u.is_active or 0)}
    except Exception as e:
        session.rollback()
        logger.error(f"authenticate_user error: {e}")
        return None
    finally:
        session.close()
def create_session(self, user_id: str, *, ttl_seconds: int = 7 * 24 * 3600, ip: str = None, user_agent: str = None) -> str:
    """Create a session row for *user_id* and return the RAW token.

    Only SHA256(token) is persisted (token_hash); the raw token lives solely
    in the return value / client cookie. Default TTL is 7 days.
    Raises on DB errors (after rollback).
    """
    token = _new_session_token()
    token_hash = _hash_token(token)
    sid = secrets.token_hex(16)
    session = self._get_session()
    try:
        now = time.time()
        s = UserSession(
            id=sid,
            user_id=user_id,
            token_hash=token_hash,
            expires_at=now + int(ttl_seconds),
            created_at=now,
            last_seen_at=now,
            ip=ip,
            user_agent=user_agent,
        )
        session.add(s)
        session.commit()
        return token
    except Exception as e:
        session.rollback()
        logger.error(f"create_session error: {e}")
        raise
    finally:
        session.close()
def validate_session(self, token: str) -> Optional[Dict[str, Any]]:
    """Resolve a raw cookie token to its user dict, or None.

    Returns None when: token is empty, no row matches its hash, the session
    is expired, the user is missing/inactive, or any DB error occurs.
    Touches last_seen_at on success.
    """
    if not token:
        return None
    token_hash = _hash_token(token)
    session = self._get_session()
    try:
        now = time.time()
        s = session.query(UserSession).filter_by(token_hash=token_hash).first()
        if not s or (s.expires_at and float(s.expires_at) < now):
            return None
        u = session.query(User).filter_by(id=s.user_id).first()
        if not u or int(getattr(u, "is_active", 1) or 0) != 1:
            return None
        s.last_seen_at = now
        session.commit()
        return {"id": u.id, "username": u.username, "role": u.role, "is_active": int(u.is_active or 0)}
    except Exception as e:
        session.rollback()
        logger.error(f"validate_session error: {e}")
        return None
    finally:
        session.close()
def revoke_session(self, token: str) -> None:
    """Delete the session row matching *token*; silently no-ops if absent."""
    if not token:
        return
    db = self._get_session()
    try:
        row = db.query(UserSession).filter_by(token_hash=_hash_token(token)).first()
        if row is not None:
            db.delete(row)
            db.commit()
    except Exception as e:
        db.rollback()
        logger.error(f"revoke_session error: {e}")
    finally:
        db.close()
def list_users(self) -> List[Dict[str, Any]]:
    """Return all users as plain dicts, ordered oldest-first by creation time."""
    db = self._get_session()
    try:
        result: List[Dict[str, Any]] = []
        for u in db.query(User).order_by(User.created_at.asc()).all():
            result.append({
                "id": u.id,
                "username": u.username,
                "role": u.role,
                "is_active": int(u.is_active or 0),
                "created_at": u.created_at,
                "last_login_at": u.last_login_at,
            })
        return result
    finally:
        db.close()
def upsert_user(self, username: str, password: str, *, role: str = "user", is_active: int = 1) -> str:
    """Create a user, or overwrite password/role/active flag if the username exists.

    The password is always re-hashed with a fresh salt. Returns the user id.
    Raises on DB errors (after rollback).
    """
    session = self._get_session()
    try:
        u = session.query(User).filter_by(username=username).first()
        pwd_hash, salt_hex = _hash_password(password)
        now = time.time()
        if u:
            # existing user: rotate credentials and update attributes in place
            u.password_hash = pwd_hash
            u.password_salt = salt_hex
            u.role = role
            u.is_active = int(is_active)
            u.updated_at = now
            session.commit()
            return u.id
        uid = secrets.token_hex(16)
        u = User(
            id=uid,
            username=username,
            password_hash=pwd_hash,
            password_salt=salt_hex,
            role=role,
            is_active=int(is_active),
            created_at=now,
            updated_at=now,
        )
        session.add(u)
        session.commit()
        return uid
    except Exception as e:
        session.rollback()
        logger.error(f"upsert_user error: {e}")
        raise
    finally:
        session.close()
def set_user_active(self, user_id: str, is_active: int) -> None:
    """Toggle a user's active flag; unknown user_id is silently ignored."""
    db = self._get_session()
    try:
        user = db.query(User).filter_by(id=user_id).first()
        if user is None:
            return
        user.is_active = int(is_active)
        user.updated_at = time.time()
        db.commit()
    except Exception as e:
        db.rollback()
        logger.error(f"set_user_active error: {e}")
    finally:
        db.close()
def reset_user_password(self, user_id: str, new_password: str) -> None:
    """Set a new password (with a fresh salt) for *user_id*; no-op if user not found.

    DB errors are logged and swallowed.
    """
    session = self._get_session()
    try:
        u = session.query(User).filter_by(id=user_id).first()
        if not u:
            return
        pwd_hash, salt_hex = _hash_password(new_password)
        u.password_hash = pwd_hash
        u.password_salt = salt_hex
        u.updated_at = time.time()
        session.commit()
    except Exception as e:
        session.rollback()
        logger.error(f"reset_user_password error: {e}")
    finally:
        session.close()
def get_user_prompt(self, user_id: str, key: str) -> Optional[str]:
    """Fetch a per-user prompt override; None when unset or arguments are missing."""
    if not (user_id and key):
        return None
    db = self._get_session()
    try:
        row = db.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
        return None if row is None else row.value
    finally:
        db.close()
def set_user_prompt(self, user_id: str, key: str, value: Any) -> None:
    """Upsert a per-user prompt override.

    Non-string values are JSON-encoded (ensure_ascii=False); strings are
    stored verbatim. Missing user_id/key is silently ignored; DB errors are
    logged and swallowed.
    """
    if not user_id or not key:
        return
    session = self._get_session()
    try:
        now = time.time()
        v = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
        p = session.query(UserPrompt).filter_by(user_id=user_id, key=key).first()
        if p:
            p.value = v
            p.updated_at = now
        else:
            p = UserPrompt(id=secrets.token_hex(16), user_id=user_id, key=key, value=v, updated_at=now)
            session.add(p)
        session.commit()
    except Exception as e:
        session.rollback()
        logger.error(f"set_user_prompt error: {e}")
    finally:
        session.close()
# --- RBAC helpers ---
def list_projects_for_user(self, user: Dict[str, Any]) -> List[Dict[str, Any]]:
    """RBAC listing: admins see all projects, other roles only their own.

    Returns lightweight dicts (id/name/status/owner_user_id/updated_at),
    newest first; [] when *user* is falsy.
    """
    if not user:
        return []
    if user.get("role") == "admin":
        return self.list_projects()
    uid = user.get("id")
    session = self._get_session()
    try:
        projects = session.query(Project).filter_by(owner_user_id=uid).order_by(Project.updated_at.desc()).all()
        return [{"id": p.id, "name": p.name, "status": p.status, "owner_user_id": p.owner_user_id, "updated_at": p.updated_at} for p in projects]
    finally:
        session.close()
def get_project_for_user(self, project_id: str, user: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """RBAC read: admins may load any project; others only their own.

    Returns the project dict, or None when the project is missing, *user*
    is falsy, or the caller does not own the project.
    """
    data = self.get_project(project_id)
    if not data or not user:
        return None
    if user.get("role") == "admin":
        return data
    return data if data.get("owner_user_id") == user.get("id") else None
def migrate_projects_owner_to(self, owner_user_id: str) -> int:
    """Assign owner_user_id for legacy projects where it is NULL/empty.

    Returns the number of rows updated; 0 on error (errors are logged).
    """
    session = self._get_session()
    try:
        # `== None` is intentional: SQLAlchemy translates it to IS NULL
        q = session.query(Project).filter((Project.owner_user_id == None) | (Project.owner_user_id == ""))  # noqa: E711
        rows = q.all()
        for p in rows:
            p.owner_user_id = owner_user_id
            p.updated_at = time.time()
        session.commit()
        return len(rows)
    except Exception as e:
        session.rollback()
        logger.error(f"migrate_projects_owner_to error: {e}")
        return 0
    finally:
        session.close()
# --- Asset/Task Operations ---
def save_asset(self, project_id: str, scene_id: int, asset_type: str,
@@ -253,6 +659,92 @@ class DBManager:
finally:
session.close()
# --- Render Job Operations ---
def create_render_job(
    self,
    job_id: str,
    project_id: str,
    *,
    status: str = "queued",
    progress: float = 0.0,
    message: str = "",
    request: Optional[Dict[str, Any]] = None,
    parent_id: Optional[str] = None,
) -> None:
    """Insert a RenderJob row; idempotent (an existing job_id is left untouched).

    *request* (the ComposeRequest payload) is serialized to JSON when given.
    Raises on DB errors (after rollback).
    """
    session = self._get_session()
    try:
        existing = session.query(RenderJob).filter_by(id=job_id).first()
        if existing:
            return
        now = time.time()
        rj = RenderJob(
            id=job_id,
            project_id=project_id,
            status=status,
            progress=float(progress or 0.0),
            message=message or "",
            request_json=json.dumps(request, ensure_ascii=False) if request is not None else None,
            parent_id=parent_id,
            created_at=now,
            updated_at=now,
        )
        session.add(rj)
        session.commit()
    except Exception as e:
        session.rollback()
        logger.error(f"create_render_job error: {e}")
        raise
    finally:
        session.close()
def update_render_job(self, job_id: str, patch: Dict[str, Any]) -> None:
    """Apply a partial update to a render job row.

    Keys in *patch* that are not RenderJob attributes are skipped; the row's
    updated_at is always refreshed. Missing job or DB errors are non-fatal.
    """
    if not (job_id and patch):
        return
    db = self._get_session()
    try:
        job = db.query(RenderJob).filter_by(id=job_id).first()
        if job is None:
            return
        for field, value in patch.items():
            if hasattr(job, field):
                setattr(job, field, value)
        job.updated_at = time.time()
        db.commit()
    except Exception as e:
        db.rollback()
        logger.error(f"update_render_job error: {e}")
    finally:
        db.close()
def get_render_job(self, job_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one render job as a plain dict, or None if unknown.

    request_json is parsed back into a dict; unparseable JSON degrades to None.
    """
    session = self._get_session()
    try:
        rj = session.query(RenderJob).filter_by(id=job_id).first()
        if not rj:
            return None
        try:
            req = json.loads(rj.request_json) if rj.request_json else None
        except Exception:
            req = None
        return {
            "id": rj.id,
            "project_id": rj.project_id,
            "status": rj.status,
            "progress": float(rj.progress or 0.0),
            "message": rj.message or "",
            "output_path": rj.output_path,
            "output_url": rj.output_url,
            "error": rj.error,
            "request": req,
            "parent_id": rj.parent_id,
            "created_at": rj.created_at,
            "updated_at": rj.updated_at,
        }
    finally:
        session.close()
def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
session = self._get_session()
try:
@@ -279,6 +771,30 @@ class DBManager:
finally:
session.close()
def get_asset_by_id(self, asset_id: int) -> Optional[Dict[str, Any]]:
    """Fetch an asset record by its auto-increment id (used by the file proxy).

    Best-effort: returns None for a missing row or on any error (e.g. a
    non-numeric asset_id or unparseable metadata JSON).
    """
    session = self._get_session()
    try:
        a = session.query(SceneAsset).filter_by(id=int(asset_id)).first()
        if not a:
            return None
        return {
            "id": a.id,
            "project_id": a.project_id,
            "scene_id": a.scene_id,
            "asset_type": a.asset_type,
            "status": a.status,
            "local_path": a.local_path,
            "remote_url": a.remote_url,
            "task_id": a.task_id,
            "metadata": json.loads(a.metadata_json) if a.metadata_json else {},
            "updated_at": a.updated_at,
        }
    except Exception:
        return None
    finally:
        session.close()
def update_asset_metadata(self, project_id: str, scene_id: int, asset_type: str, patch: Dict[str, Any]) -> None:
"""Merge-patch asset.metadata JSON without overwriting other fields."""
if not patch:

View File

@@ -172,7 +172,8 @@ def get_video_info(video_path: str) -> Dict[str, Any]:
def concat_videos(
video_paths: List[str],
output_path: str,
target_size: Tuple[int, int] = (1080, 1920)
target_size: Tuple[int, int] = (1080, 1920),
fades: Optional[List[Dict[str, float]]] = None
) -> str:
"""
使用 FFmpeg concat demuxer 拼接多段视频
@@ -197,10 +198,72 @@ def concat_videos(
filter_parts = []
for i in range(len(video_paths)):
# scale 保持宽高比pad 填充黑边居中
filter_parts.append(
chain = (
f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1"
)
# 可选:片段末尾“火山式转场”(不改时长、不重叠)
if fades and i < len(fades):
fx = fades[i] or {}
fi = float(fx.get("in", 0) or 0.0)
fo = float(fx.get("out", 0) or 0.0)
t_type = str(fx.get("type") or "")
t_dur = float(fx.get("dur") or 0.0)
try:
dur = float(get_video_info(video_paths[i]).get("duration") or 0.0)
except Exception:
dur = 0.0
# 基础淡入/淡出
if fi > 0:
chain += f",fade=t=in:st=0:d={fi}"
if fo > 0 and dur > 0:
st = max(dur - fo, 0.0)
chain += f",fade=t=out:st={st}:d={fo}"
# 末尾动效WYSIWYG前端预览必须与此一致
if t_type and t_dur > 0 and dur > 0:
st = max(dur - t_dur, 0.0)
td = max(t_dur, 0.001)
p = f"if(between(t\\,{st}\\,{dur})\\,(t-{st})/{td}\\,0)"
if t_type == "fade":
chain += f",fade=t=out:st={st}:d={t_dur}"
elif t_type == "fadeWhite":
chain += f",fade=t=out:st={st}:d={t_dur}:color=white"
elif t_type == "blurOut":
chain += f",gblur=sigma='10*{p}':steps=1"
elif t_type == "blurFade":
chain += f",gblur=sigma='8*{p}':steps=1,fade=t=out:st={st}:d={t_dur}"
elif t_type == "flash":
chain += f",eq=brightness='0.7*(1-abs(0.5-{p})*2)'"
elif t_type == "desaturate":
chain += f",hue=s='1-0.9*{p}'"
elif t_type == "colorPop":
chain += f",hue=s='1+0.8*{p}',eq=contrast='1+0.3*{p}'"
elif t_type == "hueShift":
chain += f",hue=h='60*{p}'"
elif t_type == "darken":
chain += f",eq=brightness='-0.4*{p}'"
elif t_type in ("slideLeft", "slideRight", "slideUp", "slideDown"):
off = 80
if t_type == "slideLeft":
chain += f",pad={width+off}:{height}:{off/2}-{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
if t_type == "slideRight":
chain += f",pad={width+off}:{height}:{off/2}+{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
if t_type == "slideUp":
chain += f",pad={width}:{height+off}:0:{off/2}-{off}*{p}:black,crop={width}:{height}:0:{off/2}"
if t_type == "slideDown":
chain += f",pad={width}:{height+off}:0:{off/2}+{off}*{p}:black,crop={width}:{height}:0:{off/2}"
elif t_type in ("zoomOut", "zoomIn"):
if t_type == "zoomOut":
chain += f",scale=w='{width}*(1-0.10*{p})':h='{height}*(1-0.10*{p})':eval=frame,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black"
else:
chain += f",scale=w='{width}*(1+0.10*{p})':h='{height}*(1+0.10*{p})':eval=frame,crop={width}:{height}"
elif t_type == "rotateOut":
chain += f",rotate=a='0.12*{p}':c=black@1:ow={width}:oh={height}"
chain += f"[v{i}]"
filter_parts.append(chain)
# 拼接所有视频流
concat_inputs = "".join([f"[v{i}]" for i in range(len(video_paths))])
@@ -232,7 +295,8 @@ def concat_videos(
def concat_videos_with_audio(
video_paths: List[str],
output_path: str,
target_size: Tuple[int, int] = (1080, 1920)
target_size: Tuple[int, int] = (1080, 1920),
fades: Optional[List[Dict[str, float]]] = None
) -> str:
"""
拼接视频并保留音频轨道
@@ -250,10 +314,70 @@ def concat_videos_with_audio(
# 视频处理
for i in range(n):
filter_parts.append(
chain = (
f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1"
)
# 可选:片段末尾“火山式转场”(不改时长、不重叠)
if fades and i < len(fades):
fx = fades[i] or {}
fi = float(fx.get("in", 0) or 0.0)
fo = float(fx.get("out", 0) or 0.0)
t_type = str(fx.get("type") or "")
t_dur = float(fx.get("dur") or 0.0)
try:
dur = float(get_video_info(video_paths[i]).get("duration") or 0.0)
except Exception:
dur = 0.0
if fi > 0:
chain += f",fade=t=in:st=0:d={fi}"
if fo > 0 and dur > 0:
st = max(dur - fo, 0.0)
chain += f",fade=t=out:st={st}:d={fo}"
if t_type and t_dur > 0 and dur > 0:
st = max(dur - t_dur, 0.0)
td = max(t_dur, 0.001)
p = f"if(between(t\\,{st}\\,{dur})\\,(t-{st})/{td}\\,0)"
if t_type == "fade":
chain += f",fade=t=out:st={st}:d={t_dur}"
elif t_type == "fadeWhite":
chain += f",fade=t=out:st={st}:d={t_dur}:color=white"
elif t_type == "blurOut":
chain += f",gblur=sigma='10*{p}':steps=1"
elif t_type == "blurFade":
chain += f",gblur=sigma='8*{p}':steps=1,fade=t=out:st={st}:d={t_dur}"
elif t_type == "flash":
chain += f",eq=brightness='0.7*(1-abs(0.5-{p})*2)'"
elif t_type == "desaturate":
chain += f",hue=s='1-0.9*{p}'"
elif t_type == "colorPop":
chain += f",hue=s='1+0.8*{p}',eq=contrast='1+0.3*{p}'"
elif t_type == "hueShift":
chain += f",hue=h='60*{p}'"
elif t_type == "darken":
chain += f",eq=brightness='-0.4*{p}'"
elif t_type in ("slideLeft", "slideRight", "slideUp", "slideDown"):
off = 80
if t_type == "slideLeft":
chain += f",pad={width+off}:{height}:{off/2}-{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
if t_type == "slideRight":
chain += f",pad={width+off}:{height}:{off/2}+{off}*{p}:0:black,crop={width}:{height}:{off/2}:0"
if t_type == "slideUp":
chain += f",pad={width}:{height+off}:0:{off/2}-{off}*{p}:black,crop={width}:{height}:0:{off/2}"
if t_type == "slideDown":
chain += f",pad={width}:{height+off}:0:{off/2}+{off}*{p}:black,crop={width}:{height}:0:{off/2}"
elif t_type in ("zoomOut", "zoomIn"):
if t_type == "zoomOut":
chain += f",scale=w='{width}*(1-0.10*{p})':h='{height}*(1-0.10*{p})':eval=frame,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black"
else:
chain += f",scale=w='{width}*(1+0.10*{p})':h='{height}*(1+0.10*{p})':eval=frame,crop={width}:{height}"
elif t_type == "rotateOut":
chain += f",rotate=a='0.12*{p}':c=black@1:ow={width}:oh={height}"
chain += f"[v{i}]"
filter_parts.append(chain)
# 音频处理(静音填充如果没有音频)
for i in range(n):
@@ -469,6 +593,127 @@ def adjust_audio_duration(
return output_path
def _atempo_chain(speed: float) -> str:
"""
构造 atempo 链,支持 <0.5 或 >2.0 的倍速(通过链式 atempo
"""
try:
s = float(speed)
except Exception:
s = 1.0
if s <= 0:
s = 1.0
parts = []
# atempo 支持 0.5~2.0
while s > 2.0:
parts.append("atempo=2.0")
s /= 2.0
while s < 0.5:
parts.append("atempo=0.5")
s /= 0.5
parts.append(f"atempo={s}")
return ",".join(parts)
def change_audio_speed(input_path: str, speed: float, output_path: str) -> str:
    """Re-encode audio at *speed*x playback rate via a chained atempo filter.

    Returns output_path, or None when input_path does not exist.
    """
    if not os.path.exists(input_path):
        return None
    af = _atempo_chain(speed)
    cmd = [FFMPEG_PATH, "-y", "-i", input_path, "-filter:a", af, output_path]
    _run_ffmpeg(cmd)
    return output_path
def fit_audio_to_duration_by_speed(input_path: str, target_duration: float, output_path: str) -> str:
    """
    Fit audio to *target_duration* by changing playback speed (faster or
    slower), then trim/pad to the exact length.

    Intended for narration: when a user stretches a clip they expect the
    speech rate to change, not silence padding.

    Returns output_path (a plain copy when the target or current duration is
    unusable), or None when input_path does not exist.
    """
    if not os.path.exists(input_path):
        return None
    try:
        td = float(target_duration or 0)
    except Exception:
        td = 0.0
    if td <= 0:
        import shutil
        shutil.copy(input_path, output_path)
        return output_path
    cur = float(get_audio_info(input_path).get("duration") or 0.0)
    if cur <= 0:
        import shutil
        shutil.copy(input_path, output_path)
        return output_path
    speed = cur / td
    af_speed = _atempo_chain(speed)
    # After the speed change, trim + pad once more to pin the exact duration
    # (guards against cumulative rounding error across atempo stages).
    af = f"{af_speed},atrim=0:{td},apad=pad_dur=0,atrim=0:{td}"
    cmd = [FFMPEG_PATH, "-y", "-i", input_path, "-filter:a", af, output_path]
    _run_ffmpeg(cmd)
    return output_path
def force_audio_duration(input_path: str, target_duration: float, output_path: str) -> str:
    """
    Trim/pad audio to an exact duration WITHOUT changing speed (for cases
    where any speed change was already applied upstream).

    Returns output_path (a plain copy when the target duration is unusable),
    or None when input_path does not exist.
    """
    if not os.path.exists(input_path):
        return None
    try:
        td = float(target_duration or 0)
    except Exception:
        td = 0.0
    if td <= 0:
        import shutil
        shutil.copy(input_path, output_path)
        return output_path
    af = f"atrim=0:{td},apad=pad_dur=0,atrim=0:{td}"
    cmd = [FFMPEG_PATH, "-y", "-i", input_path, "-filter:a", af, output_path]
    _run_ffmpeg(cmd)
    return output_path
def _which(cmd: str) -> Optional[str]:
try:
import shutil
return shutil.which(cmd)
except Exception:
return None
def normalize_sticker_to_png(input_path: str, output_path: str) -> str:
    """
    Normalize a sticker image to PNG (for use with ffmpeg overlay).

    - PNG: returned as-is (no conversion, original path returned).
    - WEBP: converted to PNG (some ffmpeg builds have spotty webp support).
    - SVG: prefer rsvg-convert; fall back to ffmpeg's svg decoder (build-dependent).
    - anything else: best-effort ffmpeg conversion.

    Returns the usable PNG path, or None when input_path is missing.
    """
    if not input_path or not os.path.exists(input_path):
        return None
    ext = Path(input_path).suffix.lower()
    if ext in [".png"]:
        return input_path
    if ext in [".webp"]:
        # convert to PNG to avoid inconsistent webp support across ffmpeg builds
        cmd = [FFMPEG_PATH, "-y", "-i", input_path, output_path]
        _run_ffmpeg(cmd)
        return output_path
    if ext == ".svg":
        rsvg = _which("rsvg-convert")
        if rsvg:
            import subprocess
            subprocess.check_call([rsvg, "-o", output_path, input_path])
            return output_path
        # fallback: let ffmpeg decode the svg directly (depends on the build)
        cmd = [FFMPEG_PATH, "-y", "-i", input_path, output_path]
        _run_ffmpeg(cmd)
        return output_path
    # other formats: best-effort ffmpeg conversion
    cmd = [FFMPEG_PATH, "-y", "-i", input_path, output_path]
    _run_ffmpeg(cmd)
    return output_path
def get_audio_info(file_path: str) -> Dict[str, Any]:
    """Probe audio metadata; delegates to get_video_info (the probe handles both)."""
    return get_video_info(file_path)
@@ -830,6 +1075,12 @@ def add_bgm(
loop: bool = True,
ducking: bool = True,
duck_gain_db: float = -6.0,
# 新增:按时间段闪避(更可控,和旁白时间轴严格对齐)
duck_volume: float = 0.25,
duck_ranges: Optional[List[Tuple[float, float]]] = None,
# 新增BGM 片段可有起点/时长(不强制从 0 覆盖整段视频)
start_time: float = 0.0,
clip_duration: Optional[float] = None,
fade_in: float = 1.0,
fade_out: float = 1.0
) -> str:
@@ -856,30 +1107,46 @@ def add_bgm(
info = get_video_info(video_path)
video_duration = info["duration"]
if loop:
bgm_chain = (
f"[1:a]aloop=-1:size=2e+09,asetpts=N/SR/TB,"
f"atrim=0:{video_duration},"
f"afade=t=in:st=0:d={fade_in},"
f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out},"
f"volume={bgm_volume}[bgm]"
)
else:
bgm_chain = (
f"[1:a]"
f"afade=t=in:st=0:d={fade_in},"
f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out},"
f"volume={bgm_volume}[bgm]"
)
# 片段时长:默认覆盖整段视频
dur = float(clip_duration) if (clip_duration is not None and float(clip_duration) > 0) else float(video_duration)
st = max(0.0, float(start_time or 0.0))
end_for_fade = max(dur - float(fade_out or 0.0), 0.0)
if ducking:
# 使用安全参数的 sidechaincompress避免 unsupported 参数
# 基础链loop/trim -> fades -> base volume
if loop:
bgm_chain = f"[1:a]aloop=-1:size=2e+09,asetpts=N/SR/TB,atrim=0:{dur}"
else:
bgm_chain = f"[1:a]atrim=0:{dur}"
bgm_chain += f",afade=t=in:st=0:d={float(fade_in or 0.0)},afade=t=out:st={end_for_fade}:d={float(fade_out or 0.0)},volume={bgm_volume}"
# 延迟到 start_time
if st > 1e-6:
ms = int(st * 1000)
bgm_chain += f",adelay={ms}|{ms}"
# 闪避(按时间段)
# 注意:使用 enable 让 filter 只在区间内生效(外部直接 passthrough
if ducking and duck_ranges:
dv = max(0.05, min(1.0, float(duck_volume or 0.25)))
for (rs, re) in duck_ranges:
rsf = max(0.0, float(rs))
ref = max(rsf, float(re))
bgm_chain += f",volume={dv}:enable='between(t,{rsf},{ref})'"
bgm_chain += "[bgm]"
# 如果提供了 duck_ranges就用确定性的 amixducking 已在 bgm_chain 内完成)
if ducking and duck_ranges:
filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first:dropout_transition=0:normalize=0[outa]"
elif ducking:
# 否则退回 sidechaincompress对原视频音频进行侧链压缩
filter_complex = (
f"{bgm_chain};"
f"[0:a][bgm]sidechaincompress=threshold=0.1:ratio=4:attack=5:release=250:makeup=1:mix=1:level_in=1:level_sc=1[outa]"
)
else:
filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first[outa]"
filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first:dropout_transition=0:normalize=0[outa]"
cmd = [
FFMPEG_PATH, "-y",

View File

@@ -246,3 +246,9 @@ def normalize_legacy_project(doc: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -19,6 +19,7 @@ from typing import Optional, Tuple
LEGACY_HOST_TEMP_PREFIX = "/root/video-flow/temp/"
LEGACY_HOST_OUTPUT_PREFIX = "/root/video-flow/output/"
LEGACY_HOST_PREFIX = "/root/video-flow/"
# Container mount points (see docker-compose.yml)
LEGACY_CONTAINER_TEMP_DIR = "/legacy/temp"
@@ -42,18 +43,27 @@ def map_legacy_local_path(local_path: Optional[str]) -> Tuple[Optional[str], Opt
if os.path.exists(local_path):
return local_path, None
# Legacy host -> container mapping by basename
# Legacy host path -> current container workspace path (same repo but different prefix)
# Example:
# /root/video-flow/temp/projects/... -> /app/temp/projects/...
# This covers cases where we don't mount /legacy/* but the files were copied into current stack.
if local_path.startswith(LEGACY_HOST_PREFIX):
rest = local_path[len(LEGACY_HOST_PREFIX):].lstrip("/")
candidate = str(Path("/app") / rest)
if os.path.exists(candidate):
return candidate, None
# Legacy host -> container mapping (preserve relative path)
if local_path.startswith(LEGACY_HOST_TEMP_PREFIX):
name = Path(local_path).name
container_path = str(Path(LEGACY_CONTAINER_TEMP_DIR) / name)
url = f"{LEGACY_STATIC_TEMP_PREFIX}{name}"
return container_path, url
rel = local_path[len(LEGACY_HOST_TEMP_PREFIX):].lstrip("/")
container_path = str(Path(LEGACY_CONTAINER_TEMP_DIR) / rel)
# 静态路由通常只覆盖目录根(不包含子目录);这里交给 /api/assets/file 做 FileResponse 更稳
return container_path, None
if local_path.startswith(LEGACY_HOST_OUTPUT_PREFIX):
name = Path(local_path).name
container_path = str(Path(LEGACY_CONTAINER_OUTPUT_DIR) / name)
url = f"{LEGACY_STATIC_OUTPUT_PREFIX}{name}"
return container_path, url
rel = local_path[len(LEGACY_HOST_OUTPUT_PREFIX):].lstrip("/")
container_path = str(Path(LEGACY_CONTAINER_OUTPUT_DIR) / rel)
return container_path, None
# Unknown path: keep as-is
return local_path, None
@@ -64,3 +74,9 @@ def map_legacy_local_path(local_path: Optional[str]) -> Tuple[Optional[str], Opt

102
modules/preview_proxy.py Normal file
View File

@@ -0,0 +1,102 @@
"""
Generate and cache low-bitrate preview proxies for browser playback.
Goal:
- Improve Remotion Player preview smoothness by serving smaller/faster-to-decode videos.
- Keep original `source_path` for accurate export; only swap `source_url` for preview.
Design:
- Deterministic cache key based on (path, mtime, size).
- Proxy lives under config.TEMP_DIR / "proxy".
- Generated with ffmpeg: scale/pad + fps downsample + faststart + no audio.
"""
from __future__ import annotations
import hashlib
import os
from pathlib import Path
from typing import Optional, Tuple
import config
from modules import ffmpeg_utils
def _key_for_file(path: str) -> str:
st = os.stat(path)
raw = f"{path}|{st.st_mtime_ns}|{st.st_size}".encode("utf-8")
return hashlib.sha1(raw).hexdigest() # short, deterministic
def ensure_video_proxy(
    source_path: str,
    *,
    target_w: int = 540,
    target_h: int = 960,
    target_fps: int = 24,
    crf: int = 28,
    preset: str = "veryfast",
) -> Tuple[Optional[str], Optional[str]]:
    """
    Ensure a low-bitrate preview proxy exists for source_path.

    The proxy is scaled/padded to target_w x target_h, downsampled to
    target_fps, encoded with libx264 (given CRF/preset), audio stripped, and
    cached under config.TEMP_DIR / "proxy" keyed on the source file's
    (path, mtime, size) fingerprint.

    Returns:
        (proxy_path, proxy_url) on success, or (None, None) when source_path
        is empty/missing or transcoding did not produce a usable file.

    Raises:
        Whatever ffmpeg_utils._run_ffmpeg raises on a failed ffmpeg run
        (propagated unchanged, after best-effort cleanup of the temp file).
    """
    if not source_path or not os.path.exists(source_path):
        return None, None

    proxy_dir = Path(config.TEMP_DIR) / "proxy"
    proxy_dir.mkdir(parents=True, exist_ok=True)

    key = _key_for_file(source_path)
    out_name = f"proxy_{key}.mp4"
    out_path = proxy_dir / out_name
    proxy_url = f"/static/temp/proxy/{out_name}"

    # Cache hit: proxy already generated for this exact file version.
    if out_path.exists() and out_path.stat().st_size > 1024:
        return str(out_path), proxy_url

    vf = (
        f"scale={target_w}:{target_h}:force_original_aspect_ratio=decrease,"
        f"pad={target_w}:{target_h}:(ow-iw)/2:(oh-ih)/2:black,"
        f"fps={target_fps}"
    )
    # Encode to a temporary name first, then publish atomically via
    # os.replace. Writing straight to out_path would let a killed/failed
    # ffmpeg run leave a partial file that still passes the >1KB cache check
    # on the next call and gets served as a "valid" proxy.
    # Keep the .mp4 extension so ffmpeg picks the correct muxer.
    tmp_path = proxy_dir / f"proxy_{key}.partial.mp4"
    cmd = [
        ffmpeg_utils.FFMPEG_PATH,
        "-y",
        "-i",
        source_path,
        "-an",  # drop audio to cut decode load; narration/BGM play on separate tracks
        "-vf",
        vf,
        "-c:v",
        "libx264",
        "-preset",
        preset,
        "-crf",
        str(crf),
        "-tune",
        "fastdecode",
        "-pix_fmt",
        "yuv420p",
        "-movflags",
        "+faststart",
        str(tmp_path),
    ]
    try:
        ffmpeg_utils._run_ffmpeg(cmd)
    except Exception:
        # Best-effort cleanup of the partial output; propagate the error.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    if tmp_path.exists() and tmp_path.stat().st_size > 1024:
        os.replace(str(tmp_path), str(out_path))
        return str(out_path), proxy_url
    # Output missing or suspiciously small: discard and report failure.
    if tmp_path.exists():
        tmp_path.unlink()
    return None, None

View File

@@ -6,6 +6,7 @@ import base64
import json
import logging
import os
import time
import requests
from typing import Dict, Any, List, Optional
from pathlib import Path
@@ -27,7 +28,10 @@ class ScriptGenerator:
# OpenAI-compatible client for ShuBiaoBiao (supports multiple models incl. GPT)
self.shubiaobiao_client = OpenAI(
api_key=config.SHUBIAOBIAO_KEY,
base_url=config.SHUBIAOBIAO_BASE_URL
base_url=config.SHUBIAOBIAO_BASE_URL,
# IMPORTANT: OpenAI SDK default timeout is 10 minutes; cap it to keep UX responsive.
timeout=config.SHUBIAOBIAO_CHAT_TIMEOUT_S,
max_retries=config.SHUBIAOBIAO_CHAT_MAX_RETRIES,
)
# Default System Prompt
@@ -139,15 +143,23 @@ class ScriptGenerator:
product_name: str,
product_info: Dict[str, Any],
image_paths: List[str] = None,
model_provider: str = "shubiaobiao" # "shubiaobiao" or "doubao"
model_provider: str = "shubiaobiao", # "shubiaobiao" or "doubao"
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
生成分镜脚本
"""
logger.info(f"Generating script for: {product_name} (Provider: {model_provider})")
# 1. 构造 Prompt (优先从数据库读取配置)
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
# 1. 构造 Prompt (优先按 user_id 读取;否则回退到全局配置,再回退默认)
system_prompt = None
if user_id:
try:
system_prompt = db.get_user_prompt(user_id, "prompt_script_gen")
except Exception:
system_prompt = None
if not system_prompt:
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
user_prompt = self._build_user_prompt(product_name, product_info)
# Branch for Doubao (Volcengine)
@@ -293,21 +305,40 @@ class ScriptGenerator:
ShuBiaoBiao OpenAI-compatible multimodal chat.
IMPORTANT: For ShuBiaoBiao models, we pass image URLs (R2 public URLs), not base64.
"""
t0 = time.time()
# Use WARNING level so it shows up even if Streamlit/root logger is not configured to INFO.
logger.warning(
f"[script_gen] start shubiaobiao chat model={model_name} images={len(image_paths or [])} "
f"timeout_s={getattr(config, 'SHUBIAOBIAO_CHAT_TIMEOUT_S', 'n/a')} "
f"max_retries={getattr(config, 'SHUBIAOBIAO_CHAT_MAX_RETRIES', 'n/a')}"
)
messages = [{"role": "system", "content": system_prompt}]
user_content: List[Dict[str, Any]] = []
# Images first (URL), then text
t_upload0 = time.time()
urls = self._upload_images_to_r2(image_paths or [], limit=10)
logger.warning(
f"[script_gen] r2_upload done urls={len(urls)} elapsed_s={time.time() - t_upload0:.2f}"
)
for url in urls:
user_content.append({"type": "image_url", "image_url": {"url": url}})
user_content.append({"type": "text", "text": user_prompt})
messages.append({"role": "user", "content": user_content})
try:
resp = self.shubiaobiao_client.chat.completions.create(
client = self.shubiaobiao_client.with_options(
timeout=config.SHUBIAOBIAO_CHAT_TIMEOUT_S,
max_retries=config.SHUBIAOBIAO_CHAT_MAX_RETRIES,
)
t_call0 = time.time()
resp = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0.7,
)
logger.warning(
f"[script_gen] shubiaobiao chat done elapsed_s={time.time() - t_call0:.2f} total_s={time.time() - t0:.2f}"
)
content_text = (resp.choices[0].message.content or "").strip()
script_json = self._extract_json_from_response(content_text)
if script_json is None:
@@ -323,7 +354,9 @@ class ScriptGenerator:
}
return final_script
except Exception as e:
logger.error(f"shubiaobiao script generation failed ({model_name}): {e}")
logger.error(
f"shubiaobiao script generation failed ({model_name}) after {time.time() - t0:.2f}s: {e}"
)
return None
def _postprocess_selling_points(self, product_info: Dict[str, Any], selling_points: Any) -> List[str]:
@@ -582,7 +615,33 @@ class ScriptGenerator:
def _validate_and_fix_script(self, script: Dict[str, Any]) -> Dict[str, Any]:
"""校验并修复脚本结构"""
# 简单校验,确保必要字段存在
if "scenes" not in script:
if not isinstance(script, dict):
return {"scenes": []}
# Ensure fields exist
if "scenes" not in script or not isinstance(script.get("scenes"), list):
script["scenes"] = []
# Normalize: keep visual_anchor at top-level, but avoid repeating the full anchor in every scene.visual_prompt.
# Reason: repeating a long anchor 4-5 times explodes tokens and makes UI look like "only three sections",
# while image generation already supports passing visual_anchor separately and prepending it at runtime.
visual_anchor = script.get("visual_anchor") or ""
if isinstance(visual_anchor, str) and visual_anchor.strip() and script["scenes"]:
anchor = visual_anchor.strip()
prefix = f"[{anchor}]"
for scene in script["scenes"]:
if not isinstance(scene, dict):
continue
vp = scene.get("visual_prompt")
if not isinstance(vp, str) or not vp.strip():
continue
s = vp.strip()
# Strip exact "[anchor]" prefix if present
if s.startswith(prefix):
s = s[len(prefix):].lstrip()
# If the model output copied the raw anchor without brackets, strip it too
elif s.startswith(anchor):
s = s[len(anchor):].lstrip()
scene["visual_prompt"] = s
return script

View File

@@ -90,6 +90,79 @@ class TextRenderer:
return color
return (0, 0, 0, 255)
def _wrap_text_to_width(self, text: str, font: ImageFont.FreeTypeFont, max_width: int) -> str:
"""
将文本按最大宽度自动换行(支持中英文混排)。
- 保留原始换行符为段落边界
- 英文优先按空格断词;中文按字符贪心换行
"""
try:
mw = int(max_width or 0)
except Exception:
mw = 0
if mw <= 0:
return text
# 兼容:去掉末尾多余空行
raw_paras = (text or "").split("\n")
out_lines: List[str] = []
# 1x1 dummy draw 用于测量
dummy_draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
def text_w(s: str) -> float:
try:
return float(dummy_draw.textlength(s, font=font))
except Exception:
bbox = dummy_draw.textbbox((0, 0), s, font=font)
return float((bbox[2] - bbox[0]) if bbox else 0)
for para in raw_paras:
p = (para or "").rstrip()
if not p:
out_lines.append("")
continue
# 英文/混排:尝试按空格分词,否则按字符
use_words = (" " in p)
tokens = p.split(" ") if use_words else list(p)
cur = ""
for tok in tokens:
cand = (cur + (" " if (use_words and cur) else "") + tok) if cur else tok
if text_w(cand) <= mw:
cur = cand
continue
# 当前行放不下:先落一行
if cur:
out_lines.append(cur)
cur = tok
else:
# 单 token 超宽:强制按字符拆
if use_words:
chars = list(tok)
else:
chars = [tok]
buf = ""
for ch in chars:
cand2 = buf + ch
if text_w(cand2) <= mw or not buf:
buf = cand2
else:
out_lines.append(buf)
buf = ch
cur = buf
if cur:
out_lines.append(cur)
# 去掉尾部空行(保持中间空行)
while out_lines and out_lines[-1] == "":
out_lines.pop()
return "\n".join(out_lines)
def render(self, text: str, style: Union[Dict[str, Any], str], cache: bool = True) -> str:
"""
渲染文本并返回图片路径
@@ -122,14 +195,29 @@ class TextRenderer:
font_size = style.get("font_size", 60)
font = self._get_font(font_path, font_size)
font_color = self._parse_color(style.get("font_color", "#FFFFFF"))
bold = bool(style.get("bold", False))
italic = bool(style.get("italic", False))
underline = bool(style.get("underline", False))
# 3. 测量文本尺寸
# 3. 自动换行(可选)
max_width = style.get("max_width") or style.get("maxWidth") or style.get("text_box_width")
try:
max_width = int(max_width) if max_width is not None else 0
except Exception:
max_width = 0
if max_width > 0:
text = self._wrap_text_to_width(text, font, max_width)
# 4. 测量文本尺寸(支持多行)
dummy_draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
bbox = dummy_draw.textbbox((0, 0), text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
try:
bbox = dummy_draw.multiline_textbbox((0, 0), text, font=font, spacing=int(font_size * 0.25), align="center")
except Exception:
bbox = dummy_draw.textbbox((0, 0), text, font=font)
text_w = (bbox[2] - bbox[0]) if bbox else 0
text_h = (bbox[3] - bbox[1]) if bbox else 0
# 4. 计算总尺寸 (包含 padding, stroke, shadow)
# 5. 计算总尺寸 (包含 padding, stroke, shadow)
strokes = style.get("stroke", [])
if isinstance(strokes, dict): strokes = [strokes] # 兼容旧格式
@@ -156,7 +244,7 @@ class TextRenderer:
canvas_w = content_w + extra_margin * 2
canvas_h = content_h + extra_margin * 2
# 5. 创建画布
# 6. 创建画布
img = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
draw = ImageDraw.Draw(img)
@@ -164,7 +252,7 @@ class TextRenderer:
center_x = canvas_w // 2
center_y = canvas_h // 2
# 6. 绘制顺序: 阴影 -> 背景 -> 描边 -> 文本
# 7. 绘制顺序: 阴影 -> 背景 -> 描边 -> 文本
# --- 绘制阴影 (针对整个块) ---
if shadow:
@@ -183,7 +271,11 @@ class TextRenderer:
# 文字阴影
txt_x = center_x - text_w / 2
txt_y = center_y - text_h / 2
shadow_draw.text((txt_x, txt_y), text, font=font, fill=shadow_color)
# 多行阴影
try:
shadow_draw.multiline_text((txt_x, txt_y), text, font=font, fill=shadow_color, spacing=int(font_size * 0.25), align="center")
except Exception:
shadow_draw.text((txt_x, txt_y), text, font=font, fill=shadow_color)
# 描边阴影
for s in strokes:
width = s.get("width", 0)
@@ -217,10 +309,55 @@ class TextRenderer:
width = s.get("width", 0)
if width > 0:
# 通过偏移模拟描边 (Pillow stroke_width 效果一般,但这里先用原生参数)
draw.text((txt_x, txt_y), text, font=font, fill=color, stroke_width=width, stroke_fill=color)
try:
draw.multiline_text((txt_x, txt_y), text, font=font, fill=color, spacing=int(font_size * 0.25), align="center", stroke_width=width, stroke_fill=color)
except Exception:
draw.text((txt_x, txt_y), text, font=font, fill=color, stroke_width=width, stroke_fill=color)
# --- 绘制文字 ---
draw.text((txt_x, txt_y), text, font=font, fill=font_color)
# italic通过仿射变换做简单斜体先绘制到单独图层再 shear
# bold通过多次微小偏移叠加模拟加粗比改 stroke 更接近“字重”)
if italic:
text_layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
text_draw = ImageDraw.Draw(text_layer)
if bold:
for dx in (0, 1):
try:
text_draw.multiline_text((txt_x + dx, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
except Exception:
text_draw.text((txt_x + dx, txt_y), text, font=font, fill=font_color)
else:
try:
text_draw.multiline_text((txt_x, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
except Exception:
text_draw.text((txt_x, txt_y), text, font=font, fill=font_color)
shear = 0.22 # 经验值:适中倾斜
text_layer = text_layer.transform(
text_layer.size,
Image.AFFINE,
(1, shear, 0, 0, 1, 0),
resample=Image.BICUBIC
)
img = Image.alpha_composite(img, text_layer)
draw = ImageDraw.Draw(img)
else:
if bold:
for dx in (0, 1):
try:
draw.multiline_text((txt_x + dx, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
except Exception:
draw.text((txt_x + dx, txt_y), text, font=font, fill=font_color)
else:
try:
draw.multiline_text((txt_x, txt_y), text, font=font, fill=font_color, spacing=int(font_size * 0.25), align="center")
except Exception:
draw.text((txt_x, txt_y), text, font=font, fill=font_color)
# underline在文本底部画线与字号相关
if underline:
line_y = txt_y + text_h + max(2, int(font_size * 0.08))
line_th = max(2, int(font_size * 0.06))
draw.rectangle([txt_x, line_y, txt_x + text_w, line_y + line_th], fill=font_color)
# 7. 裁剪多余透明区域
bbox = img.getbbox()