""" MatchMe Studio - UI (Streamlit) Style: Kaogujia (Clean, Data-heavy, Professional) """ import streamlit as st import json import time import os import random from pathlib import Path import pandas as pd from time import perf_counter from concurrent.futures import ThreadPoolExecutor, as_completed # Import Backend Modules import config from modules.script_gen import ScriptGenerator from modules.image_gen import ImageGenerator from modules.video_gen import VideoGenerator from modules.composer import VideoComposer, VideoComposer as Composer # alias from modules.text_renderer import renderer from modules import export_utils from modules.db_manager import db from modules import path_utils from modules import limits from modules.legacy_path_mapper import map_legacy_local_path from modules.legacy_normalizer import normalize_legacy_project # Page Config st.set_page_config( page_title="Video Flow Console", page_icon="🎬", layout="wide", initial_sidebar_state="expanded" ) # ============================================================ # BGM 智能匹配函数 # ============================================================ def match_bgm_by_style(bgm_style: str, bgm_dir: Path) -> str: """ 根据脚本 bgm_style 智能匹配 BGM 文件 - 匹配成功:随机选一个匹配的 BGM - 匹配失败:随机选任意一个 BGM """ # 获取所有 BGM 文件 (支持 .mp3 和 .mp4) bgm_files = list(bgm_dir.glob("*.[mM][pP][34]")) + list(bgm_dir.glob("*.[mM][pP]3")) bgm_files = [f for f in bgm_files if f.is_file() and not f.name.startswith('.')] if not bgm_files: return None # 关键词匹配 if bgm_style: style_lower = bgm_style.lower() # 提取关键词 (中文分词简化版:按常见词匹配) keywords = ["活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏"] matched_keywords = [kw for kw in keywords if kw in style_lower] matched_files = [] for f in bgm_files: fname = f.name if any(kw in fname for kw in matched_keywords): matched_files.append(f) if matched_files: return str(random.choice(matched_files)) # 无匹配则随机选一个 return str(random.choice(bgm_files)) # ============================================================ # 
# CSS Styling (Kaogujia Style)
# ============================================================
st.markdown(""" """, unsafe_allow_html=True)

