Files
video-flow/app.py

1370 lines
70 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
MatchMe Studio - UI (Streamlit)
Style: Kaogujia (Clean, Data-heavy, Professional)
"""
import json
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from time import perf_counter
from typing import Any, Optional

import pandas as pd
import streamlit as st

# Import Backend Modules
import config
from modules.script_gen import ScriptGenerator
from modules.image_gen import ImageGenerator
from modules.video_gen import VideoGenerator
from modules.composer import VideoComposer, VideoComposer as Composer # alias
from modules.text_renderer import renderer
from modules import export_utils
from modules.db_manager import db
from modules import path_utils
from modules import limits
from modules.legacy_path_mapper import map_legacy_local_path
from modules.legacy_normalizer import normalize_legacy_project
# Page Config
# Browser-tab metadata and overall layout for the console.
_PAGE_SETTINGS = {
    "page_title": "Video Flow Console",
    "page_icon": "🎬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# ============================================================
# BGM 智能匹配函数
# ============================================================
def match_bgm_by_style(bgm_style: str, bgm_dir: Path) -> Optional[str]:
    """Pick a BGM file from *bgm_dir* that fits the script's ``bgm_style``.

    Args:
        bgm_style: Free-text style description from the script (may be empty).
        bgm_dir: Directory holding candidate ``.mp3`` / ``.mp4`` files.

    Returns:
        The path (as ``str``) of a randomly chosen file whose name contains one
        of the style keywords found in *bgm_style*; if nothing matches, a random
        file from the directory; ``None`` when the directory has no usable file.
    """
    # A single pattern already covers .mp3 and .mp4 in any letter case. Globbing
    # "*.mp3" a second time listed every mp3 twice, which doubled its weight in
    # random.choice.
    candidates = [
        f
        for f in bgm_dir.glob("*.[mM][pP][34]")
        if f.is_file() and not f.name.startswith(".")
    ]
    if not candidates:
        return None

    if bgm_style:
        style_lower = bgm_style.lower()
        # Simplified "tokenisation": look for a fixed list of common style words.
        keywords = ("活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏")
        wanted = [kw for kw in keywords if kw in style_lower]
        # Compare against the lower-cased file name so that the lower-cased
        # keywords (e.g. "bling") also hit names such as "BLING_beat.mp3".
        matched = [
            f for f in candidates if any(kw in f.name.lower() for kw in wanted)
        ]
        if matched:
            return str(random.choice(matched))

    # No style given or nothing matched: fall back to any available track.
    return str(random.choice(candidates))
# ============================================================
# CSS Styling (Kaogujia Style)
# ============================================================
# Global stylesheet: colours, fonts and button/heading accents only.
_APP_STYLESHEET = """
<style>
/* 只保留颜色和字体样式,完全不干预布局和滚动 */
.stApp {
background-color: #F4F5F7;
font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
}
section[data-testid="stSidebar"] {
background-color: #FFFFFF;
border-right: 1px solid #E5E7EB;
}
.stButton button {
border-radius: 4px;
font-weight: 500;
}
div[data-testid="stButton"] > button[kind="primary"] {
background-color: #1677FF;
color: white;
border: none;
}
div[data-testid="stButton"] > button[kind="primary"]:hover {
background-color: #4096FF;
}
h1, h2, h3 {
color: #1F2329;
font-weight: 600;
}
h2 { border-left: 4px solid #1677FF; padding-left: 10px; }
.stTextInput input, .stTextArea textarea {
border-radius: 4px;
}
</style>
"""
st.markdown(_APP_STYLESHEET, unsafe_allow_html=True)
# ============================================================
# Session State Management
# ============================================================
# Seed every session key the app relies on. The table is rebuilt on each script
# run, so the mutable defaults ({} / []) are never shared between sessions.
_SESSION_DEFAULTS = {
    "project_id": None,
    "current_step": 0,
    "script_data": None,
    "scene_images": {},
    "scene_videos": {},
    "final_video": None,
    "uploaded_images": [],
    "view_mode": "workspace",  # workspace, history, settings
    "selected_img_provider": "shubiaobiao",
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
def load_project(project_id):
    """Load a saved project from the DB into ``st.session_state``.

    Restores the script, the product info shown in Step 1, the uploaded
    reference images that still exist on disk and all completed assets, then
    selects the wizard step to resume from. When the project cannot be found an
    error is shown and session state is left untouched.

    Args:
        project_id: Identifier passed to ``db.get_project`` / ``db.get_assets``.
    """
    data = db.get_project(project_id)
    if not data:
        st.error("Project not found")
        return

    st.session_state.project_id = project_id
    st.session_state.script_data = data.get("script_data")
    st.session_state.view_mode = "workspace"

    # Fallback: if the script stored in the DB uses the old structure or lacks
    # fields, normalise it once more from the legacy JSON file.
    try:
        script_data = st.session_state.script_data
        legacy_json = Path(config.TEMP_DIR) / f"project_{project_id}.json"

        # `Any` must be importable here: on Python <= 3.13 the annotation is
        # evaluated when this nested def runs, and a NameError would be swallowed
        # by the broad except below, silently disabling the whole fallback.
        def _needs_normalize(sd: Any) -> bool:
            """Return True when *sd* is not a normalised script dict."""
            if not isinstance(sd, dict):
                return True
            if "_legacy_schema" not in sd:
                return True
            scenes = sd.get("scenes") or []
            if scenes and isinstance(scenes, list) and isinstance(scenes[0], dict):
                if "visual_prompt" not in scenes[0] or "video_prompt" not in scenes[0]:
                    return True
            return False

        if legacy_json.exists() and _needs_normalize(script_data):
            raw = json.loads(legacy_json.read_text(encoding="utf-8"))
            normalized = normalize_legacy_project(raw)
            st.session_state.script_data = normalized
            # Write back to the DB so the work is not repeated on every load.
            db.update_project_script(project_id, normalized)
            st.info("已从 legacy JSON 重新规范化脚本字段(兼容旧版项目)。")
    except Exception as e:
        # Best effort: keep going with whatever the DB returned.
        st.warning(f"legacy 规范化失败(将继续使用 DB 数据): {e}")

    # Restore product info for Step 1 display. `or` guards against a stored
    # None, because Step 1 calls .get() on these values.
    product_info = data.get("product_info") or {}
    st.session_state.loaded_product_name = data.get("name") or ""
    st.session_state.loaded_product_info = product_info

    # Restore uploaded images from product_info.
    if product_info and "uploaded_images" in product_info:
        stored_paths = product_info["uploaded_images"]
        # Keep only the files that still exist on disk.
        valid_paths = [p for p in stored_paths if os.path.exists(p)]
        st.session_state.uploaded_images = valid_paths
        if len(valid_paths) < len(stored_paths):
            st.warning(f"部分原始图片已丢失 ({len(stored_paths) - len(valid_paths)} 张),建议重新上传")
    else:
        st.session_state.uploaded_images = []

    # Restore assets: only rows marked "completed" that map to a usable path.
    assets = db.get_assets(project_id) or []
    images = {}
    videos = {}
    final_vid = None
    for asset in assets:
        sid = asset["scene_id"]
        source_path, _mapped_url = map_legacy_local_path(asset.get("local_path"))
        if asset["status"] != "completed" or not source_path:
            continue
        asset_type = asset["asset_type"]
        if asset_type == "image":
            images[sid] = source_path
        elif asset_type == "video":
            videos[sid] = source_path
        elif asset_type == "final_video":
            # The original author assumed scene_id 0 or -1 marks the final video.
            final_vid = source_path

    st.session_state.scene_images = images
    st.session_state.scene_videos = videos
    st.session_state.final_video = final_vid

    # Determine the step to resume from (most advanced artefact wins).
    if final_vid:
        st.session_state.current_step = 4
    elif videos:
        st.session_state.current_step = 4
    elif images:
        st.session_state.current_step = 3
    elif st.session_state.script_data:
        st.session_state.current_step = 1
    else:
        st.session_state.current_step = 0
# ============================================================
# Helper Functions
# ============================================================
# NOTE: these helpers are defined BEFORE the sidebar on purpose. Streamlit runs
# this script top to bottom on every rerun, and the sidebar calls _get_metrics;
# defining it afterwards raised NameError as soon as a project was loaded.
def _record_metrics(project_id: str, patch: dict):
    """Persist lightweight timing/diagnostic metrics into project.product_info['_metrics'].

    Merges *patch* into the stored metrics dict and stamps ``updated_at``.
    Does nothing for an empty ``project_id`` or an empty / non-dict *patch*.
    Any exception is swallowed on purpose.
    """
    if not project_id or not isinstance(patch, dict) or not patch:
        return
    try:
        proj = db.get_project(project_id) or {}
        product_info = proj.get("product_info") or {}
        metrics = product_info.get("_metrics") if isinstance(product_info.get("_metrics"), dict) else {}
        metrics.update(patch)
        metrics["updated_at"] = time.time()
        product_info["_metrics"] = metrics
        db.update_project_product_info(project_id, product_info)
    except Exception:
        # metrics must never break UX
        pass


def _get_metrics(project_id: str) -> dict:
    """Return the stored ``_metrics`` dict for a project, or ``{}`` on any problem."""
    try:
        proj = db.get_project(project_id) or {}
        product_info = proj.get("product_info") or {}
        m = product_info.get("_metrics")
        return m if isinstance(m, dict) else {}
    except Exception:
        return {}


def save_uploaded_file(project_id: str, uploaded_file):
    """Save uploaded file to per-project upload dir (avoid overwrites across projects).

    Returns:
        The saved file path as ``str``, or ``None`` when *uploaded_file* is None.

    Raises:
        ValueError: if *project_id* is empty.
    """
    if uploaded_file is None:
        return None
    if not project_id:
        raise ValueError("project_id is required to save uploaded files safely")
    upload_dir = path_utils.project_upload_dir(project_id)
    original = path_utils.sanitize_filename(getattr(uploaded_file, "name", "upload"))
    # keep original stem for readability, but ensure uniqueness
    suffix = Path(original).suffix.lstrip(".") or "bin"
    stem = Path(original).stem or "upload"
    unique_name = path_utils.unique_filename(prefix=f"upload_{stem}", ext=suffix, project_id=project_id)
    file_path = upload_dir / unique_name
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return str(file_path)


def _reset_session_state():
    """Delete every session key except ``view_mode``."""
    for key in list(st.session_state.keys()):
        if key != "view_mode":
            del st.session_state[key]


# ============================================================
# Sidebar
# ============================================================
with st.sidebar:
    st.title("📽️ Video Flow")
    # Mode selection: derive the radio index from the current view mode.
    mode_options = ["🛠️ 工作台", "📜 历史任务", "⚙️ 设置"]
    mode_map = {"workspace": 0, "history": 1, "settings": 2}
    current_index = mode_map.get(st.session_state.view_mode, 0)
    mode = st.radio("模式", mode_options, index=current_index)
    if mode == "🛠️ 工作台":
        st.session_state.view_mode = "workspace"
    elif mode == "📜 历史任务":
        st.session_state.view_mode = "history"
    elif mode == "⚙️ 设置":
        st.session_state.view_mode = "settings"
    st.markdown("---")

    if st.session_state.view_mode == "workspace":
        # Project Selection
        st.subheader("Current Project")
        projects = db.list_projects()
        proj_options = {p['id']: f"{p.get('name', 'Untitled')} ({p['id']})" for p in projects}
        selected_proj_id = st.selectbox(
            "Select Project",
            options=["New Project"] + list(proj_options.keys()),
            format_func=lambda x: " New Project" if x == "New Project" else proj_options[x]
        )
        if selected_proj_id == "New Project":
            if st.session_state.project_id and st.session_state.project_id in proj_options:
                # If switching from existing to new, reset
                if st.button("Start New"):
                    _reset_session_state()
                    st.rerun()
        else:
            if st.session_state.project_id != selected_proj_id:
                if st.button(f"Load {selected_proj_id}"):
                    load_project(selected_proj_id)
                    st.rerun()

        if st.session_state.project_id:
            st.caption(f"Current ID: {st.session_state.project_id}")
            with st.expander("⏱️ 性能与诊断", expanded=False):
                m = _get_metrics(st.session_state.project_id)
                if not m:
                    st.caption("暂无指标(执行一次脚本/生图/生视频/合成后会出现)。")
                else:
                    # Only this fixed set of metric keys is shown, in this order.
                    keys = [
                        "script_gen_s",
                        "image_gen_total_s",
                        "video_submit_s",
                        "video_recover_s",
                        "compose_s",
                        "script_model",
                        "image_provider",
                        "image_generated",
                        "video_submitted",
                        "video_recovered",
                        "bgm_used",
                    ]
                    for k in keys:
                        if k in m:
                            st.caption(f"{k}: {m.get(k)}")
            # Entry point to the online (React) editor.
            web_base_url = os.getenv("WEB_BASE_URL", "http://localhost:3000").rstrip("/")
            st.markdown(
                f"[打开在线剪辑器]({web_base_url}/editor/{st.session_state.project_id})",
                unsafe_allow_html=False,
            )
        st.markdown("---")

        # Navigation / Progress
        steps = ["1. 输入信息", "2. 生成脚本", "3. 画面生成", "4. 视频生成", "5. 合成输出"]
        current = st.session_state.current_step
        for i, step in enumerate(steps):
            if i < current:
                st.markdown(f"✅ **{step}**")
            elif i == current:
                st.markdown(f"🔵 **{step}**")
            else:
                st.markdown(f"{step}")
        st.markdown("---")
        if st.button("🔄 重置状态", type="secondary"):
            _reset_session_state()
            st.rerun()
# ============================================================
# Main Content: Workspace
# ============================================================
if st.session_state.view_mode == "workspace":
st.header("商品短视频自动生成控制台")
# --- Step 1: Input ---
# Collects product information and reference images, creates the project row
# and asks ScriptGenerator for a script. On success the wizard moves to step 1.
with st.expander("📦 1. 商品信息输入", expanded=(st.session_state.current_step == 0)):
# Read values of a loaded project from session_state, otherwise use the defaults.
loaded_info = st.session_state.get("loaded_product_info", {})
default_name = st.session_state.get("loaded_product_name", "网红气质大号发量多!高马尾香蕉夹")
default_category = loaded_info.get("category", "钟表配饰-时尚饰品-发饰")
default_price = loaded_info.get("price", "3.99元")
default_tags = loaded_info.get("tags", "回头客|款式好看|材质好|尺寸合适|颜色好看|很好用|做工好|质感不错|很牢固")
default_params = loaded_info.get("params", "金属材质:非金属; 非金属材质:树脂; 发夹分类:香蕉夹")
col1, col2 = st.columns([1, 1])
with col1:
product_name = st.text_input("商品标题", value=default_name)
category = st.text_input("商品类目", value=default_category)
price = st.text_input("价格", value=default_price)
with col2:
tags = st.text_area("评价标签 (用于提炼卖点)", value=default_tags, height=100)
params = st.text_area("商品参数", value=default_params, height=100)
# Optional merchant-supplied style hint.
style_hint = st.text_area(
"商品视频重点增强提示 (可选)",
value=loaded_info.get("style_hint", ""),
placeholder="例如:韩风、高级感、活力青春、简约日系...",
height=80
)
st.markdown("### 上传素材")
# Show images that are already attached (e.g. loaded from a history project).
if st.session_state.uploaded_images:
st.info(f"已有 {len(st.session_state.uploaded_images)} 张参考图片")
with st.expander("查看已有图片"):
# At most four thumbnails are shown, one per column.
img_cols = st.columns(min(len(st.session_state.uploaded_images), 4))
for i, img_path in enumerate(st.session_state.uploaded_images[:4]):
if os.path.exists(img_path):
img_cols[i % 4].image(img_path, width=150)
uploaded_files = st.file_uploader("上传商品主图 (建议 3-5 张)", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
# Allow submitting without new uploads as long as earlier images exist.
can_submit = uploaded_files or st.session_state.uploaded_images
# Model Selection (all support images; user explicitly chooses model)
model_options = ["GPT-5.2", "Gemini 3 Pro", "Doubao Pro (Vision)"]
selected_model_label = st.radio("选择脚本生成模型", model_options, horizontal=True, index=0)
# Map label to provider key
if selected_model_label == "GPT-5.2":
model_provider = "shubiaobiao_gpt"
elif "Doubao" in selected_model_label:
model_provider = "doubao"
else:
model_provider = "shubiaobiao"
if st.button("提交任务 & 生成脚本", type="primary", disabled=(not can_submit)):
# Start from the image paths already known to the session.
image_paths = list(st.session_state.uploaded_images) if st.session_state.uploaded_images else []
# Append any newly uploaded files.
if uploaded_files:
# Create Project ID (timestamp based) if the session has none yet
if not st.session_state.project_id:
st.session_state.project_id = f"PROJ-{int(time.time())}"
for uf in uploaded_files:
path = save_uploaded_file(st.session_state.project_id, uf)
if path: image_paths.append(path)
st.session_state.uploaded_images = image_paths
# Make sure a project ID exists even when nothing new was uploaded.
if not st.session_state.project_id:
st.session_state.project_id = f"PROJ-{int(time.time())}"
# DB: Create Project
# uploaded_images is stored inside product_info so it survives reloads.
product_info = {"category": category, "price": price, "tags": tags, "params": params, "style_hint": style_hint, "uploaded_images": image_paths}
db.create_project(st.session_state.project_id, product_name, product_info)
# Call Script Generator
with st.spinner(f"正在分析商品信息并生成脚本 ({selected_model_label})..."):
gen = ScriptGenerator()
t0 = perf_counter()
script = gen.generate_script(product_name, product_info, image_paths, model_provider=model_provider)
# Timing is recorded whether or not a script came back.
_record_metrics(st.session_state.project_id, {
"script_gen_s": round(perf_counter() - t0, 3),
"script_model": model_provider,
})
if script:
st.session_state.script_data = script
# DB: Save Script
db.update_project_script(st.session_state.project_id, script)
st.session_state.current_step = 1
st.success("脚本生成成功!")
st.rerun()
else:
st.error("脚本生成失败,请检查日志。")
# --- Step 2: Script Review ---
# Shows the generated script, lets the user edit the voiceover timeline and the
# per-scene prompts in place, and advances to image generation on confirmation.
if st.session_state.current_step >= 1:
    with st.expander("📝 2. 脚本分镜确认", expanded=(st.session_state.current_step == 1)):
        if st.session_state.script_data:
            script = st.session_state.script_data
            # Display Basic Info (tolerates the legacy schema)
            selling_points = script.get("selling_points", []) or []
            target_audience = script.get("target_audience", "") or ""
            analysis_text = script.get("analysis", "") or ""
            legacy_schema = script.get("_legacy_schema", "") or ""
            c1, c2 = st.columns(2)
            if selling_points:
                c1.write(f"**核心卖点**: {', '.join(selling_points)}")
            else:
                c1.write("**核心卖点**: (legacy 项目可能未生成该字段)")
                if analysis_text:
                    with st.expander("查看 legacy analysis用于补齐信息"):
                        st.write(analysis_text)
            if target_audience:
                c2.write(f"**目标人群**: {target_audience}")
            else:
                c2.write("**目标人群**: (legacy 项目可能未生成该字段)")

            # Hook / CTA / Schema
            hook = script.get("hook", "") or ""
            if hook:
                st.markdown(f"**Hook**: {hook}")
            cta = script.get("cta", "")
            if cta:
                if isinstance(cta, dict):
                    st.markdown("**CTAlegacy object**")
                    st.json(cta)
                else:
                    st.markdown(f"**CTA**: {cta}")
            if legacy_schema:
                st.caption(f"Legacy Schema: {legacy_schema}")

            # Prompt Visualization
            if "_debug" in script:
                with st.expander("🔍 查看 AI Prompt (Debug)"):
                    debug_info = script["_debug"]
                    st.markdown("**System Prompt:**")
                    st.code(debug_info.get("system_prompt", ""), language="markdown")
                    st.markdown("**User Prompt:**")
                    st.code(debug_info.get("user_prompt", ""), language="markdown")
                    # Show the raw model output when it was captured.
                    if "raw_output" in debug_info:
                        st.markdown("**Raw AI Output:**")
                        st.code(debug_info.get("raw_output", ""), language="markdown")

            # Editable Scenes Table
            st.markdown("### 分镜列表")
            scenes = script.get("scenes", [])

            # Global Voiceover Timeline (New)
            st.markdown("### 🎙️ 整体旁白与字幕时间轴")
            with st.expander("编辑旁白时间轴 (Voiceover Timeline)", expanded=True):
                timeline = script.get("voiceover_timeline", []) or []
                if not timeline:
                    # History projects with neither scenes nor a timeline must not
                    # get a placeholder segment injected (it would pollute the data).
                    if not scenes and analysis_text:
                        st.info("该历史项目暂无旁白时间轴(可能停留在分析/提问阶段)。")
                        timeline = []
                    else:
                        # Init with a default segment if empty (values in seconds)
                        timeline = [{"text": "示例旁白", "subtitle": "示例字幕", "start_time": 0.0, "duration": 3.0}]

                updated_timeline = []
                for i, item in enumerate(timeline):
                    c1, c2, c3, c4 = st.columns([3, 3, 1, 1])
                    with c1:
                        item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=f"tl_vo_{i}")
                    with c2:
                        item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=f"tl_sub_{i}")
                    with c3:
                        # Old format: convert start_ratio to seconds (assumes a 12 s total).
                        default_start = item.get("start_time", item.get("start_ratio", 0) * 12)
                        item["start_time"] = st.number_input(f"开始(秒) #{i+1}", value=float(default_start), min_value=0.0, max_value=30.0, step=0.5, key=f"tl_start_{i}")
                    with c4:
                        # Old format: convert duration_ratio to seconds (assumes a 12 s total).
                        default_dur = item.get("duration", item.get("duration_ratio", 0.25) * 12)
                        item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=float(default_dur), min_value=0.5, max_value=15.0, step=0.5, key=f"tl_dur_{i}")
                    # Drop the legacy ratio fields once converted.
                    item.pop("start_ratio", None)
                    item.pop("duration_ratio", None)
                    updated_timeline.append(item)

                add_segment = st.button(" 添加旁白段落")
                if add_segment:
                    updated_timeline.append({"text": "", "subtitle": "", "start_time": 0.0, "duration": 3.0})
                # Store the timeline BEFORE any rerun: st.rerun() aborts the current
                # run immediately, so rerunning first discarded the new segment.
                script["voiceover_timeline"] = updated_timeline
                if add_segment:
                    st.rerun()

            updated_scenes = []
            for i, scene in enumerate(scenes):
                with st.container():
                    st.markdown(f"#### 🎬 Scene {scene['id']}")
                    c_vis, c_aud = st.columns([2, 1])
                    with c_vis:
                        new_visual = st.text_area(f"Visual Prompt (Scene {scene['id']})", value=scene.get("visual_prompt", ""), height=80, key=f"vp_{i}")
                        new_video = st.text_area(f"Video Prompt (Scene {scene['id']})", value=scene.get("video_prompt", ""), height=80, key=f"vidp_{i}")
                        scene["visual_prompt"] = new_visual
                        scene["video_prompt"] = new_video
                    with c_aud:
                        # Fancy-text editing is kept per scene.
                        ft = scene.get("fancy_text", {})
                        if isinstance(ft, dict):
                            new_ft_text = st.text_input(
                                f"Fancy Text (Scene {scene['id']})",
                                value=ft.get("text", ""),
                                key=f"ft_{i}",
                            )
                            # Compatibility: old data may lack the fancy_text field.
                            if not isinstance(scene.get("fancy_text"), dict):
                                scene["fancy_text"] = {}
                            scene["fancy_text"]["text"] = new_ft_text
                        # Voiceover/subtitles are edited in the global timeline above.
                        st.caption("注:旁白与字幕已移至上方整体时间轴编辑")

                    # Read-only legacy details (for debugging / alignment).
                    legacy_scene = scene.get("_legacy", {}) if isinstance(scene.get("_legacy", {}), dict) else {}
                    if legacy_scene:
                        with st.expander(f"Legacy 信息 (Scene {scene['id']})", expanded=False):
                            img_url = legacy_scene.get("image_url")
                            if img_url:
                                st.markdown(f"- image_url: `{img_url}`")
                            cam = legacy_scene.get("camera_movement")
                            if cam:
                                st.markdown(f"- camera_movement: {cam}")
                            vo = legacy_scene.get("voiceover")
                            if vo:
                                st.markdown(f"- voiceover: {vo}")
                            keyframe = legacy_scene.get("keyframe")
                            if keyframe:
                                st.markdown("- keyframe:")
                                st.json(keyframe)
                            rhythm = legacy_scene.get("rhythm")
                            if rhythm:
                                st.markdown("- rhythm:")
                                st.json(rhythm)
                    updated_scenes.append(scene)
                    st.divider()
            st.session_state.script_data["scenes"] = updated_scenes

            col_act1, col_act2 = st.columns([1, 4])
            with col_act1:
                if st.button("确认脚本 & 开始生图", type="primary"):
                    # DB: Update script in case user edited it
                    db.update_project_script(st.session_state.project_id, st.session_state.script_data)
                    st.session_state.current_step = 2
                    st.rerun()
# --- Step 3: Image Generation ---
# Generates one image per scene (either one Doubao "group" call or parallel
# single-scene calls), then shows the results with per-scene regeneration.
if st.session_state.current_step >= 2:
with st.expander("🎨 3. 画面生成", expanded=(st.session_state.current_step == 2)):
# Debug: Show reference image
if st.session_state.uploaded_images:
with st.expander("🔍 查看生图参考底图 (Base Image)"):
for i, p in enumerate(st.session_state.uploaded_images):
st.info(f"{i+1} 文件名: {os.path.basename(p)}")
if os.path.exists(p):
st.image(p, width=200)
else:
st.warning("⚠️ 未检测到参考底图!将生成随机风格图像。建议在 Step 1 重新上传。")
# Trigger Generation Button (only offered while no scene image exists yet)
if not st.session_state.scene_images:
# Model Selection for Image Gen
img_model_options = ["Shubiaobiao (Gemini)", "Doubao (Volcengine)", "Gemini (Direct)", "Doubao (Group Image)"]
selected_img_model = st.radio("选择生图模型", img_model_options, horizontal=True)
# Map to provider; "Group Image" is tested first because that label also contains "Doubao".
if "Group Image" in selected_img_model:
img_provider = "doubao-group"
elif "Doubao" in selected_img_model:
img_provider = "doubao"
elif "Direct" in selected_img_model:
img_provider = "gemini"
else:
img_provider = "shubiaobiao"
# Store selected provider for regeneration
st.session_state.selected_img_provider = img_provider
if st.button("🚀 执行 AI 生图", type="primary"):
# Non-blocking acquire: a falsy `ok` means the limit was not obtained, so the run stops.
with limits.acquire_image(blocking=False) as ok:
if not ok:
st.warning("系统正在生成其他任务(生图并发已达上限),请稍后再试。")
st.stop()
img_gen = ImageGenerator()
# Pass ALL uploaded images as reference
base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else []
if not base_imgs:
st.error("No base image found (未找到参考底图). Please upload in Step 1.")
st.stop()
scenes = st.session_state.script_data.get("scenes", [])
# visual_anchor is passed to the generator to keep images consistent.
visual_anchor = st.session_state.script_data.get("visual_anchor", "")
if img_provider == "doubao-group":
# --- Group Generation Logic ---
with st.spinner("正在进行 Doubao 组图生成 (Batch Group Generation)..."):
try:
t0 = perf_counter()
results = img_gen.generate_group_images_doubao(
scenes=scenes,
reference_images=base_imgs,
visual_anchor=visual_anchor,
project_id=st.session_state.project_id
)
_record_metrics(st.session_state.project_id, {
"image_gen_total_s": round(perf_counter() - t0, 3),
"image_provider": img_provider,
"image_generated": len(results),
})
# `results` is iterated as scene id -> image path.
for s_id, path in results.items():
st.session_state.scene_images[s_id] = path
db.save_asset(st.session_state.project_id, s_id, "image", "completed", local_path=path)
if len(results) == len(scenes):
st.success("组图生成完成!")
db.update_project_status(st.session_state.project_id, "images_generated")
time.sleep(1)
st.rerun()
else:
# NOTE(review): the rerun follows immediately, so this warning is probably never visible - confirm.
st.warning(f"部分图片生成失败: {len(results)}/{len(scenes)}")
st.rerun()
except Exception as e:
st.error(f"组图生成失败: {e}")
else:
# --- Parallel Logic (default): only merchant uploaded images as references ---
total_scenes = len(scenes)
progress_bar = st.progress(0)
status_text = st.empty()
try:
t0 = perf_counter()
# Parallel workers within a single run; global semaphore already acquired above.
max_workers = 6
# Maps each future to (scene index, scene id) for progress reporting.
futures = {}
with ThreadPoolExecutor(max_workers=max_workers) as ex:
for idx, scene in enumerate(scenes):
scene_id = scene["id"]
futures[ex.submit(
img_gen.generate_single_scene_image,
scene=scene,
original_image_path=list(base_imgs), # ONLY merchant images
previous_image_path=None,
model_provider=img_provider,
visual_anchor=visual_anchor,
project_id=st.session_state.project_id,
)] = (idx, scene_id)
done = 0
# Results are handled in completion order, not scene order.
for fut in as_completed(futures):
idx, scene_id = futures[fut]
done += 1
status_text.text(f"已完成 {done}/{total_scenes}Scene {scene_id}")
try:
img_path = fut.result()
except Exception as e:
# A failed scene is reported and skipped; the others continue.
img_path = None
st.warning(f"Scene {scene_id} 生成失败:{e}")
if img_path:
st.session_state.scene_images[scene_id] = img_path
db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path)
progress_bar.progress(done / total_scenes)
status_text.text("生图完成!")
st.success("生图完成!")
_record_metrics(st.session_state.project_id, {
"image_gen_total_s": round(perf_counter() - t0, 3),
"image_provider": img_provider,
"image_generated": len(st.session_state.scene_images),
})
# Update Status
db.update_project_status(st.session_state.project_id, "images_generated")
time.sleep(1)
st.rerun()
except PermissionError as e:
st.error(f"生图服务拒绝访问 (可能余额不足): {e}")
except Exception as e:
st.error(f"生图发生错误: {e}")
# Display & Regenerate Images
if st.session_state.scene_images:
# Scenes are laid out round-robin over four columns.
cols = st.columns(4)
scenes = st.session_state.script_data.get("scenes", [])
for i, scene in enumerate(scenes):
scene_id = scene["id"]
img_path = st.session_state.scene_images.get(scene_id)
with cols[i % 4]:
st.markdown(f"**Scene {scene_id}**")
# Expand Prompt info
with st.expander("Prompt", expanded=False):
st.caption(scene.get("visual_prompt", "No prompt"))
if img_path and os.path.exists(img_path):
st.image(img_path, width=None, use_column_width=True) # Fix deprecated
else:
st.error("Image missing")
# Regenerate specific image
if st.button(f"🔄 重生 S{scene_id}", key=f"regen_img_{scene_id}"):
if not st.session_state.uploaded_images:
st.error("缺少参考底图,请在 Step 1 重新上传图片")
else:
with st.spinner(f"正在重绘 Scene {scene_id}..."):
img_gen = ImageGenerator()
# Only merchant uploaded images as references (no chaining)
current_refs_for_regen = list(st.session_state.uploaded_images)
# Fallback to single mode for regen if group was used
provider = st.session_state.get("selected_img_provider", "shubiaobiao")
if provider == "doubao-group": provider = "doubao"
# Read visual_anchor for the regeneration call as well.
regen_visual_anchor = st.session_state.script_data.get("visual_anchor", "")
new_path = img_gen.generate_single_scene_image(
scene=scene,
original_image_path=current_refs_for_regen,
previous_image_path=None,
model_provider=provider,
visual_anchor=regen_visual_anchor,
project_id=st.session_state.project_id
)
if new_path:
st.session_state.scene_images[scene_id] = new_path
db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=new_path)
st.rerun()
if st.button("下一步:生成视频", type="primary"):
st.session_state.current_step = 3
st.rerun()
# --- Step 4: Video Generation ---
if st.session_state.current_step >= 3:
with st.expander("🎥 4. 视频生成 (Volcengine I2V)", expanded=(st.session_state.current_step == 3)):
scenes = st.session_state.script_data.get("scenes", [])
vid_gen = VideoGenerator()
# Submit-only (non-blocking) to avoid freezing Streamlit under concurrency
if st.button("🎬 提交图生视频任务(非阻塞)", type="primary"):
with limits.acquire_video(blocking=False) as ok:
if not ok:
st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
st.stop()
t0 = perf_counter()
submitted = 0
for scene in scenes:
scene_id = scene["id"]
image_path = st.session_state.scene_images.get(scene_id)
prompt = scene.get("video_prompt", "High quality video")
task_id = vid_gen.submit_scene_video_task(
st.session_state.project_id, scene_id, image_path, prompt
)
if task_id:
submitted += 1
_record_metrics(st.session_state.project_id, {
"video_submit_s": round(perf_counter() - t0, 3),
"video_submitted": submitted,
})
if submitted:
db.update_project_status(st.session_state.project_id, "videos_processing")
st.success(f"已提交 {submitted} 个分镜视频任务。可点击下方“刷新恢复”下载结果。")
time.sleep(0.5)
st.rerun()
else:
st.warning("未提交任何任务(可能缺少图片或接口失败)。")
if st.button("🔄 刷新状态并恢复已完成任务", type="secondary"):
with limits.acquire_video(blocking=False) as ok:
if not ok:
st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
st.stop()
t0 = perf_counter()
updated = 0
for scene in scenes:
scene_id = scene["id"]
asset = db.get_asset(st.session_state.project_id, scene_id, "video")
if not asset or not asset.get("task_id"):
continue
# if already have local video, skip
existing = st.session_state.scene_videos.get(scene_id)
if existing and os.path.exists(existing):
continue
task_id = asset.get("task_id")
# Query volc status; store URL for direct preview (no server download)
status = None
url = None
# short retries for "succeeded but url missing"
for attempt in range(3):
status, url = vid_gen.check_task_status(task_id)
if status == "succeeded" and url:
break
time.sleep(0.5 * (2 ** attempt))
meta_patch = {"checked_at": time.time(), "volc_status": status}
if url:
meta_patch["video_url"] = url
db.update_asset_metadata(st.session_state.project_id, scene_id, "video", meta_patch)
updated += 1
_record_metrics(st.session_state.project_id, {
"video_recover_s": round(perf_counter() - t0, 3),
"video_recovered": updated,
})
if updated:
st.success(f"已刷新 {updated} 个分镜状态(成功的将以 URL 直连预览)。")
else:
st.info("暂无可恢复的视频(可能仍在排队/生成中)。")
time.sleep(0.5)
st.rerun()
if st.button("📥 准备合成素材(下载成功的视频到服务器)", type="secondary"):
with limits.acquire_video(blocking=False) as ok:
if not ok:
st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
st.stop()
downloaded = 0
for scene in scenes:
scene_id = scene["id"]
existing = st.session_state.scene_videos.get(scene_id)
if existing and os.path.exists(existing):
continue
asset = db.get_asset(st.session_state.project_id, scene_id, "video")
meta = (asset or {}).get("metadata") or {}
video_url = meta.get("video_url")
if not video_url:
continue
out_name = path_utils.unique_filename(
prefix="scene_video",
ext="mp4",
project_id=st.session_state.project_id,
scene_id=scene_id,
)
target_path = str(path_utils.project_videos_dir(st.session_state.project_id) / out_name)
if vid_gen._download_video_to(video_url, target_path):
st.session_state.scene_videos[scene_id] = target_path
db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=target_path, task_id=(asset or {}).get("task_id"), metadata=meta)
downloaded += 1
if downloaded:
st.success(f"已下载 {downloaded} 段视频,可进入合成。")
else:
st.info("暂无可下载的视频(请先刷新状态获取 video_url")
time.sleep(0.5)
st.rerun()
# Display videos for every scene (even when only some are available).
# Each scene gets one cell in a 4-column grid showing, in order of preference:
#   1. the locally downloaded video file,
#   2. a direct preview of the remote URL stored in the asset metadata,
#   3. a "missing" warning plus a button that re-queries the generation task.
# A per-scene regenerate button is rendered below the preview.
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# copy of this file - confirm it against the original layout.
if st.session_state.scene_videos or scenes:
    cols = st.columns(4)
    for i, scene in enumerate(scenes):
        scene_id = scene["id"]
        vid_path = st.session_state.scene_videos.get(scene_id)
        with cols[i % 4]:
            st.markdown(f"**Scene {scene_id}**")
            # Collapsed by default: the prompt stored on the scene.
            with st.expander("Prompt", expanded=False):
                st.caption(scene.get("video_prompt", "No prompt"))

            if vid_path and os.path.exists(vid_path):
                st.video(vid_path)
            else:
                # No local file: try a URL preview from the DB metadata.
                asset = db.get_asset(st.session_state.project_id, scene_id, "video")
                meta = (asset or {}).get("metadata") or {}
                video_url = meta.get("video_url")
                if video_url:
                    st.caption("URL 直连预览(不经服务器落盘)")
                    st.video(video_url)
                else:
                    st.warning("Video missing")
                    # --- Recovery: re-query the task and store what it reports ---
                    if asset and asset.get("task_id"):
                        task_id = asset.get("task_id")
                        if st.button(f"🔍 刷新URL (Task {task_id[-6:]})", key=f"recov_{scene_id}"):
                            with st.spinner("查询任务状态中..."):
                                status, url = vid_gen.check_task_status(task_id)
                                patch = {"checked_at": time.time(), "volc_status": status}
                                if url:
                                    patch["video_url"] = url
                                db.update_asset_metadata(st.session_state.project_id, scene_id, "video", patch)
                            st.success("已刷新任务状态。")
                            st.rerun()

            # Per-scene regenerate button.
            if st.button(f"🔄 重生 S{scene_id}", key=f"regen_vid_{scene_id}"):
                with st.spinner(f"正在重生成 Scene {scene_id} 视频..."):
                    vid_gen = VideoGenerator()
                    img_path = st.session_state.scene_images.get(scene_id)
                    if img_path and os.path.exists(img_path):
                        t_id = vid_gen.submit_scene_video_task(st.session_state.project_id, scene_id, img_path, scene.get("video_prompt"))
                        if t_id:
                            # Blocking poll for this single task: up to 60 checks,
                            # 5 seconds apart (about 5 minutes in total).
                            for _ in range(60):
                                status, url = vid_gen.check_task_status(t_id)
                                if status == "succeeded":
                                    new_path = vid_gen._download_video(url, f"scene_{scene_id}_video_{int(time.time())}.mp4")
                                    if new_path:
                                        st.session_state.scene_videos[scene_id] = new_path
                                        db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=new_path, task_id=t_id)
                                        st.rerun()
                                    else:
                                        # Previously this case ended without any feedback.
                                        st.error("视频下载失败")
                                    break
                                if status in ("failed", "cancelled"):
                                    st.error(f"生成失败: {status}")
                                    break
                                time.sleep(5)
                            else:
                                # Loop ran out without a terminal status; tell the
                                # user instead of finishing silently.
                                st.warning(f"等待超时: 任务 {t_id} 仍未完成")
                        else:
                            st.error("提交任务失败")
                    else:
                        st.error("缺少源图片")
# Footer actions of the video step: a narrow column for "regenerate all" and a
# wide column for "continue to composition".
c_act1, c_act2 = st.columns([1, 4])
# Dropping every cached scene video and rerunning sends the page back to the
# state where no videos exist yet.
if c_act1.button("🔄 重新生成所有视频", type="secondary"):
    st.session_state["scene_videos"] = {}
    st.rerun()
# Advance the wizard to the composition step.
if c_act2.button("下一步:合成最终成片", type="primary"):
    st.session_state["current_step"] = 4
    st.rerun()
# --- Step 5: Final Composition & Tuning ---
# Rendered once current_step has reached 4; the expander is open by default
# only while current_step is exactly 4.
# NOTE(review): the heading says "5" while the guard compares against 4 -
# presumably current_step is zero-based; confirm against the earlier steps.
if st.session_state.current_step >= 4:
with st.expander("🎞️ 5. 合成结果与微调", expanded=(st.session_state.current_step == 4)):
# Tabs: preview is listed first, tuning second. The tuning tab's body is
# written first in the code that follows.
tab_preview, tab_tune = st.tabs(["🎥 预览", "🎛️ 微调 (Track Editor)"])
# Tuning tab: global BGM / voice selection, voice-over timeline editor,
# per-scene fancy-text editor and the "re-compose" action.
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# copy of this file - confirm it against the original layout.
with tab_tune:
    st.subheader("轨道微调 (Tuning)")

    # ---- Global settings: BGM on the left, TTS voice on the right ----
    col_g1, col_g2 = st.columns(2)
    with col_g1:
        # BGM candidates: .mp3 and .mp4 in any letter case. One pattern is
        # enough - a second "*.[mM][pP]3" glob would match every .mp3 again and
        # list it twice. Sorting keeps the option order stable between reruns.
        bgm_dir = config.ASSETS_DIR / "bgm"
        bgm_files = sorted(
            f for f in bgm_dir.glob("*.[mM][pP][34]")
            if f.is_file() and not f.name.startswith('.')
        )
        bgm_names = [f.name for f in bgm_files]

        # Recommended default BGM. match_bgm_by_style picks at random, so the
        # pick is cached per project and style: a default index that changes
        # between reruns would make Streamlit treat the key-less selectbox as a
        # new widget and discard the user's current choice.
        bgm_style = st.session_state.script_data.get("bgm_style", "")
        rec_cache_key = f"bgm_recommendation::{st.session_state.project_id}::{bgm_style}"
        recommended_bgm = st.session_state.get(rec_cache_key)
        if not recommended_bgm:
            recommended_bgm = match_bgm_by_style(bgm_style, bgm_dir)
            st.session_state[rec_cache_key] = recommended_bgm

        default_idx = 0
        if recommended_bgm:
            rec_name = Path(recommended_bgm).name
            if rec_name in bgm_names:
                default_idx = bgm_names.index(rec_name) + 1  # +1: option 0 is "None"

        selected_bgm = st.selectbox(
            f"背景音乐 (BGM) - 推荐风格: {bgm_style or '未指定'}",
            ["None"] + bgm_names,
            index=default_idx,
        )

        # Say explicitly when this compose will run without BGM: either the
        # directory is empty or the selected file has disappeared.
        if not bgm_names:
            st.warning(f"BGM 目录为空:{bgm_dir}(本次合成将不含 BGM)")
        elif selected_bgm != "None":
            candidate = config.ASSETS_DIR / "bgm" / selected_bgm
            if not candidate.exists():
                st.warning(f"所选 BGM 文件不存在:{candidate}(本次合成将不含 BGM)")

    with col_g2:
        # Voice selection; dict.fromkeys drops a duplicate entry should the
        # configured default equal the hard-coded alternative.
        voice_options = list(dict.fromkeys(
            [config.VOLC_TTS_DEFAULT_VOICE, "zh_female_meilinvyou_saturn_bigtts"]
        ))
        selected_voice = st.selectbox("配音音色 (TTS)", voice_options)

    st.divider()

    # ---- Global timeline editor (voice-over and subtitles, in seconds) ----
    st.markdown("### 🎙️ 旁白与字幕时间轴 (单位: 秒)")
    timeline = st.session_state.script_data.get("voiceover_timeline") or []
    updated_timeline = []
    for i, item in enumerate(timeline):
        c1, c2, c3, c4 = st.columns([3, 3, 1, 1])
        with c1:
            item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=f"tune_vo_{i}")
        with c2:
            item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=f"tune_sub_{i}")
        with c3:
            # Legacy format: a start_ratio is converted to seconds by scaling
            # with 12 (presumably the legacy total duration - confirm).
            default_start = item.get("start_time", item.get("start_ratio", 0) * 12)
            # Clamp into the widget's range; st.number_input raises when the
            # initial value lies outside [min_value, max_value].
            default_start = min(max(float(default_start), 0.0), 30.0)
            item["start_time"] = st.number_input(f"开始(秒) #{i+1}", value=default_start, min_value=0.0, max_value=30.0, step=0.5, key=f"tune_start_{i}")
        with c4:
            # Legacy format: duration_ratio is converted the same way.
            default_dur = item.get("duration", item.get("duration_ratio", 0.25) * 12)
            default_dur = min(max(float(default_dur), 0.5), 15.0)
            item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=default_dur, min_value=0.5, max_value=15.0, step=0.5, key=f"tune_dur_{i}")
        # Drop the legacy fields once they have been converted.
        item.pop("start_ratio", None)
        item.pop("duration_ratio", None)
        updated_timeline.append(item)
    st.session_state.script_data["voiceover_timeline"] = updated_timeline

    st.divider()

    # ---- Per-scene editor (fancy text only) ----
    scenes = st.session_state.script_data.get("scenes", [])
    updated_scenes = []
    for i, scene in enumerate(scenes):
        with st.container():
            st.markdown(f"**Scene {scene['id']}**")
            ft = scene.get("fancy_text", {})
            ft_text = ft.get("text", "") if isinstance(ft, dict) else ""
            new_ft = st.text_input("花字", value=ft_text, key=f"tune_ft_{i}")
            # Older data may have no fancy_text dict at all.
            if not isinstance(scene.get("fancy_text"), dict):
                scene["fancy_text"] = {}
            scene["fancy_text"]["text"] = new_ft
        updated_scenes.append(scene)
    st.session_state.script_data["scenes"] = updated_scenes

    # ---- Re-compose with the edited script ----
    if st.button("🔄 重新合成 (Re-Compose)", type="primary"):
        with st.spinner("正在应用修改并重新合成..."):
            composer = VideoComposer(voice_type=selected_voice)
            # Pass a BGM path only when the file really exists, matching the
            # warning shown above and the first-compose path.
            bgm_path = None
            if selected_bgm != "None":
                bgm_candidate = config.ASSETS_DIR / "bgm" / selected_bgm
                if bgm_candidate.exists():
                    bgm_path = str(bgm_candidate)
            try:
                # Persist the edited script before composing from it.
                db.update_project_script(st.session_state.project_id, st.session_state.script_data)
                t0 = perf_counter()
                output_path = composer.compose_from_script(
                    script=st.session_state.script_data,
                    video_map=st.session_state.scene_videos,
                    bgm_path=bgm_path,
                    # Timestamped name so every compose is kept as its own version.
                    output_name=f"final_{st.session_state.project_id}_{int(time.time())}",
                    project_id=st.session_state.project_id,
                )
                _record_metrics(st.session_state.project_id, {
                    "compose_s": round(perf_counter() - t0, 3),
                    "bgm_used": bool(bgm_path and Path(bgm_path).exists()),
                })
                st.session_state.final_video = output_path
                db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
                st.success("重新合成成功!")
                st.rerun()
            except Exception as e:
                st.error(f"合成失败: {e}")
# Preview tab: export of a CapCut asset package, then either the "first
# compose" action (no final video on disk yet) or a list of every composed
# version with a player and a download button.
# NOTE(review): this copy of the file has lost its indentation, so the exact
# nesting of the dividers and of the per-version download button cannot be
# read off these lines - confirm against the original layout.
with tab_preview:
st.subheader("🛠️ 导出与交付")
col_ex1, col_ex2 = st.columns([1, 2])
with col_ex1:
# Build the package and remember its path in session state so the download
# button in the second column can find it on later reruns.
if st.button("📦 生成剪映素材包 (ZIP)", type="primary", key="btn_export_zip"):
with st.spinner("正在打包视频、音频与字幕..."):
try:
zip_path = export_utils.create_capcut_package(
st.session_state.project_id,
st.session_state.script_data,
{"scene_videos": st.session_state.scene_videos}
)
st.session_state['export_zip_path'] = zip_path
st.success("打包完成!请点击下载")
except Exception as e:
st.error(f"打包失败: {e}")
with col_ex2:
# Offer the download only while the remembered package still exists on disk.
# NOTE(review): os.path.exists() raises TypeError for None - verify that
# create_capcut_package never returns None.
if 'export_zip_path' in st.session_state and os.path.exists(st.session_state['export_zip_path']):
zip_name = f"capcut_pack_{st.session_state.project_id}.zip"
with open(st.session_state['export_zip_path'], "rb") as f:
st.download_button(
label=f"📥 点击下载素材包 ({zip_name})",
data=f,
file_name=zip_name,
mime="application/zip"
)
st.divider()
# Scan for every previously composed version
import glob
# Find all final videos that belong to the current project: timestamped
# names (final_<id>_<ts>.mp4) and the plain name (final_<id>.mp4).
pattern_timestamp = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}_*.mp4")
pattern_simple = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}.mp4")
found_files = []
for f in glob.glob(pattern_timestamp):
found_files.append(f)
for f in glob.glob(pattern_simple):
if f not in found_files:
found_files.append(f)
# Sort by modification time, newest first
found_files.sort(key=os.path.getmtime, reverse=True)
if not found_files:
st.info("暂无合成视频,请先点击开始合成。")
if st.button("✨ 开始首次合成", type="primary"):
with st.spinner("正在进行多轨合成..."):
# Default compose logic with smart BGM matching
composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE)
# Smart BGM matching: pick a BGM from the script's bgm_style
bgm_style = st.session_state.script_data.get("bgm_style", "")
bgm_path = match_bgm_by_style(bgm_style, config.ASSETS_DIR / "bgm")
# A recommended file that is missing on disk is dropped, so the compose
# runs without BGM.
if bgm_path and not Path(bgm_path).exists():
st.warning(f"推荐的 BGM 文件不存在:{bgm_path}(本次将不含 BGM")
bgm_path = None
try:
# The first compose also gets a timestamped output name
output_name = f"final_{st.session_state.project_id}_{int(time.time())}"
t0 = perf_counter()
output_path = composer.compose_from_script(
script=st.session_state.script_data,
video_map=st.session_state.scene_videos,
bgm_path=bgm_path,
output_name=output_name,
project_id=st.session_state.project_id,
)
# Record how long the compose took and whether a BGM file was used.
_record_metrics(st.session_state.project_id, {
"compose_s": round(perf_counter() - t0, 3),
"bgm_used": bool(bgm_path and Path(bgm_path).exists()),
})
st.session_state.final_video = output_path
db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
db.update_project_status(st.session_state.project_id, "completed")
st.rerun()
except Exception as e:
st.error(f"合成失败: {e}")
else:
st.success(f"共找到 {len(found_files)} 个历史版本")
# Render every version found, newest first
for idx, vid_path in enumerate(found_files):
mtime = os.path.getmtime(vid_path)
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime))
file_name = os.path.basename(vid_path)
# One container per version
with st.container():
# Index 0 is the newest file; older ones are numbered so that the oldest
# version gets number 1.
if idx == 0:
st.markdown(f"### 🟢 最新版本 ({time_str})")
else:
st.markdown(f"#### ⚪ 历史版本 {len(found_files)-idx} ({time_str})")
# The player goes into the middle of three columns.
c1, c2, c3 = st.columns([1, 2, 1])
with c2:
st.video(vid_path)
# Download button; the key is derived from the file name to keep it unique
with open(vid_path, "rb") as file:
st.download_button(
label=f"📥 下载 ({file_name})",
data=file,
file_name=file_name,
mime="video/mp4",
key=f"dl_btn_{file_name}"
)
st.divider()
# ============================================================
# Page: History (New)
# ============================================================
elif st.session_state.view_mode == "history":
# History page: one expander per stored project, offering downloads of the
# final video, the script JSON and the per-scene images and videos.
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# copy of this file - confirm it against the original layout.
st.header("📜 历史任务")
projects = db.list_projects()
if not projects:
    st.info("暂无历史任务")
for proj in projects:
    with st.expander(f"{proj['name']} ({proj['updated_at']})"):
        st.caption(f"ID: {proj['id']} | Status: {proj.get('status', 'unknown')}")
        assets = db.get_assets(proj['id'])
        if assets:
            st.markdown("#### 📥 下载素材")
            # Group the project's assets by type.
            final_vids = [a for a in assets if a['asset_type'] == 'final_video']
            scene_imgs = [a for a in assets if a['asset_type'] == 'image']
            scene_vids = [a for a in assets if a['asset_type'] == 'video']

            # Final video download (first final_video asset returned by the DB).
            if final_vids:
                fv = final_vids[0]
                if fv['local_path'] and os.path.exists(fv['local_path']):
                    st.success("✅ **最终成片已生成**")
                    with open(fv['local_path'], "rb") as f:
                        st.download_button(
                            label=f"📥 下载最终成片 ({os.path.basename(fv['local_path'])})",
                            data=f,
                            file_name=os.path.basename(fv['local_path']),
                            mime="video/mp4",
                            # Explicit key: this button is rendered once per project.
                            key=f"dl_final_{proj['id']}",
                        )

            # Script download. script_data may already be a dict or still be a
            # JSON string; only the (de)serialisation is guarded, so widget
            # errors and Streamlit control-flow exceptions are not swallowed.
            if proj.get("script_data"):
                raw_script = proj["script_data"]
                try:
                    script_obj = raw_script if isinstance(raw_script, dict) else json.loads(raw_script)
                    script_json = json.dumps(script_obj, ensure_ascii=False, indent=2)
                except (TypeError, ValueError) as exc:
                    # Best effort: keep rendering the rest, but say why the
                    # script download is not offered.
                    st.caption(f"脚本数据无法导出: {exc}")
                else:
                    st.download_button(
                        label="📥 下载脚本/字幕配置 (JSON)",
                        data=script_json,
                        file_name=f"script_{proj['id']}.json",
                        mime="application/json",
                        key=f"dl_script_{proj['id']}",
                    )

            st.markdown("---")
            c1, c2 = st.columns(2)
            with c1:
                st.markdown("**分镜图片**")
                for img in scene_imgs:
                    if img['local_path'] and os.path.exists(img['local_path']):
                        with open(img['local_path'], "rb") as f:
                            st.download_button(f"Scene {img['scene_id']} 图片", f, key=f"dl_img_{proj['id']}_{img['id']}", file_name=f"scene_{img['scene_id']}.png")
            with c2:
                st.markdown("**分镜视频**")
                for vid in scene_vids:
                    if vid['local_path'] and os.path.exists(vid['local_path']):
                        with open(vid['local_path'], "rb") as f:
                            st.download_button(f"Scene {vid['scene_id']} 视频", f, key=f"dl_vid_{proj['id']}_{vid['id']}", file_name=f"scene_{vid['scene_id']}.mp4")
        else:
            st.info("暂无生成素材")
# ============================================================
# Page: Settings (New)
# ============================================================
elif st.session_state.view_mode == "settings":
# Settings page: view, edit, save and reset the system prompt used for script
# generation. The value stored under "prompt_script_gen" wins; without one the
# generator's built-in default is shown.
st.header("⚙️ 设置配置")
st.subheader("Prompt 配置")

current_prompt = db.get_config("prompt_script_gen")
if not current_prompt:
    # Nothing stored: say so and fall back to the default from a generator instance.
    st.warning("⚠️ 使用默认 Prompt数据库中无自定义配置")
    temp_gen = ScriptGenerator()
    current_prompt = temp_gen.default_system_prompt
else:
    st.info("✅ 已加载自定义 Prompt来自数据库")

new_prompt = st.text_area(
    "脚本拆分 Prompt (System Prompt)",
    value=current_prompt,
    height=400,
)

col_save, col_reset = st.columns([1, 3])

with col_save:
    if st.button("💾 保存配置", type="primary"):
        db.set_config(
            "prompt_script_gen",
            new_prompt,
            "System prompt for script generation step",
        )
        # Read the value back to confirm the write took effect.
        stored_prompt = db.get_config("prompt_script_gen")
        if stored_prompt != new_prompt:
            st.error("❌ 保存可能失败,请检查日志")
        else:
            st.success("✅ 配置已保存并验证成功!下次生成脚本时将使用新 Prompt。")

with col_reset:
    if st.button("🔄 恢复默认"):
        # Resetting writes the built-in default back into the config store.
        temp_gen = ScriptGenerator()
        default_prompt = temp_gen.default_system_prompt
        db.set_config(
            "prompt_script_gen",
            default_prompt,
            "System prompt for script generation step (DEFAULT)",
        )
        st.success("已恢复默认 Prompt请刷新页面查看")
        st.rerun()