Files
video-flow/app.py
Tony Zhang 33a165a615 feat: video-flow initial commit
- app.py: Streamlit UI for video generation workflow
- main_flow.py: CLI tool with argparse support
- modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.)
- config.py: Configuration with API keys and paths
- requirements.txt: Python dependencies
- docs/: System prompt documentation
2025-12-12 19:18:27 +08:00

1060 lines
53 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
MatchMe Studio - UI (Streamlit)
Style: Kaogujia (Clean, Data-heavy, Professional)
"""
import streamlit as st
import json
import time
import os
import random
from pathlib import Path
import pandas as pd
# Import Backend Modules
import config
from modules.script_gen import ScriptGenerator
from modules.image_gen import ImageGenerator
from modules.video_gen import VideoGenerator
from modules.composer import VideoComposer, VideoComposer as Composer # alias
from modules.text_renderer import renderer
from modules import export_utils
from modules.db_manager import db
# Page Config: page metadata and layout, applied before any UI is drawn.
_PAGE_CONFIG = {
    "page_title": "Video Flow Console",
    "page_icon": "🎬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# ============================================================
# BGM smart matching
# ============================================================
def match_bgm_by_style(bgm_style: str, bgm_dir: Path) -> "str | None":
    """Pick a background-music file from *bgm_dir* that fits *bgm_style*.

    Candidates are the non-hidden ``.mp3`` / ``.mp4`` files (any letter case)
    directly inside *bgm_dir*. If *bgm_style* mentions one of the known mood
    keywords, a random candidate whose file name contains one of those
    keywords is returned. Otherwise, or when no file name matches, a random
    candidate is returned.

    Args:
        bgm_style: Free-text style hint from the script; may be empty or None.
        bgm_dir: Directory holding the BGM library (``Path`` or ``str``).

    Returns:
        Path of the chosen file as a string, or None when there is no candidate.
    """
    bgm_dir = Path(bgm_dir)
    # "[34]" already covers both .mp3 and .mp4. The set/sort guarantees each
    # file appears once, so the random choice is not biased toward mp3 files.
    bgm_files = sorted(
        f for f in set(bgm_dir.glob("*.[mM][pP][34]"))
        if f.is_file() and not f.name.startswith(".")
    )
    if not bgm_files:
        return None

    if bgm_style:
        style_lower = bgm_style.lower()
        # Simplified keyword match instead of real Chinese word segmentation.
        keywords = ["活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏"]
        matched_keywords = [kw for kw in keywords if kw in style_lower]
        if matched_keywords:
            # Compare against the lower-cased name so "Bling.mp3" matches "bling".
            matched_files = [
                f for f in bgm_files
                if any(kw in f.name.lower() for kw in matched_keywords)
            ]
            if matched_files:
                return str(random.choice(matched_files))

    # No keyword hit: fall back to any available track.
    return str(random.choice(bgm_files))
# ============================================================
# CSS Styling (Kaogujia Style)
# ============================================================
# Only colours, fonts and button/heading accents are overridden; layout and
# scrolling are left to Streamlit. Rendered as raw HTML, hence unsafe_allow_html.
_KAOGUJIA_CSS = """
<style>
/* 只保留颜色和字体样式,完全不干预布局和滚动 */
.stApp {
background-color: #F4F5F7;
font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
}
section[data-testid="stSidebar"] {
background-color: #FFFFFF;
border-right: 1px solid #E5E7EB;
}
.stButton button {
border-radius: 4px;
font-weight: 500;
}
div[data-testid="stButton"] > button[kind="primary"] {
background-color: #1677FF;
color: white;
border: none;
}
div[data-testid="stButton"] > button[kind="primary"]:hover {
background-color: #4096FF;
}
h1, h2, h3 {
color: #1F2329;
font-weight: 600;
}
h2 { border-left: 4px solid #1677FF; padding-left: 10px; }
.stTextInput input, .stTextArea textarea {
border-radius: 4px;
}
</style>
"""
st.markdown(_KAOGUJIA_CSS, unsafe_allow_html=True)
# ============================================================
# Session State Management
# ============================================================
# Initial value for every session key the UI reads. A key that already holds
# a value (set by an earlier rerun) is left untouched.
_SESSION_DEFAULTS = {
    "project_id": None,
    "current_step": 0,
    "script_data": None,
    "scene_images": {},
    "scene_videos": {},
    "final_video": None,
    "uploaded_images": [],
    "view_mode": "workspace",  # workspace, history, settings
    "selected_img_provider": "shubiaobiao",
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
def load_project(project_id):
    """Restore a saved project from the DB into ``st.session_state``.

    Sets the script, the Step 1 prefill values, the reference images and the
    completed image/video/final-video assets. Switches to the workspace view
    and picks the wizard step to resume at. If *project_id* is unknown, shows
    an error and leaves the session unchanged.

    Args:
        project_id: ID of the project row to load.
    """
    data = db.get_project(project_id)
    if not data:
        st.error("Project not found")
        return

    st.session_state.project_id = project_id
    st.session_state.script_data = data.get("script_data")
    st.session_state.view_mode = "workspace"
    # A ZIP exported for the previously open project must not be offered in
    # Step 5 as this project's download.
    if "export_zip_path" in st.session_state:
        del st.session_state["export_zip_path"]

    # Restore product info for Step 1 display. Step 1 calls .get() on it, so
    # anything other than a dict (e.g. None) is replaced with an empty dict.
    # NOTE(review): if db_manager can return raw JSON strings here, decode them.
    product_info = data.get("product_info")
    if not isinstance(product_info, dict):
        product_info = {}
    st.session_state.loaded_product_name = data.get("name", "")
    st.session_state.loaded_product_info = product_info

    # Reference images are local files and may have been deleted since saving.
    saved_images = product_info.get("uploaded_images") or []
    valid_paths = [p for p in saved_images if p and os.path.exists(p)]
    st.session_state.uploaded_images = valid_paths
    missing = len(saved_images) - len(valid_paths)
    if missing:
        # NOTE(review): the sidebar calls st.rerun() right after loading, so
        # this warning is probably never visible.
        st.warning(f"部分原始图片已丢失 ({missing} 张),建议重新上传")

    # Restore completed assets, grouped by asset type.
    restored = {"image": {}, "video": {}}
    final_vid = None
    for asset in db.get_assets(project_id):
        if asset["status"] != "completed":
            continue
        asset_type = asset["asset_type"]
        if asset_type in restored:
            restored[asset_type][asset["scene_id"]] = asset["local_path"]
        elif asset_type == "final_video":
            # Final videos are saved with scene_id 0; if several are listed,
            # the last one returned by get_assets wins.
            final_vid = asset["local_path"]

    images, videos = restored["image"], restored["video"]
    st.session_state.scene_images = images
    st.session_state.scene_videos = videos
    st.session_state.final_video = final_vid

    # Resume at the furthest step whose inputs exist
    # (0 input, 1 script review, 3 video generation, 4 composition).
    if final_vid or videos:
        st.session_state.current_step = 4
    elif images:
        st.session_state.current_step = 3
    elif st.session_state.script_data:
        st.session_state.current_step = 1
    else:
        st.session_state.current_step = 0
# ============================================================
# Sidebar
# ============================================================
def _clear_session_except_view_mode():
    """Delete every session_state entry except ``view_mode`` (full reset)."""
    doomed = [state_key for state_key in st.session_state.keys() if state_key != "view_mode"]
    for state_key in doomed:
        del st.session_state[state_key]


with st.sidebar:
    st.title("📽️ Video Flow")

    # Mode selection: radio labels in display order, each paired with the
    # view_mode value it selects. The radio starts on the current view_mode.
    mode_choices = [
        ("🛠️ 工作台", "workspace"),
        ("📜 历史任务", "history"),
        ("⚙️ 设置", "settings"),
    ]
    mode_labels = [label for label, _ in mode_choices]
    mode_values = [value for _, value in mode_choices]
    radio_index = (
        mode_values.index(st.session_state.view_mode)
        if st.session_state.view_mode in mode_values
        else 0
    )
    chosen_label = st.radio("模式", mode_labels, index=radio_index)
    st.session_state.view_mode = dict(mode_choices)[chosen_label]

    st.markdown("---")

    if st.session_state.view_mode == "workspace":
        # Project picker: "New Project" sentinel followed by every stored project.
        st.subheader("Current Project")
        proj_labels = {
            proj["id"]: f"{proj.get('name', 'Untitled')} ({proj['id']})"
            for proj in db.list_projects()
        }
        picked_project = st.selectbox(
            "Select Project",
            options=["New Project", *proj_labels],
            format_func=lambda opt: " New Project" if opt == "New Project" else proj_labels[opt],
        )

        current_pid = st.session_state.project_id
        if picked_project == "New Project":
            # A reset is only offered while an existing project is open.
            if current_pid and current_pid in proj_labels and st.button("Start New"):
                _clear_session_except_view_mode()
                st.rerun()
        elif current_pid != picked_project and st.button(f"Load {picked_project}"):
            load_project(picked_project)
            st.rerun()

        if st.session_state.project_id:
            st.caption(f"Current ID: {st.session_state.project_id}")

        st.markdown("---")

        # Wizard progress: finished steps get a check mark, the active one a dot.
        step_labels = ["1. 输入信息", "2. 生成脚本", "3. 画面生成", "4. 视频生成", "5. 合成输出"]
        active_step = st.session_state.current_step
        for position, label in enumerate(step_labels):
            if position < active_step:
                line = f"✅ **{label}**"
            elif position == active_step:
                line = f"🔵 **{label}**"
            else:
                line = label
            st.markdown(line)

        st.markdown("---")
        if st.button("🔄 重置状态", type="secondary"):
            _clear_session_except_view_mode()
            st.rerun()
# ============================================================
# Helper Functions
# ============================================================
def save_uploaded_file(uploaded_file, dest_dir=None):
    """Write an uploaded file to disk and return the saved path.

    Only the final component of the client-supplied name is used (both "/"
    and "\\" separators are stripped), so the file always lands in the target
    directory. If a *different* file already exists under that name (for
    example another project's "1.jpg"), a numeric suffix is added instead of
    overwriting it. Re-saving identical bytes reuses the existing path.

    Args:
        uploaded_file: Object with ``.name`` and ``.getbuffer()``, such as a
            Streamlit ``UploadedFile``. ``None`` is accepted and ignored.
        dest_dir: Target directory; defaults to ``config.TEMP_DIR``. It is
            created if missing.

    Returns:
        The saved file path as a string, or None when *uploaded_file* is None.
    """
    from pathlib import PureWindowsPath  # splits on both "/" and "\\"

    if uploaded_file is None:
        return None

    target_dir = Path(config.TEMP_DIR if dest_dir is None else dest_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    safe_name = PureWindowsPath(uploaded_file.name).name
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    data = bytes(uploaded_file.getbuffer())
    candidate = target_dir / safe_name
    stem, suffix = candidate.stem, candidate.suffix
    counter = 1
    # Never clobber another upload that happens to share the name.
    while candidate.exists() and (not candidate.is_file() or candidate.read_bytes() != data):
        candidate = target_dir / f"{stem}_{counter}{suffix}"
        counter += 1

    candidate.write_bytes(data)
    return str(candidate)
# ============================================================
# Main Content: Workspace
# ============================================================
# Older scripts stored voiceover timing as ratios of the whole video; they are
# converted to seconds assuming a video of this length.
_LEGACY_TIMELINE_TOTAL_SECONDS = 12


def _clamp(value, lower, upper):
    """Return *value* limited to the closed range [lower, upper]."""
    return max(lower, min(upper, value))


def _edit_voiceover_timeline(timeline, key_prefix):
    """Render one editable row per voiceover segment and return the edited list.

    Each item dict is updated in place with ``text``, ``subtitle``,
    ``start_time`` and ``duration`` (seconds). Legacy ``start_ratio`` /
    ``duration_ratio`` fields are converted and then removed. *key_prefix*
    keeps widget keys unique between editors that show the same timeline.
    """
    edited = []
    for i, item in enumerate(timeline):
        c1, c2, c3, c4 = st.columns([3, 3, 1, 1])
        with c1:
            item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=f"{key_prefix}_vo_{i}")
        with c2:
            item["subtitle"] = st.text_input(
                f"字幕 #{i+1}",
                value=item.get("subtitle", item.get("text", "")),
                key=f"{key_prefix}_sub_{i}",
            )
        with c3:
            default_start = item.get("start_time", item.get("start_ratio", 0) * _LEGACY_TIMELINE_TOTAL_SECONDS)
            # number_input rejects a default outside [min_value, max_value], so
            # out-of-range values from the AI script are clamped first.
            item["start_time"] = st.number_input(
                f"开始(秒) #{i+1}",
                value=_clamp(float(default_start), 0.0, 30.0),
                min_value=0.0,
                max_value=30.0,
                step=0.5,
                key=f"{key_prefix}_start_{i}",
            )
        with c4:
            default_dur = item.get("duration", item.get("duration_ratio", 0.25) * _LEGACY_TIMELINE_TOTAL_SECONDS)
            item["duration"] = st.number_input(
                f"时长(秒) #{i+1}",
                value=_clamp(float(default_dur), 0.5, 15.0),
                min_value=0.5,
                max_value=15.0,
                step=0.5,
                key=f"{key_prefix}_dur_{i}",
            )
        # Drop the legacy ratio fields once converted.
        item.pop("start_ratio", None)
        item.pop("duration_ratio", None)
        edited.append(item)
    return edited


def _list_bgm_files(bgm_dir):
    """Return the non-hidden .mp3/.mp4 files in *bgm_dir*, sorted, each listed once."""
    return sorted(
        f for f in set(bgm_dir.glob("*.[mM][pP][34]"))
        if f.is_file() and not f.name.startswith(".")
    )


def _recommended_bgm(bgm_style, bgm_dir):
    """Return the style-based BGM pick for the current project, computed once.

    match_bgm_by_style() chooses randomly among the matches. Caching the pick
    in session_state keeps the default of the BGM dropdown (and the track used
    by the first composition) stable across reruns instead of reshuffling on
    every interaction.
    """
    cache_key = f"_bgm_rec::{st.session_state.project_id}::{bgm_style}"
    if st.session_state.get(cache_key) is None:
        st.session_state[cache_key] = match_bgm_by_style(bgm_style, bgm_dir)
    return st.session_state[cache_key]


if st.session_state.view_mode == "workspace":
    st.header("商品短视频自动生成控制台")

    # --- Step 1: Input ---
    with st.expander("📦 1. 商品信息输入", expanded=(st.session_state.current_step == 0)):
        # Prefill from a project loaded from history, otherwise use demo values.
        # `or {}` guards against a loaded project that carries None here.
        loaded_info = st.session_state.get("loaded_product_info") or {}
        default_name = st.session_state.get("loaded_product_name", "网红气质大号发量多!高马尾香蕉夹")
        default_category = loaded_info.get("category", "钟表配饰-时尚饰品-发饰")
        default_price = loaded_info.get("price", "3.99元")
        default_tags = loaded_info.get("tags", "回头客|款式好看|材质好|尺寸合适|颜色好看|很好用|做工好|质感不错|很牢固")
        default_params = loaded_info.get("params", "金属材质:非金属; 非金属材质:树脂; 发夹分类:香蕉夹")

        col1, col2 = st.columns([1, 1])
        with col1:
            product_name = st.text_input("商品标题", value=default_name)
            category = st.text_input("商品类目", value=default_category)
            price = st.text_input("价格", value=default_price)
        with col2:
            tags = st.text_area("评价标签 (用于提炼卖点)", value=default_tags, height=100)
            params = st.text_area("商品参数", value=default_params, height=100)

        # Optional merchant style hint passed to the script generator.
        style_hint = st.text_area(
            "商品视频重点增强提示 (可选)",
            value=loaded_info.get("style_hint", ""),
            placeholder="例如:韩风、高级感、活力青春、简约日系...",
            height=80,
        )

        st.markdown("### 上传素材")
        # Preview reference images that are already attached (e.g. from history).
        if st.session_state.uploaded_images:
            st.info(f"已有 {len(st.session_state.uploaded_images)} 张参考图片")
            with st.expander("查看已有图片"):
                img_cols = st.columns(min(len(st.session_state.uploaded_images), 4))
                for i, img_path in enumerate(st.session_state.uploaded_images[:4]):
                    if os.path.exists(img_path):
                        img_cols[i % 4].image(img_path, width=150)

        uploaded_files = st.file_uploader("上传商品主图 (建议 3-5 张)", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
        # Submitting is allowed with new uploads or with previously saved images.
        can_submit = bool(uploaded_files or st.session_state.uploaded_images)

        # Script model selection; the label is mapped to a provider key.
        model_options = ["Gemini 3 Pro", "Doubao Pro (Vision)"]
        selected_model_label = st.radio("选择脚本生成模型", model_options, horizontal=True)
        model_provider = "doubao" if "Doubao" in selected_model_label else "shubiaobiao"

        if st.button("提交任务 & 生成脚本", type="primary", disabled=(not can_submit)):
            image_paths = list(st.session_state.uploaded_images or [])
            for uf in uploaded_files or []:
                path = save_uploaded_file(uf)
                if path:
                    image_paths.append(path)
            # The uploader keeps its files across reruns, so submitting again
            # saves the same paths again; keep each path once, in order.
            image_paths = list(dict.fromkeys(image_paths))
            st.session_state.uploaded_images = image_paths

            if not st.session_state.project_id:
                st.session_state.project_id = f"PROJ-{int(time.time())}"

            # uploaded_images is stored in product_info so it survives a reload.
            product_info = {"category": category, "price": price, "tags": tags, "params": params, "style_hint": style_hint, "uploaded_images": image_paths}
            db.create_project(st.session_state.project_id, product_name, product_info)

            with st.spinner(f"正在分析商品信息并生成脚本 ({selected_model_label})..."):
                gen = ScriptGenerator()
                script = gen.generate_script(product_name, product_info, image_paths, model_provider=model_provider)
                if script:
                    st.session_state.script_data = script
                    db.update_project_script(st.session_state.project_id, script)
                    st.session_state.current_step = 1
                    st.success("脚本生成成功!")
                    st.rerun()
                else:
                    st.error("脚本生成失败,请检查日志。")

    # --- Step 2: Script Review ---
    if st.session_state.current_step >= 1:
        with st.expander("📝 2. 脚本分镜确认", expanded=(st.session_state.current_step == 1)):
            if st.session_state.script_data:
                script = st.session_state.script_data

                c1, c2 = st.columns(2)
                c1.write(f"**核心卖点**: {', '.join(script.get('selling_points', []))}")
                c2.write(f"**目标人群**: {script.get('target_audience', '')}")

                # Prompt visualization for debugging the generator.
                if "_debug" in script:
                    with st.expander("🔍 查看 AI Prompt (Debug)"):
                        debug_info = script["_debug"]
                        st.markdown("**System Prompt:**")
                        st.code(debug_info.get("system_prompt", ""), language="markdown")
                        st.markdown("**User Prompt:**")
                        st.code(debug_info.get("user_prompt", ""), language="markdown")
                        if "raw_output" in debug_info:
                            st.markdown("**Raw AI Output:**")
                            st.code(debug_info.get("raw_output", ""), language="markdown")

                st.markdown("### 分镜列表")
                scenes = script.get("scenes", [])

                # Global voiceover / subtitle timeline (seconds).
                st.markdown("### 🎙️ 整体旁白与字幕时间轴")
                with st.expander("编辑旁白时间轴 (Voiceover Timeline)", expanded=True):
                    timeline = script.get("voiceover_timeline", [])
                    if not timeline:
                        timeline = [{"text": "示例旁白", "subtitle": "示例字幕", "start_time": 0.0, "duration": 3.0}]
                    updated_timeline = _edit_voiceover_timeline(timeline, "tl")
                    if st.button(" 添加旁白段落"):
                        updated_timeline.append({"text": "", "subtitle": "", "start_time": 0.0, "duration": 3.0})
                        # Store before st.rerun(): the rerun aborts this run, so the
                        # assignment below would never see the new segment.
                        script["voiceover_timeline"] = updated_timeline
                        st.rerun()
                    script["voiceover_timeline"] = updated_timeline

                updated_scenes = []
                for i, scene in enumerate(scenes):
                    with st.container():
                        st.markdown(f"#### 🎬 Scene {scene['id']}")
                        c_vis, c_aud = st.columns([2, 1])
                        with c_vis:
                            scene["visual_prompt"] = st.text_area(f"Visual Prompt (Scene {scene['id']})", value=scene.get("visual_prompt", ""), height=80, key=f"vp_{i}")
                            scene["video_prompt"] = st.text_area(f"Video Prompt (Scene {scene['id']})", value=scene.get("video_prompt", ""), height=80, key=f"vidp_{i}")
                        with c_aud:
                            ft = scene.get("fancy_text", {})
                            if isinstance(ft, dict):
                                new_ft_text = st.text_input(f"Fancy Text (Scene {scene['id']})", value=ft.get("text", ""), key=f"ft_{i}")
                                # Write through `ft` and re-attach it. A scene without a
                                # "fancy_text" entry used to raise KeyError here; it only
                                # gets one once the user actually types some text.
                                if "fancy_text" in scene or new_ft_text:
                                    ft["text"] = new_ft_text
                                    scene["fancy_text"] = ft
                            # Voiceover/subtitles are edited in the timeline above.
                            st.caption("注:旁白与字幕已移至上方整体时间轴编辑")
                        updated_scenes.append(scene)
                        st.divider()

                st.session_state.script_data["scenes"] = updated_scenes

                col_act1, _ = st.columns([1, 4])
                with col_act1:
                    if st.button("确认脚本 & 开始生图", type="primary"):
                        # Persist any edits made above before moving on.
                        db.update_project_script(st.session_state.project_id, st.session_state.script_data)
                        st.session_state.current_step = 2
                        st.rerun()

    # --- Step 3: Image Generation ---
    if st.session_state.current_step >= 2:
        with st.expander("🎨 3. 画面生成", expanded=(st.session_state.current_step == 2)):
            # Show the reference images that will be sent to the image model.
            if st.session_state.uploaded_images:
                with st.expander("🔍 查看生图参考底图 (Base Image)"):
                    for i, p in enumerate(st.session_state.uploaded_images):
                        st.info(f"{i+1} 文件名: {os.path.basename(p)}")
                        if os.path.exists(p):
                            st.image(p, width=200)
            else:
                st.warning("⚠️ 未检测到参考底图!将生成随机风格图像。建议在 Step 1 重新上传。")

            if not st.session_state.scene_images:
                img_model_options = ["Shubiaobiao (Gemini)", "Doubao (Volcengine)", "Gemini (Direct)", "Doubao (Group Image)"]
                selected_img_model = st.radio("选择生图模型", img_model_options, horizontal=True)
                if "Group Image" in selected_img_model:
                    img_provider = "doubao-group"
                elif "Doubao" in selected_img_model:
                    img_provider = "doubao"
                elif "Direct" in selected_img_model:
                    img_provider = "gemini"
                else:
                    img_provider = "shubiaobiao"
                # Remembered so per-scene regeneration uses the same provider.
                st.session_state.selected_img_provider = img_provider

                if st.button("🚀 执行 AI 生图", type="primary"):
                    img_gen = ImageGenerator()
                    # All uploaded images are passed as references.
                    base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else []
                    if not base_imgs:
                        st.error("No base image found (未找到参考底图). Please upload in Step 1.")
                        st.stop()
                    scenes = st.session_state.script_data.get("scenes", [])
                    # visual_anchor keeps the generated images visually consistent.
                    visual_anchor = st.session_state.script_data.get("visual_anchor", "")

                    if img_provider == "doubao-group":
                        # --- Group generation: all scenes in one batch request ---
                        with st.spinner("正在进行 Doubao 组图生成 (Batch Group Generation)..."):
                            try:
                                results = img_gen.generate_group_images_doubao(
                                    scenes=scenes,
                                    reference_images=base_imgs,
                                    visual_anchor=visual_anchor,
                                )
                                for s_id, path in results.items():
                                    st.session_state.scene_images[s_id] = path
                                    db.save_asset(st.session_state.project_id, s_id, "image", "completed", local_path=path)
                                if len(results) == len(scenes):
                                    st.success("组图生成完成!")
                                    db.update_project_status(st.session_state.project_id, "images_generated")
                                    time.sleep(1)
                                    st.rerun()
                                else:
                                    # No rerun, so the warning stays visible; the images
                                    # that did succeed are rendered further down this run.
                                    st.warning(f"部分图片生成失败: {len(results)}/{len(scenes)}")
                            except Exception as e:
                                st.error(f"组图生成失败: {e}")
                    else:
                        # --- Sequential generation, one scene at a time ---
                        total_scenes = len(scenes)
                        progress_bar = st.progress(0)
                        status_text = st.empty()
                        # Every generated image is added to the references for the
                        # following scenes.
                        current_refs = list(base_imgs)
                        failed_scene_ids = []
                        try:
                            for idx, scene in enumerate(scenes):
                                scene_id = scene["id"]
                                status_text.text(f"正在生成 Scene {scene_id} ({idx+1}/{total_scenes}) using {selected_img_model}...")
                                img_path = img_gen.generate_single_scene_image(
                                    scene=scene,
                                    original_image_path=current_refs,
                                    previous_image_path=None,
                                    model_provider=img_provider,
                                    visual_anchor=visual_anchor,
                                )
                                if img_path:
                                    st.session_state.scene_images[scene_id] = img_path
                                    current_refs.append(img_path)
                                    db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path)
                                else:
                                    failed_scene_ids.append(scene_id)
                                progress_bar.progress((idx + 1) / total_scenes)

                            status_text.text("生图完成!")
                            db.update_project_status(st.session_state.project_id, "images_generated")
                            if failed_scene_ids:
                                # Stay on this run so the warning is visible; successful
                                # images are displayed below.
                                st.warning(f"部分图片生成失败: {total_scenes - len(failed_scene_ids)}/{total_scenes}")
                            else:
                                st.success("生图完成!")
                                time.sleep(1)
                                st.rerun()
                        except PermissionError as e:
                            st.error(f"生图服务拒绝访问 (可能余额不足): {e}")
                        except Exception as e:
                            st.error(f"生图发生错误: {e}")

            # Display generated images with per-scene regeneration.
            if st.session_state.scene_images:
                cols = st.columns(4)
                scenes = st.session_state.script_data.get("scenes", [])
                for i, scene in enumerate(scenes):
                    scene_id = scene["id"]
                    img_path = st.session_state.scene_images.get(scene_id)
                    with cols[i % 4]:
                        st.markdown(f"**Scene {scene_id}**")
                        with st.expander("Prompt", expanded=False):
                            st.caption(scene.get("visual_prompt", "No prompt"))
                        if img_path and os.path.exists(img_path):
                            # NOTE(review): use_column_width is deprecated in newer Streamlit
                            # releases (use_container_width); switch once the pinned version allows.
                            st.image(img_path, width=None, use_column_width=True)
                        else:
                            st.error("Image missing")

                        if st.button(f"🔄 重生 S{scene_id}", key=f"regen_img_{scene_id}"):
                            if not st.session_state.uploaded_images:
                                st.error("缺少参考底图,请在 Step 1 重新上传图片")
                            else:
                                with st.spinner(f"正在重绘 Scene {scene_id}..."):
                                    img_gen = ImageGenerator()
                                    # References: all uploads, plus images of the scenes that
                                    # come earlier in the script (as in the sequential run).
                                    # Script order is used, so non-integer scene ids work too.
                                    current_refs_for_regen = list(st.session_state.uploaded_images)
                                    for prev_scene in scenes[:i]:
                                        prev_img = st.session_state.scene_images.get(prev_scene["id"])
                                        if prev_img:
                                            current_refs_for_regen.append(prev_img)
                                    # Group mode cannot redo one scene; fall back to single Doubao.
                                    provider = st.session_state.get("selected_img_provider", "shubiaobiao")
                                    if provider == "doubao-group":
                                        provider = "doubao"
                                    regen_visual_anchor = st.session_state.script_data.get("visual_anchor", "")
                                    new_path = img_gen.generate_single_scene_image(
                                        scene=scene,
                                        original_image_path=current_refs_for_regen,
                                        previous_image_path=None,
                                        model_provider=provider,
                                        visual_anchor=regen_visual_anchor,
                                    )
                                    if new_path:
                                        st.session_state.scene_images[scene_id] = new_path
                                        db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=new_path)
                                        st.rerun()

                if st.button("下一步:生成视频", type="primary"):
                    st.session_state.current_step = 3
                    st.rerun()

    # --- Step 4: Video Generation ---
    if st.session_state.current_step >= 3:
        with st.expander("🎥 4. 视频生成 (Volcengine I2V)", expanded=(st.session_state.current_step == 3)):
            if not st.session_state.scene_videos:
                if st.button("🎬 执行图生视频", type="primary"):
                    with st.spinner("正在生成视频 (耗时较长)..."):
                        vid_gen = VideoGenerator()
                        videos = vid_gen.generate_scene_videos(
                            st.session_state.project_id,
                            st.session_state.script_data,
                            st.session_state.scene_images,
                        )
                        if videos:
                            st.session_state.scene_videos = videos
                            for sid, path in videos.items():
                                db.save_asset(st.session_state.project_id, sid, "video", "completed", local_path=path)
                            db.update_project_status(st.session_state.project_id, "videos_generated")
                            st.success("视频生成完成!")
                            st.rerun()
                        else:
                            st.warning("部分或全部视频生成失败")

            if st.session_state.scene_videos:
                cols = st.columns(4)
                scenes = st.session_state.script_data.get("scenes", [])
                for i, scene in enumerate(scenes):
                    scene_id = scene["id"]
                    vid_path = st.session_state.scene_videos.get(scene_id)
                    with cols[i % 4]:
                        st.markdown(f"**Scene {scene_id}**")
                        with st.expander("Prompt", expanded=False):
                            st.caption(scene.get("video_prompt", "No prompt"))

                        if vid_path and os.path.exists(vid_path):
                            st.video(vid_path)
                        else:
                            st.warning("Video missing")
                            # Recovery: re-download from a task id stored on the asset.
                            asset = db.get_asset(st.session_state.project_id, scene_id, "video")
                            task_id = asset.get("task_id") if asset else None
                            if task_id:
                                if st.button(f"🔍 找回视频 (Task {str(task_id)[-6:]})", key=f"recov_{scene_id}"):
                                    with st.spinner("查询任务状态中..."):
                                        vid_gen = VideoGenerator()
                                        output_filename = f"scene_{scene_id}_video.mp4"
                                        target_path = str(config.TEMP_DIR / output_filename)
                                        if vid_gen.recover_video_from_task(task_id, target_path):
                                            st.session_state.scene_videos[scene_id] = target_path
                                            db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=target_path)
                                            st.success("找回成功!")
                                            st.rerun()
                                        else:
                                            st.error("找回失败")

                        if st.button(f"🔄 重生 S{scene_id}", key=f"regen_vid_{scene_id}"):
                            with st.spinner(f"正在重生成 Scene {scene_id} 视频..."):
                                vid_gen = VideoGenerator()
                                img_path = st.session_state.scene_images.get(scene_id)
                                if not (img_path and os.path.exists(img_path)):
                                    st.error("缺少源图片")
                                else:
                                    t_id = vid_gen.submit_scene_video_task(st.session_state.project_id, scene_id, img_path, scene.get("video_prompt"))
                                    if not t_id:
                                        st.error("提交任务失败")
                                    else:
                                        # Poll the single task: up to 60 checks, 5 s apart.
                                        for _ in range(60):
                                            status, url = vid_gen.check_task_status(t_id)
                                            if status == "succeeded":
                                                new_path = vid_gen._download_video(url, f"scene_{scene_id}_video_{int(time.time())}.mp4")
                                                if new_path:
                                                    st.session_state.scene_videos[scene_id] = new_path
                                                    db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=new_path, task_id=t_id)
                                                    st.rerun()
                                                st.error("视频下载失败")
                                                break
                                            if status in ("failed", "cancelled"):
                                                st.error(f"生成失败: {status}")
                                                break
                                            time.sleep(5)
                                        else:
                                            # Loop exhausted without a final status.
                                            st.warning(f"等待超时,任务仍在处理中 (Task {t_id})")

                c_act1, c_act2 = st.columns([1, 4])
                with c_act1:
                    if st.button("🔄 重新生成所有视频", type="secondary"):
                        # Clearing the map brings back the "generate all" button.
                        st.session_state.scene_videos = {}
                        st.rerun()
                with c_act2:
                    if st.button("下一步:合成最终成片", type="primary"):
                        st.session_state.current_step = 4
                        st.rerun()

    # --- Step 5: Final Composition & Tuning ---
    if st.session_state.current_step >= 4:
        with st.expander("🎞️ 5. 合成结果与微调", expanded=(st.session_state.current_step == 4)):
            tab_preview, tab_tune = st.tabs(["🎥 预览", "🎛️ 微调 (Track Editor)"])

            with tab_tune:
                st.subheader("轨道微调 (Tuning)")
                col_g1, col_g2 = st.columns(2)
                with col_g1:
                    bgm_dir = config.ASSETS_DIR / "bgm"
                    bgm_names = [f.name for f in _list_bgm_files(bgm_dir)]
                    # Preselect the style-based recommendation (stable across reruns).
                    bgm_style = st.session_state.script_data.get("bgm_style", "")
                    recommended_bgm = _recommended_bgm(bgm_style, bgm_dir)
                    default_idx = 0
                    if recommended_bgm:
                        rec_name = Path(recommended_bgm).name
                        if rec_name in bgm_names:
                            default_idx = bgm_names.index(rec_name) + 1  # +1 for the leading "None" option
                    selected_bgm = st.selectbox(
                        f"背景音乐 (BGM) - 推荐风格: {bgm_style or '未指定'}",
                        ["None"] + bgm_names,
                        index=default_idx,
                    )
                with col_g2:
                    selected_voice = st.selectbox("配音音色 (TTS)", [config.VOLC_TTS_DEFAULT_VOICE, "zh_female_meilinvyou_saturn_bigtts"])

                st.divider()
                st.markdown("### 🎙️ 旁白与字幕时间轴 (单位: 秒)")
                # NOTE(review): once step >= 4, both this editor and the Step 2 editor
                # run every rerun and write into the same item dicts; this one runs
                # later, so its widget values win. Edits made in Step 2 at this
                # stage are overwritten. Confirm this is intended.
                timeline = st.session_state.script_data.get("voiceover_timeline", [])
                st.session_state.script_data["voiceover_timeline"] = _edit_voiceover_timeline(timeline, "tune")

                st.divider()
                # Per-scene editor: fancy text only.
                scenes = st.session_state.script_data.get("scenes", [])
                updated_scenes = []
                for i, scene in enumerate(scenes):
                    with st.container():
                        st.markdown(f"**Scene {scene['id']}**")
                        ft = scene.get("fancy_text", {})
                        ft_text = ft.get("text", "") if isinstance(ft, dict) else ""
                        new_ft = st.text_input("花字", value=ft_text, key=f"tune_ft_{i}")
                        if isinstance(scene.get("fancy_text"), dict):
                            scene["fancy_text"]["text"] = new_ft
                    updated_scenes.append(scene)
                st.session_state.script_data["scenes"] = updated_scenes

                if st.button("🔄 重新合成 (Re-Compose)", type="primary"):
                    with st.spinner("正在应用修改并重新合成..."):
                        composer = VideoComposer(voice_type=selected_voice)
                        bgm_path = None
                        if selected_bgm != "None":
                            bgm_path = str(config.ASSETS_DIR / "bgm" / selected_bgm)
                        try:
                            # Save the edited script before composing.
                            db.update_project_script(st.session_state.project_id, st.session_state.script_data)
                            output_path = composer.compose_from_script(
                                script=st.session_state.script_data,
                                video_map=st.session_state.scene_videos,
                                bgm_path=bgm_path,
                                # Timestamped name keeps every version for the history list.
                                output_name=f"final_{st.session_state.project_id}_{int(time.time())}",
                            )
                            st.session_state.final_video = output_path
                            db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
                            st.success("重新合成成功!")
                            st.rerun()
                        except Exception as e:
                            st.error(f"合成失败: {e}")

            with tab_preview:
                st.subheader("🛠️ 导出与交付")
                col_ex1, col_ex2 = st.columns([1, 2])
                with col_ex1:
                    if st.button("📦 生成剪映素材包 (ZIP)", type="primary", key="btn_export_zip"):
                        with st.spinner("正在打包视频、音频与字幕..."):
                            try:
                                zip_path = export_utils.create_capcut_package(
                                    st.session_state.project_id,
                                    st.session_state.script_data,
                                    {"scene_videos": st.session_state.scene_videos},
                                )
                                st.session_state['export_zip_path'] = zip_path
                                st.success("打包完成!请点击下载")
                            except Exception as e:
                                st.error(f"打包失败: {e}")
                with col_ex2:
                    if 'export_zip_path' in st.session_state and os.path.exists(st.session_state['export_zip_path']):
                        zip_name = f"capcut_pack_{st.session_state.project_id}.zip"
                        with open(st.session_state['export_zip_path'], "rb") as f:
                            st.download_button(
                                label=f"📥 点击下载素材包 ({zip_name})",
                                data=f,
                                file_name=zip_name,
                                mime="application/zip",
                            )

                st.divider()

                # All composed versions of this project, newest first.
                import glob
                pattern_timestamp = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}_*.mp4")
                pattern_simple = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}.mp4")
                found_files = sorted(
                    set(glob.glob(pattern_timestamp)) | set(glob.glob(pattern_simple)),
                    key=os.path.getmtime,
                    reverse=True,
                )

                if not found_files:
                    st.info("暂无合成视频,请先点击开始合成。")
                    if st.button("✨ 开始首次合成", type="primary"):
                        with st.spinner("正在进行多轨合成..."):
                            composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE)
                            # Same cached recommendation the tuning tab preselects.
                            bgm_style = st.session_state.script_data.get("bgm_style", "")
                            bgm_path = _recommended_bgm(bgm_style, config.ASSETS_DIR / "bgm")
                            try:
                                output_name = f"final_{st.session_state.project_id}_{int(time.time())}"
                                output_path = composer.compose_from_script(
                                    script=st.session_state.script_data,
                                    video_map=st.session_state.scene_videos,
                                    bgm_path=bgm_path,
                                    output_name=output_name,
                                )
                                st.session_state.final_video = output_path
                                db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
                                db.update_project_status(st.session_state.project_id, "completed")
                                st.rerun()
                            except Exception as e:
                                st.error(f"合成失败: {e}")
                else:
                    st.success(f"共找到 {len(found_files)} 个历史版本")
                    for idx, vid_path in enumerate(found_files):
                        mtime = os.path.getmtime(vid_path)
                        time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime))
                        file_name = os.path.basename(vid_path)
                        with st.container():
                            if idx == 0:
                                st.markdown(f"### 🟢 最新版本 ({time_str})")
                            else:
                                st.markdown(f"#### ⚪ 历史版本 {len(found_files)-idx} ({time_str})")
                            _, center, _ = st.columns([1, 2, 1])
                            with center:
                                st.video(vid_path)
                                with open(vid_path, "rb") as file:
                                    st.download_button(
                                        label=f"📥 下载 ({file_name})",
                                        data=file,
                                        file_name=file_name,
                                        mime="video/mp4",
                                        key=f"dl_btn_{file_name}",
                                    )
                            st.divider()
# ============================================================
# Page: History (New)
# ============================================================
elif st.session_state.view_mode == "history":
st.header("📜 历史任务")
projects = db.list_projects()
for proj in projects:
with st.expander(f"{proj['name']} ({proj['updated_at']})"):
st.caption(f"ID: {proj['id']} | Status: {proj.get('status', 'unknown')}")
assets = db.get_assets(proj['id'])
if assets:
st.markdown("#### 📥 下载素材")
# Group assets
final_vids = [a for a in assets if a['asset_type'] == 'final_video']
scene_imgs = [a for a in assets if a['asset_type'] == 'image']
scene_vids = [a for a in assets if a['asset_type'] == 'video']
# Final Video Download
if final_vids:
fv = final_vids[0]
if fv['local_path'] and os.path.exists(fv['local_path']):
st.success(f"✅ **最终成片已生成**")
with open(fv['local_path'], "rb") as f:
st.download_button(
label=f"📥 下载最终成片 ({os.path.basename(fv['local_path'])})",
data=f,
file_name=os.path.basename(fv['local_path']),
mime="video/mp4"
)
# Script Download
if proj.get("script_data"):
try:
script_json = json.dumps(proj["script_data"] if isinstance(proj["script_data"], dict) else json.loads(proj["script_data"]), ensure_ascii=False, indent=2)
st.download_button(
label="📥 下载脚本/字幕配置 (JSON)",
data=script_json,
file_name=f"script_{proj['id']}.json",
mime="application/json"
)
except:
pass
st.markdown("---")
c1, c2 = st.columns(2)
with c1:
st.markdown("**分镜图片**")
for img in scene_imgs:
if img['local_path'] and os.path.exists(img['local_path']):
with open(img['local_path'], "rb") as f:
st.download_button(f"Scene {img['scene_id']} 图片", f, key=f"dl_img_{proj['id']}_{img['id']}", file_name=f"scene_{img['scene_id']}.png")
with c2:
st.markdown("**分镜视频**")
for vid in scene_vids:
if vid['local_path'] and os.path.exists(vid['local_path']):
with open(vid['local_path'], "rb") as f:
st.download_button(f"Scene {vid['scene_id']} 视频", f, key=f"dl_vid_{proj['id']}_{vid['id']}", file_name=f"scene_{vid['scene_id']}.mp4")
else:
st.info("暂无生成素材")
# ============================================================
# Page: Settings (New)
# ============================================================
elif st.session_state.view_mode == "settings":
st.header("⚙️ 设置配置")
st.subheader("Prompt 配置")
# Script Generation Prompt
current_prompt = db.get_config("prompt_script_gen")
# 显示当前状态
if current_prompt:
st.info("✅ 已加载自定义 Prompt来自数据库")
else:
st.warning("⚠️ 使用默认 Prompt数据库中无自定义配置")
# Load default from instance if not in DB
temp_gen = ScriptGenerator()
current_prompt = temp_gen.default_system_prompt
new_prompt = st.text_area("脚本拆分 Prompt (System Prompt)", value=current_prompt, height=400)
col_save, col_reset = st.columns([1, 3])
with col_save:
if st.button("💾 保存配置", type="primary"):
db.set_config("prompt_script_gen", new_prompt, "System prompt for script generation step")
# 验证保存
saved = db.get_config("prompt_script_gen")
if saved == new_prompt:
st.success("✅ 配置已保存并验证成功!下次生成脚本时将使用新 Prompt。")
else:
st.error("❌ 保存可能失败,请检查日志")
with col_reset:
if st.button("🔄 恢复默认"):
temp_gen = ScriptGenerator()
db.set_config("prompt_script_gen", temp_gen.default_system_prompt, "System prompt for script generation step (DEFAULT)")
st.success("已恢复默认 Prompt请刷新页面查看")
st.rerun()