Files
video-flow/app.py
2026-01-09 14:09:16 +08:00

1684 lines
84 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
MatchMe Studio - UI (Streamlit)
Style: Kaogujia (Clean, Data-heavy, Professional)
"""
import streamlit as st
import json
import time
import os
import random
from pathlib import Path
import pandas as pd
from time import perf_counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from os import stat
# Import Backend Modules
import config
from modules.script_gen import ScriptGenerator
from modules.image_gen import ImageGenerator
from modules.video_gen import VideoGenerator
from modules.composer import VideoComposer, VideoComposer as Composer # alias
from modules.text_renderer import renderer
from modules import export_utils
from modules.db_manager import db
from modules import path_utils
from modules import limits
from modules.legacy_path_mapper import map_legacy_local_path
from modules.legacy_normalizer import normalize_legacy_project
import extra_streamlit_components as stx
# Page Config
# Issued before any other st.* call in this file (Streamlit expects
# set_page_config to be the first Streamlit command of the script).
st.set_page_config(
    page_title="Video Flow Console",
    page_icon="🎬",
    layout="wide",
    initial_sidebar_state="expanded"
)
# ============================================================
# Auth (login + remember cookie)
# ============================================================
# Name of the browser cookie that carries the session token.
COOKIE_NAME = "vf_session"
# Cookie accessor shared by the login / logout helpers in this file.
cookie_manager = stx.CookieManager(key="vf_cookie_mgr")
def _current_user() -> "dict | None":
    """Return the logged-in user record kept in the Streamlit session.

    Returns:
        The value stored under ``st.session_state["current_user"]`` when it
        is a dict; ``None`` when the key is absent or holds anything else
        (i.e. nobody is logged in).
    """
    user = st.session_state.get("current_user")
    return user if isinstance(user, dict) else None
def _set_current_user(u: dict):
    """Remember *u* as the logged-in user for this Streamlit session."""
    st.session_state["current_user"] = u
def _try_restore_login():
    """Re-establish the session user from the remember-me cookie.

    Does nothing when a user is already present in the session. A cookie
    read error is treated the same as "no cookie".
    """
    if _current_user():
        return
    token = None
    try:
        token = cookie_manager.get(COOKIE_NAME)
    except Exception:
        # Cookie component unavailable: behave as if no token was stored.
        token = None
    if not token:
        return
    restored = db.validate_session(token)
    if restored:
        _set_current_user(restored)
def _logout():
    """Revoke the session token, drop the cookie and session user, then rerun."""
    session_token = None
    try:
        session_token = cookie_manager.get(COOKIE_NAME)
    except Exception:
        # Reading the cookie is best-effort; continue without a token.
        session_token = None
    if session_token:
        db.revoke_session(session_token)
    try:
        cookie_manager.delete(COOKIE_NAME)
    except Exception:
        # Deleting the cookie is best-effort as well.
        pass
    if "current_user" in st.session_state:
        del st.session_state["current_user"]
    st.rerun()
def _login_gate() -> None:
    """Stop the script unless a user is logged in.

    First tries to restore a login from the cookie. If a user is present the
    function returns and the rest of the script runs. Otherwise it renders
    the login form; a successful login creates a DB session, tries to store
    its token in the cookie, saves the user in the session and reruns the
    script. In every other case the script is halted with ``st.stop()``.
    """
    _try_restore_login()
    u = _current_user()
    if u:
        return
    st.title("登录")
    # Login page runs before the global CSS block below; inject minimal CSS here.
    st.markdown(
        """