# ============================================================
# Session State Management
# ============================================================
# Defaults are applied only for keys that are missing, so values set by a
# previous script run survive reruns. The dict literal is rebuilt on every
# run, so the mutable defaults ({} / []) are never shared between sessions.
_SESSION_DEFAULTS = {
    "project_id": None,
    "current_step": 0,
    "script_data": None,
    "scene_images": {},
    "scene_videos": {},
    "final_video": None,
    "uploaded_images": [],
    "view_mode": "workspace",  # workspace, history, settings
    "selected_img_provider": "shubiaobiao",
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default


def _script_needs_normalize(sd: object) -> bool:
    """Return True when ``sd`` is not a script dict in the current schema.

    A script needs re-normalizing when it is not a dict, lacks the
    ``_legacy_schema`` marker, or its first scene (if it is a dict) is missing
    ``visual_prompt`` or ``video_prompt``.
    """
    if not isinstance(sd, dict):
        return True
    if "_legacy_schema" not in sd:
        return True
    scenes = sd.get("scenes") or []
    if scenes and isinstance(scenes, list) and isinstance(scenes[0], dict):
        first = scenes[0]
        if "visual_prompt" not in first or "video_prompt" not in first:
            return True
    return False


def load_project(project_id):
    """Load a project's persisted state from the DB into ``st.session_state``.

    Restores the script, product info, uploaded reference images, per-scene
    image/video assets and the final video, then derives ``current_step``
    from whichever assets exist. Shows ``st.error`` and returns early when
    the project is not found.

    Args:
        project_id: Identifier passed to ``db.get_project`` / ``db.get_assets``.
    """
    data = db.get_project(project_id)
    if not data:
        st.error("Project not found")
        return

    st.session_state.project_id = project_id
    st.session_state.script_data = data.get("script_data")
    st.session_state.view_mode = "workspace"

    # Fallback: if the script stored in the DB uses the old structure or is
    # missing fields, re-normalize it once from the legacy JSON file.
    # The check lives in a module-level helper annotated with `object`; the
    # former nested helper was annotated with `Any`, which is not imported,
    # so defining it raised NameError and this whole branch always failed.
    try:
        legacy_json = Path(config.TEMP_DIR) / f"project_{project_id}.json"
        if legacy_json.exists() and _script_needs_normalize(st.session_state.script_data):
            raw = json.loads(legacy_json.read_text(encoding="utf-8"))
            normalized = normalize_legacy_project(raw)
            st.session_state.script_data = normalized
            # Write back to the DB so the normalization is not redone on every load.
            db.update_project_script(project_id, normalized)
            st.info("已从 legacy JSON 重新规范化脚本字段(兼容旧版项目)。")
    except Exception as e:
        # Best effort: a failed normalization must not block loading the project.
        st.warning(f"legacy 规范化失败(将继续使用 DB 数据): {e}")

    # Restore product info for Step 1 display. `or` guards against explicit
    # None values stored in the DB, which would break later `.get()` calls.
    product_info = data.get("product_info") or {}
    st.session_state.loaded_product_name = data.get("name") or ""
    st.session_state.loaded_product_info = product_info

    # Restore uploaded images from product_info
    if product_info and "uploaded_images" in product_info:
        stored_paths = product_info["uploaded_images"] or []
        # Keep only the files that still exist on disk.
        valid_paths = [p for p in stored_paths if os.path.exists(p)]
        st.session_state.uploaded_images = valid_paths
        missing = len(stored_paths) - len(valid_paths)
        if missing > 0:
            st.warning(f"部分原始图片已丢失 ({missing} 张),建议重新上传")
    else:
        st.session_state.uploaded_images = []

    # Restore assets
    images = {}
    videos = {}
    final_vid = None
    for asset in db.get_assets(project_id) or []:
        sid = asset["scene_id"]
        source_path, _mapped_url = map_legacy_local_path(asset.get("local_path"))
        asset_type = asset["asset_type"]
        # Only completed assets with a resolvable local path are restored.
        if asset["status"] != "completed" or not source_path:
            continue
        if asset_type == "image":
            images[sid] = source_path
        elif asset_type == "video":
            videos[sid] = source_path
        elif asset_type == "final_video":
            # The final video is presumably stored under scene_id 0 or -1.
            final_vid = source_path

    st.session_state.scene_images = images
    st.session_state.scene_videos = videos
    st.session_state.final_video = final_vid

    # Determine step from the most advanced kind of asset available.
    if final_vid or videos:
        st.session_state.current_step = 4
    elif images:
        st.session_state.current_step = 3
    elif st.session_state.script_data:
        st.session_state.current_step = 1
    else:
        st.session_state.current_step = 0


# ============================================================
# Sidebar
# ============================================================
# NOTE: the metric/reset helpers are defined BEFORE the sidebar on purpose.
# Streamlit executes this script top to bottom on every rerun, and the
# sidebar calls `_get_metrics`; with the definitions placed after the
# sidebar that call raised NameError as soon as a project was active.
def _record_metrics(project_id: str, patch: dict):
    """Persist lightweight timing/diagnostic metrics into project.product_info['_metrics'].

    Merges ``patch`` into the stored metrics dict, stamps ``updated_at`` with
    the current time and writes the product info back to the DB. Does nothing
    when ``project_id`` is empty or ``patch`` is not a non-empty dict.
    """
    if not project_id or not isinstance(patch, dict) or not patch:
        return
    try:
        proj = db.get_project(project_id) or {}
        product_info = proj.get("product_info") or {}
        stored = product_info.get("_metrics")
        metrics = stored if isinstance(stored, dict) else {}
        metrics.update(patch)
        metrics["updated_at"] = time.time()
        product_info["_metrics"] = metrics
        db.update_project_product_info(project_id, product_info)
    except Exception:
        # metrics must never break UX
        pass


def _get_metrics(project_id: str) -> dict:
    """Return the metrics dict stored for ``project_id``, or ``{}`` if unavailable."""
    try:
        proj = db.get_project(project_id) or {}
        product_info = proj.get("product_info") or {}
        m = product_info.get("_metrics")
        return m if isinstance(m, dict) else {}
    except Exception:
        return {}


def _reset_session_state(keep=("view_mode",)):
    """Delete every ``st.session_state`` key except those listed in ``keep``."""
    for key in list(st.session_state.keys()):
        if key not in keep:
            del st.session_state[key]


# NOTE(review): the nesting below was reconstructed from a flattened source;
# confirm that the progress list and the reset button belong to the
# workspace branch.
with st.sidebar:
    st.title("📽️ Video Flow")

    # Mode selection: derive the radio index from the current view mode.
    mode_labels = {
        "workspace": "🛠️ 工作台",
        "history": "📜 历史任务",
        "settings": "⚙️ 设置",
    }
    mode_keys = list(mode_labels)
    current_mode = st.session_state.view_mode
    current_index = mode_keys.index(current_mode) if current_mode in mode_labels else 0
    mode = st.radio("模式", list(mode_labels.values()), index=current_index)
    for mode_key, mode_label in mode_labels.items():
        if mode == mode_label:
            st.session_state.view_mode = mode_key
            break

    st.markdown("---")

    if st.session_state.view_mode == "workspace":
        # Project Selection
        st.subheader("Current Project")
        projects = db.list_projects()
        proj_options = {p['id']: f"{p.get('name', 'Untitled')} ({p['id']})" for p in projects}

        selected_proj_id = st.selectbox(
            "Select Project",
            options=["New Project"] + list(proj_options.keys()),
            format_func=lambda x: "➕ New Project" if x == "New Project" else proj_options[x]
        )

        if selected_proj_id == "New Project":
            if st.session_state.project_id and st.session_state.project_id in proj_options:
                # If switching from existing to new, reset
                if st.button("Start New"):
                    _reset_session_state()
                    st.rerun()
        else:
            if st.session_state.project_id != selected_proj_id:
                if st.button(f"Load {selected_proj_id}"):
                    load_project(selected_proj_id)
                    st.rerun()

        if st.session_state.project_id:
            st.caption(f"Current ID: {st.session_state.project_id}")
            with st.expander("⏱️ 性能与诊断", expanded=False):
                m = _get_metrics(st.session_state.project_id)
                if not m:
                    st.caption("暂无指标(执行一次脚本/生图/生视频/合成后会出现)。")
                else:
                    keys = [
                        "script_gen_s",
                        "image_gen_total_s",
                        "video_submit_s",
                        "video_recover_s",
                        "compose_s",
                        "script_model",
                        "image_provider",
                        "image_generated",
                        "video_submitted",
                        "video_recovered",
                        "bgm_used",
                    ]
                    for k in keys:
                        if k in m:
                            st.caption(f"{k}: {m.get(k)}")

            # Entry point to the online editor (React Editor)
            web_base_url = os.getenv("WEB_BASE_URL", "http://localhost:3000").rstrip("/")
            st.markdown(
                f"[打开在线剪辑器]({web_base_url}/editor/{st.session_state.project_id})",
                unsafe_allow_html=False,
            )

        st.markdown("---")

        # Navigation / Progress
        steps = ["1. 输入信息", "2. 生成脚本", "3. 画面生成", "4. 视频生成", "5. 合成输出"]
        current = st.session_state.current_step
        for i, step in enumerate(steps):
            if i < current:
                st.markdown(f"✅ **{step}**")
            elif i == current:
                st.markdown(f"🔵 **{step}**")
            else:
                st.markdown(f"⚪ {step}")

        st.markdown("---")
        if st.button("🔄 重置状态", type="secondary"):
            _reset_session_state()
            st.rerun()


# ============================================================
# Helper Functions
# ============================================================
def save_uploaded_file(project_id: str, uploaded_file):
    """Save uploaded file to per-project upload dir (avoid overwrites across projects)."""
    if uploaded_file is None:
        return None
    if not project_id:
        raise ValueError("project_id is required to save uploaded files safely")
    upload_dir = path_utils.project_upload_dir(project_id)
    original = path_utils.sanitize_filename(getattr(uploaded_file, "name", "upload"))
    # keep original stem for
readability, but ensure uniqueness suffix = Path(original).suffix.lstrip(".") or "bin" stem = Path(original).stem or "upload" unique_name = path_utils.unique_filename(prefix=f"upload_{stem}", ext=suffix, project_id=project_id) file_path = upload_dir / unique_name with open(file_path, "wb") as f: f.write(uploaded_file.getbuffer()) return str(file_path) # ============================================================ # Main Content: Workspace # ============================================================ if st.session_state.view_mode == "workspace": st.header("商品短视频自动生成控制台") # --- Step 1: Input --- with st.expander("📦 1. 商品信息输入", expanded=(st.session_state.current_step == 0)): # 从 session_state 读取已加载项目的信息,否则使用默认值 loaded_info = st.session_state.get("loaded_product_info", {}) default_name = st.session_state.get("loaded_product_name", "网红气质大号发量多!高马尾香蕉夹") default_category = loaded_info.get("category", "钟表配饰-时尚饰品-发饰") default_price = loaded_info.get("price", "3.99元") default_tags = loaded_info.get("tags", "回头客|款式好看|材质好|尺寸合适|颜色好看|很好用|做工好|质感不错|很牢固") default_params = loaded_info.get("params", "金属材质:非金属; 非金属材质:树脂; 发夹分类:香蕉夹") col1, col2 = st.columns([1, 1]) with col1: product_name = st.text_input("商品标题", value=default_name) category = st.text_input("商品类目", value=default_category) price = st.text_input("价格", value=default_price) with col2: tags = st.text_area("评价标签 (用于提炼卖点)", value=default_tags, height=100) params = st.text_area("商品参数", value=default_params, height=100) # 商家自定义风格提示 style_hint = st.text_area( "商品视频重点增强提示 (可选)", value=loaded_info.get("style_hint", ""), placeholder="例如:韩风、高级感、活力青春、简约日系...", height=80 ) st.markdown("### 上传素材") # 显示已有的上传图片(如果从历史项目加载) if st.session_state.uploaded_images: st.info(f"已有 {len(st.session_state.uploaded_images)} 张参考图片") with st.expander("查看已有图片"): img_cols = st.columns(min(len(st.session_state.uploaded_images), 4)) for i, img_path in enumerate(st.session_state.uploaded_images[:4]): if os.path.exists(img_path): img_cols[i % 4].image(img_path, 
width=150) uploaded_files = st.file_uploader("上传商品主图 (建议 3-5 张)", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True) # 允许在没有上传新图片但有历史图片的情况下继续 can_submit = uploaded_files or st.session_state.uploaded_images # Model Selection (all support images; user explicitly chooses model) model_options = ["GPT-5.2", "Gemini 3 Pro", "Doubao Pro (Vision)"] selected_model_label = st.radio("选择脚本生成模型", model_options, horizontal=True, index=0) # Map label to provider key if selected_model_label == "GPT-5.2": model_provider = "shubiaobiao_gpt" elif "Doubao" in selected_model_label: model_provider = "doubao" else: model_provider = "shubiaobiao" if st.button("提交任务 & 生成脚本", type="primary", disabled=(not can_submit)): # 处理图片路径 image_paths = list(st.session_state.uploaded_images) if st.session_state.uploaded_images else [] # 如果有新上传的文件,添加它们 if uploaded_files: # Create Project ID if not st.session_state.project_id: st.session_state.project_id = f"PROJ-{int(time.time())}" for uf in uploaded_files: path = save_uploaded_file(st.session_state.project_id, uf) if path: image_paths.append(path) st.session_state.uploaded_images = image_paths # 确保有项目 ID if not st.session_state.project_id: st.session_state.project_id = f"PROJ-{int(time.time())}" # DB: Create Project # 将 uploaded_images 保存到 product_info 以便持久化 product_info = {"category": category, "price": price, "tags": tags, "params": params, "style_hint": style_hint, "uploaded_images": image_paths} db.create_project(st.session_state.project_id, product_name, product_info) # Call Script Generator with st.spinner(f"正在分析商品信息并生成脚本 ({selected_model_label})..."): gen = ScriptGenerator() t0 = perf_counter() script = gen.generate_script(product_name, product_info, image_paths, model_provider=model_provider) _record_metrics(st.session_state.project_id, { "script_gen_s": round(perf_counter() - t0, 3), "script_model": model_provider, }) if script: st.session_state.script_data = script # DB: Save Script db.update_project_script(st.session_state.project_id, 
script) st.session_state.current_step = 1 st.success("脚本生成成功!") st.rerun() else: st.error("脚本生成失败,请检查日志。") # --- Step 2: Script Review --- if st.session_state.current_step >= 1: with st.expander("📝 2. 脚本分镜确认", expanded=(st.session_state.current_step == 1)): if st.session_state.script_data: script = st.session_state.script_data # Display Basic Info (兼容 legacy schema) selling_points = script.get("selling_points", []) or [] target_audience = script.get("target_audience", "") or "" analysis_text = script.get("analysis", "") or "" legacy_schema = script.get("_legacy_schema", "") or "" c1, c2 = st.columns(2) if selling_points: c1.write(f"**核心卖点**: {', '.join(selling_points)}") else: c1.write("**核心卖点**: (legacy 项目可能未生成该字段)") if analysis_text: with st.expander("查看 legacy analysis(用于补齐信息)"): st.write(analysis_text) if target_audience: c2.write(f"**目标人群**: {target_audience}") else: c2.write("**目标人群**: (legacy 项目可能未生成该字段)") # Hook / CTA / Schema hook = script.get("hook", "") or "" if hook: st.markdown(f"**Hook**: {hook}") cta = script.get("cta", "") if cta: if isinstance(cta, dict): st.markdown("**CTA(legacy object)**") st.json(cta) else: st.markdown(f"**CTA**: {cta}") if legacy_schema: st.caption(f"Legacy Schema: {legacy_schema}") # Prompt Visualization if "_debug" in script: with st.expander("🔍 查看 AI Prompt (Debug)"): debug_info = script["_debug"] st.markdown("**System Prompt:**") st.code(debug_info.get("system_prompt", ""), language="markdown") st.markdown("**User Prompt:**") st.code(debug_info.get("user_prompt", ""), language="markdown") # 显示原始输出 if "raw_output" in debug_info: st.markdown("**Raw AI Output:**") st.code(debug_info.get("raw_output", ""), language="markdown") # Editable Scenes Table st.markdown("### 分镜列表") scenes = script.get("scenes", []) # Global Voiceover Timeline (New) st.markdown("### 🎙️ 整体旁白与字幕时间轴") with st.expander("编辑旁白时间轴 (Voiceover Timeline)", expanded=True): timeline = script.get("voiceover_timeline", []) or [] if not timeline: # 对于历史项目:如果没有 scenes 
也没有 timeline,不要强行塞“示例旁白”,避免污染数据 if not scenes and analysis_text: st.info("该历史项目暂无旁白时间轴(可能停留在分析/提问阶段)。") timeline = [] else: # Init with default if empty (使用秒) timeline = [{"text": "示例旁白", "subtitle": "示例字幕", "start_time": 0.0, "duration": 3.0}] updated_timeline = [] for i, item in enumerate(timeline): c1, c2, c3, c4 = st.columns([3, 3, 1, 1]) with c1: item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=f"tl_vo_{i}") with c2: item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=f"tl_sub_{i}") with c3: # 兼容旧格式: 如果有 start_ratio 则转换为 start_time default_start = item.get("start_time", item.get("start_ratio", 0) * 12) # 假设总时长12秒 item["start_time"] = st.number_input(f"开始(秒) #{i+1}", value=float(default_start), min_value=0.0, max_value=30.0, step=0.5, key=f"tl_start_{i}") with c4: # 兼容旧格式: 如果有 duration_ratio 则转换为 duration default_dur = item.get("duration", item.get("duration_ratio", 0.25) * 12) # 假设总时长12秒 item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=float(default_dur), min_value=0.5, max_value=15.0, step=0.5, key=f"tl_dur_{i}") # 清理旧字段 item.pop("start_ratio", None) item.pop("duration_ratio", None) updated_timeline.append(item) if st.button("➕ 添加旁白段落"): updated_timeline.append({"text": "", "subtitle": "", "start_time": 0.0, "duration": 3.0}) st.rerun() script["voiceover_timeline"] = updated_timeline updated_scenes = [] for i, scene in enumerate(scenes): with st.container(): st.markdown(f"#### 🎬 Scene {scene['id']}") c_vis, c_aud = st.columns([2, 1]) with c_vis: new_visual = st.text_area(f"Visual Prompt (Scene {scene['id']})", value=scene.get("visual_prompt", ""), height=80, key=f"vp_{i}") new_video = st.text_area(f"Video Prompt (Scene {scene['id']})", value=scene.get("video_prompt", ""), height=80, key=f"vidp_{i}") scene["visual_prompt"] = new_visual scene["video_prompt"] = new_video with c_aud: # 花字编辑保留 ft = scene.get("fancy_text", {}) if isinstance(ft, dict): new_ft_text = st.text_input( 
f"Fancy Text (Scene {scene['id']})", value=ft.get("text", ""), key=f"ft_{i}", ) # 兼容:旧数据可能没有 fancy_text 字段 if not isinstance(scene.get("fancy_text"), dict): scene["fancy_text"] = {} scene["fancy_text"]["text"] = new_ft_text # 旁白/字幕已移至上方整体时间轴,此处仅作展示或删除 st.caption("注:旁白与字幕已移至上方整体时间轴编辑") # Legacy 信息展示(只读,用于调试/对齐) legacy_scene = scene.get("_legacy", {}) if isinstance(scene.get("_legacy", {}), dict) else {} if legacy_scene: with st.expander(f"Legacy 信息 (Scene {scene['id']})", expanded=False): img_url = legacy_scene.get("image_url") if img_url: st.markdown(f"- image_url: `{img_url}`") cam = legacy_scene.get("camera_movement") if cam: st.markdown(f"- camera_movement: {cam}") vo = legacy_scene.get("voiceover") if vo: st.markdown(f"- voiceover: {vo}") keyframe = legacy_scene.get("keyframe") if keyframe: st.markdown("- keyframe:") st.json(keyframe) rhythm = legacy_scene.get("rhythm") if rhythm: st.markdown("- rhythm:") st.json(rhythm) updated_scenes.append(scene) st.divider() st.session_state.script_data["scenes"] = updated_scenes col_act1, col_act2 = st.columns([1, 4]) with col_act1: if st.button("确认脚本 & 开始生图", type="primary"): # DB: Update script in case user edited it db.update_project_script(st.session_state.project_id, st.session_state.script_data) st.session_state.current_step = 2 st.rerun() # --- Step 3: Image Generation --- if st.session_state.current_step >= 2: with st.expander("🎨 3. 
画面生成", expanded=(st.session_state.current_step == 2)): # Debug: Show reference image if st.session_state.uploaded_images: with st.expander("🔍 查看生图参考底图 (Base Image)"): for i, p in enumerate(st.session_state.uploaded_images): st.info(f"图 {i+1} 文件名: {os.path.basename(p)}") if os.path.exists(p): st.image(p, width=200) else: st.warning("⚠️ 未检测到参考底图!将生成随机风格图像。建议在 Step 1 重新上传。") # Trigger Generation Button if not st.session_state.scene_images: # Model Selection for Image Gen img_model_options = ["Shubiaobiao (Gemini)", "Doubao (Volcengine)", "Gemini (Direct)", "Doubao (Group Image)"] selected_img_model = st.radio("选择生图模型", img_model_options, horizontal=True) # Map to provider if "Group Image" in selected_img_model: img_provider = "doubao-group" elif "Doubao" in selected_img_model: img_provider = "doubao" elif "Direct" in selected_img_model: img_provider = "gemini" else: img_provider = "shubiaobiao" # Store selected provider for regeneration st.session_state.selected_img_provider = img_provider if st.button("🚀 执行 AI 生图", type="primary"): with limits.acquire_image(blocking=False) as ok: if not ok: st.warning("系统正在生成其他任务(生图并发已达上限),请稍后再试。") st.stop() img_gen = ImageGenerator() # Pass ALL uploaded images as reference base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else [] if not base_imgs: st.error("No base image found (未找到参考底图). 
Please upload in Step 1.") st.stop() scenes = st.session_state.script_data.get("scenes", []) # 读取 visual_anchor 用于保持生图一致性 visual_anchor = st.session_state.script_data.get("visual_anchor", "") if img_provider == "doubao-group": # --- Group Generation Logic --- with st.spinner("正在进行 Doubao 组图生成 (Batch Group Generation)..."): try: t0 = perf_counter() results = img_gen.generate_group_images_doubao( scenes=scenes, reference_images=base_imgs, visual_anchor=visual_anchor, project_id=st.session_state.project_id ) _record_metrics(st.session_state.project_id, { "image_gen_total_s": round(perf_counter() - t0, 3), "image_provider": img_provider, "image_generated": len(results), }) for s_id, path in results.items(): st.session_state.scene_images[s_id] = path db.save_asset(st.session_state.project_id, s_id, "image", "completed", local_path=path) if len(results) == len(scenes): st.success("组图生成完成!") db.update_project_status(st.session_state.project_id, "images_generated") time.sleep(1) st.rerun() else: st.warning(f"部分图片生成失败: {len(results)}/{len(scenes)}") st.rerun() except Exception as e: st.error(f"组图生成失败: {e}") else: # --- Parallel Logic (default): only merchant uploaded images as references --- total_scenes = len(scenes) progress_bar = st.progress(0) status_text = st.empty() try: t0 = perf_counter() # Parallel workers within a single run; global semaphore already acquired above. 
max_workers = 6 futures = {} with ThreadPoolExecutor(max_workers=max_workers) as ex: for idx, scene in enumerate(scenes): scene_id = scene["id"] futures[ex.submit( img_gen.generate_single_scene_image, scene=scene, original_image_path=list(base_imgs), # ONLY merchant images previous_image_path=None, model_provider=img_provider, visual_anchor=visual_anchor, project_id=st.session_state.project_id, )] = (idx, scene_id) done = 0 for fut in as_completed(futures): idx, scene_id = futures[fut] done += 1 status_text.text(f"已完成 {done}/{total_scenes}(Scene {scene_id})") try: img_path = fut.result() except Exception as e: img_path = None st.warning(f"Scene {scene_id} 生成失败:{e}") if img_path: st.session_state.scene_images[scene_id] = img_path db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path) progress_bar.progress(done / total_scenes) status_text.text("生图完成!") st.success("生图完成!") _record_metrics(st.session_state.project_id, { "image_gen_total_s": round(perf_counter() - t0, 3), "image_provider": img_provider, "image_generated": len(st.session_state.scene_images), }) # Update Status db.update_project_status(st.session_state.project_id, "images_generated") time.sleep(1) st.rerun() except PermissionError as e: st.error(f"生图服务拒绝访问 (可能余额不足): {e}") except Exception as e: st.error(f"生图发生错误: {e}") # Display & Regenerate Images if st.session_state.scene_images: cols = st.columns(4) scenes = st.session_state.script_data.get("scenes", []) for i, scene in enumerate(scenes): scene_id = scene["id"] img_path = st.session_state.scene_images.get(scene_id) with cols[i % 4]: st.markdown(f"**Scene {scene_id}**") # Expand Prompt info with st.expander("Prompt", expanded=False): st.caption(scene.get("visual_prompt", "No prompt")) if img_path and os.path.exists(img_path): st.image(img_path, width=None, use_column_width=True) # Fix deprecated else: st.error("Image missing") # Regenerate specific image if st.button(f"🔄 重生 S{scene_id}", 
key=f"regen_img_{scene_id}"): if not st.session_state.uploaded_images: st.error("缺少参考底图,请在 Step 1 重新上传图片") else: with st.spinner(f"正在重绘 Scene {scene_id}..."): img_gen = ImageGenerator() # Only merchant uploaded images as references (no chaining) current_refs_for_regen = list(st.session_state.uploaded_images) # Fallback to single mode for regen if group was used provider = st.session_state.get("selected_img_provider", "shubiaobiao") if provider == "doubao-group": provider = "doubao" # 读取 visual_anchor regen_visual_anchor = st.session_state.script_data.get("visual_anchor", "") new_path = img_gen.generate_single_scene_image( scene=scene, original_image_path=current_refs_for_regen, previous_image_path=None, model_provider=provider, visual_anchor=regen_visual_anchor, project_id=st.session_state.project_id ) if new_path: st.session_state.scene_images[scene_id] = new_path db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=new_path) st.rerun() if st.button("下一步:生成视频", type="primary"): st.session_state.current_step = 3 st.rerun() # --- Step 4: Video Generation --- if st.session_state.current_step >= 3: with st.expander("🎥 4. 
视频生成 (Volcengine I2V)", expanded=(st.session_state.current_step == 3)): scenes = st.session_state.script_data.get("scenes", []) vid_gen = VideoGenerator() # Submit-only (non-blocking) to avoid freezing Streamlit under concurrency if st.button("🎬 提交图生视频任务(非阻塞)", type="primary"): with limits.acquire_video(blocking=False) as ok: if not ok: st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。") st.stop() t0 = perf_counter() submitted = 0 for scene in scenes: scene_id = scene["id"] image_path = st.session_state.scene_images.get(scene_id) prompt = scene.get("video_prompt", "High quality video") task_id = vid_gen.submit_scene_video_task( st.session_state.project_id, scene_id, image_path, prompt ) if task_id: submitted += 1 _record_metrics(st.session_state.project_id, { "video_submit_s": round(perf_counter() - t0, 3), "video_submitted": submitted, }) if submitted: db.update_project_status(st.session_state.project_id, "videos_processing") st.success(f"已提交 {submitted} 个分镜视频任务。可点击下方“刷新恢复”下载结果。") time.sleep(0.5) st.rerun() else: st.warning("未提交任何任务(可能缺少图片或接口失败)。") if st.button("🔄 刷新状态并恢复已完成任务", type="secondary"): with limits.acquire_video(blocking=False) as ok: if not ok: st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。") st.stop() t0 = perf_counter() updated = 0 for scene in scenes: scene_id = scene["id"] asset = db.get_asset(st.session_state.project_id, scene_id, "video") if not asset or not asset.get("task_id"): continue # if already have local video, skip existing = st.session_state.scene_videos.get(scene_id) if existing and os.path.exists(existing): continue task_id = asset.get("task_id") # Query volc status; store URL for direct preview (no server download) status = None url = None # short retries for "succeeded but url missing" for attempt in range(3): status, url = vid_gen.check_task_status(task_id) if status == "succeeded" and url: break time.sleep(0.5 * (2 ** attempt)) meta_patch = {"checked_at": time.time(), "volc_status": status} if url: meta_patch["video_url"] = url 
db.update_asset_metadata(st.session_state.project_id, scene_id, "video", meta_patch) updated += 1 _record_metrics(st.session_state.project_id, { "video_recover_s": round(perf_counter() - t0, 3), "video_recovered": updated, }) if updated: st.success(f"已刷新 {updated} 个分镜状态(成功的将以 URL 直连预览)。") else: st.info("暂无可恢复的视频(可能仍在排队/生成中)。") time.sleep(0.5) st.rerun() if st.button("📥 准备合成素材(下载成功的视频到服务器)", type="secondary"): with limits.acquire_video(blocking=False) as ok: if not ok: st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。") st.stop() downloaded = 0 for scene in scenes: scene_id = scene["id"] existing = st.session_state.scene_videos.get(scene_id) if existing and os.path.exists(existing): continue asset = db.get_asset(st.session_state.project_id, scene_id, "video") meta = (asset or {}).get("metadata") or {} video_url = meta.get("video_url") if not video_url: continue out_name = path_utils.unique_filename( prefix="scene_video", ext="mp4", project_id=st.session_state.project_id, scene_id=scene_id, ) target_path = str(path_utils.project_videos_dir(st.session_state.project_id) / out_name) if vid_gen._download_video_to(video_url, target_path): st.session_state.scene_videos[scene_id] = target_path db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=target_path, task_id=(asset or {}).get("task_id"), metadata=meta) downloaded += 1 if downloaded: st.success(f"已下载 {downloaded} 段视频,可进入合成。") else: st.info("暂无可下载的视频(请先刷新状态获取 video_url)。") time.sleep(0.5) st.rerun() # Display Videos (even when partially available) if st.session_state.scene_videos or scenes: cols = st.columns(4) for i, scene in enumerate(scenes): scene_id = scene["id"] vid_path = st.session_state.scene_videos.get(scene_id) with cols[i % 4]: st.markdown(f"**Scene {scene_id}**") # Expand Prompt info with st.expander("Prompt", expanded=False): st.caption(scene.get("video_prompt", "No prompt")) if vid_path and os.path.exists(vid_path): st.video(vid_path) else: # Try URL preview from DB metadata asset 
= db.get_asset(st.session_state.project_id, scene_id, "video") meta = (asset or {}).get("metadata") or {} video_url = meta.get("video_url") if video_url: st.caption("URL 直连预览(不经服务器落盘)") st.video(video_url) else: st.warning("Video missing") # --- Recovery Logic --- if asset and asset.get("task_id"): task_id = asset.get("task_id") if st.button(f"🔍 刷新URL (Task {task_id[-6:]})", key=f"recov_{scene_id}"): with st.spinner("查询任务状态中..."): status, url = vid_gen.check_task_status(task_id) patch = {"checked_at": time.time(), "volc_status": status} if url: patch["video_url"] = url db.update_asset_metadata(st.session_state.project_id, scene_id, "video", patch) st.success("已刷新任务状态。") st.rerun() # Per-scene regenerate button if st.button(f"🔄 重生 S{scene_id}", key=f"regen_vid_{scene_id}"): with st.spinner(f"正在重生成 Scene {scene_id} 视频..."): vid_gen = VideoGenerator() img_path = st.session_state.scene_images.get(scene_id) if img_path and os.path.exists(img_path): t_id = vid_gen.submit_scene_video_task(st.session_state.project_id, scene_id, img_path, scene.get("video_prompt")) if t_id: # Simple poll loop for single video regen import time for _ in range(60): status, url = vid_gen.check_task_status(t_id) if status == "succeeded": new_path = vid_gen._download_video(url, f"scene_{scene_id}_video_{int(time.time())}.mp4") if new_path: st.session_state.scene_videos[scene_id] = new_path db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=new_path, task_id=t_id) st.rerun() break elif status in ["failed", "cancelled"]: st.error(f"生成失败: {status}") break time.sleep(5) else: st.error("提交任务失败") else: st.error("缺少源图片") c_act1, c_act2 = st.columns([1, 4]) with c_act1: if st.button("🔄 重新生成所有视频", type="secondary"): # Clear videos and rerun st.session_state.scene_videos = {} st.rerun() with c_act2: if st.button("下一步:合成最终成片", type="primary"): st.session_state.current_step = 4 st.rerun() # --- Step 5: Final Composition & Tuning --- if st.session_state.current_step >= 4: 
with st.expander("🎞️ 5. 合成结果与微调", expanded=(st.session_state.current_step == 4)): # Tabs: Preview vs Tuning tab_preview, tab_tune = st.tabs(["🎥 预览", "🎛️ 微调 (Track Editor)"]) with tab_tune: st.subheader("轨道微调 (Tuning)") # Global Settings col_g1, col_g2 = st.columns(2) with col_g1: # BGM Select (支持 .mp3 和 .mp4) bgm_dir = config.ASSETS_DIR / "bgm" bgm_files = list(bgm_dir.glob("*.[mM][pP][34]")) + list(bgm_dir.glob("*.[mM][pP]3")) bgm_files = [f for f in bgm_files if f.is_file() and not f.name.startswith('.')] bgm_names = [f.name for f in bgm_files] # 智能推荐默认 BGM bgm_style = st.session_state.script_data.get("bgm_style", "") recommended_bgm = match_bgm_by_style(bgm_style, bgm_dir) default_idx = 0 if recommended_bgm: rec_name = Path(recommended_bgm).name if rec_name in bgm_names: default_idx = bgm_names.index(rec_name) + 1 # +1 因为有 "None" 选项 selected_bgm = st.selectbox( f"背景音乐 (BGM) - 推荐风格: {bgm_style or '未指定'}", ["None"] + bgm_names, index=default_idx ) # 明确提示:BGM 目录为空或选中 BGM 不存在时,本次将不含 BGM if not bgm_names: st.warning(f"BGM 目录为空:{bgm_dir}(本次合成将不含 BGM)") elif selected_bgm != "None": candidate = config.ASSETS_DIR / "bgm" / selected_bgm if not candidate.exists(): st.warning(f"所选 BGM 文件不存在:{candidate}(本次合成将不含 BGM)") with col_g2: # Voice Select selected_voice = st.selectbox("配音音色 (TTS)", [config.VOLC_TTS_DEFAULT_VOICE, "zh_female_meilinvyou_saturn_bigtts"]) st.divider() # Global Timeline Editor st.markdown("### 🎙️ 旁白与字幕时间轴 (单位: 秒)") timeline = st.session_state.script_data.get("voiceover_timeline", []) updated_timeline = [] for i, item in enumerate(timeline): c1, c2, c3, c4 = st.columns([3, 3, 1, 1]) with c1: item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=f"tune_vo_{i}") with c2: item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=f"tune_sub_{i}") with c3: # 兼容旧格式: 如果有 start_ratio 则转换为 start_time default_start = item.get("start_time", item.get("start_ratio", 0) * 12) item["start_time"] = 
st.number_input(f"开始(秒) #{i+1}", value=float(default_start), min_value=0.0, max_value=30.0, step=0.5, key=f"tune_start_{i}") with c4: # 兼容旧格式: 如果有 duration_ratio 则转换为 duration default_dur = item.get("duration", item.get("duration_ratio", 0.25) * 12) item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=float(default_dur), min_value=0.5, max_value=15.0, step=0.5, key=f"tune_dur_{i}") # 清理旧字段 item.pop("start_ratio", None) item.pop("duration_ratio", None) updated_timeline.append(item) st.session_state.script_data["voiceover_timeline"] = updated_timeline st.divider() # Per Scene Editor (Only Fancy Text) scenes = st.session_state.script_data.get("scenes", []) updated_scenes = [] for i, scene in enumerate(scenes): with st.container(): st.markdown(f"**Scene {scene['id']}**") ft = scene.get("fancy_text", {}) ft_text = ft.get("text", "") if isinstance(ft, dict) else "" new_ft = st.text_input(f"花字", value=ft_text, key=f"tune_ft_{i}") # 兼容:旧数据可能没有 fancy_text 字段 if not isinstance(scene.get("fancy_text"), dict): scene["fancy_text"] = {} scene["fancy_text"]["text"] = new_ft updated_scenes.append(scene) st.session_state.script_data["scenes"] = updated_scenes if st.button("🔄 重新合成 (Re-Compose)", type="primary"): with st.spinner("正在应用修改并重新合成..."): composer = VideoComposer(voice_type=selected_voice) bgm_path = None if selected_bgm != "None": bgm_path = str(config.ASSETS_DIR / "bgm" / selected_bgm) try: # Save updated script first db.update_project_script(st.session_state.project_id, st.session_state.script_data) t0 = perf_counter() output_path = composer.compose_from_script( script=st.session_state.script_data, video_map=st.session_state.scene_videos, bgm_path=bgm_path, output_name=f"final_{st.session_state.project_id}_{int(time.time())}", # Unique name for history project_id=st.session_state.project_id, ) _record_metrics(st.session_state.project_id, { "compose_s": round(perf_counter() - t0, 3), "bgm_used": bool(bgm_path and Path(bgm_path).exists()), }) 
st.session_state.final_video = output_path db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path) st.success("重新合成成功!") st.rerun() except Exception as e: st.error(f"合成失败: {e}") with tab_preview: st.subheader("🛠️ 导出与交付") col_ex1, col_ex2 = st.columns([1, 2]) with col_ex1: if st.button("📦 生成剪映素材包 (ZIP)", type="primary", key="btn_export_zip"): with st.spinner("正在打包视频、音频与字幕..."): try: zip_path = export_utils.create_capcut_package( st.session_state.project_id, st.session_state.script_data, {"scene_videos": st.session_state.scene_videos} ) st.session_state['export_zip_path'] = zip_path st.success("打包完成!请点击下载") except Exception as e: st.error(f"打包失败: {e}") with col_ex2: if 'export_zip_path' in st.session_state and os.path.exists(st.session_state['export_zip_path']): zip_name = f"capcut_pack_{st.session_state.project_id}.zip" with open(st.session_state['export_zip_path'], "rb") as f: st.download_button( label=f"📥 点击下载素材包 ({zip_name})", data=f, file_name=zip_name, mime="application/zip" ) st.divider() # 扫描所有历史合成版本 import glob # 查找匹配当前项目的所有最终视频 pattern_timestamp = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}_*.mp4") pattern_simple = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}.mp4") found_files = [] for f in glob.glob(pattern_timestamp): found_files.append(f) for f in glob.glob(pattern_simple): if f not in found_files: found_files.append(f) # 按修改时间倒序排列 found_files.sort(key=os.path.getmtime, reverse=True) if not found_files: st.info("暂无合成视频,请先点击开始合成。") if st.button("✨ 开始首次合成", type="primary"): with st.spinner("正在进行多轨合成..."): # Default compose logic with smart BGM matching composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE) # 智能匹配 BGM:根据脚本 bgm_style 匹配 bgm_style = st.session_state.script_data.get("bgm_style", "") bgm_path = match_bgm_by_style(bgm_style, config.ASSETS_DIR / "bgm") if bgm_path and not Path(bgm_path).exists(): st.warning(f"推荐的 BGM 文件不存在:{bgm_path}(本次将不含 BGM)") bgm_path = 
None try: # 首次合成也加上时间戳 output_name = f"final_{st.session_state.project_id}_{int(time.time())}" t0 = perf_counter() output_path = composer.compose_from_script( script=st.session_state.script_data, video_map=st.session_state.scene_videos, bgm_path=bgm_path, output_name=output_name, project_id=st.session_state.project_id, ) _record_metrics(st.session_state.project_id, { "compose_s": round(perf_counter() - t0, 3), "bgm_used": bool(bgm_path and Path(bgm_path).exists()), }) st.session_state.final_video = output_path db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path) db.update_project_status(st.session_state.project_id, "completed") st.rerun() except Exception as e: st.error(f"合成失败: {e}") else: st.success(f"共找到 {len(found_files)} 个历史版本") # 遍历显示所有历史版本 for idx, vid_path in enumerate(found_files): mtime = os.path.getmtime(vid_path) time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) file_name = os.path.basename(vid_path) # 容器样式 with st.container(): if idx == 0: st.markdown(f"### 🟢 最新版本 ({time_str})") else: st.markdown(f"#### ⚪ 历史版本 {len(found_files)-idx} ({time_str})") c1, c2, c3 = st.columns([1, 2, 1]) with c2: st.video(vid_path) # 下载按钮 with open(vid_path, "rb") as file: st.download_button( label=f"📥 下载 ({file_name})", data=file, file_name=file_name, mime="video/mp4", key=f"dl_btn_{file_name}" ) st.divider() # ============================================================ # Page: History (New) # ============================================================ elif st.session_state.view_mode == "history": st.header("📜 历史任务") projects = db.list_projects() for proj in projects: with st.expander(f"{proj['name']} ({proj['updated_at']})"): st.caption(f"ID: {proj['id']} | Status: {proj.get('status', 'unknown')}") assets = db.get_assets(proj['id']) if assets: st.markdown("#### 📥 下载素材") # Group assets final_vids = [a for a in assets if a['asset_type'] == 'final_video'] scene_imgs = [a for a in assets if a['asset_type'] == 
'image'] scene_vids = [a for a in assets if a['asset_type'] == 'video'] # Final Video Download if final_vids: fv = final_vids[0] if fv['local_path'] and os.path.exists(fv['local_path']): st.success(f"✅ **最终成片已生成**") with open(fv['local_path'], "rb") as f: st.download_button( label=f"📥 下载最终成片 ({os.path.basename(fv['local_path'])})", data=f, file_name=os.path.basename(fv['local_path']), mime="video/mp4" ) # Script Download if proj.get("script_data"): try: script_json = json.dumps(proj["script_data"] if isinstance(proj["script_data"], dict) else json.loads(proj["script_data"]), ensure_ascii=False, indent=2) st.download_button( label="📥 下载脚本/字幕配置 (JSON)", data=script_json, file_name=f"script_{proj['id']}.json", mime="application/json" ) except: pass st.markdown("---") c1, c2 = st.columns(2) with c1: st.markdown("**分镜图片**") for img in scene_imgs: if img['local_path'] and os.path.exists(img['local_path']): with open(img['local_path'], "rb") as f: st.download_button(f"Scene {img['scene_id']} 图片", f, key=f"dl_img_{proj['id']}_{img['id']}", file_name=f"scene_{img['scene_id']}.png") with c2: st.markdown("**分镜视频**") for vid in scene_vids: if vid['local_path'] and os.path.exists(vid['local_path']): with open(vid['local_path'], "rb") as f: st.download_button(f"Scene {vid['scene_id']} 视频", f, key=f"dl_vid_{proj['id']}_{vid['id']}", file_name=f"scene_{vid['scene_id']}.mp4") else: st.info("暂无生成素材") # ============================================================ # Page: Settings (New) # ============================================================ elif st.session_state.view_mode == "settings": st.header("⚙️ 设置配置") st.subheader("Prompt 配置") # Script Generation Prompt current_prompt = db.get_config("prompt_script_gen") # 显示当前状态 if current_prompt: st.info("✅ 已加载自定义 Prompt(来自数据库)") else: st.warning("⚠️ 使用默认 Prompt(数据库中无自定义配置)") # Load default from instance if not in DB temp_gen = ScriptGenerator() current_prompt = temp_gen.default_system_prompt new_prompt = st.text_area("脚本拆分 Prompt 
(System Prompt)", value=current_prompt, height=400) col_save, col_reset = st.columns([1, 3]) with col_save: if st.button("💾 保存配置", type="primary"): db.set_config("prompt_script_gen", new_prompt, "System prompt for script generation step") # 验证保存 saved = db.get_config("prompt_script_gen") if saved == new_prompt: st.success("✅ 配置已保存并验证成功!下次生成脚本时将使用新 Prompt。") else: st.error("❌ 保存可能失败,请检查日志") with col_reset: if st.button("🔄 恢复默认"): temp_gen = ScriptGenerator() db.set_config("prompt_script_gen", temp_gen.default_system_prompt, "System prompt for script generation step (DEFAULT)") st.success("已恢复默认 Prompt,请刷新页面查看") st.rerun()