<style>
button[aria-label="Show password text"],
button[aria-label="Hide password text"] {
display: none !important;
}
</style>
""",
        unsafe_allow_html=True,
    )
    # Only the left column is used, which keeps the form narrow.
    c1, c2 = st.columns([1, 2])
    with c1:
        username = st.text_input("用户名", value="", key="login_user")
        password = st.text_input("密码", value="", type="password", key="login_pass")
        if st.button("登录", type="primary"):
            # The username is trimmed; the password is passed through as typed.
            au = db.authenticate_user(username.strip(), password)
            if not au:
                st.error("用户名或密码错误,或账号已禁用")
                st.stop()
            token = db.create_session(au["id"])
            try:
                # Remember the login for 7 days.
                cookie_manager.set(COOKIE_NAME, token, max_age=7 * 24 * 3600)
            except Exception:
                # fallback: session-only
                pass
            _set_current_user(au)
            st.rerun()
    # Not logged in: nothing below this call may render.
    st.stop()


# Gate the whole app: everything after this line requires a logged-in user.
_login_gate()
# ============================================================
# BGM 智能匹配函数
# ============================================================
def match_bgm_by_style(bgm_style: str, bgm_dir: Path) -> "str | None":
    """Pick a BGM file from *bgm_dir* that fits the script's ``bgm_style``.

    Args:
        bgm_style: Free-text style description from the script (may be empty
            or ``None``).
        bgm_dir: Directory holding ``.mp3`` / ``.mp4`` BGM files.

    Returns:
        Path (as ``str``) of a randomly chosen file whose name contains one
        of the known style keywords found in *bgm_style*; if nothing matches,
        a randomly chosen file from the whole directory; ``None`` when the
        directory contains no usable BGM file.
    """
    # One pattern covers .mp3 and .mp4 in any letter case. The previous
    # extra "*.mp3" glob listed every mp3 twice, biasing the random choice.
    bgm_files = [
        f
        for f in Path(bgm_dir).glob("*.[mM][pP][34]")
        if f.is_file() and not f.name.startswith(".")
    ]
    if not bgm_files:
        return None

    if bgm_style:
        style_lower = bgm_style.lower()
        # Simplified keyword extraction: match against a fixed list of common words.
        keywords = ("活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏")
        matched_keywords = [kw for kw in keywords if kw in style_lower]
        if matched_keywords:
            # Compare against the lower-cased file name so that e.g. "bling"
            # also matches "BLING.mp3" (the style text is lower-cased too).
            matched_files = [
                f
                for f in bgm_files
                if any(kw in f.name.lower() for kw in matched_keywords)
            ]
            if matched_files:
                return str(random.choice(matched_files))

    # No style or no keyword hit: fall back to any available BGM.
    return str(random.choice(bgm_files))
# ============================================================
# CSS Styling (Kaogujia Style)
# ============================================================
# Inject the app-wide stylesheet (colors, fonts, button look) as raw HTML.
# The last rule hides the password "reveal" buttons, same as the login page.
st.markdown("""
<style>
/* 只保留颜色和字体样式,完全不干预布局和滚动 */
.stApp {
background-color: #F4F5F7;
font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
}
section[data-testid="stSidebar"] {
background-color: #FFFFFF;
border-right: 1px solid #E5E7EB;
}
.stButton button {
border-radius: 4px;
font-weight: 500;
}
div[data-testid="stButton"] > button[kind="primary"] {
background-color: #1677FF;
color: white;
border: none;
}
div[data-testid="stButton"] > button[kind="primary"]:hover {
background-color: #4096FF;
}
h1, h2, h3 {
color: #1F2329;
font-weight: 600;
}
h2 { border-left: 4px solid #1677FF; padding-left: 10px; }
.stTextInput input, .stTextArea textarea {
border-radius: 4px;
}
/* Hide Streamlit password reveal (eye) buttons to avoid accidental exposure */
button[aria-label="Show password text"],
button[aria-label="Hide password text"] {
display: none !important;
}
</style>
""", unsafe_allow_html=True)
# ============================================================
# Session State Management
# ============================================================
# Seed every session key the app relies on, without touching keys that
# already exist from a previous script run.
for _state_key, _state_default in (
    ("project_id", None),
    ("current_step", 0),
    ("script_data", None),
    ("scene_images", {}),
    ("scene_videos", {}),
    ("final_video", None),
    ("uploaded_images", []),
    ("view_mode", "workspace"),  # workspace, history, settings
    ("selected_img_provider", "shubiaobiao"),
    # used to scope Streamlit widget keys per project load to avoid
    # cross-project state pollution
    ("ui_rev", int(time.time() * 1000)),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def _ui_key(suffix: str) -> str:
    """
    Return a widget key that is unique to the current project and UI revision.

    Streamlit stores widget state under global keys. Prefixing each key with
    the project id and the per-load revision keeps one project's widget state
    from leaking into another (which can even surface as frontend DOM errors
    such as removeChild NotFoundError).
    """
    state = st.session_state
    project = state.get("project_id") or "NEW"
    revision = state.get("ui_rev") or 0
    return "p:{}|r:{}|{}".format(project, revision, suffix)
def load_project(project_id):
    """Load a project from the DB into ``st.session_state``.

    Restores the script, product info, uploaded reference images and the
    completed image / video / final-video assets, then derives the wizard
    step to resume at. Shows an error and leaves the session untouched when
    the project does not exist or the current user may not read it.

    Args:
        project_id: Identifier of the project to load.
    """
    data = db.get_project_for_user(project_id, _current_user())
    if not data:
        st.error("Project not found / no permission")
        return
    st.session_state.project_id = project_id
    # bump UI revision to ensure all widget keys are isolated per project load
    st.session_state.ui_rev = int(time.time() * 1000)
    st.session_state.script_data = data.get("script_data")
    st.session_state.view_mode = "workspace"

    # Fallback: if the script stored in the DB uses the old structure or is
    # missing fields, re-normalize it once from the legacy JSON file.
    try:
        script_data = st.session_state.script_data
        legacy_json = Path(config.TEMP_DIR) / f"project_{project_id}.json"

        # Annotated with ``object``: the previous ``Any`` annotation was never
        # imported, so defining this helper raised NameError, which the
        # ``except`` below swallowed - normalization therefore never ran.
        def _needs_normalize(sd: object) -> bool:
            if not isinstance(sd, dict):
                return True
            if "_legacy_schema" not in sd:
                return True
            scenes = sd.get("scenes") or []
            if scenes and isinstance(scenes, list) and isinstance(scenes[0], dict):
                if "visual_prompt" not in scenes[0] or "video_prompt" not in scenes[0]:
                    return True
            return False

        if legacy_json.exists() and _needs_normalize(script_data):
            raw = json.loads(legacy_json.read_text(encoding="utf-8"))
            normalized = normalize_legacy_project(raw)
            st.session_state.script_data = normalized
            # Write back to the DB so the next load does not redo the work.
            db.update_project_script(project_id, normalized)
            st.info("已从 legacy JSON 重新规范化脚本字段(兼容旧版项目)。")
    except Exception as e:
        st.warning(f"legacy 规范化失败(将继续使用 DB 数据): {e}")

    # Restore product info for Step 1 display.
    # ``or`` guards against explicit NULLs in the DB row: Step 1 calls
    # ``.get`` on the stored product info and would fail on ``None``.
    product_info = data.get("product_info") or {}
    st.session_state.loaded_product_name = data.get("name") or ""
    st.session_state.loaded_product_info = product_info

    # Restore uploaded images from product_info
    if product_info and "uploaded_images" in product_info:
        stored_paths = product_info["uploaded_images"] or []
        # Keep only the files that still exist on disk.
        valid_paths = [p for p in stored_paths if p and os.path.exists(p)]
        st.session_state.uploaded_images = valid_paths
        lost = len(stored_paths) - len(valid_paths)
        if lost > 0:
            st.warning(f"部分原始图片已丢失 ({lost} 张),建议重新上传")
    else:
        st.session_state.uploaded_images = []

    # Restore assets
    # RBAC: assets are project-scoped, so permission already checked above.
    assets = db.get_assets(project_id) or []
    images = {}
    videos = {}
    final_vid = None
    for asset in assets:
        sid = asset["scene_id"]
        source_path, _mapped_url = map_legacy_local_path(asset.get("local_path"))
        # Only completed assets that resolve to a path are restored.
        if asset["status"] != "completed" or not source_path:
            continue
        asset_type = asset["asset_type"]
        if asset_type == "image":
            images[sid] = source_path
        elif asset_type == "video":
            videos[sid] = source_path
        elif asset_type == "final_video":
            final_vid = source_path
    st.session_state.scene_images = images
    st.session_state.scene_videos = videos
    st.session_state.final_video = final_vid

    # Determine step (0-based index into the five wizard steps).
    if final_vid or videos:
        st.session_state.current_step = 4
    elif images:
        st.session_state.current_step = 3
    elif st.session_state.script_data:
        st.session_state.current_step = 1
    else:
        st.session_state.current_step = 0
# ============================================================
# Helper Functions (must be defined before Sidebar uses them)
# ============================================================
def _record_metrics(project_id: str, patch: dict):
"""Persist lightweight timing/diagnostic metrics into project.product_info['_metrics']."""
if not project_id or not isinstance(patch, dict) or not patch:
return
try:
proj = db.get_project(project_id) or {}
product_info = proj.get("product_info") or {}
metrics = product_info.get("_metrics") if isinstance(product_info.get("_metrics"), dict) else {}
metrics.update(patch)
metrics["updated_at"] = time.time()
product_info["_metrics"] = metrics
db.update_project_product_info(project_id, product_info)
except Exception:
# metrics must never break UX
pass
def _get_metrics(project_id: str) -> dict:
try:
proj = db.get_project(project_id) or {}
product_info = proj.get("product_info") or {}
m = product_info.get("_metrics")
return m if isinstance(m, dict) else {}
except Exception:
return {}
def _ensure_local_videos(project_id: str, scenes: list):
"""
Step 5 合成前的保障逻辑:
检查分镜视频是否已下载到服务器本地。如果只有 URL 没有本地文件,则后台静默下载。
"""
if not project_id or not scenes:
return
vid_gen = VideoGenerator()
downloaded = 0
missing_assets = []
for scene in scenes:
scene_id = scene["id"]
local_path = st.session_state.scene_videos.get(scene_id)
# 如果本地文件不存在,尝试补全
if not local_path or not os.path.exists(local_path):
asset = db.get_asset(project_id, scene_id, "video")
if not asset:
missing_assets.append(f"Scene {scene_id}")
continue
meta = asset.get("metadata") or {}
video_url = meta.get("video_url")
task_id = asset.get("task_id")
# 如果 DB 里没有 URL但有 task_id尝试现场查询一次
if not video_url and task_id:
logger.info(f"Checking task {task_id} status for Scene {scene_id} during composition...")
status, url = vid_gen.check_task_status(task_id)
if status == "succeeded" and url:
video_url = url
db.update_asset_metadata(project_id, scene_id, "video", {"video_url": url, "volc_status": status})
if video_url:
out_name = path_utils.unique_filename(
prefix="scene_video",
ext="mp4",
project_id=project_id,
scene_id=scene_id,
extra=f"auto_{int(time.time())}"
)
logger.info(f"Auto-downloading missing video for Scene {scene_id} from {video_url}")
target_dir = path_utils.project_videos_dir(project_id)
target_path = str(target_dir / out_name)
if vid_gen._download_video_to(video_url, target_path):
st.session_state.scene_videos[scene_id] = target_path
db.save_asset(project_id, scene_id, "video", "completed", local_path=target_path)
downloaded += 1
else:
missing_assets.append(f"Scene {scene_id} (未生成/无URL)")
if downloaded > 0:
logger.info(f"Successfully auto-downloaded {downloaded} missing videos for project {project_id}")
if missing_assets:
msg = f"无法合成:以下分镜缺少视频素材,请先在第 4 步生成:{', '.join(missing_assets)}"
logger.warning(msg)
raise ValueError(msg)
# ============================================================
# Sidebar
# ============================================================
# Sidebar: user info / logout, view-mode switch and (in workspace mode)
# project selection, diagnostics, step progress and a state reset button.
with st.sidebar:
    st.title("📽️ Video Flow")
    cu = _current_user()
    if cu:
        st.caption(f"登录用户: {cu.get('username')} ({cu.get('role')})")
        if st.button("退出登录", type="secondary"):
            _logout()
    # Mode Selection - compute the radio index from the stored view mode
    mode_options = ["🛠️ 工作台", "📜 历史任务", "⚙️ 设置"]
    mode_map = {"workspace": 0, "history": 1, "settings": 2}
    current_index = mode_map.get(st.session_state.view_mode, 0)
    mode = st.radio("模式", mode_options, index=current_index)
    # Translate the selected label back into the internal view-mode name.
    if mode == "🛠️ 工作台":
        st.session_state.view_mode = "workspace"
    elif mode == "📜 历史任务":
        st.session_state.view_mode = "history"
    elif mode == "⚙️ 设置":
        st.session_state.view_mode = "settings"
    st.markdown("---")
    if st.session_state.view_mode == "workspace":
        # Project Selection
        st.subheader("Current Project")
        projects = db.list_projects_for_user(_current_user())
        # project id -> display label "name (id)"
        proj_options = {p['id']: f"{p.get('name', 'Untitled')} ({p['id']})" for p in projects}
        selected_proj_id = st.selectbox(
            "Select Project",
            options=["New Project"] + list(proj_options.keys()),
            format_func=lambda x: " New Project" if x == "New Project" else proj_options[x]
        )
        if selected_proj_id == "New Project":
            if st.session_state.project_id and st.session_state.project_id in proj_options:
                # If switching from existing to new, reset
                if st.button("Start New"):
                    # NOTE(review): this also removes "current_user"; the login
                    # then depends on the cookie restore on the next run - confirm
                    # that is intended.
                    for key in list(st.session_state.keys()):
                        if key != "view_mode": del st.session_state[key]
                    st.rerun()
        else:
            # An existing project is selected: offer to load it unless it is
            # already the active one.
            if st.session_state.project_id != selected_proj_id:
                if st.button(f"Load {selected_proj_id}"):
                    load_project(selected_proj_id)
                    st.rerun()
        if st.session_state.project_id:
            st.caption(f"Current ID: {st.session_state.project_id}")
            with st.expander("⏱️ 性能与诊断", expanded=False):
                m = _get_metrics(st.session_state.project_id)
                if not m:
                    st.caption("暂无指标(执行一次脚本/生图/生视频/合成后会出现)。")
                else:
                    # Metric names shown, in display order; absent ones are skipped.
                    keys = [
                        "script_gen_s",
                        "image_gen_total_s",
                        "video_submit_s",
                        "video_recover_s",
                        "compose_s",
                        "script_model",
                        "image_provider",
                        "image_generated",
                        "video_submitted",
                        "video_recovered",
                        "bgm_used",
                    ]
                    for k in keys:
                        if k in m:
                            st.caption(f"{k}: {m.get(k)}")
        st.markdown("---")
        # Navigation / Progress
        steps = ["1. 输入信息", "2. 生成脚本", "3. 画面生成", "4. 视频生成", "5. 合成输出"]
        current = st.session_state.current_step
        # Steps before the current one are ticked, the current one is highlighted.
        for i, step in enumerate(steps):
            if i < current:
                st.markdown(f"✅ **{step}**")
            elif i == current:
                st.markdown(f"🔵 **{step}**")
            else:
                st.markdown(f"{step}")
        st.markdown("---")
        if st.button("🔄 重置状态", type="secondary"):
            # Drop every session key except the view mode, then rerun.
            for key in list(st.session_state.keys()):
                if key != "view_mode": del st.session_state[key]
            st.rerun()
def save_uploaded_file(project_id: str, uploaded_file):
    """Write an uploaded file into the project's own upload directory.

    Storing per project (with a unique file name) avoids overwriting files
    that belong to other projects.

    Returns:
        The saved file's path as ``str``, or ``None`` when *uploaded_file*
        is ``None``.

    Raises:
        ValueError: if *project_id* is empty while a file was supplied.
    """
    if uploaded_file is None:
        return None
    if not project_id:
        raise ValueError("project_id is required to save uploaded files safely")

    target_dir = path_utils.project_upload_dir(project_id)
    safe_name = Path(path_utils.sanitize_filename(getattr(uploaded_file, "name", "upload")))
    # keep original stem for readability, but ensure uniqueness
    extension = safe_name.suffix.lstrip(".") or "bin"
    base = safe_name.stem or "upload"
    destination = target_dir / path_utils.unique_filename(
        prefix=f"upload_{base}",
        ext=extension,
        project_id=project_id,
    )
    with open(destination, "wb") as out:
        out.write(uploaded_file.getbuffer())
    return str(destination)
# ============================================================
# Main Content: Workspace
# ============================================================
if st.session_state.view_mode == "workspace":
st.header("商品短视频自动生成控制台")
# --- Step 1: Input ---
with st.expander("📦 1. 商品信息输入", expanded=(st.session_state.current_step == 0)):
# 从 session_state 读取已加载项目的信息,否则使用默认值
loaded_info = st.session_state.get("loaded_product_info", {})
default_name = st.session_state.get("loaded_product_name", "网红气质大号发量多!高马尾香蕉夹")
default_category = loaded_info.get("category", "钟表配饰-时尚饰品-发饰")
default_price = loaded_info.get("price", "3.99元")
default_tags = loaded_info.get("tags", "回头客|款式好看|材质好|尺寸合适|颜色好看|很好用|做工好|质感不错|很牢固")
default_params = loaded_info.get("params", "金属材质:非金属; 非金属材质:树脂; 发夹分类:香蕉夹")
col1, col2 = st.columns([1, 1])
with col1:
product_name = st.text_input("商品标题", value=default_name, key=_ui_key("product_name"))
category = st.text_input("商品类目", value=default_category, key=_ui_key("category"))
price = st.text_input("价格", value=default_price, key=_ui_key("price"))
with col2:
tags = st.text_area("评价标签 (用于提炼卖点)", value=default_tags, height=100, key=_ui_key("tags"))
params = st.text_area("商品参数", value=default_params, height=100, key=_ui_key("params"))
# 商家自定义风格提示
style_hint = st.text_area(
"商品视频重点增强提示 (可选)",
value=loaded_info.get("style_hint", ""),
placeholder="例如:韩风、高级感、活力青春、简约日系...",
height=80,
key=_ui_key("style_hint"),
)
st.markdown("### 上传素材")
# 显示已有的上传图片(如果从历史项目加载)
if st.session_state.uploaded_images:
st.info(f"已有 {len(st.session_state.uploaded_images)} 张参考图片")
with st.expander("查看已有图片"):
img_cols = st.columns(min(len(st.session_state.uploaded_images), 4))
for i, img_path in enumerate(st.session_state.uploaded_images[:4]):
if os.path.exists(img_path):
img_cols[i % 4].image(img_path, width=150)
uploaded_files = st.file_uploader("上传商品主图 (建议 3-5 张)", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
# 允许在没有上传新图片但有历史图片的情况下继续
can_submit = uploaded_files or st.session_state.uploaded_images
# Model Selection (all support images; user explicitly chooses model)
model_options = ["GPT-5.2", "Gemini 3 Pro", "Doubao Pro (Vision)"]
selected_model_label = st.radio("选择脚本生成模型", model_options, horizontal=True, index=0, key=_ui_key("script_model"))
# Map label to provider key
if selected_model_label == "GPT-5.2":
model_provider = "shubiaobiao_gpt"
elif "Doubao" in selected_model_label:
model_provider = "doubao"
else:
model_provider = "shubiaobiao"
if st.button("提交任务 & 生成脚本", type="primary", disabled=(not can_submit)):
# 处理图片路径
image_paths = list(st.session_state.uploaded_images) if st.session_state.uploaded_images else []
# 如果有新上传的文件,添加它们
if uploaded_files:
# Create Project ID
if not st.session_state.project_id:
st.session_state.project_id = f"PROJ-{int(time.time())}"
for uf in uploaded_files:
path = save_uploaded_file(st.session_state.project_id, uf)
if path: image_paths.append(path)
st.session_state.uploaded_images = image_paths
# 确保有项目 ID
if not st.session_state.project_id:
st.session_state.project_id = f"PROJ-{int(time.time())}"
# DB: Create Project
# 将 uploaded_images 保存到 product_info 以便持久化
product_info = {"category": category, "price": price, "tags": tags, "params": params, "style_hint": style_hint, "uploaded_images": image_paths}
db.create_project(
st.session_state.project_id,
product_name,
product_info,
owner_user_id=(_current_user() or {}).get("id"),
)
# Call Script Generator
with st.spinner(f"正在分析商品信息并生成脚本 ({selected_model_label})..."):
gen = ScriptGenerator()
t0 = perf_counter()
script = gen.generate_script(
product_name,
product_info,
image_paths,
model_provider=model_provider,
user_id=(_current_user() or {}).get("id"),
)
_record_metrics(st.session_state.project_id, {
"script_gen_s": round(perf_counter() - t0, 3),
"script_model": model_provider,
})
if script:
st.session_state.script_data = script
# DB: Save Script
db.update_project_script(st.session_state.project_id, script)
st.session_state.current_step = 1
st.success("脚本生成成功!")
st.rerun()
else:
st.error("脚本生成失败,请检查日志。")
# --- Step 2: Script Review ---
if st.session_state.current_step >= 1:
with st.expander("📝 2. 脚本分镜确认", expanded=(st.session_state.current_step == 1)):
if st.session_state.script_data:
script = st.session_state.script_data
# Display Basic Info (兼容 legacy schema)
selling_points = script.get("selling_points", []) or []
target_audience = script.get("target_audience", "") or ""
analysis_text = script.get("analysis", "") or ""
legacy_schema = script.get("_legacy_schema", "") or ""
c1, c2 = st.columns(2)
if selling_points:
c1.write(f"**核心卖点**: {', '.join(selling_points)}")
else:
c1.write("**核心卖点**: (legacy 项目可能未生成该字段)")
if analysis_text:
with st.expander("查看 legacy analysis用于补齐信息"):
st.write(analysis_text)
if target_audience:
c2.write(f"**目标人群**: {target_audience}")
else:
c2.write("**目标人群**: (legacy 项目可能未生成该字段)")
# Hook / CTA / Schema
hook = script.get("hook", "") or ""
if hook:
st.markdown(f"**Hook**: {hook}")
cta = script.get("cta", "")
if cta:
if isinstance(cta, dict):
st.markdown("**CTAlegacy object**")
st.json(cta)
else:
st.markdown(f"**CTA**: {cta}")
if legacy_schema:
st.caption(f"Legacy Schema: {legacy_schema}")
# Prompt Visualization
if "_debug" in script:
with st.expander("🔍 查看 AI Prompt (Debug)"):
debug_info = script["_debug"]
st.markdown("**System Prompt:**")
st.code(debug_info.get("system_prompt", ""), language="markdown")
st.markdown("**User Prompt:**")
st.code(debug_info.get("user_prompt", ""), language="markdown")
# 显示原始输出
if "raw_output" in debug_info:
st.markdown("**Raw AI Output:**")
st.code(debug_info.get("raw_output", ""), language="markdown")
# Editable Scenes Table
st.markdown("### 分镜列表")
scenes = script.get("scenes", [])
# Global Voiceover Timeline (New)
st.markdown("### 🎙️ 整体旁白与字幕时间轴")
with st.expander("编辑旁白时间轴 (Voiceover Timeline)", expanded=True):
timeline = script.get("voiceover_timeline", []) or []
if not timeline:
# 对于历史项目:如果没有 scenes 也没有 timeline不要强行塞“示例旁白”避免污染数据
if not scenes and analysis_text:
st.info("该历史项目暂无旁白时间轴(可能停留在分析/提问阶段)。")
timeline = []
else:
# Init with default if empty (使用秒)
timeline = [{"text": "示例旁白", "subtitle": "示例字幕", "start_time": 0.0, "duration": 3.0}]
updated_timeline = []
for i, item in enumerate(timeline):
c1, c2, c3, c4 = st.columns([3, 3, 1, 1])
with c1:
item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=_ui_key(f"tl_vo_{i}"))
with c2:
item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=_ui_key(f"tl_sub_{i}"))
with c3:
# 兼容旧格式: 如果有 start_ratio 则转换为 start_time
default_start = item.get("start_time", item.get("start_ratio", 0) * 12) # 假设总时长12秒
item["start_time"] = st.number_input(f"开始(秒) #{i+1}", value=float(default_start), min_value=0.0, max_value=30.0, step=0.5, key=_ui_key(f"tl_start_{i}"))
with c4:
# 兼容旧格式: 如果有 duration_ratio 则转换为 duration
default_dur = item.get("duration", item.get("duration_ratio", 0.25) * 12) # 假设总时长12秒
item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=float(default_dur), min_value=0.5, max_value=15.0, step=0.5, key=_ui_key(f"tl_dur_{i}"))
# 清理旧字段
item.pop("start_ratio", None)
item.pop("duration_ratio", None)
updated_timeline.append(item)
if st.button(" 添加旁白段落"):
updated_timeline.append({"text": "", "subtitle": "", "start_time": 0.0, "duration": 3.0})
st.rerun()
script["voiceover_timeline"] = updated_timeline
updated_scenes = []
for i, scene in enumerate(scenes):
with st.container():
st.markdown(f"#### 🎬 Scene {scene['id']}")
c_vis, c_aud = st.columns([2, 1])
with c_vis:
new_visual = st.text_area(f"Visual Prompt (Scene {scene['id']})", value=scene.get("visual_prompt", ""), height=80, key=_ui_key(f"vp_{i}"))
new_video = st.text_area(f"Video Prompt (Scene {scene['id']})", value=scene.get("video_prompt", ""), height=80, key=_ui_key(f"vidp_{i}"))
scene["visual_prompt"] = new_visual
scene["video_prompt"] = new_video
with c_aud:
# 花字编辑保留
ft = scene.get("fancy_text", {})
if isinstance(ft, dict):
new_ft_text = st.text_input(
f"Fancy Text (Scene {scene['id']})",
value=ft.get("text", ""),
key=_ui_key(f"ft_{i}"),
)
# 兼容:旧数据可能没有 fancy_text 字段
if not isinstance(scene.get("fancy_text"), dict):
scene["fancy_text"] = {}
scene["fancy_text"]["text"] = new_ft_text
# 旁白/字幕已移至上方整体时间轴,此处仅作展示或删除
st.caption("注:旁白与字幕已移至上方整体时间轴编辑")
# Legacy 信息展示(只读,用于调试/对齐)
legacy_scene = scene.get("_legacy", {}) if isinstance(scene.get("_legacy", {}), dict) else {}
if legacy_scene:
with st.expander(f"Legacy 信息 (Scene {scene['id']})", expanded=False):
img_url = legacy_scene.get("image_url")
if img_url:
st.markdown(f"- image_url: `{img_url}`")
cam = legacy_scene.get("camera_movement")
if cam:
st.markdown(f"- camera_movement: {cam}")
vo = legacy_scene.get("voiceover")
if vo:
st.markdown(f"- voiceover: {vo}")
keyframe = legacy_scene.get("keyframe")
if keyframe:
st.markdown("- keyframe:")
st.json(keyframe)
rhythm = legacy_scene.get("rhythm")
if rhythm:
st.markdown("- rhythm:")
st.json(rhythm)
updated_scenes.append(scene)
st.divider()
st.session_state.script_data["scenes"] = updated_scenes
col_act1, col_act2 = st.columns([1, 4])
with col_act1:
if st.button("确认脚本 & 开始生图", type="primary"):
# DB: Update script in case user edited it
db.update_project_script(st.session_state.project_id, st.session_state.script_data)
st.session_state.current_step = 2
st.rerun()
# --- Step 3: Image Generation ---
if st.session_state.current_step >= 2:
with st.expander("🎨 3. 画面生成", expanded=(st.session_state.current_step == 2)):
# Debug: Show reference image
if st.session_state.uploaded_images:
with st.expander("🔍 查看生图参考底图 (Base Image)"):
for i, p in enumerate(st.session_state.uploaded_images):
st.info(f"{i+1} 文件名: {os.path.basename(p)}")
if os.path.exists(p):
st.image(p, width=200)
else:
st.warning("⚠️ 未检测到参考底图!将生成随机风格图像。建议在 Step 1 重新上传。")
# Trigger Generation Button
if not st.session_state.scene_images:
# Model Selection for Image Gen
img_model_options = ["Shubiaobiao (Gemini)", "Doubao (Volcengine)", "Gemini (Direct)", "Doubao (Group Image)"]
selected_img_model = st.radio("选择生图模型", img_model_options, horizontal=True, key=_ui_key("img_model"))
# Map to provider
if "Group Image" in selected_img_model:
img_provider = "doubao-group"
elif "Doubao" in selected_img_model:
img_provider = "doubao"
elif "Direct" in selected_img_model:
img_provider = "gemini"
else:
img_provider = "shubiaobiao"
# Store selected provider for regeneration
st.session_state.selected_img_provider = img_provider
if st.button("🚀 执行 AI 生图", type="primary"):
with limits.acquire_image(blocking=False) as ok:
if not ok:
st.warning("系统正在生成其他任务(生图并发已达上限),请稍后再试。")
st.stop()
img_gen = ImageGenerator()
# Pass ALL uploaded images as reference
base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else []
if not base_imgs:
st.error("No base image found (未找到参考底图). Please upload in Step 1.")
st.stop()
scenes = st.session_state.script_data.get("scenes", [])
# 读取 visual_anchor 用于保持生图一致性
visual_anchor = st.session_state.script_data.get("visual_anchor", "")
if img_provider == "doubao-group":
# --- Group Generation Logic ---
with st.spinner("正在进行 Doubao 组图生成 (Batch Group Generation)..."):
try:
t0 = perf_counter()
results = img_gen.generate_group_images_doubao(
scenes=scenes,
reference_images=base_imgs,
visual_anchor=visual_anchor,
project_id=st.session_state.project_id
)
_record_metrics(st.session_state.project_id, {
"image_gen_total_s": round(perf_counter() - t0, 3),
"image_provider": img_provider,
"image_generated": len(results),
})
for s_id, path in results.items():
st.session_state.scene_images[s_id] = path
db.save_asset(st.session_state.project_id, s_id, "image", "completed", local_path=path)
# Invalidate stale video for this scene (group image regen also changes image)
db.clear_asset(st.session_state.project_id, s_id, "video", status="pending")
st.session_state.scene_videos.pop(s_id, None)
if len(results) == len(scenes):
st.success("组图生成完成!")
db.update_project_status(st.session_state.project_id, "images_generated")
time.sleep(1)
st.rerun()
else:
st.warning(f"部分图片生成失败: {len(results)}/{len(scenes)}")
st.rerun()
except Exception as e:
st.error(f"组图生成失败: {e}")
else:
# --- Parallel Logic (default): only merchant uploaded images as references ---
total_scenes = len(scenes)
progress_bar = st.progress(0)
status_text = st.empty()
try:
t0 = perf_counter()
# Parallel workers within a single run; global semaphore already acquired above.
max_workers = 6
futures = {}
with ThreadPoolExecutor(max_workers=max_workers) as ex:
for idx, scene in enumerate(scenes):
scene_id = scene["id"]
futures[ex.submit(
img_gen.generate_single_scene_image,
scene=scene,
original_image_path=list(base_imgs), # ONLY merchant images
previous_image_path=None,
model_provider=img_provider,
visual_anchor=visual_anchor,
project_id=st.session_state.project_id,
)] = (idx, scene_id)
done = 0
for fut in as_completed(futures):
idx, scene_id = futures[fut]
done += 1
status_text.text(f"已完成 {done}/{total_scenes}Scene {scene_id}")
try:
img_path = fut.result()
if img_path:
st.session_state.scene_images[scene_id] = img_path
db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path)
# Invalidate stale video for this scene (image changed => old video is wrong)
db.clear_asset(st.session_state.project_id, scene_id, "video", status="pending")
st.session_state.scene_videos.pop(scene_id, None)
except Exception as e:
st.warning(f"Scene {scene_id} 生成失败:{e}")
progress_bar.progress(done / total_scenes)
status_text.text("生图完成!")
st.success("生图完成!")
_record_metrics(st.session_state.project_id, {
"image_gen_total_s": round(perf_counter() - t0, 3),
"image_provider": img_provider,
"image_generated": len(st.session_state.scene_images),
})
# Update Status
db.update_project_status(st.session_state.project_id, "images_generated")
time.sleep(1)
st.rerun()
except PermissionError as e:
st.error(f"生图服务拒绝访问 (可能余额不足): {e}")
except Exception as e:
st.error(f"生图发生错误: {e}")
# Display & Regenerate Images
if st.session_state.scene_images:
cols = st.columns(4)
scenes = st.session_state.script_data.get("scenes", [])
for i, scene in enumerate(scenes):
scene_id = scene["id"]
img_path = st.session_state.scene_images.get(scene_id)
with cols[i % 4]:
st.markdown(f"**Scene {scene_id}**")
# Expand Prompt info
with st.expander("Prompt", expanded=False):
st.caption(scene.get("visual_prompt", "No prompt"))
if img_path and os.path.exists(img_path):
st.image(img_path, width=None, use_column_width=True) # Fix deprecated
else:
st.error("Image missing")
# Regenerate specific image
if st.button(f"🔄 重生 S{scene_id}", key=f"regen_img_{scene_id}"):
if not st.session_state.uploaded_images:
st.error("缺少参考底图,请在 Step 1 重新上传图片")
else:
with st.spinner(f"正在重绘 Scene {scene_id}..."):
img_gen = ImageGenerator()
# Only merchant uploaded images as references (no chaining)
current_refs_for_regen = list(st.session_state.uploaded_images)
# Fallback to single mode for regen if group was used
provider = st.session_state.get("selected_img_provider", "shubiaobiao")
if provider == "doubao-group": provider = "doubao"
# 读取 visual_anchor
regen_visual_anchor = st.session_state.script_data.get("visual_anchor", "")
new_path = img_gen.generate_single_scene_image(
scene=scene,
original_image_path=current_refs_for_regen,
previous_image_path=None,
model_provider=provider,
visual_anchor=regen_visual_anchor,
project_id=st.session_state.project_id
)
if new_path:
st.session_state.scene_images[scene_id] = new_path
db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=new_path)
# Invalidate stale video for this scene
db.clear_asset(st.session_state.project_id, scene_id, "video", status="pending")
st.session_state.scene_videos.pop(scene_id, None)
st.rerun()
if st.button("下一步:生成视频", type="primary"):
st.session_state.current_step = 3
st.rerun()
# --- Step 4: Video Generation ---
if st.session_state.current_step >= 3:
with st.expander("🎥 4. 视频生成 (Volcengine I2V)", expanded=(st.session_state.current_step == 3)):
scenes = st.session_state.script_data.get("scenes", [])
vid_gen = VideoGenerator()
st.caption("简化策略:第 4 步完成“生成→轮询→下载到服务器本地”。当至少有一个分镜视频落盘后,才允许进入第 5 步合成。")
# Batch generation: submit (or resume) one remote task per scene that has no
# local video yet, poll every task until it finishes or the deadline passes,
# and download each finished video into the project's local video directory.
if st.button("🎬 生成分镜视频并下载到服务器(阻塞)", type="primary"):
    # Non-blocking acquire: `ok` is falsy when no video slot could be taken.
    with limits.acquire_video(blocking=False) as ok:
        if not ok:
            st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
            st.stop()
        if not st.session_state.project_id:
            st.error("缺少 project_id无法生成视频。")
            st.stop()
        if not scenes:
            st.error("脚本中没有分镜,无法生成视频。")
            st.stop()

        # Restore videos that the DB records as downloaded and that still exist on disk.
        for scene in scenes:
            scene_id = scene["id"]
            existing = st.session_state.scene_videos.get(scene_id)
            if existing and os.path.exists(existing):
                continue
            asset = db.get_asset(st.session_state.project_id, scene_id, "video")
            local_path = (asset or {}).get("local_path")
            if local_path and os.path.exists(local_path):
                st.session_state.scene_videos[scene_id] = local_path

        t0 = perf_counter()
        total = len(scenes)
        done = sum(
            1
            for p in (st.session_state.scene_videos or {}).values()
            if p and os.path.exists(p)
        )
        progress = st.progress(0.0)
        status_text = st.empty()

        # Collect the tasks to poll: scene_id -> task_id.
        tasks = {}
        for scene in scenes:
            scene_id = scene["id"]
            existing = st.session_state.scene_videos.get(scene_id)
            if existing and os.path.exists(existing):
                continue

            asset = db.get_asset(st.session_state.project_id, scene_id, "video") or {}
            task_id = asset.get("task_id")
            # NOTE(review): assumes `metadata` comes back as a dict; if it is stored
            # in another form the scene is simply resubmitted — confirm with db_manager.
            meta = asset.get("metadata")
            download_failed = isinstance(meta, dict) and bool(meta.get("download_error"))
            # Resume an existing task unless it was already recorded as failed.
            # Re-polling a remotely failed task can never succeed and would leave
            # the scene stuck on every click; a failed *download* is worth
            # resuming because the remote result may still be retrievable.
            if task_id and (asset.get("status") != "failed" or download_failed):
                tasks[scene_id] = task_id
                continue

            image_path = st.session_state.scene_images.get(scene_id)
            # Same precondition as the per-scene regenerate button: a source image is required.
            if not (image_path and os.path.exists(image_path)):
                st.warning(f"Scene {scene_id} 缺少源图片,已跳过视频生成。")
                continue

            prompt = scene.get("video_prompt", "High quality video")
            new_task_id = vid_gen.submit_scene_video_task(
                st.session_state.project_id, scene_id, image_path, prompt
            )
            if new_task_id:
                tasks[scene_id] = new_task_id
            else:
                st.warning(f"Scene {scene_id} 提交任务失败")

        if tasks:
            db.update_project_status(st.session_state.project_id, "videos_processing")

        out_dir = path_utils.project_videos_dir(st.session_state.project_id)
        pending = set(tasks)
        # Monotonic clock so that system clock adjustments cannot shorten or extend the wait.
        deadline = time.monotonic() + 15 * 60  # 15 min
        while pending and time.monotonic() < deadline:
            status_text.text(f"视频生成/下载中:已完成 {done}/{total},队列中 {len(pending)} ...")
            # Iterate over a snapshot because finished scenes are removed inside the loop.
            for scene_id in list(pending):
                task_id = tasks.get(scene_id)
                if not task_id:
                    pending.discard(scene_id)
                    continue

                status, url = vid_gen.check_task_status(task_id)
                if status == "succeeded" and url:
                    out_name = path_utils.unique_filename(
                        prefix="scene_video",
                        ext="mp4",
                        project_id=st.session_state.project_id,
                        scene_id=scene_id,
                        extra=(task_id[-8:] if isinstance(task_id, str) else None),
                    )
                    target_path = str(out_dir / out_name)
                    ok_dl = vid_gen._download_video_to(url, target_path)
                    if ok_dl and os.path.exists(target_path):
                        st.session_state.scene_videos[scene_id] = target_path
                        db.save_asset(
                            st.session_state.project_id,
                            scene_id,
                            "video",
                            "completed",
                            local_path=target_path,
                            task_id=task_id,
                            metadata={"downloaded_at": time.time()},
                        )
                        done += 1
                    else:
                        db.save_asset(
                            st.session_state.project_id,
                            scene_id,
                            "video",
                            "failed",
                            task_id=task_id,
                            metadata={"download_error": True, "checked_at": time.time()},
                        )
                    pending.discard(scene_id)
                elif status in ("failed", "cancelled"):
                    db.save_asset(
                        st.session_state.project_id,
                        scene_id,
                        "video",
                        "failed",
                        task_id=task_id,
                        metadata={"volc_status": status, "checked_at": time.time()},
                    )
                    pending.discard(scene_id)
                else:
                    # Still running (or succeeded without a URL yet): record the
                    # last observed status and keep polling.
                    db.update_asset_metadata(
                        st.session_state.project_id,
                        scene_id,
                        "video",
                        {"volc_status": status, "checked_at": time.time()},
                    )

            progress.progress(min(1.0, done / max(total, 1)))
            if pending:
                time.sleep(5)

        if pending:
            st.warning(f"仍有 {len(pending)} 个分镜未在本轮完成:{sorted(pending)}。可再次点击按钮继续轮询与下载。")

        _record_metrics(st.session_state.project_id, {
            "video_blocking_total_s": round(perf_counter() - t0, 3),
            "video_done": done,
            "video_total": total,
        })

        # Advance to step 5 as soon as at least one scene video exists locally.
        if any(p and os.path.exists(p) for p in (st.session_state.scene_videos or {}).values()):
            db.update_project_status(st.session_state.project_id, "videos_generated")
            st.success("已生成并下载到服务器本地。进入第 5 步合成。")
            st.session_state.current_step = 4
            st.rerun()
        else:
            st.error("未生成任何可用视频,请检查生视频接口或稍后重试。")
# Display Videos (even when partially available)
# Four-column grid: each cell shows the scene's prompt, its local video (or the
# asset status recorded in the DB), and a per-scene regenerate button.
if st.session_state.scene_videos or scenes:
    cols = st.columns(4)
    for i, scene in enumerate(scenes):
        scene_id = scene["id"]
        vid_path = st.session_state.scene_videos.get(scene_id)
        with cols[i % 4]:
            st.markdown(f"**Scene {scene_id}**")
            # Expand Prompt info
            with st.expander("Prompt", expanded=False):
                st.caption(scene.get("video_prompt", "No prompt"))

            if vid_path and os.path.exists(vid_path):
                st.video(vid_path)
            else:
                asset = db.get_asset(st.session_state.project_id, scene_id, "video")
                status = (asset or {}).get("status") or "pending"
                task_id = (asset or {}).get("task_id")
                if task_id:
                    st.caption(f"状态: {status} | Task: {str(task_id)[-6:]}")
                else:
                    st.caption(f"状态: {status}")
                st.warning("暂无本地视频(请点击上方“生成分镜视频并下载到服务器”)。")

            # Per-scene regenerate button
            if st.button(f"🔄 重生 S{scene_id}", key=f"regen_vid_{scene_id}"):
                with st.spinner(f"正在重生成 Scene {scene_id} 视频..."):
                    regen_gen = VideoGenerator()
                    img_path = st.session_state.scene_images.get(scene_id)
                    if not (img_path and os.path.exists(img_path)):
                        st.error("缺少源图片")
                    else:
                        # Same prompt default as the batch generation path.
                        t_id = regen_gen.submit_scene_video_task(
                            st.session_state.project_id,
                            scene_id,
                            img_path,
                            scene.get("video_prompt", "High quality video"),
                        )
                        if not t_id:
                            st.error("提交任务失败")
                        else:
                            # Simple poll loop for single video regen: 60 polls x 5 s.
                            for _ in range(60):
                                status, url = regen_gen.check_task_status(t_id)
                                if status == "succeeded" and url:
                                    # Use per-project unique name to avoid cross-project overwrite
                                    out_name = path_utils.unique_filename(
                                        prefix="scene_video",
                                        ext="mp4",
                                        project_id=st.session_state.project_id,
                                        scene_id=scene_id,
                                        extra=(t_id[-8:] if isinstance(t_id, str) else None),
                                    )
                                    target_dir = path_utils.project_videos_dir(st.session_state.project_id)
                                    target_path = str(target_dir / out_name)
                                    if regen_gen._download_video_to(url, target_path) and os.path.exists(target_path):
                                        st.session_state.scene_videos[scene_id] = target_path
                                        db.save_asset(
                                            st.session_state.project_id,
                                            scene_id,
                                            "video",
                                            "completed",
                                            local_path=target_path,
                                            task_id=t_id,
                                        )
                                        st.rerun()
                                    # Download failed: record it and tell the user
                                    # instead of ending the loop silently.
                                    db.save_asset(
                                        st.session_state.project_id,
                                        scene_id,
                                        "video",
                                        "failed",
                                        task_id=t_id,
                                        metadata={"download_error": True, "checked_at": time.time()},
                                    )
                                    st.error("视频下载失败,请稍后重试")
                                    break
                                if status in ("failed", "cancelled"):
                                    # Persist the failure so the status caption reflects it.
                                    db.save_asset(
                                        st.session_state.project_id,
                                        scene_id,
                                        "video",
                                        "failed",
                                        task_id=t_id,
                                        metadata={"volc_status": status, "checked_at": time.time()},
                                    )
                                    st.error(f"生成失败: {status}")
                                    break
                                time.sleep(5)
                            else:
                                # All polls used up without a terminal status.
                                st.warning("视频仍在生成中(已等待约 5 分钟),请稍后点击上方生成按钮继续轮询与下载。")
# Action row: reset all scene videos (left) or move on to composition (right).
clear_col, next_col = st.columns([1, 4])

with clear_col:
    reset_clicked = st.button("🧹 清空视频并重新生成", type="secondary")
    if reset_clicked:
        # Drop the in-session video map, then reset the stored video assets.
        st.session_state.scene_videos = {}
        if st.session_state.project_id:
            db.clear_assets(st.session_state.project_id, "video", status="pending")
        st.rerun()

with next_col:
    # Composition needs at least one scene video that exists on local disk.
    local_videos = (st.session_state.scene_videos or {}).values()
    can_compose = any(os.path.exists(p) for p in local_videos if p)
    goto_compose = st.button(
        "进入第 5 步:合成最终成片",
        type="primary",
        disabled=not can_compose,
    )
    if goto_compose:
        st.session_state.current_step = 4
        st.rerun()
    if not can_compose:
        st.caption("⚠️ 需要至少一个分镜视频已下载到服务器本地后,才能进入合成。")
# --- Step 5: Final Composition & Tuning ---
# Rendered once the workflow has reached step index 4; the expander is opened
# by default only while step index 4 is the current step.
if st.session_state.current_step >= 4:
with st.expander("🎞️ 5. 合成结果与微调", expanded=(st.session_state.current_step == 4)):
# Tabs: Preview vs Tuning
# The preview tab is listed first in the UI; its code appears after the tuning tab's code.
tab_preview, tab_tune = st.tabs(["🎥 预览", "🎛️ 微调 (Track Editor)"])
# Tuning tab: choose BGM and TTS voice, edit the voice-over/subtitle timeline
# and per-scene fancy text, then re-compose the final video with those edits.
with tab_tune:
    st.subheader("轨道微调 (Tuning)")

    def _clamped_seconds(raw, low, high):
        """Return *raw* as a float limited to [low, high]; *low* if it is not a usable number."""
        try:
            seconds = float(raw)
        except (TypeError, ValueError):
            return low
        if seconds != seconds:  # NaN compares unequal to itself
            return low
        return min(max(seconds, low), high)

    # Global Settings
    col_g1, col_g2 = st.columns(2)
    with col_g1:
        # BGM select (.mp3 and .mp4). A single pattern already covers both
        # extensions in any letter case; globbing "*.mp3" a second time listed
        # every mp3 twice in the select box.
        bgm_dir = config.ASSETS_DIR / "bgm"
        bgm_files = sorted(
            f
            for f in bgm_dir.glob("*.[mM][pP][34]")
            if f.is_file() and not f.name.startswith('.')
        )
        bgm_names = [f.name for f in bgm_files]

        # Pre-select the BGM recommended for the script's style, if it is available.
        bgm_style = st.session_state.script_data.get("bgm_style", "")
        recommended_bgm = match_bgm_by_style(bgm_style, bgm_dir)
        default_idx = 0
        if recommended_bgm:
            rec_name = Path(recommended_bgm).name
            if rec_name in bgm_names:
                default_idx = bgm_names.index(rec_name) + 1  # +1 for the leading "None" option

        selected_bgm = st.selectbox(
            f"背景音乐 (BGM) - 推荐风格: {bgm_style or '未指定'}",
            ["None"] + bgm_names,
            index=default_idx
        )
        # Tell the user up front when this composition will carry no BGM.
        if not bgm_names:
            st.warning(f"BGM 目录为空:{bgm_dir}(本次合成将不含 BGM")
        elif selected_bgm != "None":
            candidate = config.ASSETS_DIR / "bgm" / selected_bgm
            if not candidate.exists():
                st.warning(f"所选 BGM 文件不存在:{candidate}(本次合成将不含 BGM")
    with col_g2:
        # Voice Select
        selected_voice = st.selectbox("配音音色 (TTS)", [config.VOLC_TTS_DEFAULT_VOICE, "zh_female_meilinvyou_saturn_bigtts"])

    st.divider()

    # Global Timeline Editor
    st.markdown("### 🎙️ 旁白与字幕时间轴 (单位: 秒)")
    timeline = st.session_state.script_data.get("voiceover_timeline", [])
    updated_timeline = []
    for i, item in enumerate(timeline):
        c1, c2, c3, c4 = st.columns([3, 3, 1, 1])
        with c1:
            item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=_ui_key(f"tune_vo_{i}"))
        with c2:
            item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=_ui_key(f"tune_sub_{i}"))
        with c3:
            # Legacy format: convert start_ratio with the same x12 factor as before.
            if "start_time" in item:
                raw_start = item["start_time"]
            else:
                raw_start = (item.get("start_ratio") or 0) * 12
            # Streamlit rejects a default outside [min_value, max_value], so clamp it.
            item["start_time"] = st.number_input(
                f"开始(秒) #{i+1}",
                value=_clamped_seconds(raw_start, 0.0, 30.0),
                min_value=0.0,
                max_value=30.0,
                step=0.5,
                key=_ui_key(f"tune_start_{i}"),
            )
        with c4:
            # Legacy format: convert duration_ratio with the same x12 factor as before.
            if "duration" in item:
                raw_dur = item["duration"]
            else:
                raw_dur = (item.get("duration_ratio") or 0.25) * 12
            item["duration"] = st.number_input(
                f"时长(秒) #{i+1}",
                value=_clamped_seconds(raw_dur, 0.5, 15.0),
                min_value=0.5,
                max_value=15.0,
                step=0.5,
                key=_ui_key(f"tune_dur_{i}"),
            )
        # Drop the legacy fields now that absolute times are stored.
        item.pop("start_ratio", None)
        item.pop("duration_ratio", None)
        updated_timeline.append(item)
    st.session_state.script_data["voiceover_timeline"] = updated_timeline

    st.divider()

    # Per Scene Editor (Only Fancy Text)
    scenes = st.session_state.script_data.get("scenes", [])
    updated_scenes = []
    for i, scene in enumerate(scenes):
        with st.container():
            st.markdown(f"**Scene {scene['id']}**")
            ft = scene.get("fancy_text", {})
            ft_text = ft.get("text", "") if isinstance(ft, dict) else ""
            new_ft = st.text_input("花字", value=ft_text, key=_ui_key(f"tune_ft_{i}"))
            # Older data may have no (or a non-dict) fancy_text field.
            if not isinstance(scene.get("fancy_text"), dict):
                scene["fancy_text"] = {}
            scene["fancy_text"]["text"] = new_ft
            updated_scenes.append(scene)
    st.session_state.script_data["scenes"] = updated_scenes

    if st.button("🔄 重新合成 (Re-Compose)", type="primary"):
        with st.spinner("正在应用修改并重新合成..."):
            # Make sure scene videos are available locally before composing.
            _ensure_local_videos(st.session_state.project_id, st.session_state.script_data.get("scenes", []))
            composer = VideoComposer(voice_type=selected_voice)
            # Only pass a BGM path that exists, matching the warning shown above
            # and the first-composition path, which drops a missing BGM.
            bgm_path = None
            if selected_bgm != "None":
                candidate = config.ASSETS_DIR / "bgm" / selected_bgm
                if candidate.exists():
                    bgm_path = str(candidate)
            try:
                # Save updated script first
                db.update_project_script(st.session_state.project_id, st.session_state.script_data)
                t0 = perf_counter()
                output_path = composer.compose_from_script(
                    script=st.session_state.script_data,
                    video_map=st.session_state.scene_videos,
                    bgm_path=bgm_path,
                    output_name=f"final_{st.session_state.project_id}_{int(time.time())}",  # Unique name for history
                    project_id=st.session_state.project_id,
                )
                _record_metrics(st.session_state.project_id, {
                    "compose_s": round(perf_counter() - t0, 3),
                    "bgm_used": bool(bgm_path and Path(bgm_path).exists()),
                })
                st.session_state.final_video = output_path
                db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
                st.success("重新合成成功!")
                st.rerun()
            except Exception as e:
                st.error(f"合成失败: {e}")
# Preview tab: export a CapCut material package, run the first composition when
# no final video exists yet, and list every composed version for playback/download.
with tab_preview:
    st.subheader("🛠️ 导出与交付")
    col_ex1, col_ex2 = st.columns([1, 2])

    with col_ex1:
        export_clicked = st.button("📦 生成剪映素材包 (ZIP)", type="primary", key="btn_export_zip")
        if export_clicked:
            with st.spinner("正在打包视频、音频与字幕..."):
                try:
                    zip_path = export_utils.create_capcut_package(
                        st.session_state.project_id,
                        st.session_state.script_data,
                        {"scene_videos": st.session_state.scene_videos},
                    )
                    st.session_state['export_zip_path'] = zip_path
                    st.success("打包完成!请点击下载")
                except Exception as e:
                    st.error(f"打包失败: {e}")

    with col_ex2:
        # Offer the download only while the packaged zip is still on disk.
        if 'export_zip_path' in st.session_state and os.path.exists(st.session_state['export_zip_path']):
            zip_name = f"capcut_pack_{st.session_state.project_id}.zip"
            with open(st.session_state['export_zip_path'], "rb") as zip_fh:
                st.download_button(
                    label=f"📥 点击下载素材包 ({zip_name})",
                    data=zip_fh,
                    file_name=zip_name,
                    mime="application/zip",
                )

    st.divider()

    # Scan every composed version that belongs to the current project.
    import glob
    pattern_timestamp = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}_*.mp4")
    pattern_simple = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}.mp4")
    found_files = list(glob.glob(pattern_timestamp))
    found_files += [p for p in glob.glob(pattern_simple) if p not in found_files]
    # Newest first, by modification time.
    found_files = sorted(found_files, key=os.path.getmtime, reverse=True)

    if found_files:
        st.success(f"共找到 {len(found_files)} 个历史版本")
        # Hint shown to the user: re-composing applies the current tuning settings.
        st.caption("💡 重新合成将应用当前的微调设置。如果分镜未下载,系统将尝试自动补全。")
        for idx, vid_path in enumerate(found_files):
            time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(os.path.getmtime(vid_path)))
            file_name = os.path.basename(vid_path)
            with st.container():
                # The first entry is the most recently modified file.
                heading = (
                    f"### 🟢 最新版本 ({time_str})"
                    if idx == 0
                    else f"#### ⚪ 历史版本 {len(found_files)-idx} ({time_str})"
                )
                st.markdown(heading)
                _left, center_col, _right = st.columns([1, 2, 1])
                with center_col:
                    st.video(vid_path)
                    # Download button
                    with open(vid_path, "rb") as video_fh:
                        st.download_button(
                            label=f"📥 下载 ({file_name})",
                            data=video_fh,
                            file_name=file_name,
                            mime="video/mp4",
                            key=f"dl_btn_{file_name}",
                        )
                st.divider()
    else:
        st.info("暂无合成视频,请先点击开始合成。")
        first_compose_clicked = st.button("✨ 开始首次合成", type="primary")
        if first_compose_clicked:
            with st.spinner("正在进行多轨合成..."):
                # Make sure scene videos are available locally before composing.
                _ensure_local_videos(st.session_state.project_id, st.session_state.script_data.get("scenes", []))
                # Default compose logic with smart BGM matching
                composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE)
                # Pick a BGM from the script's bgm_style; drop it if the file is missing.
                bgm_style = st.session_state.script_data.get("bgm_style", "")
                bgm_path = match_bgm_by_style(bgm_style, config.ASSETS_DIR / "bgm")
                if bgm_path and not Path(bgm_path).exists():
                    st.warning(f"推荐的 BGM 文件不存在:{bgm_path}(本次将不含 BGM")
                    bgm_path = None
                try:
                    # The first composition also gets a timestamped name.
                    output_name = f"final_{st.session_state.project_id}_{int(time.time())}"
                    t0 = perf_counter()
                    output_path = composer.compose_from_script(
                        script=st.session_state.script_data,
                        video_map=st.session_state.scene_videos,
                        bgm_path=bgm_path,
                        output_name=output_name,
                        project_id=st.session_state.project_id,
                    )
                    elapsed_s = round(perf_counter() - t0, 3)
                    _record_metrics(st.session_state.project_id, {
                        "compose_s": elapsed_s,
                        "bgm_used": bool(bgm_path and Path(bgm_path).exists()),
                    })
                    st.session_state.final_video = output_path
                    db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
                    db.update_project_status(st.session_state.project_id, "completed")
                    st.rerun()
                except Exception as e:
                    st.error(f"合成失败: {e}")
# ============================================================
# Page: History (New)
# ============================================================
# Shown when the active view mode is "history": lists the current user's
# projects and offers downloads for their stored assets.
elif st.session_state.view_mode == "history":
# History page body: one expander per project with download buttons for the
# final video, the script JSON, and every scene image / scene video on disk.
st.header("📜 历史任务")
projects = db.list_projects_for_user(_current_user()) or []
if not projects:
    st.info("暂无历史任务")

for proj in projects:
    with st.expander(f"{proj['name']} ({proj['updated_at']})"):
        st.caption(f"ID: {proj['id']} | Status: {proj.get('status', 'unknown')}")
        assets = db.get_assets(proj['id'])
        if not assets:
            st.info("暂无生成素材")
            continue

        st.markdown("#### 📥 下载素材")
        # Group assets
        final_vids = [a for a in assets if a['asset_type'] == 'final_video']
        scene_imgs = [a for a in assets if a['asset_type'] == 'image']
        scene_vids = [a for a in assets if a['asset_type'] == 'video']

        # Final Video Download: use the first final video whose file still
        # exists, instead of giving up when the first record's file is gone.
        fv = next(
            (a for a in final_vids if a.get('local_path') and os.path.exists(a['local_path'])),
            None,
        )
        if fv:
            st.success("✅ **最终成片已生成**")
            with open(fv['local_path'], "rb") as f:
                st.download_button(
                    label=f"📥 下载最终成片 ({os.path.basename(fv['local_path'])})",
                    data=f,
                    file_name=os.path.basename(fv['local_path']),
                    mime="video/mp4",
                    # Explicit key, like the per-scene buttons below.
                    key=f"dl_final_{proj['id']}",
                )

        # Script Download
        raw_script = proj.get("script_data")
        if raw_script:
            try:
                # script_data is either a dict already or a JSON string.
                script_obj = raw_script if isinstance(raw_script, dict) else json.loads(raw_script)
                script_json = json.dumps(script_obj, ensure_ascii=False, indent=2)
            except (TypeError, ValueError) as exc:
                # json.JSONDecodeError is a ValueError; report it instead of hiding it.
                st.caption(f"⚠️ 脚本数据无法解析,已跳过脚本下载:{exc}")
            else:
                st.download_button(
                    label="📥 下载脚本/字幕配置 (JSON)",
                    data=script_json,
                    file_name=f"script_{proj['id']}.json",
                    mime="application/json",
                    key=f"dl_script_{proj['id']}",
                )

        st.markdown("---")
        c1, c2 = st.columns(2)
        with c1:
            st.markdown("**分镜图片**")
            for img in scene_imgs:
                if img['local_path'] and os.path.exists(img['local_path']):
                    with open(img['local_path'], "rb") as f:
                        st.download_button(f"Scene {img['scene_id']} 图片", f, key=f"dl_img_{proj['id']}_{img['id']}", file_name=f"scene_{img['scene_id']}.png")
        with c2:
            st.markdown("**分镜视频**")
            for vid in scene_vids:
                if vid['local_path'] and os.path.exists(vid['local_path']):
                    with open(vid['local_path'], "rb") as f:
                        st.download_button(f"Scene {vid['scene_id']} 视频", f, key=f"dl_vid_{proj['id']}_{vid['id']}", file_name=f"scene_{vid['scene_id']}.mp4")
# ============================================================
# Page: Settings (New)
# ============================================================
# Shown when the active view mode is "settings": prompt configuration,
# password change and, for admin users, account management.
elif st.session_state.view_mode == "settings":
# Prompt settings: show the effective script-generation system prompt
# (per-user override, then global config, then the generator's built-in
# default) and let the current user save or clear their own override.
st.header("⚙️ 设置配置")
st.subheader("Prompt 配置")

# Script Generation Prompt
cu = _current_user() or {}
user_prompt = None
if cu.get("id"):
    try:
        user_prompt = db.get_user_prompt(cu.get("id"), "prompt_script_gen")
    except Exception:
        # Best-effort read: fall back to the global/default prompt.
        user_prompt = None
current_prompt = user_prompt or db.get_config("prompt_script_gen")

# Show which prompt source is in effect.
if user_prompt:
    st.info("✅ 已加载自定义 Prompt当前用户")
elif current_prompt:
    st.info("✅ 已加载自定义 Prompt全局默认")
else:
    st.warning("⚠️ 使用默认 Prompt数据库中无自定义配置")
    # Load default from instance if not in DB
    temp_gen = ScriptGenerator()
    current_prompt = temp_gen.default_system_prompt

new_prompt = st.text_area("脚本拆分 Prompt (System Prompt)", value=current_prompt, height=400)

col_save, col_reset = st.columns([1, 3])
with col_save:
    if st.button("💾 保存配置", type="primary"):
        if not cu.get("id"):
            # Same guard as the read path above: never write without a user id.
            st.error("未获取到当前用户,无法保存配置")
        else:
            # Save per-user prompt, then read it back to confirm the write.
            db.set_user_prompt(cu.get("id"), "prompt_script_gen", new_prompt)
            saved = db.get_user_prompt(cu.get("id"), "prompt_script_gen")
            if saved == new_prompt:
                st.success("✅ 已保存为“当前用户 Prompt”。下次生成脚本仅影响你自己的账号。")
            else:
                st.error("❌ 保存可能失败,请检查日志")
with col_reset:
    if st.button("🔄 恢复默认"):
        if not cu.get("id"):
            st.error("未获取到当前用户,无法保存配置")
        else:
            # Clear per-user override: an empty value removes its effect.
            db.set_user_prompt(cu.get("id"), "prompt_script_gen", "")
            st.success("已清除当前用户 Prompt将回退到全局/默认 Prompt。")
            st.rerun()
st.markdown("---")
st.subheader("账号与权限")

# Self-service password change; a successful change logs the user out.
with st.expander("修改我的密码", expanded=False):
    new_pwd = st.text_input("新密码", type="password", key=_ui_key("pwd_new"))
    confirm_pwd = st.text_input("确认新密码", type="password", key=_ui_key("pwd_new2"))
    save_clicked = st.button("保存新密码", type="primary", key=_ui_key("pwd_save"))
    if save_clicked:
        # Validate first: the length rule is checked before the confirmation match.
        if len(new_pwd or "") < 6:
            pwd_problem = "密码至少 6 位"
        elif new_pwd != confirm_pwd:
            pwd_problem = "两次输入不一致"
        else:
            pwd_problem = None

        if pwd_problem is not None:
            st.error(pwd_problem)
        else:
            db.reset_user_password(cu.get("id"), new_pwd)
            st.success("密码已更新,请重新登录。")
            _logout()
# Admin console: visible only when the current user's role is "admin".
# Lists all accounts and provides forms to create/update a user and to reset a
# user's password or change their active flag.
if cu.get("role") == "admin":
    st.markdown("### Admin账号管理")
    users = db.list_users() or []
    if users:
        st.dataframe(
            [{k: u.get(k) for k in ["username", "role", "is_active", "last_login_at", "created_at", "id"]} for u in users],
            use_container_width=True,
        )

    with st.expander("创建/更新用户", expanded=False):
        u_name = st.text_input("用户名", key=_ui_key("adm_u_name"))
        u_role = st.selectbox("角色", ["user", "admin"], index=0, key=_ui_key("adm_u_role"))
        u_active = st.selectbox("是否启用", ["启用", "禁用"], index=0, key=_ui_key("adm_u_active"))
        u_pwd = st.text_input("初始/重置密码", type="password", key=_ui_key("adm_u_pwd"))
        if st.button("保存用户", type="primary", key=_ui_key("adm_u_save")):
            # Strip before validating so a whitespace-only name is rejected
            # instead of being saved as an empty username.
            clean_name = (u_name or "").strip()
            if not clean_name:
                st.error("用户名不能为空")
            elif not u_pwd or len(u_pwd) < 6:
                st.error("密码至少 6 位")
            else:
                uid = db.upsert_user(
                    username=clean_name,
                    password=u_pwd,
                    role=u_role,
                    is_active=(1 if u_active == "启用" else 0),
                )
                st.success(f"已保存用户:{clean_name} (id={uid})")
                st.rerun()

    with st.expander("重置/禁用用户(按用户名)", expanded=False):
        uname = st.selectbox(
            "选择用户",
            options=[u.get("username") for u in users if u.get("username")],
            key=_ui_key("adm_sel_user"),
        ) if users else None
        new_pass = st.text_input("新密码(可选)", type="password", key=_ui_key("adm_reset_pwd"))
        new_active = st.selectbox("状态", ["不修改", "启用", "禁用"], key=_ui_key("adm_reset_active"))
        if st.button("应用修改", type="secondary", key=_ui_key("adm_apply")):
            target = next((u for u in users if u.get("username") == uname), None)
            if not target:
                st.error("未找到用户")
            elif new_pass and len(new_pass) < 6:
                # Reject the whole request before touching the DB; no st.stop()
                # is needed, so the rest of the page still renders.
                st.error("密码至少 6 位")
            else:
                if new_pass:
                    db.reset_user_password(target.get("id"), new_pass)
                if new_active != "不修改":
                    db.set_user_active(target.get("id"), 1 if new_active == "启用" else 0)
                st.success("已更新")
                st.rerun()