commit 33a165a615872129131690956ff65e7d48d569cc Author: Tony Zhang Date: Fri Dec 12 19:18:27 2025 +0800 feat: video-flow initial commit - app.py: Streamlit UI for video generation workflow - main_flow.py: CLI tool with argparse support - modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.) - config.py: Configuration with API keys and paths - requirements.txt: Python dependencies - docs/: System prompt documentation diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7c718ab --- /dev/null +++ b/.gitignore @@ -0,0 +1,56 @@ +# Environment +.env +.env.local +.env.*.local + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +venv/ +ENV/ +*.egg-info/ + +# Output +output/ +*.mp4 +*.mp3 +*.wav +*.m4a + +# Assets (downloaded) +assets/fonts/*.ttf + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# 参考 +参考/ + +# 素材 +素材/ + +# Images +*.png +*.jpeg +*.jpg + +# Database & Logs +*.db +*.log + +# Temp files +temp/ + +# Binaries +bin/ + diff --git a/app.py b/app.py new file mode 100644 index 0000000..a92f2c5 --- /dev/null +++ b/app.py @@ -0,0 +1,1059 @@ +""" +MatchMe Studio - UI (Streamlit) +Style: Kaogujia (Clean, Data-heavy, Professional) +""" +import streamlit as st +import json +import time +import os +import random +from pathlib import Path +import pandas as pd + +# Import Backend Modules +import config +from modules.script_gen import ScriptGenerator +from modules.image_gen import ImageGenerator +from modules.video_gen import VideoGenerator +from modules.composer import VideoComposer, VideoComposer as Composer # alias +from modules.text_renderer import renderer +from modules import export_utils +from modules.db_manager import db + +# Page Config +st.set_page_config( + page_title="Video Flow Console", + page_icon="🎬", + layout="wide", + initial_sidebar_state="expanded" +) + +# ============================================================ +# BGM 智能匹配函数 +# 
============================================================ +def match_bgm_by_style(bgm_style: str, bgm_dir: Path) -> str: + """ + 根据脚本 bgm_style 智能匹配 BGM 文件 + - 匹配成功:随机选一个匹配的 BGM + - 匹配失败:随机选任意一个 BGM + """ + # 获取所有 BGM 文件 (支持 .mp3 和 .mp4) + bgm_files = list(bgm_dir.glob("*.[mM][pP][34]")) + list(bgm_dir.glob("*.[mM][pP]3")) + bgm_files = [f for f in bgm_files if f.is_file() and not f.name.startswith('.')] + + if not bgm_files: + return None + + # 关键词匹配 + if bgm_style: + style_lower = bgm_style.lower() + # 提取关键词 (中文分词简化版:按常见词匹配) + keywords = ["活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏"] + matched_keywords = [kw for kw in keywords if kw in style_lower] + + matched_files = [] + for f in bgm_files: + fname = f.name + if any(kw in fname for kw in matched_keywords): + matched_files.append(f) + + if matched_files: + return str(random.choice(matched_files)) + + # 无匹配则随机选一个 + return str(random.choice(bgm_files)) + +# ============================================================ +# CSS Styling (Kaogujia Style) +# ============================================================ +st.markdown(""" + +""", unsafe_allow_html=True) + +# ============================================================ +# Session State Management +# ============================================================ +if "project_id" not in st.session_state: + st.session_state.project_id = None +if "current_step" not in st.session_state: + st.session_state.current_step = 0 +if "script_data" not in st.session_state: + st.session_state.script_data = None +if "scene_images" not in st.session_state: + st.session_state.scene_images = {} +if "scene_videos" not in st.session_state: + st.session_state.scene_videos = {} +if "final_video" not in st.session_state: + st.session_state.final_video = None +if "uploaded_images" not in st.session_state: + st.session_state.uploaded_images = [] +if "view_mode" not in st.session_state: + st.session_state.view_mode = "workspace" # workspace, history, settings +if 
"selected_img_provider" not in st.session_state: + st.session_state.selected_img_provider = "shubiaobiao" + +def load_project(project_id): + """Load project state from DB""" + data = db.get_project(project_id) + if not data: + st.error("Project not found") + return + + st.session_state.project_id = project_id + st.session_state.script_data = data.get("script_data") + st.session_state.view_mode = "workspace" + + # Restore product info for Step 1 display + product_info = data.get("product_info", {}) + st.session_state.loaded_product_name = data.get("name", "") + st.session_state.loaded_product_info = product_info + + # Restore uploaded images from product_info + if product_info and "uploaded_images" in product_info: + # 检查文件是否仍存在 + valid_paths = [p for p in product_info["uploaded_images"] if os.path.exists(p)] + st.session_state.uploaded_images = valid_paths + if len(valid_paths) < len(product_info["uploaded_images"]): + st.warning(f"部分原始图片已丢失 ({len(product_info['uploaded_images']) - len(valid_paths)} 张),建议重新上传") + else: + st.session_state.uploaded_images = [] + + # Restore assets + assets = db.get_assets(project_id) + images = {} + videos = {} + final_vid = None + + for asset in assets: + sid = asset["scene_id"] + # 假设 scene_id 0 或 -1 用于 final video + if asset["asset_type"] == "image" and asset["status"] == "completed": + images[sid] = asset["local_path"] + elif asset["asset_type"] == "video" and asset["status"] == "completed": + videos[sid] = asset["local_path"] + elif asset["asset_type"] == "final_video" and asset["status"] == "completed": + final_vid = asset["local_path"] + + st.session_state.scene_images = images + st.session_state.scene_videos = videos + st.session_state.final_video = final_vid + + # Determine step + if final_vid: + st.session_state.current_step = 4 + elif videos: + st.session_state.current_step = 4 + elif images: + st.session_state.current_step = 3 + elif st.session_state.script_data: + st.session_state.current_step = 1 + else: + 
st.session_state.current_step = 0 + +# ============================================================ +# Sidebar +# ============================================================ +with st.sidebar: + st.title("📽️ Video Flow") + + # Mode Selection - 正确计算 index + mode_options = ["🛠️ 工作台", "📜 历史任务", "⚙️ 设置"] + mode_map = {"workspace": 0, "history": 1, "settings": 2} + current_index = mode_map.get(st.session_state.view_mode, 0) + + mode = st.radio("模式", mode_options, index=current_index) + if mode == "🛠️ 工作台": + st.session_state.view_mode = "workspace" + elif mode == "📜 历史任务": + st.session_state.view_mode = "history" + elif mode == "⚙️ 设置": + st.session_state.view_mode = "settings" + + st.markdown("---") + + if st.session_state.view_mode == "workspace": + # Project Selection + st.subheader("Current Project") + projects = db.list_projects() + proj_options = {p['id']: f"{p.get('name', 'Untitled')} ({p['id']})" for p in projects} + + selected_proj_id = st.selectbox( + "Select Project", + options=["New Project"] + list(proj_options.keys()), + format_func=lambda x: "➕ New Project" if x == "New Project" else proj_options[x] + ) + + if selected_proj_id == "New Project": + if st.session_state.project_id and st.session_state.project_id in proj_options: + # If switching from existing to new, reset + if st.button("Start New"): + for key in list(st.session_state.keys()): + if key != "view_mode": del st.session_state[key] + st.rerun() + else: + if st.session_state.project_id != selected_proj_id: + if st.button(f"Load {selected_proj_id}"): + load_project(selected_proj_id) + st.rerun() + + if st.session_state.project_id: + st.caption(f"Current ID: {st.session_state.project_id}") + + st.markdown("---") + + # Navigation / Progress + steps = ["1. 输入信息", "2. 生成脚本", "3. 画面生成", "4. 视频生成", "5. 
合成输出"] + current = st.session_state.current_step + + for i, step in enumerate(steps): + if i < current: + st.markdown(f"✅ **{step}**") + elif i == current: + st.markdown(f"🔵 **{step}**") + else: + st.markdown(f"⚪ {step}") + + st.markdown("---") + + if st.button("🔄 重置状态", type="secondary"): + for key in list(st.session_state.keys()): + if key != "view_mode": del st.session_state[key] + st.rerun() + +# ============================================================ +# Helper Functions +# ============================================================ +def save_uploaded_file(uploaded_file): + """Save uploaded file to temp dir.""" + if uploaded_file is not None: + file_path = config.TEMP_DIR / uploaded_file.name + with open(file_path, "wb") as f: + f.write(uploaded_file.getbuffer()) + return str(file_path) + return None + +# ============================================================ +# Main Content: Workspace +# ============================================================ +if st.session_state.view_mode == "workspace": + st.header("商品短视频自动生成控制台") + + # --- Step 1: Input --- + with st.expander("📦 1. 
商品信息输入", expanded=(st.session_state.current_step == 0)): + # 从 session_state 读取已加载项目的信息,否则使用默认值 + loaded_info = st.session_state.get("loaded_product_info", {}) + default_name = st.session_state.get("loaded_product_name", "网红气质大号发量多!高马尾香蕉夹") + default_category = loaded_info.get("category", "钟表配饰-时尚饰品-发饰") + default_price = loaded_info.get("price", "3.99元") + default_tags = loaded_info.get("tags", "回头客|款式好看|材质好|尺寸合适|颜色好看|很好用|做工好|质感不错|很牢固") + default_params = loaded_info.get("params", "金属材质:非金属; 非金属材质:树脂; 发夹分类:香蕉夹") + + col1, col2 = st.columns([1, 1]) + + with col1: + product_name = st.text_input("商品标题", value=default_name) + category = st.text_input("商品类目", value=default_category) + price = st.text_input("价格", value=default_price) + + with col2: + tags = st.text_area("评价标签 (用于提炼卖点)", value=default_tags, height=100) + params = st.text_area("商品参数", value=default_params, height=100) + + # 商家自定义风格提示 + style_hint = st.text_area( + "商品视频重点增强提示 (可选)", + value=loaded_info.get("style_hint", ""), + placeholder="例如:韩风、高级感、活力青春、简约日系...", + height=80 + ) + + st.markdown("### 上传素材") + + # 显示已有的上传图片(如果从历史项目加载) + if st.session_state.uploaded_images: + st.info(f"已有 {len(st.session_state.uploaded_images)} 张参考图片") + with st.expander("查看已有图片"): + img_cols = st.columns(min(len(st.session_state.uploaded_images), 4)) + for i, img_path in enumerate(st.session_state.uploaded_images[:4]): + if os.path.exists(img_path): + img_cols[i % 4].image(img_path, width=150) + + uploaded_files = st.file_uploader("上传商品主图 (建议 3-5 张)", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True) + + # 允许在没有上传新图片但有历史图片的情况下继续 + can_submit = uploaded_files or st.session_state.uploaded_images + + # Model Selection + model_options = ["Gemini 3 Pro", "Doubao Pro (Vision)"] + selected_model_label = st.radio("选择脚本生成模型", model_options, horizontal=True) + # Map label to provider key + model_provider = "doubao" if "Doubao" in selected_model_label else "shubiaobiao" + + if st.button("提交任务 & 生成脚本", type="primary", 
disabled=(not can_submit)): + # 处理图片路径 + image_paths = list(st.session_state.uploaded_images) if st.session_state.uploaded_images else [] + + # 如果有新上传的文件,添加它们 + if uploaded_files: + # Create Project ID + if not st.session_state.project_id: + st.session_state.project_id = f"PROJ-{int(time.time())}" + + for uf in uploaded_files: + path = save_uploaded_file(uf) + if path: image_paths.append(path) + + st.session_state.uploaded_images = image_paths + + # 确保有项目 ID + if not st.session_state.project_id: + st.session_state.project_id = f"PROJ-{int(time.time())}" + + # DB: Create Project + # 将 uploaded_images 保存到 product_info 以便持久化 + product_info = {"category": category, "price": price, "tags": tags, "params": params, "style_hint": style_hint, "uploaded_images": image_paths} + db.create_project(st.session_state.project_id, product_name, product_info) + + # Call Script Generator + with st.spinner(f"正在分析商品信息并生成脚本 ({selected_model_label})..."): + gen = ScriptGenerator() + script = gen.generate_script(product_name, product_info, image_paths, model_provider=model_provider) + + if script: + st.session_state.script_data = script + # DB: Save Script + db.update_project_script(st.session_state.project_id, script) + + st.session_state.current_step = 1 + st.success("脚本生成成功!") + st.rerun() + else: + st.error("脚本生成失败,请检查日志。") + + # --- Step 2: Script Review --- + if st.session_state.current_step >= 1: + with st.expander("📝 2. 
脚本分镜确认", expanded=(st.session_state.current_step == 1)): + if st.session_state.script_data: + script = st.session_state.script_data + + # Display Basic Info + c1, c2 = st.columns(2) + c1.write(f"**核心卖点**: {', '.join(script.get('selling_points', []))}") + c2.write(f"**目标人群**: {script.get('target_audience', '')}") + + # Prompt Visualization + if "_debug" in script: + with st.expander("🔍 查看 AI Prompt (Debug)"): + debug_info = script["_debug"] + st.markdown("**System Prompt:**") + st.code(debug_info.get("system_prompt", ""), language="markdown") + st.markdown("**User Prompt:**") + st.code(debug_info.get("user_prompt", ""), language="markdown") + + # 显示原始输出 + if "raw_output" in debug_info: + st.markdown("**Raw AI Output:**") + st.code(debug_info.get("raw_output", ""), language="markdown") + + # Editable Scenes Table + st.markdown("### 分镜列表") + scenes = script.get("scenes", []) + + # Global Voiceover Timeline (New) + st.markdown("### 🎙️ 整体旁白与字幕时间轴") + with st.expander("编辑旁白时间轴 (Voiceover Timeline)", expanded=True): + timeline = script.get("voiceover_timeline", []) + if not timeline: + # Init with default if empty (使用秒) + timeline = [{"text": "示例旁白", "subtitle": "示例字幕", "start_time": 0.0, "duration": 3.0}] + + updated_timeline = [] + for i, item in enumerate(timeline): + c1, c2, c3, c4 = st.columns([3, 3, 1, 1]) + with c1: + item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=f"tl_vo_{i}") + with c2: + item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=f"tl_sub_{i}") + with c3: + # 兼容旧格式: 如果有 start_ratio 则转换为 start_time + default_start = item.get("start_time", item.get("start_ratio", 0) * 12) # 假设总时长12秒 + item["start_time"] = st.number_input(f"开始(秒) #{i+1}", value=float(default_start), min_value=0.0, max_value=30.0, step=0.5, key=f"tl_start_{i}") + with c4: + # 兼容旧格式: 如果有 duration_ratio 则转换为 duration + default_dur = item.get("duration", item.get("duration_ratio", 0.25) * 12) # 假设总时长12秒 + 
item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=float(default_dur), min_value=0.5, max_value=15.0, step=0.5, key=f"tl_dur_{i}") + # 清理旧字段 + item.pop("start_ratio", None) + item.pop("duration_ratio", None) + updated_timeline.append(item) + + if st.button("➕ 添加旁白段落"): + updated_timeline.append({"text": "", "subtitle": "", "start_time": 0.0, "duration": 3.0}) + st.rerun() + + script["voiceover_timeline"] = updated_timeline + + updated_scenes = [] + for i, scene in enumerate(scenes): + with st.container(): + st.markdown(f"#### 🎬 Scene {scene['id']}") + c_vis, c_aud = st.columns([2, 1]) + + with c_vis: + new_visual = st.text_area(f"Visual Prompt (Scene {scene['id']})", value=scene.get("visual_prompt", ""), height=80, key=f"vp_{i}") + new_video = st.text_area(f"Video Prompt (Scene {scene['id']})", value=scene.get("video_prompt", ""), height=80, key=f"vidp_{i}") + scene["visual_prompt"] = new_visual + scene["video_prompt"] = new_video + + with c_aud: + # 花字编辑保留 + ft = scene.get("fancy_text", {}) + if isinstance(ft, dict): + new_ft_text = st.text_input(f"Fancy Text (Scene {scene['id']})", value=ft.get("text", ""), key=f"ft_{i}") + scene["fancy_text"]["text"] = new_ft_text + + # 旁白/字幕已移至上方整体时间轴,此处仅作展示或删除 + st.caption("注:旁白与字幕已移至上方整体时间轴编辑") + + updated_scenes.append(scene) + st.divider() + + st.session_state.script_data["scenes"] = updated_scenes + + col_act1, col_act2 = st.columns([1, 4]) + with col_act1: + if st.button("确认脚本 & 开始生图", type="primary"): + # DB: Update script in case user edited it + db.update_project_script(st.session_state.project_id, st.session_state.script_data) + st.session_state.current_step = 2 + st.rerun() + + # --- Step 3: Image Generation --- + if st.session_state.current_step >= 2: + with st.expander("🎨 3. 
画面生成", expanded=(st.session_state.current_step == 2)): + + # Debug: Show reference image + if st.session_state.uploaded_images: + with st.expander("🔍 查看生图参考底图 (Base Image)"): + for i, p in enumerate(st.session_state.uploaded_images): + st.info(f"图 {i+1} 文件名: {os.path.basename(p)}") + if os.path.exists(p): + st.image(p, width=200) + else: + st.warning("⚠️ 未检测到参考底图!将生成随机风格图像。建议在 Step 1 重新上传。") + + # Trigger Generation Button + if not st.session_state.scene_images: + # Model Selection for Image Gen + img_model_options = ["Shubiaobiao (Gemini)", "Doubao (Volcengine)", "Gemini (Direct)", "Doubao (Group Image)"] + selected_img_model = st.radio("选择生图模型", img_model_options, horizontal=True) + + # Map to provider + if "Group Image" in selected_img_model: + img_provider = "doubao-group" + elif "Doubao" in selected_img_model: + img_provider = "doubao" + elif "Direct" in selected_img_model: + img_provider = "gemini" + else: + img_provider = "shubiaobiao" + + # Store selected provider for regeneration + st.session_state.selected_img_provider = img_provider + + if st.button("🚀 执行 AI 生图", type="primary"): + img_gen = ImageGenerator() + # Pass ALL uploaded images as reference + base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else [] + + if not base_imgs: + st.error("No base image found (未找到参考底图). 
Please upload in Step 1.") + st.stop() + + scenes = st.session_state.script_data.get("scenes", []) + + # 读取 visual_anchor 用于保持生图一致性 + visual_anchor = st.session_state.script_data.get("visual_anchor", "") + + if img_provider == "doubao-group": + # --- Group Generation Logic --- + with st.spinner("正在进行 Doubao 组图生成 (Batch Group Generation)..."): + try: + results = img_gen.generate_group_images_doubao( + scenes=scenes, + reference_images=base_imgs, + visual_anchor=visual_anchor + ) + + for s_id, path in results.items(): + st.session_state.scene_images[s_id] = path + db.save_asset(st.session_state.project_id, s_id, "image", "completed", local_path=path) + + if len(results) == len(scenes): + st.success("组图生成完成!") + db.update_project_status(st.session_state.project_id, "images_generated") + time.sleep(1) + st.rerun() + else: + st.warning(f"部分图片生成失败: {len(results)}/{len(scenes)}") + st.rerun() + + except Exception as e: + st.error(f"组图生成失败: {e}") + else: + # --- Sequential Logic --- + total_scenes = len(scenes) + progress_bar = st.progress(0) + status_text = st.empty() + + current_refs = list(base_imgs) # Start with base images + + try: + for idx, scene in enumerate(scenes): + scene_id = scene["id"] + status_text.text(f"正在生成 Scene {scene_id} ({idx+1}/{total_scenes}) using {selected_img_model}...") + + img_path = img_gen.generate_single_scene_image( + scene=scene, + original_image_path=current_refs, # Pass ALL accumulated images + previous_image_path=None, + model_provider=img_provider, + visual_anchor=visual_anchor + ) + + if img_path: + st.session_state.scene_images[scene_id] = img_path + current_refs.append(img_path) # Add newly generated image to references for next scene + db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path) + + progress_bar.progress((idx + 1) / total_scenes) + + status_text.text("生图完成!") + st.success("生图完成!") + # Update Status + db.update_project_status(st.session_state.project_id, "images_generated") + 
time.sleep(1) + st.rerun() + + except PermissionError as e: + st.error(f"生图服务拒绝访问 (可能余额不足): {e}") + except Exception as e: + st.error(f"生图发生错误: {e}") + + # Display & Regenerate Images + if st.session_state.scene_images: + cols = st.columns(4) + scenes = st.session_state.script_data.get("scenes", []) + + for i, scene in enumerate(scenes): + scene_id = scene["id"] + img_path = st.session_state.scene_images.get(scene_id) + + with cols[i % 4]: + st.markdown(f"**Scene {scene_id}**") + + # Expand Prompt info + with st.expander("Prompt", expanded=False): + st.caption(scene.get("visual_prompt", "No prompt")) + + if img_path and os.path.exists(img_path): + st.image(img_path, width=None, use_column_width=True) # Fix deprecated + else: + st.error("Image missing") + + # Regenerate specific image + if st.button(f"🔄 重生 S{scene_id}", key=f"regen_img_{scene_id}"): + if not st.session_state.uploaded_images: + st.error("缺少参考底图,请在 Step 1 重新上传图片") + else: + with st.spinner(f"正在重绘 Scene {scene_id}..."): + img_gen = ImageGenerator() + # Use ALL uploaded images + previously generated images up to this point + current_refs_for_regen = list(st.session_state.uploaded_images) + for prev_s_id in range(1, scene_id): + if prev_s_id in st.session_state.scene_images: + current_refs_for_regen.append(st.session_state.scene_images[prev_s_id]) + + # Fallback to single mode for regen if group was used + provider = st.session_state.get("selected_img_provider", "shubiaobiao") + if provider == "doubao-group": provider = "doubao" + + # 读取 visual_anchor + regen_visual_anchor = st.session_state.script_data.get("visual_anchor", "") + + new_path = img_gen.generate_single_scene_image( + scene=scene, + original_image_path=current_refs_for_regen, + previous_image_path=None, + model_provider=provider, + visual_anchor=regen_visual_anchor + ) + if new_path: + st.session_state.scene_images[scene_id] = new_path + db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=new_path) + 
st.rerun() + + if st.button("下一步:生成视频", type="primary"): + st.session_state.current_step = 3 + st.rerun() + + # --- Step 4: Video Generation --- + if st.session_state.current_step >= 3: + with st.expander("🎥 4. 视频生成 (Volcengine I2V)", expanded=(st.session_state.current_step == 3)): + + if not st.session_state.scene_videos: + if st.button("🎬 执行图生视频", type="primary"): + with st.spinner("正在生成视频 (耗时较长)..."): + vid_gen = VideoGenerator() + # Pass project_id + videos = vid_gen.generate_scene_videos( + st.session_state.project_id, + st.session_state.script_data, + st.session_state.scene_images + ) + + if videos: + st.session_state.scene_videos = videos + for sid, path in videos.items(): + db.save_asset(st.session_state.project_id, sid, "video", "completed", local_path=path) + + # Update Status + db.update_project_status(st.session_state.project_id, "videos_generated") + st.success("视频生成完成!") + st.rerun() + else: + st.warning("部分或全部视频生成失败") + + # Display Videos + if st.session_state.scene_videos: + cols = st.columns(4) + scenes = st.session_state.script_data.get("scenes", []) + + for i, scene in enumerate(scenes): + scene_id = scene["id"] + vid_path = st.session_state.scene_videos.get(scene_id) + + with cols[i % 4]: + st.markdown(f"**Scene {scene_id}**") + # Expand Prompt info + with st.expander("Prompt", expanded=False): + st.caption(scene.get("video_prompt", "No prompt")) + + if vid_path and os.path.exists(vid_path): + st.video(vid_path) + else: + st.warning("Video missing") + # --- Recovery Logic --- + asset = db.get_asset(st.session_state.project_id, scene_id, "video") + if asset and asset.get("task_id"): + task_id = asset.get("task_id") + if st.button(f"🔍 找回视频 (Task {task_id[-6:]})", key=f"recov_{scene_id}"): + with st.spinner("查询任务状态中..."): + vid_gen = VideoGenerator() + output_filename = f"scene_{scene_id}_video.mp4" + target_path = str(config.TEMP_DIR / output_filename) + + if vid_gen.recover_video_from_task(task_id, target_path): + 
st.session_state.scene_videos[scene_id] = target_path + db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=target_path) + st.success("找回成功!") + st.rerun() + else: + st.error("找回失败") + + # Per-scene regenerate button + if st.button(f"🔄 重生 S{scene_id}", key=f"regen_vid_{scene_id}"): + with st.spinner(f"正在重生成 Scene {scene_id} 视频..."): + vid_gen = VideoGenerator() + img_path = st.session_state.scene_images.get(scene_id) + if img_path and os.path.exists(img_path): + t_id = vid_gen.submit_scene_video_task(st.session_state.project_id, scene_id, img_path, scene.get("video_prompt")) + if t_id: + # Simple poll loop for single video regen + import time + for _ in range(60): + status, url = vid_gen.check_task_status(t_id) + if status == "succeeded": + new_path = vid_gen._download_video(url, f"scene_{scene_id}_video_{int(time.time())}.mp4") + if new_path: + st.session_state.scene_videos[scene_id] = new_path + db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=new_path, task_id=t_id) + st.rerun() + break + elif status in ["failed", "cancelled"]: + st.error(f"生成失败: {status}") + break + time.sleep(5) + else: + st.error("提交任务失败") + else: + st.error("缺少源图片") + + c_act1, c_act2 = st.columns([1, 4]) + with c_act1: + if st.button("🔄 重新生成所有视频", type="secondary"): + # Clear videos and rerun + st.session_state.scene_videos = {} + st.rerun() + + with c_act2: + if st.button("下一步:合成最终成片", type="primary"): + st.session_state.current_step = 4 + st.rerun() + + # --- Step 5: Final Composition & Tuning --- + if st.session_state.current_step >= 4: + with st.expander("🎞️ 5. 
合成结果与微调", expanded=(st.session_state.current_step == 4)): + + # Tabs: Preview vs Tuning + tab_preview, tab_tune = st.tabs(["🎥 预览", "🎛️ 微调 (Track Editor)"]) + + with tab_tune: + st.subheader("轨道微调 (Tuning)") + + # Global Settings + col_g1, col_g2 = st.columns(2) + with col_g1: + # BGM Select (支持 .mp3 和 .mp4) + bgm_dir = config.ASSETS_DIR / "bgm" + bgm_files = list(bgm_dir.glob("*.[mM][pP][34]")) + list(bgm_dir.glob("*.[mM][pP]3")) + bgm_files = [f for f in bgm_files if f.is_file() and not f.name.startswith('.')] + bgm_names = [f.name for f in bgm_files] + + # 智能推荐默认 BGM + bgm_style = st.session_state.script_data.get("bgm_style", "") + recommended_bgm = match_bgm_by_style(bgm_style, bgm_dir) + default_idx = 0 + if recommended_bgm: + rec_name = Path(recommended_bgm).name + if rec_name in bgm_names: + default_idx = bgm_names.index(rec_name) + 1 # +1 因为有 "None" 选项 + + selected_bgm = st.selectbox( + f"背景音乐 (BGM) - 推荐风格: {bgm_style or '未指定'}", + ["None"] + bgm_names, + index=default_idx + ) + with col_g2: + # Voice Select + selected_voice = st.selectbox("配音音色 (TTS)", [config.VOLC_TTS_DEFAULT_VOICE, "zh_female_meilinvyou_saturn_bigtts"]) + + st.divider() + + # Global Timeline Editor + st.markdown("### 🎙️ 旁白与字幕时间轴 (单位: 秒)") + timeline = st.session_state.script_data.get("voiceover_timeline", []) + updated_timeline = [] + for i, item in enumerate(timeline): + c1, c2, c3, c4 = st.columns([3, 3, 1, 1]) + with c1: + item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=f"tune_vo_{i}") + with c2: + item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=f"tune_sub_{i}") + with c3: + # 兼容旧格式: 如果有 start_ratio 则转换为 start_time + default_start = item.get("start_time", item.get("start_ratio", 0) * 12) + item["start_time"] = st.number_input(f"开始(秒) #{i+1}", value=float(default_start), min_value=0.0, max_value=30.0, step=0.5, key=f"tune_start_{i}") + with c4: + # 兼容旧格式: 如果有 duration_ratio 则转换为 duration + default_dur = 
item.get("duration", item.get("duration_ratio", 0.25) * 12) + item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=float(default_dur), min_value=0.5, max_value=15.0, step=0.5, key=f"tune_dur_{i}") + # 清理旧字段 + item.pop("start_ratio", None) + item.pop("duration_ratio", None) + updated_timeline.append(item) + + st.session_state.script_data["voiceover_timeline"] = updated_timeline + + st.divider() + + # Per Scene Editor (Only Fancy Text) + scenes = st.session_state.script_data.get("scenes", []) + updated_scenes = [] + + for i, scene in enumerate(scenes): + with st.container(): + st.markdown(f"**Scene {scene['id']}**") + ft = scene.get("fancy_text", {}) + ft_text = ft.get("text", "") if isinstance(ft, dict) else "" + new_ft = st.text_input(f"花字", value=ft_text, key=f"tune_ft_{i}") + if isinstance(scene.get("fancy_text"), dict): + scene["fancy_text"]["text"] = new_ft + + updated_scenes.append(scene) + + st.session_state.script_data["scenes"] = updated_scenes + + if st.button("🔄 重新合成 (Re-Compose)", type="primary"): + with st.spinner("正在应用修改并重新合成..."): + composer = VideoComposer(voice_type=selected_voice) + + bgm_path = None + if selected_bgm != "None": + bgm_path = str(config.ASSETS_DIR / "bgm" / selected_bgm) + + try: + # Save updated script first + db.update_project_script(st.session_state.project_id, st.session_state.script_data) + + output_path = composer.compose_from_script( + script=st.session_state.script_data, + video_map=st.session_state.scene_videos, + bgm_path=bgm_path, + output_name=f"final_{st.session_state.project_id}_{int(time.time())}" # Unique name for history + ) + st.session_state.final_video = output_path + db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path) + + st.success("重新合成成功!") + st.rerun() + except Exception as e: + st.error(f"合成失败: {e}") + + with tab_preview: + st.subheader("🛠️ 导出与交付") + + col_ex1, col_ex2 = st.columns([1, 2]) + with col_ex1: + if st.button("📦 生成剪映素材包 (ZIP)", type="primary", 
key="btn_export_zip"): + with st.spinner("正在打包视频、音频与字幕..."): + try: + zip_path = export_utils.create_capcut_package( + st.session_state.project_id, + st.session_state.script_data, + {"scene_videos": st.session_state.scene_videos} + ) + st.session_state['export_zip_path'] = zip_path + st.success("打包完成!请点击下载") + except Exception as e: + st.error(f"打包失败: {e}") + + with col_ex2: + if 'export_zip_path' in st.session_state and os.path.exists(st.session_state['export_zip_path']): + zip_name = f"capcut_pack_{st.session_state.project_id}.zip" + with open(st.session_state['export_zip_path'], "rb") as f: + st.download_button( + label=f"📥 点击下载素材包 ({zip_name})", + data=f, + file_name=zip_name, + mime="application/zip" + ) + st.divider() + + # 扫描所有历史合成版本 + import glob + + # 查找匹配当前项目的所有最终视频 + pattern_timestamp = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}_*.mp4") + pattern_simple = str(config.OUTPUT_DIR / f"final_{st.session_state.project_id}.mp4") + + found_files = [] + for f in glob.glob(pattern_timestamp): + found_files.append(f) + for f in glob.glob(pattern_simple): + if f not in found_files: + found_files.append(f) + + # 按修改时间倒序排列 + found_files.sort(key=os.path.getmtime, reverse=True) + + if not found_files: + st.info("暂无合成视频,请先点击开始合成。") + if st.button("✨ 开始首次合成", type="primary"): + with st.spinner("正在进行多轨合成..."): + # Default compose logic with smart BGM matching + composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE) + + # 智能匹配 BGM:根据脚本 bgm_style 匹配 + bgm_style = st.session_state.script_data.get("bgm_style", "") + bgm_path = match_bgm_by_style(bgm_style, config.ASSETS_DIR / "bgm") + + try: + # 首次合成也加上时间戳 + output_name = f"final_{st.session_state.project_id}_{int(time.time())}" + output_path = composer.compose_from_script( + script=st.session_state.script_data, + video_map=st.session_state.scene_videos, + bgm_path=bgm_path, + output_name=output_name + ) + st.session_state.final_video = output_path + db.save_asset(st.session_state.project_id, 0, 
"final_video", "completed", local_path=output_path) + db.update_project_status(st.session_state.project_id, "completed") + st.rerun() + except Exception as e: + st.error(f"合成失败: {e}") + + else: + st.success(f"共找到 {len(found_files)} 个历史版本") + + # 遍历显示所有历史版本 + for idx, vid_path in enumerate(found_files): + mtime = os.path.getmtime(vid_path) + time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime)) + file_name = os.path.basename(vid_path) + + # 容器样式 + with st.container(): + if idx == 0: + st.markdown(f"### 🟢 最新版本 ({time_str})") + else: + st.markdown(f"#### ⚪ 历史版本 {len(found_files)-idx} ({time_str})") + + c1, c2, c3 = st.columns([1, 2, 1]) + with c2: + st.video(vid_path) + + # 下载按钮 + with open(vid_path, "rb") as file: + st.download_button( + label=f"📥 下载 ({file_name})", + data=file, + file_name=file_name, + mime="video/mp4", + key=f"dl_btn_{file_name}" + ) + st.divider() + +# ============================================================ +# Page: History (New) +# ============================================================ +elif st.session_state.view_mode == "history": + st.header("📜 历史任务") + + projects = db.list_projects() + + for proj in projects: + with st.expander(f"{proj['name']} ({proj['updated_at']})"): + st.caption(f"ID: {proj['id']} | Status: {proj.get('status', 'unknown')}") + + assets = db.get_assets(proj['id']) + + if assets: + st.markdown("#### 📥 下载素材") + + # Group assets + final_vids = [a for a in assets if a['asset_type'] == 'final_video'] + scene_imgs = [a for a in assets if a['asset_type'] == 'image'] + scene_vids = [a for a in assets if a['asset_type'] == 'video'] + + # Final Video Download + if final_vids: + fv = final_vids[0] + if fv['local_path'] and os.path.exists(fv['local_path']): + st.success(f"✅ **最终成片已生成**") + with open(fv['local_path'], "rb") as f: + st.download_button( + label=f"📥 下载最终成片 ({os.path.basename(fv['local_path'])})", + data=f, + file_name=os.path.basename(fv['local_path']), + mime="video/mp4" + ) + + # Script Download 
+ if proj.get("script_data"): + try: + script_json = json.dumps(proj["script_data"] if isinstance(proj["script_data"], dict) else json.loads(proj["script_data"]), ensure_ascii=False, indent=2) + st.download_button( + label="📥 下载脚本/字幕配置 (JSON)", + data=script_json, + file_name=f"script_{proj['id']}.json", + mime="application/json" + ) + except: + pass + + st.markdown("---") + + c1, c2 = st.columns(2) + with c1: + st.markdown("**分镜图片**") + for img in scene_imgs: + if img['local_path'] and os.path.exists(img['local_path']): + with open(img['local_path'], "rb") as f: + st.download_button(f"Scene {img['scene_id']} 图片", f, key=f"dl_img_{proj['id']}_{img['id']}", file_name=f"scene_{img['scene_id']}.png") + + with c2: + st.markdown("**分镜视频**") + for vid in scene_vids: + if vid['local_path'] and os.path.exists(vid['local_path']): + with open(vid['local_path'], "rb") as f: + st.download_button(f"Scene {vid['scene_id']} 视频", f, key=f"dl_vid_{proj['id']}_{vid['id']}", file_name=f"scene_{vid['scene_id']}.mp4") + else: + st.info("暂无生成素材") + +# ============================================================ +# Page: Settings (New) +# ============================================================ +elif st.session_state.view_mode == "settings": + st.header("⚙️ 设置配置") + + st.subheader("Prompt 配置") + + # Script Generation Prompt + current_prompt = db.get_config("prompt_script_gen") + + # 显示当前状态 + if current_prompt: + st.info("✅ 已加载自定义 Prompt(来自数据库)") + else: + st.warning("⚠️ 使用默认 Prompt(数据库中无自定义配置)") + # Load default from instance if not in DB + temp_gen = ScriptGenerator() + current_prompt = temp_gen.default_system_prompt + + new_prompt = st.text_area("脚本拆分 Prompt (System Prompt)", value=current_prompt, height=400) + + col_save, col_reset = st.columns([1, 3]) + with col_save: + if st.button("💾 保存配置", type="primary"): + db.set_config("prompt_script_gen", new_prompt, "System prompt for script generation step") + # 验证保存 + saved = db.get_config("prompt_script_gen") + if saved == new_prompt: 
+ st.success("✅ 配置已保存并验证成功!下次生成脚本时将使用新 Prompt。") + else: + st.error("❌ 保存可能失败,请检查日志") + + with col_reset: + if st.button("🔄 恢复默认"): + temp_gen = ScriptGenerator() + db.set_config("prompt_script_gen", temp_gen.default_system_prompt, "System prompt for script generation step (DEFAULT)") + st.success("已恢复默认 Prompt,请刷新页面查看") + st.rerun() + diff --git a/assets/fonts/NotoSansSC-Bold.otf b/assets/fonts/NotoSansSC-Bold.otf new file mode 100644 index 0000000..3bec311 --- /dev/null +++ b/assets/fonts/NotoSansSC-Bold.otf @@ -0,0 +1,2152 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Page not found · GitHub · GitHub + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + + +
+ Skip to content + + + + + + + + + + + +
+
+ + + + + + + + + + + + + + + + + +
+ +
+ + + + + + + + +
+ + + + + +
+ + + + + + + + + +
+
+ + + +
+
+ +
+
+ 404 “This is not the web page you are looking for” + + + + + + + + + + + + +
+
+ +
+
+ +
+ + +
+
+ +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + +
+
+
+ + + diff --git a/assets/fonts/NotoSansSC-Regular.otf b/assets/fonts/NotoSansSC-Regular.otf new file mode 100644 index 0000000..8c1f434 Binary files /dev/null and b/assets/fonts/NotoSansSC-Regular.otf differ diff --git a/config.py b/config.py new file mode 100644 index 0000000..9d741e4 --- /dev/null +++ b/config.py @@ -0,0 +1,181 @@ +""" +MatchMe Studio - Configuration +""" +import os +from pathlib import Path +from dotenv import load_dotenv + +load_dotenv() + +# ============================================================ +# API Keys +# ============================================================ + +# Volcengine / Doubao (Official) +VOLC_API_KEY = os.getenv("VOLC_API_KEY", "05aed9c1-f5e6-487b-9273-fe7d6be51957") +VOLC_BASE_URL = os.getenv("VOLC_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3") + +# Models (Updated with User-Provided Endpoint IDs) +# LLM: Doubao Pro 1.5 (Using provided brain/vision endpoint) +BRAIN_MODEL_ID = os.getenv("BRAIN_MODEL_ID", "ep-20251203231055-dpsp7") +# Vision: Doubao Vision Pro 1.5 +VISION_MODEL_ID = os.getenv("VISION_MODEL_ID", "ep-20251203232121-xjt2s") +# Image: Doubao Image (Updated to user provided model) +IMAGE_MODEL_ID = os.getenv("IMAGE_MODEL_ID", "ep-20251203231641-wg9nb") +# Video: Doubao Video (PixelDance) +VIDEO_MODEL_ID = os.getenv("VIDEO_MODEL_ID", "ep-20251207100506-rjx4x") + +# Doubao Specifics (User Provided) +DOUBAO_SCRIPT_MODEL = "ep-20251203231055-dpsp7" +DOUBAO_IMG_MODEL = "ep-20251203231641-wg9nb" + + +# Text/Brain API (Legacy) +SHUBIAOBIAO_KEY = os.getenv("SHUBIAOBIAO_KEY", "sk-aL167A8sQEyvs40yBfC140Fc0fDa4c198f029aAcF0429108") +SHUBIAOBIAO_BASE_URL = os.getenv("SHUBIAOBIAO_BASE_URL", "https://api.shubiaobiao.cn/v1") +SHUBIAOBIAO_MODEL_TEXT = "gemini-3-pro-preview" + +# Image Generation API (Updated) +# Host: https://api.wuyinkeji.com/ +# Model: nanoBanana-pro (Gemini) +GEMINI_IMG_KEY = os.getenv("GEMINI_IMG_KEY", "G9rXx3Ag2Xfa7Gs8zou6t6HqeZ") +GEMINI_IMG_API_URL = os.getenv("GEMINI_IMG_API_URL", 
"https://api.wuyinkeji.com/api/img/nanoBanana-pro") +GEMINI_IMG_DETAIL_URL = os.getenv("GEMINI_IMG_DETAIL_URL", "https://api.wuyinkeji.com/api/img/drawDetail") + +# Legacy Image API +SHUBIAOBIAO_IMG_KEY = os.getenv("SHUBIAOBIAO_IMG_KEY", "sk-1yr2h4sJybHB7DED57CeF446D08c4bC989F621Db5b48E70d") +SHUBIAOBIAO_IMG_BASE_URL = os.getenv("SHUBIAOBIAO_IMG_BASE_URL", "https://api2img.shubiaobiao.com") +SHUBIAOBIAO_IMG_MODEL_NAME = "gemini-3-pro-image-preview" + +# Backup +FAL_KEY = os.getenv("FAL_KEY", "") +KLING_ACCESS_KEY = os.getenv("KLING_ACCESS_KEY", "") +KLING_SECRET_KEY = os.getenv("KLING_SECRET_KEY", "") + +XI_KEY = os.getenv("XI_KEY", "") + +# ============================================================ +# Cloudflare R2 Storage +# ============================================================ +R2_ENDPOINT = os.getenv("R2_ENDPOINT", "") +R2_ACCESS_KEY = os.getenv("R2_ACCESS_KEY", "") +R2_SECRET_KEY = os.getenv("R2_SECRET_KEY", "") +R2_BUCKET_NAME = os.getenv("R2_BUCKET_NAME", "mms-assets") +# Public URL for accessing uploaded files +R2_PUBLIC_URL = os.getenv("R2_PUBLIC_URL", "https://pub-7942a75aa66d4315a628ee464267ebf4.r2.dev") + +# ============================================================ +# ElevenLabs Settings (Legacy - for English) +# ============================================================ +ELEVENLABS_VOICE_ID = os.getenv("XI_VOICE_ID", "21m00Tcm4TlvDq8ikWAM") +ELEVENLABS_MODEL = "eleven_turbo_v2_5" + +# ============================================================ +# Volcengine TTS Settings (火山引擎语音合成 - 中文) +# ============================================================ +# 申请地址: https://console.volcengine.com/speech/service/8 +VOLC_TTS_APPID = os.getenv("VOLC_TTS_APPID", "6771884088") +VOLC_TTS_ACCESS_TOKEN = os.getenv("VOLC_TTS_ACCESS_TOKEN", "Q5sR2SNfxO8Vb9g2ucsaqfUGOpcpZi3S") +VOLC_TTS_SECRET_KEY = os.getenv("VOLC_TTS_SECRET_KEY", "RXc2WiA6OK6G1xuEZ7cyAU3Q3B5Z1oUx") + +# 默认音色 +# 抖音热门带货音色推荐: +# - BV700_streaming: 甜美小媛(甜美活泼,适合美妆/好物)- 可能无权限 +# - 
def pick_font(candidates=None, fallback="/System/Library/Fonts/PingFang.ttc"):
    """Return the first usable font file from a priority list.

    Args:
        candidates: iterable of font file paths probed in order. Defaults
            to the module-level ``SYSTEM_FONTS`` list, preserving the
            original behaviour.
        fallback: path returned when no candidate is usable.

    A candidate counts as usable only when it exists AND is larger than
    1000 bytes — tiny files are almost certainly corrupted downloads
    (e.g. an HTML error page saved in place of the real font file).
    """
    if candidates is None:
        candidates = SYSTEM_FONTS
    for path in candidates:
        if os.path.exists(path) and os.path.getsize(path) > 1000:
            return path
    return fallback
+""" + +import os +import sys +from pathlib import Path +from dotenv import load_dotenv + +from fabric import Connection, Config +from invoke import task + +# Load environment variables +load_dotenv() + +# Server configuration +SERVER_IP = os.getenv("SERVER_IP", "") +SERVER_USER = os.getenv("SERVER_USER", "root") +SERVER_PASS = os.getenv("SERVER_PASS", "") + +# Remote paths +REMOTE_APP_DIR = "/opt/gloda-factory" +REMOTE_VENV = f"{REMOTE_APP_DIR}/venv" + +# Files to upload +LOCAL_FILES = [ + "config.py", + "web_app.py", + "requirements.txt", + ".env", + "modules/__init__.py", + "modules/utils.py", + "modules/brain.py", + "modules/factory.py", + "modules/editor.py", +] + + +def get_connection() -> Connection: + """Create SSH connection to remote server.""" + if not SERVER_IP or not SERVER_PASS: + raise ValueError("SERVER_IP and SERVER_PASS must be set in .env") + + config = Config(overrides={"sudo": {"password": SERVER_PASS}}) + return Connection( + host=SERVER_IP, + user=SERVER_USER, + connect_kwargs={"password": SERVER_PASS}, + config=config + ) + + +def deploy(): + """Full deployment: setup server, upload code, start app.""" + print("🚀 Starting deployment...") + + conn = get_connection() + + # Step 1: Install system dependencies + print("\n📦 Step 1/5: Installing system dependencies...") + install_dependencies(conn) + + # Step 2: Create app directory + print("\n📁 Step 2/5: Setting up directories...") + setup_directories(conn) + + # Step 3: Upload code + print("\n📤 Step 3/5: Uploading code...") + upload_code(conn) + + # Step 4: Setup Python environment + print("\n🐍 Step 4/5: Setting up Python environment...") + setup_python(conn) + + # Step 5: Start application + print("\n🎬 Step 5/5: Starting application...") + start_app(conn) + + print(f"\n✅ Deployment complete!") + print(f"🌐 Access the app at: http://{SERVER_IP}:8501") + + conn.close() + + +def install_dependencies(conn: Connection): + """Install system-level dependencies.""" + commands = [ + "apt-get update -qq", 
def upload_code(conn: Connection):
    """Upload application code to remote server."""
    base = Path(__file__).parent

    for rel in LOCAL_FILES:
        src = base / rel
        dest = f"{REMOTE_APP_DIR}/{rel}"

        # Guard clause: report and move on when a listed file is absent.
        if not src.exists():
            print(f" ⚠️ Skipped (not found): {rel}")
            continue

        # Create the remote parent directory before transferring.
        conn.run(f"mkdir -p {Path(dest).parent}", hide=True)
        conn.put(str(src), dest)
        print(f" ✅ Uploaded: {rel}")

    print(" ✅ Code uploaded")
def stop_app():
    """Stop the running application on the remote server."""
    print("🛑 Stopping application...")
    conn = get_connection()
    # `|| true` keeps the command successful when no process matches.
    conn.run("pkill -f 'streamlit run web_app.py' || true", hide=True, warn=True)
    conn.close()
    print("✅ Application stopped")


def logs():
    """Show the tail of the remote application log."""
    print("📋 Recent logs:")
    conn = get_connection()
    conn.run("tail -50 /var/log/gloda-factory.log", warn=True)
    conn.close()


def status():
    """Report whether the Streamlit process is running remotely."""
    print("📊 Checking status...")
    conn = get_connection()

    result = conn.run("pgrep -f 'streamlit run web_app.py'", hide=True, warn=True)
    if result.ok:
        print(f"✅ Application is running (PID: {result.stdout.strip()})")
        print(f"🌐 URL: http://{SERVER_IP}:8501")
    else:
        print("❌ Application is not running")

    conn.close()


def restart():
    """Stop the remote application (if running) and start it again."""
    print("🔄 Restarting application...")
    conn = get_connection()

    conn.run("pkill -f 'streamlit run web_app.py' || true", hide=True, warn=True)

    import time
    time.sleep(2)  # give the old process a moment to die

    start_app(conn)
    conn.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Gloda Factory Deployment Tool")
    parser.add_argument(
        "command",
        choices=["deploy", "start", "stop", "restart", "status", "logs"],
        help="Deployment command to run"
    )
    args = parser.parse_args()

    def _start():
        # Fix: the original used `lambda: start_app(get_connection())`,
        # which leaked the SSH connection. Close it deterministically.
        conn = get_connection()
        try:
            start_app(conn)
        finally:
            conn.close()

    # Dispatch table: one handler per sub-command.
    commands = {
        "deploy": deploy,
        "stop": stop_app,
        "status": status,
        "logs": logs,
        "restart": restart,
        "start": _start,
    }
    commands[args.command]()
6-9s | 深化 | 对比效果 / 动态美感 / 爆浆拉丝 / 权威背书 | +| Scene 4 | 9-12s | 收尾 (可选) | 信任背书 / 使用后美好状态 / 行动号召 | + +--- + +# 🎙️ 旁白设计规范 (Voiceover Rules) + +## 核心定位 ⚠️ +旁白是**卖点传递的主力军**,不是画面解说词。10秒内必须完成:场景共鸣 → 核心卖点 → 信任背书 → 行动召唤。 + +## 技术规范 +1. **语速**: **5 字/秒** (9秒视频 = 45-50字旁白),可略超视频时长,后期 1.1x 倍速压入 +2. **气口间隔**: 两段旁白之间留 **0.3-0.5 秒** +3. **时间控制**: `start_time` 和 `duration` 单位为秒 +4. **字幕同步**: `subtitle` 与 `text` 完全一致 + +## 写作禁忌 +- ❌ 描述画面:"这是一款发夹" → ✅ 带入场景:"想要千金范?这款发夹绝了" +- ❌ 空洞形容:"非常好看" → ✅ 具体感受:"黑发棕发都显贵气" +- ❌ 无信任背书 → ✅ 加数据:"月销3万单,回购率超高" +- ❌ 无行动召唤 → ✅ 加引导:"现在下单,还送同款小号" + +## 示例对比 +``` +❌ 旧版 (24字,信息不足): +"秋冬氛围感,财阀千金风" + "毛绒质感,搭配璀璨水钻" + "精致耐看,百搭不挑人" + +✅ 新版 (52字,信息密集): +"想要秋冬千金范?这款发夹绝了" + "奥地利进口水钻,手工镶嵌不掉钻" + +"黑发棕发都显贵,扎个马尾直接气质拉满" + "月销3万单,现在下单送同款小号" +``` + +--- + +# 🎨 商家风格提示 (Style Hint - Optional) + +如果用户提供了风格关键词(如"韩风"、"高级感"、"日系"),需融入: +- `video_style`: 调整色调、光影、构图 + - 韩风 → 低饱和、柔光、简洁留白 + - 高级感 → 暗调、金属质感、几何构图 + - 日系 → 自然光、木质/棉麻元素、温暖色调 +- `fancy_text.style`: 选择匹配的字幕风格 + - 高级感 → `minimal` (白字) + - 活力 → `highlight` (黄字) + - 食欲/警示 → `warning` (红字) + +--- + +# ⚠️ EXECUTION CONSTRAINTS (执行红线) + +## 视觉干净度 (Visual Cleanliness) + +**禁止 AI 额外生成**:装饰性文字、标语、水印、非商品元素 +**必须保留**:商品包装自带的文字、Logo、品牌标识(这是商品真实外观的一部分) + +正确写法: +``` +✅ "商品正面全貌,保留包装原有设计 --no added text --no watermarks" +❌ "--no text" (这会错误移除包装文字) +``` + +## 视觉一致性 (Visual Consistency) + +所有分镜的 `visual_prompt` **必须包含完整的 Visual Anchor**,确保主体外观不变形。 + +## 运动控制 (Motion Control) + +| 允许 ✅ | 禁止 ❌ | +|---------|---------| +| 物理运镜: Zoom In/Out, Pan, Tilt | 复杂生物动作: 手部翻转、穿衣、咀嚼 | +| 环境微动: 光影流动、水珠滑落、蒸汽升腾 | 主体形变: 产品旋转360°、折叠展开 | +| 物理动态: 掰开、倾倒、碎屑飞溅 | 长时间连续动作 (>3秒) | + +--- + +# ❌ 禁止示例 (Counter-examples) + +## Bad visual_prompt +``` +❌ "一只手拿起曲奇,放入嘴中咀嚼" + → 手部和嘴部动作必然变形 +✅ "曲奇被掰开的瞬间,巧克力流心缓缓溢出,微距特写" + → 物理动作,无人体 +``` + +## Bad video_prompt +``` +❌ "镜头跟随产品旋转一周,展示各个角度" + → 超出 3 秒,旋转运动易变形 +✅ "Slow Zoom In, 光影在表面流动, 背景蒸汽微动" + → 简单运镜 + 物理微动 +``` + +## Bad fancy_text +``` +❌ "进口黄油手工烘焙每日新鲜发货限时特惠" + → 超过 6 字,静音下无法快速阅读 +✅ "进口黄油" + → 核心卖点浓缩,一眼可读 +``` + +--- + +# 📐 
Visual Prompt 语法规范 + +## 结构模板 +``` +[Visual Anchor] + [主体状态/动作] + [景别] + [环境/光影] + [否定提示] +``` + +## 完整示例 +``` +"[深棕色圆形曲奇,表面嵌入巧克力碎块,牛皮纸包装] + + 饼干被掰开,流心巧克力缓缓流出 + + 微距特写,浅景深 + + 暖黄色逆光,大理石台面 + + --no added text --no watermarks --no hands" +``` + +## 否定提示规范 (--no) +- `--no added text` (禁止AI添加的文字,保留包装原有文字) +- `--no watermarks` (禁止水印) +- `--no hands` / `--no human body` (如非必要) +- `--no complex motion` (禁止复杂动作) + +--- + +# 📄 OUTPUT FORMAT (Strict JSON Schema) + +**重要**:必须保留以下顶层字段,确保与现有系统兼容。 + +```json +{ + "product_name": "商品名称", + "visual_anchor": "商品视觉锚点:材质+颜色+形状+包装特征,用于保持生图一致性", + "selling_points": ["核心卖点1", "核心卖点2", "核心卖点3"], + "target_audience": "目标人群描述", + "video_style": "视频风格 (色调/光影/构图)", + "bgm_style": "BGM风格", + "voiceover_timeline": [ + { + "id": 1, + "text": "旁白文案 (口语化, 4字/秒)", + "subtitle": "字幕文案 (与text完全一致)", + "start_time": 0.0, + "duration": 3.0 + } + ], + "scenes": [ + { + "id": 1, + "duration": 3, + "visual_prompt": "[Visual Anchor] + 场景描述 --no added text --no watermarks", + "video_prompt": "运镜 + 物理动态描述", + "fancy_text": { + "text": "最多6字", + "style": "highlight | warning | minimal", + "position": "top | center | bottom", + "start_time": 0.5, + "duration": 2.0 + } + } + ] +} +``` + +--- + +# 📝 完整示例 (Type C - 爆浆曲奇) + +**Input**: 商品名"爆浆流心曲奇",参考图为深棕色曲奇+巧克力流心特写 + +**Output**: +```json +{ + "product_name": "爆浆流心曲奇", + "visual_anchor": "深棕色圆形曲奇饼干,表面嵌入巧克力碎块,内部巧克力流心,牛皮纸包装袋印有品牌Logo", + "selling_points": ["真·爆浆流心", "进口黄油", "香浓不腻"], + "target_audience": "18-35岁女性,追求零食品质,喜欢巧克力甜品", + "video_style": "Macro photography, warm golden backlight, shallow DOF, rustic wood surface", + "bgm_style": "ASMR crackling + light upbeat rhythm", + "voiceover_timeline": [ + { + "id": 1, + "text": "下午嘴馋了?来一口真·爆浆流心曲奇", + "subtitle": "下午嘴馋了?来一口真·爆浆流心曲奇", + "start_time": 0.0, + "duration": 3.0 + }, + { + "id": 2, + "text": "新西兰进口黄油,纯可可脂,咬开瞬间流心爆浆", + "subtitle": "新西兰进口黄油,纯可可脂,咬开瞬间流心爆浆", + "start_time": 3.2, + "duration": 3.5 + }, + { + "id": 3, + "text": "已售50万盒,回购率超高,现在下单买二送一", + 
"subtitle": "已售50万盒,回购率超高,现在下单买二送一", + "start_time": 7.0, + "duration": 3.0 + } + ], + "scenes": [ + { + "id": 1, + "duration": 3, + "visual_prompt": "[深棕色圆形曲奇饼干,表面嵌入巧克力碎块,牛皮纸包装印有品牌标识] 正面全貌堆叠展示,大理石台面,暖黄逆光,浅景深 --no added text --no watermarks", + "video_prompt": "Slow Zoom In, 光影在曲奇表面缓缓流动,背景轻微虚化", + "fancy_text": { + "text": "爆浆流心", + "style": "warning", + "position": "center", + "start_time": 0.5, + "duration": 2.0 + } + }, + { + "id": 2, + "duration": 3, + "visual_prompt": "[深棕色圆形曲奇饼干,内部巧克力流心] 饼干被掰开的瞬间,巧克力流心缓缓溢出,微距特写,暖色调 --no hands --no added text --no watermarks", + "video_prompt": "Static macro shot, 流心自然流动,碎屑微微散落", + "fancy_text": { + "text": "真·爆浆", + "style": "highlight", + "position": "bottom", + "start_time": 0.3, + "duration": 2.0 + } + }, + { + "id": 3, + "duration": 3, + "visual_prompt": "[深棕色圆形曲奇饼干,牛皮纸包装印有品牌标识] 包装盒俯拍,旁边散落黄油块和可可豆原料,简洁浅色背景 --no added text --no watermarks", + "video_prompt": "Slow Pan Right, 依次掠过原料和包装", + "fancy_text": { + "text": "进口黄油", + "style": "minimal", + "position": "top", + "start_time": 0.5, + "duration": 2.0 + } + } + ] +} +``` + +--- + +# ✅ 输出前自检清单 + +1. [ ] `product_name`, `visual_anchor`, `selling_points`, `target_audience` 是否存在于顶层? +2. [ ] `visual_anchor` 是否包含:材质+颜色+形状+包装特征? +3. [ ] `video_style`, `bgm_style` 是否存在于顶层? +4. [ ] 每个分镜 duration 是否 = 3? +5. [ ] 总时长是否在 9-12 秒范围内? +6. [ ] voiceover_timeline 使用的是 `start_time` 和 `duration` (秒) 而非 ratio? +7. [ ] 旁白语速是否 ≤ 4字/秒? +8. [ ] fancy_text 是否 ≤ 6 字? +9. [ ] 是否使用 `--no added text` 而非 `--no text`? +10. [ ] 是否避免了复杂人体动作描述? 
diff --git a/main_flow.py b/main_flow.py new file mode 100644 index 0000000..4b4744f --- /dev/null +++ b/main_flow.py @@ -0,0 +1,339 @@ +""" +Video Flow v2.0 - 命令行主流程控制器 + +独立的 CLI 入口,支持命令行参数调用完整的视频生成流程。 +与 app.py (Streamlit UI) 分离,共用 modules 层。 + +Usage: + python main_flow.py --help + + python main_flow.py \ + --product-name "网红气质大号发量多!高马尾香蕉夹" \ + --images /path/to/主图1.png /path/to/主图2.png \ + --category "钟表配饰-时尚饰品-发饰" \ + --price "3.99元" \ + --tags "回头客|款式好看|材质好" \ + --model doubao \ + --output final_hairclip +""" +import argparse +import logging +import sys +import json +import time +import random +from pathlib import Path +from typing import Dict, List, Optional + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout), + logging.FileHandler("video_flow.log") + ] +) + +import config +from modules.script_gen import ScriptGenerator +from modules.image_gen import ImageGenerator +from modules.video_gen import VideoGenerator +from modules.composer import VideoComposer + +logger = logging.getLogger("MainFlow") + + +def parse_args(): + """解析命令行参数""" + parser = argparse.ArgumentParser( + description="Video Flow CLI - 商品短视频自动生成命令行工具", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + # 使用默认测试数据 + python main_flow.py --demo + + # 指定商品信息 + python main_flow.py \\ + --product-name "网红气质大号发量多!高马尾香蕉夹" \\ + --images ./素材/发夹/原始稿/主图1.png ./素材/发夹/原始稿/主图2.png \\ + --category "钟表配饰-时尚饰品-发饰" \\ + --price "3.99元" \\ + --tags "回头客|款式好看|材质好" \\ + --model doubao \\ + --output final_hairclip + """ + ) + + # 基本参数 + parser.add_argument("--demo", action="store_true", help="使用内置测试数据(发夹案例)") + parser.add_argument("--product-name", type=str, help="商品标题") + parser.add_argument("--images", nargs="+", type=str, help="商品主图路径列表 (建议 3-5 张)") + + # 商品信息 + parser.add_argument("--category", type=str, default="", help="商品类目") + parser.add_argument("--price", type=str, 
def match_bgm_by_style(bgm_style: str, bgm_dir: Path) -> Optional[str]:
    """Pick a background-music file matching the script's declared style.

    Selection strategy:
      - collect every ``*.mp3`` / ``*.mp4`` file in ``bgm_dir``
        (case-insensitive extension, hidden files excluded)
      - if any filename contains a style keyword that also appears in
        ``bgm_style``, randomly pick among those matches
      - otherwise randomly pick any candidate
      - return ``None`` when the directory holds no BGM at all

    Bug fix: the original built the candidate list from two overlapping
    glob patterns (``*.[mM][pP][34]`` and ``*.[mM][pP]3``), so every mp3
    appeared twice in the pool, biasing the random choice toward mp3
    files. Candidates are now collected once via suffix filtering.
    """
    candidates = sorted(
        f for f in bgm_dir.glob("*")
        if f.is_file()
        and not f.name.startswith('.')
        and f.suffix.lower() in (".mp3", ".mp4")
    )
    if not candidates:
        return None

    if bgm_style:
        style_lower = bgm_style.lower()
        # Simplified Chinese keyword matching (no real tokenizer).
        keywords = ["活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏"]
        wanted = [kw for kw in keywords if kw in style_lower]

        matched = [f for f in candidates if any(kw in f.name for kw in wanted)]
        if matched:
            return str(random.choice(matched))

    # No keyword hit: fall back to any BGM.
    return str(random.choice(candidates))
生成脚本 ===== + logger.info("="*50) + logger.info("Step 1: Generating Script...") + logger.info("="*50) + + script_gen = ScriptGenerator() + script = script_gen.generate_script( + product_name, + product_info, + valid_images, + model_provider=args.script_model + ) + + if not script: + logger.error("Script generation failed.") + return None + + # 保存脚本供检查 + script_path = config.OUTPUT_DIR / f"script_{project_id}.json" + with open(script_path, "w", encoding="utf-8") as f: + json.dump(script, f, ensure_ascii=False, indent=2) + logger.info(f"Script saved to {script_path}") + + scenes = script.get("scenes", []) + logger.info(f"Generated {len(scenes)} scenes") + + # ===== 3. 生成分镜图片 ===== + logger.info("="*50) + logger.info("Step 2: Generating Scene Images...") + logger.info("="*50) + + image_gen = ImageGenerator() + visual_anchor = script.get("visual_anchor", "") + + scene_images: Dict[int, str] = {} + + if args.image_model == "doubao-group": + # 组图生成模式 + logger.info("Using Doubao Group Image Generation...") + scene_images = image_gen.generate_group_images_doubao( + scenes=scenes, + reference_images=valid_images, + visual_anchor=visual_anchor + ) + else: + # 顺序生成模式 + current_refs = list(valid_images) + + for idx, scene in enumerate(scenes): + scene_id = scene["id"] + logger.info(f"Generating image for Scene {scene_id} ({idx+1}/{len(scenes)})...") + + img_path = image_gen.generate_single_scene_image( + scene=scene, + original_image_path=current_refs, + previous_image_path=None, + model_provider=args.image_model, + visual_anchor=visual_anchor + ) + + if img_path: + scene_images[scene_id] = img_path + current_refs.append(img_path) + logger.info(f"Scene {scene_id} image: {img_path}") + else: + logger.warning(f"Failed to generate image for Scene {scene_id}") + + if not scene_images: + logger.error("Image generation failed (no images generated).") + return None + + logger.info(f"Generated {len(scene_images)} scene images.") + + if args.skip_video: + logger.info("Skipping video 
def main():
    """CLI entry point: validate args, run the flow, map result to exit code."""
    args = parse_args()

    # Fail fast with a friendly message before any heavy work starts.
    # Fix: the original only checked --product-name even though the error
    # message (and run_video_flow itself) also requires --images.
    if not args.demo and (not args.product_name or not args.images):
        print("Error: Must provide --product-name and --images, or use --demo")
        print("Run with --help for usage information.")
        sys.exit(1)

    try:
        result = run_video_flow(args)
        # Non-empty result means the workflow produced a final video.
        sys.exit(0 if result else 1)
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        sys.exit(130)  # conventional exit code for SIGINT
    except Exception as e:
        logger.exception(f"Unexpected error: {e}")
        sys.exit(1)
def extract_audio_from_video(video_path: str) -> str:
    """Extract the audio track of a video as 16 kHz mono MP3 via ffmpeg.

    Returns the path of the extracted audio file (under config.TEMP_DIR).
    Raises RuntimeError when ffmpeg exits non-zero.
    """
    source = Path(video_path)
    audio_path = config.TEMP_DIR / f"{source.stem}_audio.mp3"

    command = [
        "ffmpeg", "-y",
        "-i", str(source),
        "-vn",                      # drop the video stream
        "-acodec", "libmp3lame",
        "-ar", "16000",             # 16 kHz sample rate (Whisper-friendly)
        "-ac", "1",                 # mono
        str(audio_path),
    ]

    try:
        subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"FFmpeg error: {e.stderr.decode()}")
        raise RuntimeError("Failed to extract audio from video")

    logger.info(f"Audio extracted to {audio_path}")
    return str(audio_path)


def transcribe(audio_path: str) -> str:
    """Send an audio file to the Whisper API and return the transcript text.

    API/client errors are logged, then re-raised to the caller.
    """
    logger.info(f"Transcribing {audio_path}...")

    try:
        with open(audio_path, "rb") as audio_file:
            response = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                language="zh",      # Chinese
                response_format="text",
            )
        # The SDK may hand back either a plain string or an object with .text.
        text = response if isinstance(response, str) else response.text
        logger.info(f"Transcription complete: {len(text)} chars")
        return text
    except Exception as e:
        logger.error(f"Whisper API error: {e}")
        raise
str) -> str: + """Extract audio from video and transcribe.""" + audio_path = extract_audio_from_video(video_path) + return transcribe(audio_path) + + + + + + + + + + + + diff --git a/modules/brain.py b/modules/brain.py new file mode 100644 index 0000000..3ff6554 --- /dev/null +++ b/modules/brain.py @@ -0,0 +1,346 @@ +""" +MatchMe Studio - Brain Module (Multi-stage Analysis & Script Generation) +""" +import json +import logging +from typing import Dict, Any, List, Optional +from openai import OpenAI + +import config + +logger = logging.getLogger(__name__) + +# Use Volcengine (Doubao) via OpenAI Compatible Interface +client = OpenAI( + api_key=config.VOLC_API_KEY, + base_url=config.VOLC_BASE_URL +) + +# ============================================================ +# Stage 1: Analyze Materials +# ============================================================ + +ANALYZE_SYSTEM_PROMPT = """你是一位资深短视频创作总监,专精TikTok/抖音爆款内容。 + +任务:深度分析用户提供的素材和需求,识别产品特性、使用场景、目标人群。 + +分析维度: +1. 产品/服务核心卖点(从素材中提取视觉特征) +2. 视觉风格特征(颜色、质感、包装) +3. 潜在目标受众 +4. 内容调性建议 + +然后检查是否缺少关键信息,如果缺少,生成2-5个问题帮助完善需求。 +每个问题必须与短视频创作直接相关。 + +输出严格JSON格式: +{ + "analysis": "详细分析结果,包括从素材中识别到的视觉元素...", + "detected_info": { + "product": "识别到的产品名称和类型", + "visual_features": ["视觉特征1", "视觉特征2"], + "audience": "推测的目标人群", + "style": "推测的风格" + }, + "missing_info": ["缺少的信息1", "缺少的信息2"], + "questions": [ + { + "id": "q1", + "text": "问题文字(说明为什么这个问题重要)", + "options": ["选项A", "选项B", "选项C"], + "allow_multiple": true, + "allow_custom": true + } + ], + "ready": false +} + +如果信息足够,ready=true,questions为空数组。 +""" + +def analyze_materials( + prompt: str, + image_urls: List[str] = None, + asr_text: str = "" +) -> Dict[str, Any]: + """ + Deep analysis of user materials. + Returns analysis text and questions if info is missing. 
+ """ + logger.info("Brain: Analyzing materials...") + + # Using Vision Model format (Doubao Vision) + # Input format: messages with content list (text + image_url) + + content_parts = [{"type": "text", "text": f"用户需求: {prompt}"}] + + if asr_text: + content_parts.append({"type": "text", "text": f"\n视频原声(ASR转写): {asr_text}"}) + + if image_urls: + content_parts.append({"type": "text", "text": "\n用户上传的素材图片(请仔细分析这些图片中的产品特征):"}) + for url in image_urls: + content_parts.append({ + "type": "image_url", + "image_url": {"url": url} + }) + + messages = [ + # Note: Some vision models might not support 'system' role with images well, + # but Doubao usually follows standard chat structure. + # If system prompt fails, prepend it to user content. + {"role": "system", "content": ANALYZE_SYSTEM_PROMPT}, + {"role": "user", "content": content_parts} + ] + + try: + # Use Vision Model for Analysis + response = client.chat.completions.create( + model=config.VISION_MODEL_ID, + messages=messages, + temperature=0.7, + max_tokens=4000 + ) + + content = response.choices[0].message.content.strip() + if content.startswith("```"): + parts = content.split("```") + if len(parts) > 1: + content = parts[1] + if content.startswith("json"): content = content[4:] + + return json.loads(content) + + except Exception as e: + logger.error(f"Brain Analyze Error: {e}") + raise + + +# ============================================================ +# Stage 2: Refine Brief with Answers +# ============================================================ + +REFINE_SYSTEM_PROMPT = """你是短视频创作总监。 +根据原始需求、AI分析结果、用户补充回答,整合为完整的创意简报。 + +注意:用户选择的风格偏好(如ASMR、剧情、视觉流等)必须作为核心创作方向贯穿整个简报。 + +输出JSON: +{ + "brief": { + "product": "产品名称", + "product_visual_description": "产品视觉描述(颜色、形状、包装、质感等,用于后续图片生成)", + "selling_points": ["卖点1", "卖点2"], + "target_audience": "目标人群", + "platform": "投放平台", + "style": "视频风格(必须明确,如ASMR/剧情/视觉流等)", + "style_requirements": "该风格的具体创作要求(如ASMR需要:开盖声、质感特写、无人脸等)", + "creativity_level": "创意程度", + "reference": 
"对标账号/竞品", + "user_assets_description": "用户上传素材的描述(用于后续继承)" + }, + "creative_summary": "整体创意概述(50字以内,描述这个视频的核心创意方向)", + "ready": true +} +""" + +def refine_brief( + original_prompt: str, + analysis: Dict[str, Any], + answers: Dict[str, Any], + image_urls: List[str] = None +) -> Dict[str, Any]: + """ + Integrate user answers into a complete creative brief. + """ + logger.info("Brain: Refining brief with answers...") + + user_content = f""" +原始需求: {original_prompt} + +AI分析结果: {json.dumps(analysis, ensure_ascii=False)} + +用户补充回答: {json.dumps(answers, ensure_ascii=False)} + +用户上传的素材URL: {json.dumps(image_urls or [], ensure_ascii=False)} +""" + + try: + # Use Text LLM for reasoning/refining if no new images involved + # But to keep it simple, we can stick to BRAIN_MODEL_ID (Doubao Pro) + response = client.chat.completions.create( + model=config.BRAIN_MODEL_ID, + messages=[ + {"role": "system", "content": REFINE_SYSTEM_PROMPT}, + {"role": "user", "content": user_content} + ], + temperature=0.5, + max_tokens=3000 + ) + + content = response.choices[0].message.content.strip() + if content.startswith("```"): + parts = content.split("```") + if len(parts) > 1: + content = parts[1] + if content.startswith("json"): content = content[4:] + + return json.loads(content) + + except Exception as e: + logger.error(f"Brain Refine Error: {e}") + raise + + +# ============================================================ +# Stage 3: Generate Script +# ============================================================ + +SCRIPT_SYSTEM_PROMPT = """你是顶级短视频编导,专精{style}风格内容创作。 + +根据创意简报生成爆款脚本。必须严格遵循用户选择的风格要求。 + +脚本结构要求: +1. creative_summary: 整体创意概述(这条视频的核心创意是什么) +2. hook: 前3秒钩子设计(必须抓眼球,符合{style}风格) +3. scenes: 3-8个分镜 +4. 
def generate_script(
    brief: Dict[str, Any],
    image_urls: Optional[List[str]] = None,
    regenerate_feedback: str = ""
) -> Dict[str, Any]:
    """
    Generate the complete video script (hook, scenes, CTA) from a brief.

    Args:
        brief: Creative brief produced by refine_brief; its "style" entry
            is substituted into the system prompt.
        image_urls: Optional reference image URLs; when present the vision
            model is shown them so image_prompts match the real product.
        regenerate_feedback: Optional user feedback when re-running.

    Returns:
        Parsed JSON script dict (shape defined by SCRIPT_SYSTEM_PROMPT).

    Raises:
        json.JSONDecodeError: If the model reply is not valid JSON.
        Exception: API/client errors are logged and re-raised.
    """
    logger.info("Brain: Generating script...")

    style = brief.get("style", "现代广告")
    # The prompt template embeds literal JSON braces, so str.format() would
    # choke on them; replacing only the {style} placeholder is deliberate.
    system_prompt = SCRIPT_SYSTEM_PROMPT.replace("{style}", style)

    content_parts = [{"type": "text", "text": f"创意简报: {json.dumps(brief, ensure_ascii=False)}"}]

    if regenerate_feedback:
        content_parts.append({"type": "text", "text": f"\n用户反馈(请据此调整): {regenerate_feedback}"})

    if image_urls:
        content_parts.append({"type": "text", "text": "\n用户上传的参考素材(生成的image_prompt必须参考这些素材中的产品外观):"})
        for url in image_urls:
            content_parts.append({
                "type": "image_url",
                "image_url": {"url": url}
            })

    try:
        response = client.chat.completions.create(
            model=config.VISION_MODEL_ID,  # vision model so reference images are seen
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content_parts}
            ],
            temperature=0.8,
            max_tokens=8000
        )

        content = response.choices[0].message.content.strip()
        # Strip a leading markdown code fence if the model wrapped the JSON.
        if content.startswith("```"):
            parts = content.split("```")
            if len(parts) > 1:
                content = parts[1]
                if content.startswith("json"):
                    content = content[4:]

        return json.loads(content)

    except Exception as e:
        logger.error(f"Brain Script Error: {e}")
        raise
+ """ + logger.info(f"Brain: Regenerating scene {scene_id}...") + + style = brief.get("style", "现代广告") if brief else "现代广告" + + system_prompt = f"""你是短视频编导,专精{style}风格。根据用户反馈重新生成指定分镜。 +保持与其他分镜的风格连贯性。 +image_prompt必须继承产品的视觉描述。 +只输出新的scene对象(JSON)。 +""" + + user_content = f""" +完整脚本: {json.dumps(full_script, ensure_ascii=False)} + +创意简报: {json.dumps(brief, ensure_ascii=False) if brief else "无"} + +需要重新生成的分镜ID: {scene_id} + +用户反馈: {feedback} +""" + + try: + response = client.chat.completions.create( + model=config.BRAIN_MODEL_ID, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content} + ], + temperature=0.8, + max_tokens=2000 + ) + + content = response.choices[0].message.content.strip() + if content.startswith("```"): + parts = content.split("```") + if len(parts) > 1: + content = parts[1] + if content.startswith("json"): content = content[4:] + + return json.loads(content) + + except Exception as e: + logger.error(f"Brain Regenerate Scene Error: {e}") + raise diff --git a/modules/composer.py b/modules/composer.py new file mode 100644 index 0000000..198cfb6 --- /dev/null +++ b/modules/composer.py @@ -0,0 +1,717 @@ +""" +视频合成器模块 +整合视频拼接、花字叠加、旁白配音的完整流程 +""" +import os +import time +import logging +from pathlib import Path +from typing import Dict, Any, List, Optional, Union + +import config +from modules import ffmpeg_utils, fancy_text, factory, storage +from modules.text_renderer import renderer + +logger = logging.getLogger(__name__) + + +class VideoComposer: + """视频合成器""" + + def __init__( + self, + output_dir: str = None, + target_size: tuple = (1080, 1920), + voice_type: str = "sweet_female" + ): + """ + 初始化合成器 + + Args: + output_dir: 输出目录 + target_size: 目标分辨率 (width, height) + voice_type: 默认旁白音色 + """ + self.output_dir = Path(output_dir) if output_dir else config.OUTPUT_DIR + self.output_dir.mkdir(exist_ok=True) + self.target_size = target_size + self.voice_type = voice_type + + # 临时文件追踪 + self._temp_files = [] + + 
    def _add_temp(self, path: str):
        """Track a temp file so cleanup() can remove it later. None/"" ignored."""
        if path:
            self._temp_files.append(path)

    def cleanup(self):
        """Delete all tracked temp files; per-file failures are logged, not raised."""
        for f in self._temp_files:
            try:
                if os.path.exists(f):
                    os.remove(f)
            except Exception as e:
                logger.warning(f"Failed to cleanup {f}: {e}")
        self._temp_files = []

    def compose(
        self,
        video_paths: List[str],
        subtitles: List[Dict[str, Any]] = None,
        fancy_texts: List[Dict[str, Any]] = None,
        voiceover_text: str = None,
        voiceover_segments: List[Dict[str, Any]] = None,
        bgm_path: str = None,
        bgm_volume: float = 0.15,
        output_name: str = None,
        upload_to_r2: bool = False
    ) -> str:
        """
        Full composition pipeline: concat -> subtitles -> fancy text ->
        voiceover -> BGM -> final copy (and optional R2 upload).

        Args:
            video_paths: Scene video paths, in playback order.
            subtitles: Subtitle configs [{text, start, duration, style}].
            fancy_texts: Fancy-text configs [{text, style, x, y, start, duration}].
            voiceover_text: Full narration text (TTS-generated and mixed in).
            voiceover_segments: Segmented narration [{text, start}] —
                mutually exclusive with voiceover_text (text wins).
            bgm_path: Background music path.
            bgm_volume: BGM mix volume.
            output_name: Output file name without extension.
            upload_to_r2: When True, upload the result and return the R2 URL.

        Returns:
            Final video path, or the R2 URL when upload_to_r2 is True.

        Raises:
            ValueError: If video_paths is empty.
        """
        if not video_paths:
            raise ValueError("No video paths provided")

        timestamp = int(time.time())
        output_name = output_name or f"composed_{timestamp}"

        logger.info(f"Starting composition: {len(video_paths)} videos")

        try:
            # Step 1: concatenate the scene clips at the target resolution.
            merged_path = str(config.TEMP_DIR / f"{output_name}_merged.mp4")
            ffmpeg_utils.concat_videos(video_paths, merged_path, self.target_size)
            self._add_temp(merged_path)
            current_video = merged_path

            # Step 1.1: add a silent base audio track if none exists, so later
            # audio filters always find stream 0:a.
            silent_path = str(config.TEMP_DIR / f"{output_name}_silent.mp4")
            ffmpeg_utils.add_silence_audio(current_video, silent_path)
            self._add_temp(silent_path)
            current_video = silent_path

            # Step 2: burn subtitles (white text, black border, no box,
            # centered in the lower region).
            if subtitles:
                subtitled_path = str(config.TEMP_DIR / f"{output_name}_subtitled.mp4")
                subtitle_style = {
                    "font": ffmpeg_utils._get_font_path(),
                    "fontsize": 60,
                    "fontcolor": "white",
                    "borderw": 5,
                    "bordercolor": "black",
                    "box": 0,  # no background box
                    "y": "h-200",  # lower region, centered
                }
                ffmpeg_utils.add_multiple_subtitles(
                    current_video, subtitles, subtitled_path, default_style=subtitle_style
                )
                self._add_temp(subtitled_path)
                current_video = subtitled_path

            # Step 3: overlay fancy text. Three render paths, newest first:
            # dict style -> atomic renderer; custom_style with font_size ->
            # atomic renderer; anything else -> legacy fancy_text module.
            if fancy_texts:
                overlay_configs = []
                for ft in fancy_texts:
                    text = ft.get("text", "")
                    style = ft.get("style")
                    custom_style = ft.get("custom_style")

                    # A dict style means atomic parameters — use them directly.
                    if isinstance(style, dict):
                        img_path = renderer.render(text, style, cache=False)
                    elif custom_style and isinstance(custom_style, dict):
                        # Legacy compatibility: try the atomic renderer when
                        # custom_style looks atomic (has font_size).
                        if "font_size" in custom_style:
                            img_path = renderer.render(text, custom_style, cache=False)
                        else:
                            # Fall back to the old fancy_text implementation.
                            img_path = fancy_text.create_fancy_text(
                                text=text,
                                style=style if isinstance(style, str) else "subtitle",
                                custom_style={
                                    **(custom_style or {}),
                                    "font_name": "/System/Library/Fonts/PingFang.ttc",
                                },
                                cache=False
                            )
                    else:
                        # Legacy path with defaults only.
                        img_path = fancy_text.create_fancy_text(
                            text=text,
                            style=style if isinstance(style, str) else "subtitle",
                            custom_style={
                                "font_name": "/System/Library/Fonts/PingFang.ttc",
                            },
                            cache=False
                        )

                    overlay_configs.append({
                        "path": img_path,
                        "x": ft.get("x", "(W-w)/2"),
                        "y": ft.get("y", "(H-h)/2"),
                        "start": ft.get("start", 0),
                        "duration": ft.get("duration", 999)
                    })

                fancy_path = str(config.TEMP_DIR / f"{output_name}_fancy.mp4")
                ffmpeg_utils.overlay_multiple_images(
                    current_video, overlay_configs, fancy_path
                )
                self._add_temp(fancy_path)
                current_video = fancy_path

            # Step 4: generate and mix narration (Volcengine WS first; the
            # factory falls back to Edge TTS on failure).
            if voiceover_text:
                vo_path = factory.generate_voiceover_volcengine(
                    text=voiceover_text,
                    voice_type=self.voice_type
                )
                self._add_temp(vo_path)

                voiced_path = str(config.TEMP_DIR / f"{output_name}_voiced.mp4")
                ffmpeg_utils.mix_audio(
                    current_video, vo_path, voiced_path,
                    audio_volume=1.5,   # boost narration
                    video_volume=0.2    # duck the original track
                )
                self._add_temp(voiced_path)
                current_video = voiced_path

            elif voiceover_segments:
                current_video = self._add_segmented_voiceover(
                    current_video, voiceover_segments, output_name
                )

            # Step 5: add BGM with fade in/out. Ducking is disabled here for
            # compatibility; the low bgm_volume keeps it under the narration.
            if bgm_path:
                bgm_output = str(config.TEMP_DIR / f"{output_name}_bgm.mp4")
                ffmpeg_utils.add_bgm(
                    current_video, bgm_path, bgm_output,
                    bgm_volume=bgm_volume,
                    ducking=False,
                    duck_gain_db=-6.0,
                    fade_in=1.0,
                    fade_out=1.0
                )
                self._add_temp(bgm_output)
                current_video = bgm_output

            # Step 6: copy the last intermediate into the output directory.
            final_path = str(self.output_dir / f"{output_name}.mp4")

            import shutil
            shutil.copy(current_video, final_path)

            logger.info(f"Composition complete: {final_path}")

            # Optionally upload and return the remote URL instead.
            if upload_to_r2:
                r2_url = storage.upload_file(final_path)
                logger.info(f"Uploaded to R2: {r2_url}")
                return r2_url

            return final_path

        finally:
            # Remove intermediates; the final output was copied out above.
            self.cleanup()

    def _add_segmented_voiceover(
        self,
        video_path: str,
        segments: List[Dict[str, Any]],
        output_name: str
    ) -> str:
        """Generate TTS for each segment and mix them in at their start offsets.

        The original track is ducked only on the first mix pass so repeated
        mixing does not keep attenuating it.
        """
        if not segments:
            return video_path

        # Generate one audio file per non-empty segment.
        audio_files = []
        for i, seg in enumerate(segments):
            text = seg.get("text", "")
            if not text:
                continue

            voice = seg.get("voice_type", self.voice_type)
            audio_path = factory.generate_voiceover_volcengine(
                text=text,
                voice_type=voice,
                output_path=str(config.TEMP_DIR / f"{output_name}_seg_{i}.mp3")
            )

            if audio_path:
                audio_files.append({
                    "path": audio_path,
                    "start": seg.get("start", 0)
                })
                self._add_temp(audio_path)

        if not audio_files:
            return video_path

        # Mix the segments into the video one at a time.
        current = video_path
        for i, af in enumerate(audio_files):
            output = str(config.TEMP_DIR / f"{output_name}_seg_mixed_{i}.mp4")
            ffmpeg_utils.mix_audio(
                current, af["path"], output,
                audio_volume=1.0,
                video_volume=0.2 if i == 0 else 1.0,  # duck original only once
                audio_start=af["start"]
            )
            self._add_temp(output)
            current = output

        return current
    def compose_from_script(
        self,
        script: Dict[str, Any],
        video_map: Dict[int, str],
        bgm_path: str = None,
        output_name: str = None
    ) -> str:
        """
        Compose the final video from a generated script plus a scene->video map.

        Args:
            script: Normalized scene script (expects "scenes" and optionally
                "voiceover_timeline").
            video_map: Mapping of scene id to rendered video path; scenes
                with a missing/nonexistent path are skipped with a warning.
            bgm_path: Optional background music path.
            output_name: Output file name without extension.

        Returns:
            Path of the final composed video.

        Raises:
            ValueError: If the script contains no scenes.
        """
        scenes = script.get("scenes", [])
        if not scenes:
            raise ValueError("Empty script")

        video_paths = []
        fancy_texts = []

        # 1. Collect video paths and fancy-text overlays in scene order,
        # accumulating real clip durations to place overlays on the global
        # timeline.
        total_duration = 0.0

        for scene in scenes:
            scene_id = scene["id"]
            video_path = video_map.get(scene_id)

            if not video_path or not os.path.exists(video_path):
                logger.warning(f"Missing video for scene {scene_id}, skipping")
                continue

            # Use the actual clip duration; fall back to 5s if probing fails.
            try:
                info = ffmpeg_utils.get_video_info(video_path)
                duration = float(info.get("duration", 5.0))
            except:  # NOTE(review): bare except also catches KeyboardInterrupt
                duration = 5.0

            video_paths.append(video_path)

            # Fancy text: fixed look (white text, black stroke, no box),
            # pinned to the upper region, horizontally centered.
            if "fancy_text" in scene:
                ft = scene["fancy_text"]
                if isinstance(ft, dict):
                    text = ft.get("text", "")

                    if text:
                        fixed_style = {
                            "font_size": 72,
                            "font_color": "#FFFFFF",
                            "stroke": {"color": "#000000", "width": 5}
                            # no "background" key -> no box behind the text
                        }

                        fancy_texts.append({
                            "text": text,
                            "style": fixed_style,
                            "x": "(W-w)/2",  # centered
                            "y": "180",      # upper region
                            # start_time is relative to the scene; shift it
                            # onto the global timeline.
                            "start": total_duration + float(ft.get("start_time", 0)),
                            "duration": float(ft.get("duration", duration))
                        })

            total_duration += duration

        # 2. Concatenate the collected clips.
        timestamp = int(time.time())
        output_name = output_name or f"composed_{timestamp}"

        merged_path = str(config.TEMP_DIR / f"{output_name}_merged.mp4")
        ffmpeg_utils.concat_videos(video_paths, merged_path, self.target_size)
        self._add_temp(merged_path)
        current_video = merged_path

        # 3. Build the narration track from the script's voiceover timeline.
        voiceover_timeline = script.get("voiceover_timeline", [])
        mixed_audio_path = str(config.TEMP_DIR / f"{output_name}_mixed_vo.mp3")

        # Initialize a silent base track spanning the whole video so each
        # narration snippet can be mixed in at an absolute offset.
        ffmpeg_utils._run_ffmpeg([
            ffmpeg_utils.FFMPEG_PATH, "-y",
            "-f", "lavfi", "-i", "anullsrc=r=44100:cl=stereo",
            "-t", str(total_duration),
            "-c:a", "mp3",
            mixed_audio_path
        ])
        self._add_temp(mixed_audio_path)

        subtitles = []

        if voiceover_timeline:
            for i, item in enumerate(voiceover_timeline):
                text = item.get("text", "")
                sub_text = item.get("subtitle", text)

                # Two supported timeline formats:
                # new: start_time/duration in absolute seconds;
                # old: start_ratio/duration_ratio as fractions of the total.
                if "start_time" in item:
                    target_start = float(item.get("start_time", 0))
                    target_duration = float(item.get("duration", 3))
                else:
                    start_ratio = float(item.get("start_ratio", 0))
                    duration_ratio = float(item.get("duration_ratio", 0))
                    target_start = start_ratio * total_duration
                    target_duration = duration_ratio * total_duration

                if not text: continue

                # Generate TTS for this snippet.
                tts_path = factory.generate_voiceover_volcengine(
                    text=text,
                    voice_type=self.voice_type,
                    output_path=str(config.TEMP_DIR / f"{output_name}_vo_{i}.mp3")
                )
                self._add_temp(tts_path)

                # Stretch/trim the snippet to its allotted slot.
                adjusted_path = str(config.TEMP_DIR / f"{output_name}_vo_adj_{i}.mp3")
                ffmpeg_utils.adjust_audio_duration(tts_path, target_duration, adjusted_path)
                self._add_temp(adjusted_path)

                # Mix the snippet onto the running narration track.
                new_mixed = str(config.TEMP_DIR / f"{output_name}_mixed_{i}.mp3")
                ffmpeg_utils.mix_audio_at_offset(mixed_audio_path, adjusted_path, target_start, new_mixed)
                mixed_audio_path = new_mixed  # update the current mixed path
                self._add_temp(new_mixed)

                # Subtitle entry fully synchronized with the narration slot.
                subtitles.append({
                    "text": ffmpeg_utils.wrap_text_smart(sub_text),
                    "start": target_start,
                    "duration": target_duration,
                    "style": {}  # default styling applied below
                })

        # 4. Mix the assembled narration track into the video.
        voiced_path = str(config.TEMP_DIR / f"{output_name}_voiced.mp4")
        ffmpeg_utils.mix_audio(
            current_video, mixed_audio_path, voiced_path,
            audio_volume=1.5,
            video_volume=0.2  # duck the original audio
        )
        self._add_temp(voiced_path)
        current_video = voiced_path

        # 5. Burn subtitles (white text, black border, no box, lower region).
        if subtitles:
            subtitled_path = str(config.TEMP_DIR / f"{output_name}_subtitled.mp4")
            subtitle_style = {
                "font": ffmpeg_utils._get_font_path(),
                "fontsize": 60,
                "fontcolor": "white",
                "borderw": 5,
                "bordercolor": "black",
                "box": 0,  # no background box
                "y": "h-200",  # lower region, centered
            }
            ffmpeg_utils.add_multiple_subtitles(
                current_video, subtitles, subtitled_path, default_style=subtitle_style
            )
            self._add_temp(subtitled_path)
            current_video = subtitled_path

        # 6. Overlay the fancy-text images collected in step 1.
        if fancy_texts:
            fancy_path = str(config.TEMP_DIR / f"{output_name}_fancy.mp4")

            overlay_configs = []
            for ft in fancy_texts:
                # Render each fancy-text image via the atomic renderer.
                img_path = renderer.render(ft["text"], ft["style"], cache=False)
                overlay_configs.append({
                    "path": img_path,
                    "x": ft["x"],
                    "y": ft["y"],
                    "start": ft["start"],
                    "duration": ft["duration"]
                })

            ffmpeg_utils.overlay_multiple_images(
                current_video, overlay_configs, fancy_path
            )
            self._add_temp(fancy_path)
            current_video = fancy_path

        # 7. Add background music.
        if bgm_path:
            bgm_output = str(config.TEMP_DIR / f"{output_name}_bgm.mp4")
            ffmpeg_utils.add_bgm(
                current_video, bgm_path, bgm_output,
                bgm_volume=0.15
            )
            self._add_temp(bgm_output)
            current_video = bgm_output

        # 8. Copy the final intermediate into the output directory.
        final_path = str(self.output_dir / f"{output_name}.mp4")
        import shutil
        shutil.copy(current_video, final_path)

        logger.info(f"Composition complete: {final_path}")

        self.cleanup()
        return final_path
def compose_product_video(
    video_paths: List[str],
    subtitle_configs: List[Dict[str, Any]] = None,
    fancy_text_configs: List[Dict[str, Any]] = None,
    voiceover_text: str = None,
    bgm_path: str = None,
    output_path: str = None,
    voice_type: str = "sweet_female"
) -> str:
    """
    Convenience wrapper: compose a product short-video in one call.

    Args:
        video_paths: Scene video paths in playback order.
        subtitle_configs: Subtitle configs passed through to compose().
        fancy_text_configs: Fancy-text configs passed through to compose().
        voiceover_text: Full narration text (TTS-generated by compose()).
        bgm_path: Background music path.
        output_path: Full target path; its stem becomes the output name and
            its parent becomes the output directory.
        voice_type: Narration voice.

    Returns:
        Path of the composed video.
    """
    composer = VideoComposer(voice_type=voice_type)

    output_name = None
    if output_path:
        target = Path(output_path)
        output_name = target.stem
        composer.output_dir = target.parent
        # Bug fix: VideoComposer.__init__ only mkdirs its default directory;
        # reassigning output_dir afterwards skipped that, so the final copy
        # failed when the caller-supplied directory did not exist yet.
        composer.output_dir.mkdir(parents=True, exist_ok=True)

    return composer.compose(
        video_paths=video_paths,
        subtitles=subtitle_configs,
        fancy_texts=fancy_text_configs,
        voiceover_text=voiceover_text,
        bgm_path=bgm_path,
        output_name=output_name
    )
def quick_compose(
    video_folder: str,
    script: List[Dict[str, Any]],
    output_path: str = None,
    voice_type: str = "sweet_female",
    bgm_path: str = None
) -> str:
    """
    Quick composition: read videos from a folder and pair them with a script.

    Each script item may carry "video" (explicit filename inside the folder),
    "subtitle" (+ optional "subtitle_style"), "fancy_text" (str or dict),
    "voiceover", and a fallback "duration". Items without an explicit video
    are matched positionally against the folder's sorted video files.

    Args:
        video_folder: Directory containing the scene clips.
        script: Per-scene configuration list (see above).
        output_path: Optional full output path forwarded to
            compose_product_video.
        voice_type: Narration voice.
        bgm_path: Background music path.

    Returns:
        Path of the composed video.
    """
    folder = Path(video_folder)

    # Deterministic ordering so script item i maps to the i-th file by name.
    video_files = sorted(
        f for f in folder.iterdir()
        if f.suffix.lower() in ['.mp4', '.mov', '.avi', '.mkv']
    )

    video_paths = []
    subtitles = []
    fancy_texts = []
    voiceovers = []

    current_time = 0

    for i, item in enumerate(script):
        # An explicit "video" filename wins; otherwise fall back to the
        # positional match; otherwise skip the item.
        if "video" in item:
            vp = folder / item["video"]
        elif i < len(video_files):
            vp = video_files[i]
        else:
            logger.warning(f"No video for script item {i}")
            continue

        video_paths.append(str(vp))

        # Prefer the real clip duration; fall back to the script's value.
        # Bug fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        try:
            info = ffmpeg_utils.get_video_info(str(vp))
            duration = info.get("duration", 5)
        except Exception:
            duration = item.get("duration", 5)

        if "subtitle" in item:
            subtitles.append({
                "text": item["subtitle"],
                "start": current_time,
                "duration": duration,
                "style": item.get("subtitle_style", {})
            })

        if "fancy_text" in item:
            ft = item["fancy_text"]
            if isinstance(ft, str):
                ft = {"text": ft}  # plain string means default styling
            fancy_texts.append({
                "text": ft.get("text", ""),
                "style": ft.get("style", "highlight"),
                "custom_style": ft.get("custom_style"),
                "x": ft.get("x", "(W-w)/2"),
                "y": ft.get("y", 200),
                "start": current_time,
                "duration": duration
            })

        if "voiceover" in item:
            voiceovers.append(item["voiceover"])

        current_time += duration

    # Join narration sentences with the Chinese full stop.
    voiceover_text = "。".join(voiceovers) if voiceovers else None

    return compose_product_video(
        video_paths=video_paths,
        subtitle_configs=subtitles if subtitles else None,
        fancy_text_configs=fancy_texts if fancy_texts else None,
        voiceover_text=voiceover_text,
        bgm_path=bgm_path,
        output_path=output_path,
        voice_type=voice_type
    )
"视频-分镜1.mp4"), + str(素材目录 / "视频-分镜2.mp4"), + str(素材目录 / "视频-分镜3.mp4"), + str(素材目录 / "视频-分镜4.mp4"), + str(素材目录 / "视频-分镜5.mp4"), + ] + + script = [ + { + "subtitle": "塌马尾 vs 高颅顶", + "fancy_text": { + "text": "塌马尾 vs 高颅顶", + "style": "comparison", + "y": 150 + }, + "voiceover": "普通马尾和高颅顶马尾的区别,你看出来了吗", + }, + { + "subtitle": "3秒出门,无需皮筋", + "fancy_text": {"text": "发量+50%", "style": "bubble", "y": 300}, + "voiceover": "只需要三秒钟,不需要皮筋,发量瞬间增加百分之五十", + }, + { + "subtitle": "发量+50%", + "voiceover": "蓬松的高颅顶效果,让你瞬间变美", + }, + { + "subtitle": "狂甩不掉!", + "fancy_text": {"text": "狂甩不掉!", "style": "warning", "y": 400}, + "voiceover": "而且超级牢固,怎么甩都不会掉", + }, + { + "subtitle": "¥3.99 立即抢购", + "fancy_text": {"text": "3.99", "style": "price", "y": 500}, + "voiceover": "只要三块九毛九,点击下方链接立即购买", + }, + ] + + output = quick_compose( + video_folder=str(素材目录), + script=script, + output_path="/Volumes/Tony/video-flow/output/发夹_合成视频.mp4", + voice_type="sweet_female" + ) + + print(f"视频合成完成: {output}") + return output + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + example_hairclip_video() diff --git a/modules/db_manager.py b/modules/db_manager.py new file mode 100644 index 0000000..09c5552 --- /dev/null +++ b/modules/db_manager.py @@ -0,0 +1,305 @@ +""" +数据库管理模块 (SQLAlchemy) +负责项目数据、任务状态、素材路径的持久化存储 +支持 SQLite 和 PostgreSQL +""" +import json +import logging +import time +from typing import Dict, List, Any, Optional + +from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func +from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base +from sqlalchemy.dialects.postgresql import JSONB + +import config + +logger = logging.getLogger(__name__) + +Base = declarative_base() + +class Project(Base): + __tablename__ = 'projects' + + id = Column(String, primary_key=True) + name = Column(String) + status = Column(String) # created, script_generated, images_generated, videos_generated, completed + product_info = Column(Text) # JSON 
# NOTE(review): the opening columns of the Project model (id/name/status/
# product_info) are defined just above this chunk; script_data below stores the
# generated script as a JSON string (Text for SQLite/Postgres compatibility).

class SceneAsset(Base):
    """One generated asset (image or video) belonging to a project scene."""
    __tablename__ = 'scene_assets'

    id = Column(Integer, primary_key=True, autoincrement=True)
    project_id = Column(String, index=True)
    scene_id = Column(Integer)
    asset_type = Column(String)              # "image" | "video"
    status = Column(String)                  # pending / processing / completed / failed
    local_path = Column(Text, nullable=True)
    remote_url = Column(Text, nullable=True)
    task_id = Column(String, nullable=True)  # task id issued by the external generation API
    # Stored in DB column "metadata"; renamed on the Python side to avoid
    # clashing with SQLAlchemy's declarative `metadata` attribute.
    metadata_json = Column("metadata", Text, nullable=True)  # JSON string
    created_at = Column(Float, default=time.time)
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    # Exactly one row per (project, scene, asset_type); save_asset() relies on
    # this constraint for its UPSERT behaviour.
    __table_args__ = (UniqueConstraint('project_id', 'scene_id', 'asset_type',
                                       name='uix_project_scene_asset'),)


class AppConfig(Base):
    """Key/value application configuration; values are JSON strings."""
    __tablename__ = 'app_config'

    key = Column(String, primary_key=True)
    value = Column(Text)  # JSON string
    description = Column(Text, nullable=True)
    updated_at = Column(Float, default=time.time, onupdate=time.time)


class DBManager:
    """Thin persistence layer over SQLAlchemy for projects, assets and config.

    Each public method opens a short-lived session and always closes it, so
    instances are safe to share as a module-level singleton.
    """

    def __init__(self, connection_string: str = None):
        if not connection_string:
            connection_string = config.DB_CONNECTION_STRING

        # pool_recycle guards against stale pooled connections on long runs.
        self.engine = create_engine(connection_string, pool_recycle=3600)
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        self._init_db()

    def _init_db(self):
        """Create all tables that do not exist yet."""
        Base.metadata.create_all(self.engine)

    def _get_session(self):
        return self.Session()

    # --- Project Operations ---

    def create_project(self, project_id: str, name: str, product_info: Dict[str, Any]):
        """Insert a new project row; a duplicate id is logged and ignored."""
        session = self._get_session()
        try:
            existing = session.query(Project).filter_by(id=project_id).first()
            if existing:
                logger.warning(f"Project {project_id} already exists.")
                return

            new_project = Project(
                id=project_id,
                name=name,
                status="created",
                product_info=json.dumps(product_info, ensure_ascii=False),
                created_at=time.time(),
                updated_at=time.time()
            )
            session.add(new_project)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error creating project: {e}")
            raise
        finally:
            session.close()

    def update_project_script(self, project_id: str, script: Dict[str, Any]):
        """Attach a generated script (stored as JSON) and advance the status."""
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.script_data = json.dumps(script, ensure_ascii=False)
                project.status = "script_generated"
                project.updated_at = time.time()
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating script: {e}")
        finally:
            session.close()

    def update_project_status(self, project_id: str, status: str):
        """Set the project's workflow status; missing projects are a no-op."""
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.status = status
                project.updated_at = time.time()
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating status: {e}")
        finally:
            session.close()

    def get_project(self, project_id: str) -> Optional[Dict[str, Any]]:
        """Return the project as a plain dict (JSON columns decoded), or None."""
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                data = {
                    "id": project.id,
                    "name": project.name,
                    "status": project.status,
                    "product_info": json.loads(project.product_info) if project.product_info else {},
                    "script_data": json.loads(project.script_data) if project.script_data else None,
                    "created_at": project.created_at,
                    "updated_at": project.updated_at
                }
                return data
            return None
        finally:
            session.close()

    def list_projects(self) -> List[Dict[str, Any]]:
        """Return summary dicts for all projects, most recently updated first."""
        session = self._get_session()
        try:
            projects = session.query(Project).order_by(Project.updated_at.desc()).all()
            results = []
            for p in projects:
                results.append({
                    "id": p.id,
                    "name": p.name,
                    "status": p.status,
                    "updated_at": p.updated_at
                })
            return results
        finally:
            session.close()

    # --- Asset/Task Operations ---

    def save_asset(self, project_id: str, scene_id: int, asset_type: str,
                   status: str, local_path: str = None, remote_url: str = None,
                   task_id: str = None, metadata: Dict = None):
        """Insert or update the asset record (UPSERT on project/scene/type)."""
        session = self._get_session()
        try:
            asset = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type
            ).first()

            meta_json = json.dumps(metadata, ensure_ascii=False) if metadata else "{}"

            if asset:
                asset.status = status
                asset.local_path = local_path
                asset.remote_url = remote_url
                asset.task_id = task_id
                asset.metadata_json = meta_json
                asset.updated_at = time.time()
            else:
                new_asset = SceneAsset(
                    project_id=project_id,
                    scene_id=scene_id,
                    asset_type=asset_type,
                    status=status,
                    local_path=local_path,
                    remote_url=remote_url,
                    task_id=task_id,
                    metadata_json=meta_json,
                    created_at=time.time(),
                    updated_at=time.time()
                )
                session.add(new_asset)

            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error saving asset: {e}")
        finally:
            session.close()

    def get_assets(self, project_id: str, asset_type: str = None) -> List[Dict[str, Any]]:
        """Return all assets of a project (optionally filtered by type) as dicts."""
        session = self._get_session()
        try:
            query = session.query(SceneAsset).filter_by(project_id=project_id)
            if asset_type:
                query = query.filter_by(asset_type=asset_type)

            assets = query.all()
            results = []
            for a in assets:
                data = {
                    "id": a.id,
                    "project_id": a.project_id,
                    "scene_id": a.scene_id,
                    "asset_type": a.asset_type,
                    "status": a.status,
                    "local_path": a.local_path,
                    "remote_url": a.remote_url,
                    "task_id": a.task_id,
                    "metadata": json.loads(a.metadata_json) if a.metadata_json else {},
                    "updated_at": a.updated_at
                }
                results.append(data)
            return results
        finally:
            session.close()

    def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
        """Return a single asset dict for (project, scene, type), or None."""
        session = self._get_session()
        try:
            a = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type
            ).first()

            if a:
                return {
                    "id": a.id,
                    "project_id": a.project_id,
                    "scene_id": a.scene_id,
                    "asset_type": a.asset_type,
                    "status": a.status,
                    "local_path": a.local_path,
                    "remote_url": a.remote_url,
                    "task_id": a.task_id,
                    "metadata": json.loads(a.metadata_json) if a.metadata_json else {},
                    "updated_at": a.updated_at
                }
            return None
        finally:
            session.close()

    # --- Config/Prompt Operations ---

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the JSON-decoded value for `key`, the raw stored string if it
        is not valid JSON, or `default` when the key is absent."""
        session = self._get_session()
        try:
            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg:
                try:
                    return json.loads(cfg.value)
                except (TypeError, ValueError):
                    # FIX: was a bare `except:` which also hid programming
                    # errors; only JSON-decode problems should fall back to
                    # returning the raw string.
                    return cfg.value
            return default
        finally:
            session.close()

    def set_config(self, key: str, value: Any, description: str = None):
        """Store `value` (JSON-encoded) under `key`, creating or updating."""
        session = self._get_session()
        try:
            json_val = json.dumps(value, ensure_ascii=False)

            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg:
                cfg.value = json_val
                if description:
                    cfg.description = description
                cfg.updated_at = time.time()
            else:
                new_cfg = AppConfig(
                    key=key,
                    value=json_val,
                    description=description,
                    updated_at=time.time()
                )
                session.add(new_cfg)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error setting config: {e}")
        finally:
            session.close()


# Singleton instance shared by the rest of the application.
db = DBManager()


# ============================================================
# --- modules/editor.py ---
# ============================================================
"""
MatchMe Studio - Editor Module (Assembly + BGM)
"""
import logging
import requests
from pathlib import Path
from typing import Dict, Any, List, Optional
import time
from moviepy.editor import (
    VideoFileClip, AudioFileClip, TextClip,
    CompositeVideoClip, CompositeAudioClip,
    concatenate_videoclips
)

import config
from modules import storage

logger = logging.getLogger(__name__)


# ============================================================
# Video Assembly
# ============================================================

def download_video(url: str) -> str:
    """Download a media file from `url` into TEMP_DIR and return the local path.

    Despite the name this is also used for audio downloads (see mix_voiceover).
    """
    filename = f"dl_{Path(url).name}"
    local_path = config.TEMP_DIR / filename

    # FIX: check the HTTP status and stream to disk instead of buffering the
    # whole response in memory (scene videos can be large).
    resp = requests.get(url, stream=True, timeout=60)
    resp.raise_for_status()
    with open(local_path, "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)

    return str(local_path)


def concatenate_scenes(video_urls: List[str]) -> str:
    """Concatenate multiple video clips into one 1080x1920 file; returns path."""
    logger.info(f"Concatenating {len(video_urls)} clips...")

    clips = []
    for url in video_urls:
        local_path = download_video(url)
        clip = VideoFileClip(local_path)

        # Normalize everything to 9:16 portrait before concatenation.
        if clip.w != 1080 or clip.h != 1920:
            clip = clip.resize(newsize=(1080, 1920))

        clips.append(clip)

    final = concatenate_videoclips(clips, method="compose")

    output_path = config.TEMP_DIR / f"merged_{int(time.time())}.mp4"
    final.write_videofile(
        str(output_path),
        fps=30,
        codec="libx264",
        audio_codec="aac",
        threads=4,
        logger=None
    )

    # Release ffmpeg readers promptly.
    for clip in clips:
        clip.close()
    final.close()

    return str(output_path)


# ============================================================
# Subtitle Burning
# ============================================================

def burn_subtitles(
    video_path: str,
    scenes: List[Dict[str, Any]]
) -> str:
    """Burn per-scene voiceover text onto the video as bottom-centered subtitles.

    Scene timing is cumulative: each scene's subtitle starts where the previous
    one ended, using the scene's `duration` (default 5s).
    """
    logger.info("Burning subtitles...")

    clip = VideoFileClip(video_path)
    subtitle_clips = []

    current_time = 0
    for scene in scenes:
        voiceover = scene.get("voiceover", "")
        duration = scene.get("duration", 5)

        if voiceover:
            try:
                txt = TextClip(
                    voiceover,
                    fontsize=48,
                    color='white',
                    stroke_color='black',
                    stroke_width=2,
                    font='DejaVu-Sans',
                    method='caption',
                    size=(900, None)
                ).set_position(('center', 1600)).set_start(current_time).set_duration(duration)

                subtitle_clips.append(txt)
            except Exception as e:
                # Best-effort: a failed subtitle must not kill the render.
                logger.warning(f"Subtitle error: {e}")

        current_time += duration

    if subtitle_clips:
        final = CompositeVideoClip([clip] + subtitle_clips)
    else:
        final = clip

    output_path = config.TEMP_DIR / f"subtitled_{int(time.time())}.mp4"
    final.write_videofile(
        str(output_path),
        fps=30,
        codec="libx264",
        audio_codec="aac",
        threads=4,
        logger=None
    )

    clip.close()
    final.close()

    return str(output_path)


# ============================================================
# Voiceover Mixing
# ============================================================

def mix_voiceover(video_path: str, voiceover_url: str) -> str:
    """Mix a voiceover track into the video, ducking any original audio.

    Returns the original path unchanged when no voiceover URL is given.
    """
    if not voiceover_url:
        return video_path

    logger.info("Mixing voiceover...")

    vo_local = download_video(voiceover_url)

    video = VideoFileClip(video_path)
    voiceover = AudioFileClip(vo_local)

    # Trim voiceover if longer than video.
    if voiceover.duration > video.duration:
        voiceover = voiceover.subclip(0, video.duration)

    # Duck the original track under the voiceover when present.
    if video.audio:
        mixed = CompositeAudioClip([
            video.audio.volumex(0.3),
            voiceover.volumex(1.0)
        ])
    else:
        mixed = voiceover

    final = video.set_audio(mixed)

    output_path = config.TEMP_DIR / f"voiced_{int(time.time())}.mp4"
    final.write_videofile(
        str(output_path),
        fps=30,
        codec="libx264",
        audio_codec="aac",
        threads=4,
        logger=None
    )

    video.close()
    voiceover.close()
    final.close()

    return str(output_path)


# ============================================================
# BGM Mixing
# ============================================================

def mix_bgm(
    video_path: str,
    bgm_path: str,
    bgm_volume: float = 0.2
) -> str:
    """Mix background music (looped/trimmed to the video length) into the video."""
    logger.info("Mixing BGM...")

    video = VideoFileClip(video_path)
    bgm = AudioFileClip(bgm_path)

    # Loop BGM if shorter than the video, then trim to exact length.
    if bgm.duration < video.duration:
        loops_needed = int(video.duration / bgm.duration) + 1
        bgm = bgm.loop(loops_needed)

    bgm = bgm.subclip(0, video.duration).volumex(bgm_volume)

    if video.audio:
        mixed = CompositeAudioClip([video.audio, bgm])
    else:
        mixed = bgm

    final = video.set_audio(mixed)

    output_path = config.TEMP_DIR / f"bgm_{int(time.time())}.mp4"
    final.write_videofile(
        str(output_path),
        fps=30,
        codec="libx264",
        audio_codec="aac",
        threads=4,
        logger=None
    )

    video.close()
    bgm.close()
    final.close()

    return str(output_path)


# ============================================================
# Full Pipeline
# ============================================================

def assemble_final_video(
    video_urls: List[str],
    scenes: List[Dict[str, Any]],
    voiceover_url: str = "",
    bgm_url: str = ""
) -> str:
    """
    Full assembly pipeline:
    1. Concatenate scene videos
    2. Burn subtitles
    3. Mix voiceover
    4. Mix BGM
    5. Upload to R2

    Returns the public URL of the uploaded final video.
    """
    logger.info("Starting full assembly...")

    merged = concatenate_scenes(video_urls)
    subtitled = burn_subtitles(merged, scenes)

    if voiceover_url:
        voiced = mix_voiceover(subtitled, voiceover_url)
    else:
        voiced = subtitled

    if bgm_url:
        bgm_local = download_video(bgm_url)
        final_path = mix_bgm(voiced, bgm_local)
    else:
        final_path = voiced

    final_url = storage.upload_file(final_path)
    logger.info(f"Final video uploaded: {final_url}")

    return final_url


# ============================================================
# --- modules/export_utils.py ---
# ============================================================
import os
import zipfile
import shutil
import math


def format_timestamp(seconds: float) -> str:
    """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)"""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int((seconds - int(seconds)) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def generate_srt(script_data: Dict[str, Any], video_map: Dict[int, str]) -> str:
    """Generate SRT content from script data.

    Durations are probed from the rendered scene videos when available,
    otherwise a 5-second fallback per scene is assumed.
    """
    scenes = script_data.get("scenes", [])
    srt_content = ""
    current_time = 0.0

    from modules import ffmpeg_utils

    for i, scene in enumerate(scenes):
        scene_id = scene["id"]
        duration = 5.0
        if scene_id in video_map and os.path.exists(video_map[scene_id]):
            try:
                info = ffmpeg_utils.get_video_info(video_map[scene_id])
                duration = info.get("duration", 5.0)
            except Exception:
                # FIX: narrowed from a bare `except:`; keep best-effort fallback.
                pass

        start_time = current_time
        end_time = current_time + duration
        current_time = end_time

        text = scene.get("subtitle", "")
        if text:
            srt_content += f"{i+1}\n"
            srt_content += f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}\n"
            srt_content += f"{text}\n\n"

    return srt_content
def create_capcut_package(project_id: str, script_data: Dict[str, Any], assets: Dict[str, str]) -> str:
    """
    Create a ZIP package for CapCut (JianYing) import
    Contains:
        - videos/ (scene videos)
        - audios/ (voiceover, bgm)
        - images/ (fancy text transparent pngs)
        - subtitles.srt
    """
    package_dir = config.TEMP_DIR / f"capcut_pkg_{project_id}_{int(os.getpid())}"
    if package_dir.exists():
        shutil.rmtree(package_dir)
    package_dir.mkdir()

    (package_dir / "videos").mkdir()
    (package_dir / "audios").mkdir()
    (package_dir / "images").mkdir()

    # 1. Subtitles: timing is reconstructed from the per-scene video files.
    scene_videos = assets.get("scene_videos", {})
    srt_content = generate_srt(script_data, scene_videos)
    with open(package_dir / "subtitles.srt", "w", encoding="utf-8") as f:
        f.write(srt_content)

    # 2. Copy scene videos, numbered for easy sorting (01_scene_<id>.mp4).
    scenes = script_data.get("scenes", [])
    for i, scene in enumerate(scenes):
        sid = scene["id"]
        if sid in scene_videos and os.path.exists(scene_videos[sid]):
            ext = Path(scene_videos[sid]).suffix
            dest_name = f"{i+1:02d}_scene_{sid}{ext}"
            shutil.copy(scene_videos[sid], package_dir / "videos" / dest_name)

    # 3. Voiceover: the mixed intermediate audio is not kept as a standalone
    # asset, so re-generate the full track from the script text here.
    from modules import factory
    full_vo_text = " ".join([s.get("voiceover", "") for s in scenes if s.get("voiceover")])
    if full_vo_text:
        try:
            # assumes the project used the default voice — TODO confirm against
            # the project's saved voice selection if one is ever persisted.
            voice_type = config.VOLC_TTS_DEFAULT_VOICE
            vo_path = factory.generate_voiceover_volcengine(full_vo_text, voice_type)
            shutil.copy(vo_path, package_dir / "audios" / "full_voiceover.mp3")
        except Exception as e:
            logger.warning(f"Failed to generate export voiceover: {e}")

    # BGM is a global Composer-side setting not tracked per project, so no
    # specific BGM file is bundled; the user can drop their own into audios/.

    # 4. Fancy-text overlays: the composer renders these to temp files which
    # may already be gone, so re-render them with a default style here.
    from modules.text_renderer import renderer
    for i, scene in enumerate(scenes):
        ft = scene.get("fancy_text")
        if ft:
            text = ft.get("text", "") if isinstance(ft, dict) else ""
            if text:
                try:
                    # The composer's full style-resolution logic is not
                    # reproduced here; a simple default render is enough
                    # for an editable export.
                    img_path = renderer.render(text, {"font_size": 60, "font_color": "#FFFFFF"}, cache=False)
                    shutil.copy(img_path, package_dir / "images" / f"{i+1:02d}_text_{scene['id']}.png")
                except Exception as e:
                    # FIX: was a bare `except: pass`; at least log the skip.
                    logger.warning(f"Fancy text render failed for scene {scene['id']}: {e}")

    # 5. Zip it
    zip_path = config.TEMP_DIR / f"capcut_export_{project_id}.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(package_dir):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, package_dir)
                zipf.write(file_path, arcname)

    # Cleanup the staging directory; only the zip remains.
    shutil.rmtree(package_dir)
    return str(zip_path)


# ============================================================
# --- modules/factory.py ---
# ============================================================
"""
MatchMe Studio - Factory Module (Concurrent Scene Generation)
Using Volcengine (Doubao) API for Image and Video
"""
import os
import time
import logging
import requests
import json
import re
import base64
import subprocess
from pathlib import Path
from typing import Dict, Any, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from elevenlabs import ElevenLabs, VoiceSettings
from openai import OpenAI

import config
from modules import storage

logger = logging.getLogger(__name__)

# OpenAI-compatible client pointed at Volcengine for image generation.
client = OpenAI(
    api_key=config.VOLC_API_KEY,
    base_url=config.VOLC_BASE_URL
)

# ============================================================
# Helper Functions
# ============================================================

def _download_as_base64(url: str) -> str:
    """Download image from URL and convert to Base64; "" on any failure."""
    try:
        # FIX: added a timeout so a stalled CDN cannot hang the pipeline.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return base64.b64encode(response.content).decode('utf-8')
    except Exception as e:
        logger.error(f"Failed to download/encode image: {e}")
        return ""

# ============================================================
# Image Generation (Doubao / Volcengine)
# ============================================================

def generate_scene_image(
    scene: Dict[str, Any],
    brief: Dict[str, Any] = None,
    reference_images: List[str] = None
) -> str:
    """
    Generate image using Volcengine API (Doubao Image).
    Using raw requests to match user's curl example exactly.

    Returns the R2 URL of the uploaded image; raises on API failure.
    """
    # Build prompt: prefer the script's explicit image_prompt, otherwise
    # compose one from the keyframe description and product brief.
    image_prompt = scene.get("image_prompt", "")
    if not image_prompt:
        keyframe = scene.get("keyframe", {})
        parts = ["Cinematic shot, 8k, photorealistic"]
        if brief:
            if brief.get("product_visual_description"):
                parts.append(f"Product: {brief['product_visual_description']}")
        parts.extend([
            f"Subject: {keyframe.get('subject', 'product')}",
            f"Environment: {keyframe.get('environment', 'studio')}",
            f"Action: {keyframe.get('focus', '')}"
        ])
        image_prompt = ", ".join(parts)

    # Enforce product visual consistency across scenes by prepending the
    # product description when the prompt does not already contain it.
    if brief and brief.get("product_visual_description"):
        if brief['product_visual_description'] not in image_prompt:
            image_prompt = f"{brief['product_visual_description']}, {image_prompt}"

    logger.info(f"Generating image (Volcengine): {image_prompt[:50]}...")

    url = f"{config.VOLC_BASE_URL}/images/generations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.VOLC_API_KEY}"
    }

    payload = {
        "model": config.IMAGE_MODEL_ID,
        "prompt": image_prompt,
        "sequential_image_generation": "disabled",
        "response_format": "b64_json",  # base64 avoids temp-URL expiration issues
        "size": "2K",
        "stream": False,
        "watermark": True
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)

        if response.status_code != 200:
            logger.error(f"Image API Error: {response.text}")
            raise ValueError(f"Image API failed: {response.status_code} - {response.text}")

        data = response.json()

        # Extract image data; fall back to downloading the URL form.
        image_data = None
        if "data" in data and len(data["data"]) > 0:
            image_data = data["data"][0].get("b64_json")
            if not image_data:
                img_url = data["data"][0].get("url")
                if img_url:
                    image_data = _download_as_base64(img_url)

        if not image_data:
            raise ValueError("No image data returned")

        # Decode and save locally, then upload to R2.
        filename = f"scene_{scene.get('id', 0)}_{int(time.time())}.jpg"
        local_path = config.TEMP_DIR / filename

        with open(local_path, "wb") as f:
            f.write(base64.b64decode(image_data))

        r2_url = storage.upload_file(str(local_path))
        logger.info(f"Scene {scene.get('id', '?')} image uploaded: {r2_url}")
        return r2_url

    except Exception as e:
        logger.error(f"Image Generation Failed: {e}")
        raise
def generate_all_scene_images_concurrent(
    scenes: List[Dict[str, Any]],
    brief: Dict[str, Any] = None,
    reference_images: List[str] = None,
    max_workers: int = 3
) -> List[str]:
    """Generate images for all scenes concurrently.

    Returns a list aligned with `scenes`; entries stay None for failed scenes
    (failures are logged, not raised, so one bad scene doesn't abort the batch).
    """
    logger.info(f"Generating {len(scenes)} images concurrently...")
    image_urls = [None] * len(scenes)

    def generate_single(index: int, scene: Dict[str, Any]) -> tuple:
        url = generate_scene_image(scene, brief, reference_images)
        return index, url

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(generate_single, i, scene): i
            for i, scene in enumerate(scenes)
        }

        for future in as_completed(futures):
            index = futures[future]
            try:
                _, url = future.result()
                image_urls[index] = url
            except Exception as e:
                logger.error(f"Scene {index+1} failed: {e}")

    return image_urls


# ============================================================
# Video Generation (Doubao Video / PixelDance)
# ============================================================

def generate_scene_video(
    start_frame_url: str,
    motion_prompt: str,
    duration: int = 5
) -> str:
    """
    Generate video using Volcengine API (Async Task Flow).

    Creates a generation task, polls it to completion (up to ~5 minutes),
    downloads the result and uploads it to R2. Returns the R2 URL.
    Raises ValueError/TimeoutError on API failure or timeout.
    """
    logger.info(f"Generating video (Volcengine): {motion_prompt[:50]}...")

    # 1. Create Task
    create_url = f"{config.VOLC_BASE_URL}/contents/generations/tasks"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.VOLC_API_KEY}"
    }

    # Content list: text prompt (with inline generation flags) + optional
    # start frame for image-to-video.
    content_list = [
        {
            "type": "text",
            "text": f"{motion_prompt} --resolution 1080p --duration {duration} --camerafixed false --watermark true"
        }
    ]

    if start_frame_url:
        content_list.append({
            "type": "image_url",
            "image_url": {"url": start_frame_url}
        })

    payload = {
        "model": config.VIDEO_MODEL_ID,
        "content": content_list
    }

    try:
        response = requests.post(create_url, headers=headers, json=payload, timeout=30)
        if response.status_code != 200:
            # 202 Accepted is also a valid response for async task creation.
            if response.status_code != 202:
                logger.error(f"Video Task Creation Error: {response.text}")
                raise ValueError(f"Video Task failed: {response.status_code} - {response.text}")

        data = response.json()
        task_id = data.get("id")
        if not task_id:
            # Some responses nest the id under "data".
            task_id = data.get("data", {}).get("id")

        if not task_id:
            raise ValueError(f"No Task ID returned: {data}")

        logger.info(f"Video Task Created: {task_id}. Polling for result...")

        # 2. Poll for Result (GET /contents/generations/tasks/{id})
        max_retries = 60  # 5 mins max (5s interval)
        video_url = None

        for _ in range(max_retries):
            time.sleep(5)
            status_url = f"{config.VOLC_BASE_URL}/contents/generations/tasks/{task_id}"
            resp = requests.get(status_url, headers=headers, timeout=30)

            if resp.status_code == 200:
                res_data = resp.json()
                # Status may live at the top level or under "data".
                status = res_data.get("status")
                if not status and "data" in res_data:
                    status = res_data["data"].get("status")

                if status == "succeeded" or status == "SUCCEEDED":
                    content = res_data.get("data", {}).get("content", [])
                    if not content and "content" in res_data:
                        content = res_data["content"]

                    # Content items may carry either "video_url" or plain "url".
                    for item in content:
                        if item.get("video_url"):
                            video_url = item["video_url"]
                            break
                        if item.get("url"):
                            video_url = item["url"]
                            break

                    if video_url:
                        break
                elif status == "failed" or status == "FAILED":
                    reason = res_data.get("data", {}).get("error", "Unknown error")
                    raise ValueError(f"Video Generation Failed: {reason}")

                # Still running/queued: keep polling.

        if not video_url:
            raise TimeoutError("Video generation timed out or failed to return URL.")

        # 3. Download and upload to R2.
        logger.info(f"Video Generated. Downloading: {video_url}")
        filename = f"vid_doubao_{int(time.time())}.mp4"
        local_path = config.TEMP_DIR / filename

        # FIX: added a timeout; the streamed download previously had none and
        # could hang the worker thread indefinitely.
        resp = requests.get(video_url, stream=True, timeout=120)
        if resp.status_code != 200:
            raise ValueError(f"Failed to download generated video: {resp.status_code}")

        with open(local_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                f.write(chunk)

        r2_url = storage.upload_file(str(local_path))
        return r2_url

    except Exception as e:
        logger.error(f"Video Generation Error: {e}")
        raise


def generate_all_scene_videos_concurrent(
    scenes: List[Dict[str, Any]],
    image_urls: List[str],
    max_workers: int = 2
) -> List[str]:
    """Generate videos concurrently, one per scene, from the scene start frames.

    Returns a list aligned with `scenes`; failed scenes stay None (logged only).
    """
    logger.info(f"Generating {len(scenes)} videos concurrently...")
    video_urls = [None] * len(scenes)

    def generate_single(index: int, scene: Dict[str, Any], img_url: str) -> tuple:
        motion = scene.get("camera_movement", "slow zoom")
        # Prefix the scene's visual prompt so the motion stays on-subject.
        if scene.get("image_prompt"):
            motion = f"{scene['image_prompt']}. {motion}"

        duration = scene.get("duration", 5)
        url = generate_scene_video(img_url, motion, duration)
        return index, url

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(generate_single, i, scene, image_urls[i]): i
            for i, scene in enumerate(scenes)
        }

        for future in as_completed(futures):
            index = futures[future]
            try:
                _, url = future.result()
                video_urls[index] = url
            except Exception as e:
                logger.error(f"Scene {index+1} video failed: {e}")

    return video_urls


# ============================================================
# Audio Generation (ElevenLabs)
# ============================================================

def generate_voiceover(text: str, style: str = "") -> str:
    """Generate voiceover audio. Returns R2 URL, or "" for blank text/failure."""
    if not text or not text.strip():
        return ""

    # ASMR delivery wants a looser, more intimate voice profile.
    stability = 0.3 if "ASMR" in style else 0.5
    similarity = 0.9 if "ASMR" in style else 0.8

    logger.info(f"Generating voiceover ({len(text)} chars, style={style})...")

    try:
        el_client = ElevenLabs(api_key=config.XI_KEY)

        audio_stream = el_client.text_to_speech.convert(
            voice_id=config.ELEVENLABS_VOICE_ID,
            text=text,
            model_id=config.ELEVENLABS_MODEL,
            voice_settings=VoiceSettings(stability=stability, similarity_boost=similarity)
        )

        filename = f"vo_{int(time.time())}.mp3"
        local_path = config.TEMP_DIR / filename

        with open(local_path, "wb") as f:
            for chunk in audio_stream:
                f.write(chunk)

        r2_url = storage.upload_file(str(local_path))
        return r2_url
    except Exception as e:
        # Best-effort by design: callers treat "" as "no voiceover".
        logger.error(f"Voiceover failed: {e}")
        return ""
def generate_full_voiceover(scenes: List[Dict[str, Any]], style: str = "") -> str:
    """Generate one combined voiceover for all scenes; "" when nothing to say.

    Lines starting with "(" are treated as stage directions and skipped.
    """
    voiceovers = []
    for s in scenes:
        vo = s.get("voiceover", "")
        if vo and vo.strip() and not vo.startswith("("):
            voiceovers.append(vo.strip())

    if not voiceovers:
        return ""

    full_text = " ".join(voiceovers)
    return generate_voiceover(full_text, style)


# ============================================================
# Audio Generation (Edge TTS - free Chinese speech synthesis)
# ============================================================

# Edge TTS Chinese voice presets (free, good quality).
EDGE_TTS_VOICES = {
    # Female
    "sweet_female": "zh-CN-XiaoxiaoNeural",      # Xiaoxiao - sweet & lively (recommended)
    "gentle_female": "zh-CN-XiaoyiNeural",       # Xiaoyi - gentle & refined
    "lively_female": "zh-CN-XiaochenNeural",     # Xiaochen - lively & cute
    "broadcast_female": "zh-CN-XiaoqiuNeural",   # Xiaoqiu - news broadcast
    # Male
    "general_male": "zh-CN-YunxiNeural",         # Yunxi - warm male voice
    "broadcast_male": "zh-CN-YunjianNeural",     # Yunjian - professional broadcast
}

# Volcengine TTS voice presets (paid service) - Douyin-commerce-friendly picks.
VOLC_TTS_VOICES = {
    # Female voices suited to short-video commerce
    "sweet_female": "zh_female_vv_uranus_bigtts",            # viv 2.0 general female (sweet)
    "lively_female": "zh_female_jitangnv_saturn_bigtts",     # "chicken soup girl" (energetic)
    "broadcast_female": "zh_male_ruyaichen_saturn_bigtts",   # Ruyachen (news) - for a female broadcaster swap to zh_female_meilinvyou_saturn_bigtts
    "meilinvyou": "zh_female_meilinvyou_saturn_bigtts",
    # Male
    "general_male": "zh_male_dayi_saturn_bigtts",            # Dayi (steady male voice)
}


def generate_voiceover_edge(
    text: str,
    voice_type: str = "sweet_female",
    rate: str = "+0%",
    volume: str = "+0%",
    output_path: str = None
) -> str:
    """
    Generate a Chinese voiceover with Edge TTS (free, good quality).

    Args:
        text: narration text
        voice_type: preset key from EDGE_TTS_VOICES, or a raw voice name
        rate: speed adjustment, e.g. "+10%", "-20%"
        volume: volume adjustment, e.g. "+10%", "-20%"
        output_path: optional output path (auto-generated when omitted)

    Returns:
        Path to the audio file, or "" on failure.
    """
    import asyncio
    import edge_tts

    if not text or not text.strip():
        logger.warning("Empty text provided for TTS")
        return ""

    # Resolve preset key -> voice name; pass raw names through unchanged.
    voice = EDGE_TTS_VOICES.get(voice_type, voice_type)

    logger.info(f"Generating voiceover (Edge TTS): {len(text)} chars, voice={voice}")

    if not output_path:
        filename = f"vo_edge_{int(time.time())}.mp3"
        output_path = str(config.TEMP_DIR / filename)

    async def _generate():
        communicate = edge_tts.Communicate(text, voice, rate=rate, volume=volume)
        await communicate.save(output_path)

    # Simple retry loop: the Edge endpoint is occasionally flaky.
    max_retries = 3
    for i in range(max_retries):
        try:
            asyncio.run(_generate())
            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                logger.info(f"Edge TTS voiceover generated: {output_path}")
                return output_path
        except Exception as e:
            logger.warning(f"Edge TTS attempt {i+1} failed: {e}")
        # FIX: only sleep when another attempt remains.
        if i < max_retries - 1:
            time.sleep(1.0)

    logger.error("Edge TTS failed after retries.")
    return ""


def generate_voiceover_volcengine_ws(
    text: str,
    voice_type: str = "sweet_female",
    output_path: str = None,
    timeout: int = 120
) -> str:
    """
    Generate TTS audio via the Volcengine WebSocket binary demo.
    Depends on: /Volumes/Tony/video-flow/volcengine_binary_demo/.venv/bin/python

    Returns the audio path, or "" on any failure (caller falls back to HTTP).
    """
    if not text or not text.strip():
        logger.warning("Empty text provided for TTS (ws)")
        return ""

    voice_id = VOLC_TTS_VOICES.get(voice_type, voice_type)

    venv_python = Path("/Volumes/Tony/video-flow/volcengine_binary_demo/.venv/bin/python")
    demo_script = Path("/Volumes/Tony/video-flow/volcengine_binary_demo/examples/volcengine/binary.py")

    if not venv_python.exists() or not demo_script.exists():
        logger.error("Volcengine WS demo or venv not found. Please install under volcengine_binary_demo/.venv")
        return ""

    if not output_path:
        output_path = str(config.TEMP_DIR / f"vo_volc_ws_{int(time.time())}.mp3")

    cmd = [
        str(venv_python),
        str(demo_script),
        "--appid", config.VOLC_TTS_APPID,
        "--access_token", config.VOLC_TTS_ACCESS_TOKEN,
        "--voice_type", voice_id,
        "--text", text,
        "--encoding", "mp3",
    ]

    logger.info(f"Calling Volcengine WS TTS: voice={voice_id}, len={len(text)}")
    try:
        result = subprocess.run(
            cmd,
            cwd="/Volumes/Tony/video-flow/volcengine_binary_demo",
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            logger.error(f"Volc WS TTS failed: {result.stderr}")
            return ""

        # The demo writes "<voice_type>.mp3" into its own cwd.
        demo_out = Path("/Volumes/Tony/video-flow/volcengine_binary_demo") / f"{voice_id}.mp3"
        if not demo_out.exists():
            logger.error("Volc WS TTS output not found")
            return ""

        Path(output_path).write_bytes(demo_out.read_bytes())
        logger.info(f"Volc WS TTS saved to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Volc WS TTS error: {e}")
        return ""


def generate_voiceover_volcengine(
    text: str,
    voice_type: str = "sweet_female",
    speed_ratio: float = 1.0,
    volume_ratio: float = 1.0,
    pitch_ratio: float = 1.0,
    output_path: str = None
) -> str:
    """
    Generate a Chinese voiceover with Volcengine TTS.

    Tries the (verified-working) WebSocket demo first, then the HTTP API,
    and finally falls back to Edge TTS.

    Args:
        text: narration text
        voice_type: preset key from VOLC_TTS_VOICES, or a raw voice id
        speed_ratio: speed (0.5-2.0, default 1.0)
        volume_ratio: volume (0.5-2.0, default 1.0)
        pitch_ratio: pitch (0.5-2.0, default 1.0)
        output_path: optional output path (auto-generated when omitted)

    Returns:
        Path to the audio file, or "" when every backend failed.
    """
    import uuid

    if not text or not text.strip():
        logger.warning("Empty text provided for TTS")
        return ""

    # Resolve preset key -> Volcengine voice id; raw ids pass through.
    voice_id = VOLC_TTS_VOICES.get(voice_type, voice_type)

    logger.info(f"Generating voiceover (Volcengine TTS): {len(text)} chars, voice={voice_id}")

    # 1) WebSocket binary demo (already validated against the official demo).
    ws_path = generate_voiceover_volcengine_ws(text, voice_type, output_path)
    if ws_path:
        return ws_path

    # 2) HTTP API fallback.
    url = "https://openspeech.bytedance.com/api/v1/tts"

    headers = {
        "Content-Type": "application/json",
        # NOTE: "Bearer;<token>" (semicolon) is the format this endpoint expects.
        "Authorization": f"Bearer;{config.VOLC_TTS_ACCESS_TOKEN}"
    }

    payload = {
        "app": {
            "appid": config.VOLC_TTS_APPID,
            "token": config.VOLC_TTS_ACCESS_TOKEN,
            "cluster": "volcano_tts"
        },
        "user": {
            "uid": "video_flow_user"
        },
        "audio": {
            "voice_type": voice_id,
            "encoding": "mp3",
            "speed_ratio": speed_ratio,
            "volume_ratio": volume_ratio,
            "pitch_ratio": pitch_ratio
        },
        "request": {
            "reqid": str(uuid.uuid4()),
            "text": text,
            "text_type": "plain",
            "operation": "query",
            "with_timestamp": "1",
            "extra_param": json.dumps({
                "disable_markdown_filter": False
            })
        }
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)

        if response.status_code != 200:
            logger.error(f"Volcengine TTS Error: {response.status_code} - {response.text}")
            # 3) Fall back to Edge TTS with a safe default voice.
            fallback_voice = "sweet_female" if voice_type not in EDGE_TTS_VOICES else voice_type
            return generate_voiceover_edge(text, fallback_voice, output_path=output_path)

        data = response.json()

        ret_code = data.get("code")
        if ret_code not in (0, 3000, 20000000):
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Volcengine TTS Error: {error_msg}")
            fallback_voice = "sweet_female" if voice_type not in EDGE_TTS_VOICES else voice_type
            return generate_voiceover_edge(text, fallback_voice, output_path=output_path)

        audio_data = data.get("data", "")
        if not audio_data:
            raise ValueError("No audio data returned")

        if not output_path:
            filename = f"vo_volc_{int(time.time())}.mp3"
            output_path = str(config.TEMP_DIR / filename)

        with open(output_path, "wb") as f:
            f.write(base64.b64decode(audio_data))

        logger.info(f"Voiceover generated (HTTP): {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Volcengine TTS HTTP error: {e}")
        fallback_voice = "sweet_female" if voice_type not in EDGE_TTS_VOICES else voice_type
        return generate_voiceover_edge(text, fallback_voice, output_path=output_path)


def generate_voiceover_volcengine_long(
    text: str,
    voice_type: str = "sweet_female",
    speed_ratio: float = 1.0,
    output_path: str = None,
    max_chunk_length: int = 300
) -> str:
    """
    Volcengine TTS for long text: automatically split, synthesize per chunk,
    and concatenate the chunk audio with ffmpeg.

    Raises ValueError when every chunk fails.
    """
    if len(text) <= max_chunk_length:
        return generate_voiceover_volcengine(
            text=text,
            voice_type=voice_type,
            speed_ratio=speed_ratio,
            output_path=output_path
        )

    logger.info(f"Long text ({len(text)} chars), splitting into chunks...")

    # Split on sentence-ending punctuation, keeping the separators so that
    # each chunk ends on a natural pause.
    import re
    sentences = re.split(r'([。!?;.!?;])', text)

    chunks = []
    current_chunk = ""

    for i in range(0, len(sentences) - 1, 2):
        sentence = sentences[i] + (sentences[i + 1] if i + 1 < len(sentences) else "")

        if len(current_chunk) + len(sentence) <= max_chunk_length:
            current_chunk += sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    # Trailing text after the last separator (re.split leaves it as the
    # final odd element) is appended to the last chunk.
    if len(sentences) % 2 == 1 and sentences[-1]:
        if chunks:
            chunks[-1] += sentences[-1]
        else:
            chunks.append(sentences[-1])

    logger.info(f"Split into {len(chunks)} chunks")

    # Synthesize each chunk; failures are logged and skipped.
    chunk_files = []
    for i, chunk in enumerate(chunks):
        chunk_path = str(config.TEMP_DIR / f"vo_chunk_{i}_{int(time.time())}.mp3")
        try:
            path = generate_voiceover_volcengine(
                text=chunk,
                voice_type=voice_type,
                speed_ratio=speed_ratio,
                output_path=chunk_path
            )
            # FIX: the TTS call returns "" on total failure; previously that
            # empty path was appended and produced an invalid "file ''" entry
            # in the ffmpeg concat list below.
            if path:
                chunk_files.append(path)
            else:
                logger.error(f"Chunk {i} produced no audio")
        except Exception as e:
            logger.error(f"Chunk {i} failed: {e}")
            # Keep processing the remaining chunks.

    if not chunk_files:
        raise ValueError("All TTS chunks failed")

    # Single surviving chunk: no merge needed.
    if len(chunk_files) == 1:
        if output_path:
            import shutil
            shutil.move(chunk_files[0], output_path)
            return output_path
        return chunk_files[0]

    # Build the ffmpeg concat list file.
    concat_list = config.TEMP_DIR / f"concat_audio_{os.getpid()}.txt"
    with open(concat_list, "w") as f:
        for cf in chunk_files:
            f.write(f"file '{cf}'\n")

    if not output_path:
        output_path = str(config.TEMP_DIR / f"vo_volc_merged_{int(time.time())}.mp3")

    # Lossless stream copy concat.
    import subprocess
    cmd = [
        "ffmpeg", "-y",
        "-f", "concat",
        "-safe", "0",
        "-i", str(concat_list),
        "-c", "copy",
        output_path
    ]

    subprocess.run(cmd, capture_output=True, check=True)

    # Clean up the per-chunk temp files.
    for cf in chunk_files:
        try:
            os.remove(cf)
        except OSError:
            pass
    concat_list.unlink(missing_ok=True)

    logger.info(f"Merged voiceover: {output_path}")
    return output_path
str(config.TEMP_DIR / f"vo_chunk_{i}_{int(time.time())}.mp3") + try: + path = generate_voiceover_volcengine( + text=chunk, + voice_type=voice_type, + speed_ratio=speed_ratio, + output_path=chunk_path + ) + chunk_files.append(path) + except Exception as e: + logger.error(f"Chunk {i} failed: {e}") + # 继续处理其他段落 + + if not chunk_files: + raise ValueError("All TTS chunks failed") + + # 使用 FFmpeg 合并音频 + if len(chunk_files) == 1: + if output_path: + import shutil + shutil.move(chunk_files[0], output_path) + return output_path + return chunk_files[0] + + # 创建合并文件列表 + concat_list = config.TEMP_DIR / f"concat_audio_{os.getpid()}.txt" + with open(concat_list, "w") as f: + for cf in chunk_files: + f.write(f"file '{cf}'\n") + + if not output_path: + output_path = str(config.TEMP_DIR / f"vo_volc_merged_{int(time.time())}.mp3") + + # FFmpeg 合并 + import subprocess + cmd = [ + "ffmpeg", "-y", + "-f", "concat", + "-safe", "0", + "-i", str(concat_list), + "-c", "copy", + output_path + ] + + subprocess.run(cmd, capture_output=True, check=True) + + # 清理临时文件 + for cf in chunk_files: + try: + os.remove(cf) + except: + pass + concat_list.unlink(missing_ok=True) + + logger.info(f"Merged voiceover: {output_path}") + return output_path + + +def generate_scene_voiceovers_volcengine( + scenes: List[Dict[str, Any]], + voice_type: str = "sweet_female", + output_dir: str = None +) -> List[str]: + """ + 为每个场景单独生成旁白音频 + + Args: + scenes: 场景列表,每个场景包含 voiceover 字段 + voice_type: 音色类型 + output_dir: 输出目录 + + Returns: + 音频文件路径列表 + """ + if output_dir: + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True) + else: + output_dir = config.TEMP_DIR + + audio_paths = [] + + for i, scene in enumerate(scenes): + vo_text = scene.get("voiceover", "") + + if not vo_text or not vo_text.strip() or vo_text.startswith("("): + # 无旁白或是注释 + audio_paths.append("") + continue + + try: + output_path = str(output_dir / f"scene_{i+1}_vo.mp3") + path = generate_voiceover_volcengine( + text=vo_text.strip(), + 
voice_type=voice_type, + output_path=output_path + ) + audio_paths.append(path) + except Exception as e: + logger.error(f"Scene {i+1} voiceover failed: {e}") + audio_paths.append("") + + return audio_paths diff --git a/modules/fancy_text.py b/modules/fancy_text.py new file mode 100644 index 0000000..31d438c --- /dev/null +++ b/modules/fancy_text.py @@ -0,0 +1,708 @@ +""" +抖音风格花字生成模块 +使用 Pillow 生成透明 PNG 图片,支持描边、渐变、气泡框等效果 +""" +import os +import hashlib +import logging +from pathlib import Path +from typing import Dict, Any, Tuple, List, Optional + +from PIL import Image, ImageDraw, ImageFont, ImageFilter + +import config + +logger = logging.getLogger(__name__) + +# 花字缓存目录 +FANCY_TEXT_CACHE_DIR = config.TEMP_DIR / "fancy_text_cache" +FANCY_TEXT_CACHE_DIR.mkdir(exist_ok=True) + + +def _get_font(font_name: str = None, size: int = 48) -> ImageFont.FreeTypeFont: + """获取字体对象,遇到无效字体会继续尝试下一候选,最后才降级为默认字体""" + candidates = [] + if font_name and os.path.exists(font_name): + candidates.append(font_name) + else: + candidates.extend([ + config.FONTS_DIR / "AlibabaPuHuiTi-Bold.ttf", + config.FONTS_DIR / "AlibabaPuHuiTi-Regular.ttf", + config.FONTS_DIR / "NotoSansSC-Bold.otf", + config.FONTS_DIR / "NotoSansSC-Regular.otf", + ]) + candidates.extend([ + "/System/Library/Fonts/PingFang.ttc", + "/System/Library/Fonts/STHeiti Medium.ttc", + "/Library/Fonts/Arial Unicode.ttf", + "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + "C:/Windows/Fonts/msyh.ttc", + "C:/Windows/Fonts/simhei.ttf", + ]) + for path in candidates: + if not path: + continue + p = str(path) + if not os.path.exists(p): + continue + if isinstance(path, Path) and path.stat().st_size < 10000: + continue + try: + return ImageFont.truetype(p, size) + except Exception as e: + logger.warning(f"Failed to load font {p}: {e}") + continue + logger.warning("No suitable font found, using default") + return 
    ImageFont.load_default()


def _hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
    """Convert a hex color string (with or without '#') to an RGB tuple."""
    hex_color = hex_color.lstrip("#")
    return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))


def _get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
    """Measure the rendered pixel size of *text* with *font*."""
    # Draw onto a throwaway 1x1 image purely to measure the text
    dummy_img = Image.new("RGBA", (1, 1))
    draw = ImageDraw.Draw(dummy_img)
    bbox = draw.textbbox((0, 0), text, font=font)
    return bbox[2] - bbox[0], bbox[3] - bbox[1]


def _cache_key(text: str, style: Dict) -> str:
    """Build a deterministic cache key from the text and the (sorted) style dict."""
    content = f"{text}_{str(sorted(style.items()))}"
    return hashlib.md5(content.encode()).hexdigest()


def create_text_with_stroke(
    text: str,
    font_size: int = 60,
    font_color: str = "#FFFFFF",
    stroke_color: str = "#000000",
    stroke_width: int = 4,
    font_name: str = None,
    padding: int = 20
) -> Image.Image:
    """
    Create a stroked (outlined) text image.

    Args:
        text: text content
        font_size: font size in px
        font_color: text color (hex)
        stroke_color: outline color
        stroke_width: outline width in px
        font_name: font file path
        padding: inner padding in px

    Returns:
        Transparent RGBA image
    """
    font = _get_font(font_name, font_size)
    text_w, text_h = _get_text_size(text, font)

    # Image size (text plus stroke and padding on all sides)
    img_w = text_w + stroke_width * 2 + padding * 2
    img_h = text_h + stroke_width * 2 + padding * 2

    # Transparent canvas
    img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Text origin
    x = padding + stroke_width
    y = padding + stroke_width

    # Draw the stroke by stamping the text at every offset inside a disc
    stroke_rgb = _hex_to_rgb(stroke_color) + (255,)
    for dx in range(-stroke_width, stroke_width + 1):
        for dy in range(-stroke_width, stroke_width + 1):
            if dx * dx + dy * dy <= stroke_width * stroke_width:
                draw.text((x + dx, y + dy), text, font=font, fill=stroke_rgb)

    # Draw the main text on top
    font_rgb = _hex_to_rgb(font_color) + (255,)
    draw.text((x, y), text, font=font, fill=font_rgb)

    return img


def create_text_with_shadow(
    text: str,
    font_size: int = 60,
    font_color: str = "#FFFFFF",
    shadow_color: str = "#000000",
    shadow_offset: Tuple[int, int] = (4, 4),
    shadow_blur: int = 5,
    font_name: str = None,
    padding: int = 30,
    stroke_color: str = None,
    stroke_width: int = 0
) -> Image.Image:
    """
    Create a drop-shadowed text image, with an optional stroke
    (used for a double-layer "safety" outline).
    """
    font = _get_font(font_name, font_size)
    text_w, text_h = _get_text_size(text, font)

    # Image size (room for blur/stroke plus the shadow offset)
    extra = max(shadow_blur, stroke_width * 2)
    img_w = text_w + abs(shadow_offset[0]) + extra * 2 + padding * 2
    img_h = text_h + abs(shadow_offset[1]) + extra * 2 + padding * 2

    shadow_img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    shadow_draw = ImageDraw.Draw(shadow_img)

    x = padding + extra
    y = padding + extra

    # Shadow layer: draw offset text then blur the whole layer
    shadow_rgb = _hex_to_rgb(shadow_color) + (180,)
    shadow_draw.text((x + shadow_offset[0], y + shadow_offset[1]), text, font=font, fill=shadow_rgb)
    shadow_img = shadow_img.filter(ImageFilter.GaussianBlur(shadow_blur))

    draw = ImageDraw.Draw(shadow_img)

    # Optional stroke (outer dark or light outline)
    if stroke_color and stroke_width > 0:
        stroke_rgb = _hex_to_rgb(stroke_color) + (255,)
        for dx in range(-stroke_width, stroke_width + 1):
            for dy in range(-stroke_width, stroke_width + 1):
                if dx * dx + dy * dy <= stroke_width * stroke_width:
                    draw.text((x + dx, y + dy), text, font=font, fill=stroke_rgb)

    # Main text
    font_rgb = _hex_to_rgb(font_color) + (255,)
    draw.text((x, y), text, font=font, fill=font_rgb)

    return shadow_img


def create_text_with_gradient(
    text: str,
    font_size: int = 60,
    gradient_colors: List[str] = None,
    gradient_direction: str = "vertical",  # vertical, horizontal
    stroke_color: str = "#000000",
    stroke_width: int = 3,
    font_name: str = None,
    padding: int = 20
) -> Image.Image:
    """
    Create gradient-filled text.

    Args:
        gradient_colors: gradient color list, e.g. ["#FF6B6B", "#FFE66D"]
        gradient_direction: gradient direction
    """
    if not gradient_colors:
        gradient_colors = ["#FF6B6B", "#FFE66D"]  # default red-to-yellow gradient

    font = _get_font(font_name, font_size)
    text_w, text_h = _get_text_size(text, font)

    img_w = text_w + stroke_width * 2 + padding * 2
    img_h = text_h + stroke_width * 2 + padding * 2

    # Gradient layer
    gradient = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    gradient_draw = ImageDraw.Draw(gradient)

    # Generate the gradient line by line
    colors = [_hex_to_rgb(c) for c in gradient_colors]

    for i in range(img_h if gradient_direction == "vertical" else img_w):
        ratio = i / (img_h if gradient_direction == "vertical" else img_w)
        # Piecewise-linear interpolation across up to three colors
        if ratio < 0.5:
            r = ratio * 2
            c1, c2 = colors[0], colors[min(1, len(colors) - 1)]
        else:
            r = (ratio - 0.5) * 2
            c1 = colors[min(1, len(colors) - 1)]
            c2 = colors[min(2, len(colors) - 1)] if len(colors) > 2 else c1

        color = tuple(int(c1[j] + (c2[j] - c1[j]) * r) for j in range(3)) + (255,)

        if gradient_direction == "vertical":
            gradient_draw.line([(0, i), (img_w, i)], fill=color)
        else:
            gradient_draw.line([(i, 0), (i, img_h)], fill=color)

    # Text mask
    mask = Image.new("L", (img_w, img_h), 0)
    mask_draw = ImageDraw.Draw(mask)

    x = padding + stroke_width
    y = padding + stroke_width

    # Stroke mask first (half opacity)
    for dx in range(-stroke_width, stroke_width + 1):
        for dy in range(-stroke_width, stroke_width + 1):
            if dx * dx + dy * dy <= stroke_width * stroke_width:
                mask_draw.text((x + dx, y + dy), text, font=font, fill=128)

    # Main text mask
    mask_draw.text((x, y), text, font=font, fill=255)

    # Result image
    result = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))

    # Draw the stroke
    stroke_img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    stroke_draw = ImageDraw.Draw(stroke_img)
    stroke_rgb = _hex_to_rgb(stroke_color) + (255,)

    for dx in range(-stroke_width, stroke_width + 1):
        for dy in range(-stroke_width, stroke_width + 1):
            if dx * dx + dy * dy <= stroke_width * stroke_width:
                stroke_draw.text((x + dx, y + dy), text, font=font, fill=stroke_rgb)

    result = Image.alpha_composite(result, stroke_img)

    # Apply the gradient through the text mask
    text_mask = Image.new("L", (img_w, img_h), 0)
    ImageDraw.Draw(text_mask).text((x, y), text, font=font, fill=255)

    gradient_text = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    gradient_text.paste(gradient, mask=text_mask)

    result = Image.alpha_composite(result, gradient_text)

    return result


def create_bubble_text(
    text: str,
    font_size: int = 48,
    font_color: str = "#333333",
    bg_color: str = "#FFFFFF",
    border_color: str = "#CCCCCC",
    border_width: int = 2,
    corner_radius: int = 20,
    padding: Tuple[int, int] = (30, 15),
    font_name: str = None,
    tail_direction: str = None  # "left", "right", "bottom", None
) -> Image.Image:
    """
    Create speech-bubble text (dialog-box effect).

    Args:
        tail_direction: which side the bubble tail points from
    """
    font = _get_font(font_name, font_size)
    text_w, text_h = _get_text_size(text, font)

    # Bubble size
    bubble_w = text_w + padding[0] * 2
    bubble_h = text_h + padding[1] * 2

    # Extra canvas room for the tail
    tail_size = 20 if tail_direction else 0

    if tail_direction in ["left", "right"]:
        img_w = bubble_w + tail_size
        img_h = bubble_h
    elif tail_direction == "bottom":
        img_w = bubble_w
        img_h = bubble_h + tail_size
    else:
        img_w = bubble_w
        img_h = bubble_h

    img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Bubble origin (shifted right when the tail points left)
    if tail_direction == "left":
        bx = tail_size
    else:
        bx = 0
    by = 0

    # Rounded-rectangle bubble body
    bg_rgb = _hex_to_rgb(bg_color) + (255,)
    border_rgb = _hex_to_rgb(border_color) + (255,)

    draw.rounded_rectangle(
        [bx, by, bx + bubble_w, by + bubble_h],
        radius=corner_radius,
        fill=bg_rgb,
        outline=border_rgb,
        width=border_width
    )

    # Tail triangle
    if tail_direction == "left":
        points = [
            (bx, bubble_h // 2 - 10),
            (0, bubble_h // 2),
            (bx, bubble_h // 2 + 10)
        ]
        draw.polygon(points, fill=bg_rgb, outline=border_rgb)
        # Re-fill to cover the inner border edge
        draw.polygon(points, fill=bg_rgb)
    elif tail_direction == "right":
        points = [
            (bx + bubble_w, bubble_h // 2 - 10),
            (img_w, bubble_h // 2),
            (bx + bubble_w, bubble_h // 2 + 10)
        ]
        draw.polygon(points, fill=bg_rgb, outline=border_rgb)
        draw.polygon(points, fill=bg_rgb)
    elif tail_direction == 
"bottom": + points = [ + (bubble_w // 2 - 10, bubble_h), + (bubble_w // 2, img_h), + (bubble_w // 2 + 10, bubble_h) + ] + draw.polygon(points, fill=bg_rgb, outline=border_rgb) + draw.polygon(points, fill=bg_rgb) + + # 绘制文字 + font_rgb = _hex_to_rgb(font_color) + (255,) + text_x = bx + padding[0] + text_y = by + padding[1] + draw.text((text_x, text_y), text, font=font, fill=font_rgb) + + return img + + +def create_price_tag( + price: str, + currency: str = "¥", + font_size: int = 72, + price_color: str = "#FF4444", + currency_color: str = "#FF4444", + stroke_color: str = "#FFFFFF", + stroke_width: int = 4, + font_name: str = None +) -> Image.Image: + """ + 创建价格标签(电商风格) + """ + font_large = _get_font(font_name, font_size) + font_small = _get_font(font_name, int(font_size * 0.5)) + + # 测量尺寸 + currency_w, currency_h = _get_text_size(currency, font_small) + price_w, price_h = _get_text_size(price, font_large) + + total_w = currency_w + price_w + 5 + total_h = max(currency_h, price_h) + + padding = stroke_width + 10 + img_w = total_w + padding * 2 + img_h = total_h + padding * 2 + + img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # 绘制描边 + stroke_rgb = _hex_to_rgb(stroke_color) + (255,) + for dx in range(-stroke_width, stroke_width + 1): + for dy in range(-stroke_width, stroke_width + 1): + if dx * dx + dy * dy <= stroke_width * stroke_width: + # 货币符号 + draw.text( + (padding + dx, padding + (total_h - currency_h) // 2 + dy), + currency, font=font_small, fill=stroke_rgb + ) + # 价格 + draw.text( + (padding + currency_w + 5 + dx, padding + (total_h - price_h) // 2 + dy), + price, font=font_large, fill=stroke_rgb + ) + + # 绘制文字 + currency_rgb = _hex_to_rgb(currency_color) + (255,) + price_rgb = _hex_to_rgb(price_color) + (255,) + + draw.text( + (padding, padding + (total_h - currency_h) // 2), + currency, font=font_small, fill=currency_rgb + ) + draw.text( + (padding + currency_w + 5, padding + (total_h - price_h) // 2), + price, 
font=font_large, fill=price_rgb + ) + + return img + + +def create_button( + text: str, + font_size: int = 36, + font_color: str = "#FFFFFF", + bg_color: str = "#FF6B35", + corner_radius: int = 25, + padding: Tuple[int, int] = (40, 15), + font_name: str = None, + shadow: bool = True +) -> Image.Image: + """ + 创建按钮样式文字(如"立即抢购") + """ + font = _get_font(font_name, font_size) + text_w, text_h = _get_text_size(text, font) + + btn_w = text_w + padding[0] * 2 + btn_h = text_h + padding[1] * 2 + + shadow_offset = 4 if shadow else 0 + img_w = btn_w + shadow_offset + img_h = btn_h + shadow_offset + + img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # 绘制阴影 + if shadow: + shadow_color = (0, 0, 0, 80) + draw.rounded_rectangle( + [shadow_offset, shadow_offset, btn_w + shadow_offset, btn_h + shadow_offset], + radius=corner_radius, + fill=shadow_color + ) + + # 绘制按钮背景 + bg_rgb = _hex_to_rgb(bg_color) + (255,) + draw.rounded_rectangle( + [0, 0, btn_w, btn_h], + radius=corner_radius, + fill=bg_rgb + ) + + # 绘制文字 + font_rgb = _hex_to_rgb(font_color) + (255,) + text_x = padding[0] + text_y = padding[1] + draw.text((text_x, text_y), text, font=font, fill=font_rgb) + + return img + + +def create_comparison_text( + left_text: str, + right_text: str, + vs_text: str = "vs", + font_size: int = 48, + left_color: str = "#666666", + right_color: str = "#FF6B35", + vs_color: str = "#FF0000", + font_name: str = None +) -> Image.Image: + """ + 创建对比文字(如"塌马尾 vs 高颅顶") + """ + font = _get_font(font_name, font_size) + font_vs = _get_font(font_name, int(font_size * 0.8)) + + left_w, left_h = _get_text_size(left_text, font) + vs_w, vs_h = _get_text_size(vs_text, font_vs) + right_w, right_h = _get_text_size(right_text, font) + + spacing = 15 + total_w = left_w + vs_w + right_w + spacing * 2 + total_h = max(left_h, vs_h, right_h) + + padding = 20 + stroke_width = 3 + img_w = total_w + padding * 2 + stroke_width * 2 + img_h = total_h + padding * 2 + stroke_width * 2 + 
    img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)

    x = padding + stroke_width
    y = padding + stroke_width

    # Stroke pass for all three segments
    stroke_color = (0, 0, 0, 255)
    for dx in range(-stroke_width, stroke_width + 1):
        for dy in range(-stroke_width, stroke_width + 1):
            if dx * dx + dy * dy <= stroke_width * stroke_width:
                draw.text((x + dx, y + (total_h - left_h) // 2 + dy), left_text, font=font, fill=stroke_color)
                draw.text((x + left_w + spacing + dx, y + (total_h - vs_h) // 2 + dy), vs_text, font=font_vs, fill=stroke_color)
                draw.text((x + left_w + spacing + vs_w + spacing + dx, y + (total_h - right_h) // 2 + dy), right_text, font=font, fill=stroke_color)

    # Foreground pass
    left_rgb = _hex_to_rgb(left_color) + (255,)
    vs_rgb = _hex_to_rgb(vs_color) + (255,)
    right_rgb = _hex_to_rgb(right_color) + (255,)

    draw.text((x, y + (total_h - left_h) // 2), left_text, font=font, fill=left_rgb)
    draw.text((x + left_w + spacing, y + (total_h - vs_h) // 2), vs_text, font=font_vs, fill=vs_rgb)
    draw.text((x + left_w + spacing + vs_w + spacing, y + (total_h - right_h) // 2), right_text, font=font, fill=right_rgb)

    return img


# ============================================================
# Preset styles
# ============================================================

PRESET_STYLES = {
    "subtitle": {
        "font_size": 48,
        "font_color": "#FFFFFF",
        "stroke_color": "#000000",
        "stroke_width": 3,
        "version": "v2"
    },
    "highlight": {
        # Warm off-white main color + light stroke + dark shadow, matching light-brown backgrounds
        "font_size": 90,
        "font_color": "#F7E7D3",
        "stroke_color": "#C9B59A",  # light stroke
        "stroke_width": 4,
        "type": "shadow",
        "shadow_color": "#3A2C1F",  # dark brown shadow
        "shadow_offset": (3, 3),
        "shadow_blur": 10,
        "padding": 32,
        "version": "gloda"
    },
    "warning": {
        # Low-saturation terracotta red + cream stroke + dark brown shadow
        "font_size": 80,
        "font_color": "#D96B4F",
        "stroke_color": "#F6E5D6",
        "stroke_width": 4,
        "type": "shadow",
        "shadow_color": "#3A2C1F",
        "shadow_offset": (3, 3),
        "shadow_blur": 10,
        "padding": 30,
        "version": "gloda"
    },
    "success": {
        "font_size": 52,
        "font_color": "#4CAF50",
        "stroke_color": "#FFFFFF",
        "stroke_width": 4,
        "version": "v2"
    },
    "price": {
        # Price tag: warm red + cream currency symbol + dark stroke
        "font_size": 110,
        "price_color": "#E25B4F",
        "currency_color": "#F6E5D6",
        "stroke_color": "#3A2C1F",
        "stroke_width": 8,
        "type": "price",
        "version": "gloda"
    },
    "cta_button": {
        # Warm orange button with a light shadow
        "font_size": 46,
        "font_color": "#FFFFFF",
        "bg_color": "#E6763A",
        "corner_radius": 32,
        "type": "button",
        "shadow": True,
        "version": "gloda"
    }
}


def create_fancy_text(
    text: str,
    style: str = "subtitle",
    custom_style: Dict[str, Any] = None,
    cache: bool = True
) -> str:
    """
    Unified entry point for creating a fancy-text image.

    Args:
        text: text content
        style: preset style name
        custom_style: custom style overrides (merged over the preset)
        cache: whether to use the on-disk render cache

    Returns:
        Path to the rendered PNG image
    """
    # Merge styles
    base_style = PRESET_STYLES.get(style, PRESET_STYLES["subtitle"]).copy()
    if custom_style:
        base_style.update(custom_style)

    # Cache lookup
    if cache:
        cache_name = _cache_key(text, base_style)
        cache_path = FANCY_TEXT_CACHE_DIR / f"{cache_name}.png"
        if cache_path.exists():
            return str(cache_path)

    # Dispatch on the style type; each branch forwards only the keys
    # the target renderer accepts
    style_type = base_style.pop("type", None)

    if style == "price" or style_type == "price":
        img = create_price_tag(text, **{k: v for k, v in base_style.items() if k in [
            "currency", "font_size", "price_color", "currency_color", "stroke_color", "stroke_width", "font_name"
        ]})
    elif style == "cta_button" or style_type == "button":
        img = create_button(text, **{k: v for k, v in base_style.items() if k in [
            "font_size", "font_color", "bg_color", "corner_radius", "padding", "font_name", "shadow"
        ]})
    elif style_type == "bubble":
        img = create_bubble_text(text, **{k: v for k, v in base_style.items() if k in [
            "font_size", "font_color", "bg_color", "border_color", "border_width",
            "corner_radius", "padding", "font_name", "tail_direction"
        ]})
    elif style_type == "gradient":
        img = create_text_with_gradient(text, **{k: v for k, v in base_style.items() if k in [
            "font_size", "gradient_colors", "gradient_direction", "stroke_color", "stroke_width", "font_name", "padding"
        ]})
    elif style_type == "shadow":
        img = create_text_with_shadow(text, **{k: v for k, v in base_style.items() if k in [
            "font_size", "font_color", "shadow_color", "shadow_offset", "shadow_blur", "font_name", "padding"
        ]})
    else:
        # Default: stroked text
        img = create_text_with_stroke(text, **{k: v for k, v in base_style.items() if k in [
            "font_size", "font_color", "stroke_color", "stroke_width", "font_name", "padding"
        ]})

    # Save
    if cache:
        output_path = str(cache_path)
    else:
        output_path = str(config.TEMP_DIR / f"fancy_{hash(text)}_{os.getpid()}.png")

    img.save(output_path, "PNG")
    logger.info(f"Created fancy text: '{text[:20]}...' -> {output_path}")

    return output_path


def batch_create_fancy_texts(
    configs: List[Dict[str, Any]]
) -> List[str]:
    """
    Create fancy-text images in batch.

    Args:
        configs: config list [{text, style, custom_style}]

    Returns:
        List of PNG image paths
    """
    paths = []
    for cfg in configs:
        path = create_fancy_text(
            text=cfg.get("text", ""),
            style=cfg.get("style", "subtitle"),
            custom_style=cfg.get("custom_style")
        )
        paths.append(path)
    return paths

diff --git a/modules/ffmpeg_utils.py b/modules/ffmpeg_utils.py
new file mode 100644
index 0000000..4ed8d3c
--- /dev/null
+++ b/modules/ffmpeg_utils.py
@@ -0,0 +1,960 @@
"""
FFmpeg video-processing utility module.
Supports batch video processing at scale: concatenation, subtitles, overlays, audio mixing.
"""
import os
import re
import subprocess
import tempfile
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

import config

logger = logging.getLogger(__name__)

# FFmpeg/FFprobe paths (prefer the binaries bundled in the project)
FFMPEG_PATH = str(config.BASE_DIR / "bin" / "ffmpeg") if (config.BASE_DIR / "bin" / "ffmpeg").exists() else "ffmpeg"
FFPROBE_PATH = str(config.BASE_DIR / "bin" / "ffprobe") if (config.BASE_DIR / "bin" / "ffprobe").exists() else "ffprobe"

# 
# Font paths: prefer the project's bundled Chinese fonts, then Linux system
# fonts, finally fall back to macOS paths.
DEFAULT_FONT_PATHS = [
    # Prefer system-wide Linux Chinese fonts (most robust in server environments)
    "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
    "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",

    # Project-bundled fonts (note: make sure the files are not LFS pointers)
    str(config.FONTS_DIR / "HarmonyOS-Sans-SC-Regular.ttf"),
    str(config.FONTS_DIR / "AlibabaPuHuiTi-Regular.ttf"),

    # macOS fonts (only useful for local debugging)
    "/System/Library/Fonts/PingFang.ttc",
    "/System/Library/Fonts/STHeiti Medium.ttc",
    "/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
]


def _get_font_path() -> str:
    # Pick the first existing, plausibly-real (>1000 bytes) font file
    for p in DEFAULT_FONT_PATHS:
        if os.path.exists(p) and os.path.getsize(p) > 1000:
            return p
    return "Arial"  # extreme fallback to an English font name so we never crash


def _sanitize_text(text: str) -> str:
    """
    Strip control characters that could break the ffmpeg command line while
    keeping emoji, digits, punctuation, and all languages.
    """
    if not text:
        return ""

    # No characters are filtered any more; we only guard against None/empty
    return text


def add_silence_audio(video_path: str, output_path: str) -> str:
    """
    Add a silent audio track (stereo, 44.1 kHz) to a video without audio,
    so later filters can always find stream 0:a.
    """
    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-f", "lavfi",
        "-i", "anullsrc=channel_layout=stereo:sample_rate=44100",
        "-shortest",
        "-c:v", "copy",
        "-c:a", "aac",
        output_path
    ]
    _run_ffmpeg(cmd)
    return output_path


def _run_ffmpeg(cmd: List[str], check: bool = True) -> subprocess.CompletedProcess:
    """Run an FFmpeg command."""
    logger.debug(f"FFmpeg command: {' '.join(cmd)}")
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=check
        )
        # Always surface stderr so font warnings and the like can be diagnosed
        if result.stderr:
            print(f"[FFmpeg stderr] {result.stderr}", flush=True)
        if result.returncode != 0:
            logger.error(f"FFmpeg stderr: {result.stderr}")
        return result
    except subprocess.CalledProcessError as e:
        logger.error(f"FFmpeg failed: {e.stderr}")
        raise


def get_video_info(video_path: str) -> Dict[str, Any]:
    """Probe a video for duration, resolution, frame rate, etc."""
    cmd = [
        FFPROBE_PATH,
        "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        "-show_streams",
        video_path
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise ValueError(f"Failed to probe video: {video_path}")

    import json
    data = json.loads(result.stdout)

    # Extract the key fields
    info = {
        "duration": float(data.get("format", {}).get("duration", 0)),
        "width": 0,
        "height": 0,
        "fps": 30
    }

    for stream in data.get("streams", []):
        if stream.get("codec_type") == "video":
            info["width"] = stream.get("width", 0)
            info["height"] = stream.get("height", 0)
            # Parse the frame rate (e.g. "30/1" or "29.97")
            fps_str = stream.get("r_frame_rate", "30/1")
            if "/" in fps_str:
                num, den = fps_str.split("/")
                info["fps"] = float(num) / float(den) if float(den) != 0 else 30
            else:
                info["fps"] = float(fps_str)
            break

    return info


def concat_videos(
    video_paths: List[str],
    output_path: str,
    target_size: Tuple[int, int] = (1080, 1920)
) -> str:
    """
    Concatenate multiple video segments with the FFmpeg concat demuxer.

    Args:
        video_paths: list of video file paths
        output_path: output file path
        target_size: target resolution (width, height), portrait 1080x1920 by default

    Returns:
        Output file path
    """
    if not video_paths:
        raise ValueError("No video paths provided")

    logger.info(f"Concatenating {len(video_paths)} videos...")

    # Create the concat list file
    concat_file = config.TEMP_DIR / f"concat_{os.getpid()}.txt"

    with open(concat_file, "w", encoding="utf-8") as f:
        for vp in video_paths:
            # Use absolute paths; single quotes are escaped by the quoting below
            abs_path = os.path.abspath(vp)
            f.write(f"file '{abs_path}'\n")

    width, height = target_size

    # Normalize each clip's resolution with filter_complex, then concat:
    # scale preserving aspect ratio, pad with black bars, center
    filter_parts = []
    for i in range(len(video_paths)):
        filter_parts.append(
            f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
            f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
        )

    # Concatenate all video streams (video only)
    concat_inputs = "".join([f"[v{i}]" for i in range(len(video_paths))])
    filter_parts.append(f"{concat_inputs}concat=n={len(video_paths)}:v=1:a=0[outv]")

    filter_complex = ";".join(filter_parts)

    # Build the ffmpeg command
    cmd = [FFMPEG_PATH, "-y"]
    for vp in video_paths:
        cmd.extend(["-i", vp])

    cmd.extend([
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-pix_fmt", "yuv420p",
        output_path
    ])

    _run_ffmpeg(cmd)

    # Clean up the temp list file
    if concat_file.exists():
        concat_file.unlink()

    logger.info(f"Concatenated video saved: {output_path}")
    return output_path


def concat_videos_with_audio(
    video_paths: List[str],
    output_path: str,
    target_size: Tuple[int, int] = (1080, 1920)
) -> str:
    """
    Concatenate videos while preserving their audio tracks.
    """
    if not video_paths:
        raise ValueError("No video paths provided")

    logger.info(f"Concatenating {len(video_paths)} videos with audio...")

    width, height = target_size
    n = len(video_paths)

    # Build the filter_complex graph
    filter_parts = []

    # Video normalization
    for i in range(n):
        filter_parts.append(
            f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
            f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
        )

    # Audio normalization (uniform sample rate / channel layout)
    for i in range(n):
        filter_parts.append(f"[{i}:a]aformat=sample_rates=44100:channel_layouts=stereo[a{i}]")

    # Concatenate video and audio streams separately
    v_concat = "".join([f"[v{i}]" for i in range(n)])
    a_concat = "".join([f"[a{i}]" for i in range(n)])
    filter_parts.append(f"{v_concat}concat=n={n}:v=1:a=0[outv]")
    filter_parts.append(f"{a_concat}concat=n={n}:v=0:a=1[outa]")

    filter_complex = ";".join(filter_parts)

    cmd = [FFMPEG_PATH, "-y"]
    for vp in video_paths:
        cmd.extend(["-i", vp])

    cmd.extend([
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-map", "[outa]",
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "aac",
        "-b:a", "128k",
        "-pix_fmt", "yuv420p",
        output_path
    ])

    try:
        _run_ffmpeg(cmd)
    except subprocess.CalledProcessError:
        # If audio concatenation fails, fall back to the video-only path
        logger.warning("Audio concat failed, falling back to video only")
        return concat_videos(video_paths, output_path,
                             target_size)

    logger.info(f"Concatenated video with audio saved: {output_path}")
    return output_path


def add_subtitle(
    video_path: str,
    text: str,
    start: float,
    duration: float,
    output_path: str,
    style: Dict[str, Any] = None
) -> str:
    """
    Add a single subtitle using the drawtext filter.

    Args:
        video_path: input video path
        text: subtitle text
        start: start time in seconds
        duration: display duration in seconds
        output_path: output path
        style: style config {
            fontsize: font size,
            fontcolor: font color,
            borderw: border width,
            bordercolor: border color,
            x: x position (expression such as "(w-text_w)/2" allowed),
            y: y position,
            font: font path or name
        }

    Returns:
        Output file path
    """
    style = style or {}

    # Default styling
    fontsize = style.get("fontsize", 48)
    fontcolor = style.get("fontcolor", "white")
    borderw = style.get("borderw", 3)
    bordercolor = style.get("bordercolor", "black")
    x = style.get("x", "(w-text_w)/2")  # default: horizontally centered
    y = style.get("y", "h-200")  # default: near the bottom

    # Prefer the dynamically detected valid font over hard-coded, possibly broken paths
    default_font_path = _get_font_path()
    font = style.get("font", default_font_path)

    # Escape characters that drawtext treats specially
    escaped_text = text.replace("'", "\\'").replace(":", "\\:")

    # drawtext filter
    drawtext = (
        f"drawtext=text='{escaped_text}':"
        f"fontfile='{font}':"
        f"fontsize={fontsize}:"
        f"fontcolor={fontcolor}:"
        f"borderw={borderw}:"
        f"bordercolor={bordercolor}:"
        f"x={x}:y={y}:"
        f"enable='between(t,{start},{start + duration})'"
    )

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-vf", drawtext,
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Added subtitle: '{text[:20]}...' at {start}s")
    return output_path


def wrap_text(text: str, max_chars: int = 18) -> str:
    """
    Simple hard-wrap of subtitle text at max_chars characters.
    """
    if not text: return ""

    # If the text already contains newlines, assume the caller wrapped it manually
    if "\n" in text:
        return text

    result = ""
    count = 0
    for char in text:
        if count >= max_chars:
            result += "\n"
            count = 0
        result += char
        # Rough estimate: CJK and Latin both count as 1 (monospace assumption).
        # Mixed-width text is more complex; kept deliberately simple here.
        count += 1
    return result


def mix_audio_at_offset(
    base_audio: str,
    overlay_audio: str,
    offset: float,
    output_path: str,
    base_volume: float = 1.0,
    overlay_volume: float = 1.0
) -> str:
    """
    Mix an overlay audio file into a base track at the given offset (seconds).
    """
    # If the base audio is missing, fall back to the overlay alone
    if not os.path.exists(base_audio):
        logger.warning(f"Base audio not found: {base_audio}")
        return overlay_audio

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", base_audio,
        "-i", overlay_audio,
        "-filter_complex",
        f"[0:a]volume={base_volume}[a0];[1:a]volume={overlay_volume},adelay={int(offset*1000)}|{int(offset*1000)}[a1];[a0][a1]amix=inputs=2:duration=first:dropout_transition=0:normalize=0[out]",
        "-map", "[out]",
        "-c:a", "mp3",  # Use MP3 for audio only mixing
        output_path
    ]
    _run_ffmpeg(cmd)
    return output_path


def adjust_audio_duration(
    input_path: str,
    target_duration: float,
    output_path: str
) -> str:
    """
    Adjust audio duration (speed up only when the audio is too long;
    keep the original speed when it is shorter).

    Requirement:
    - audio duration > target  -> speed up playback
    - audio duration <= target -> keep original speed (never slow down)
    """
    if not os.path.exists(input_path):
        return None

    current_duration = float(get_audio_info(input_path).get("duration", 0))
    if current_duration <= 0:
        return input_path

    # Only speed up over-long audio; shorter audio keeps its original speed
    if current_duration <= target_duration:
        # Nothing to adjust; copy straight through
        import shutil
        shutil.copy(input_path, output_path)
        logger.info(f"Audio ({current_duration:.2f}s) <= target ({target_duration:.2f}s), keeping original speed")
        return output_path

    # Audio too long: speed it up
    speed_ratio = current_duration / target_duration

    # Cap the speed-up (max 2x) to avoid severe pitch artifacts
    speed_ratio = min(speed_ratio, 2.0)

    logger.info(f"Audio 
({current_duration:.2f}s) > target ({target_duration:.2f}s), speeding up {speed_ratio:.2f}x") + + cmd = [ + FFMPEG_PATH, "-y", + "-i", input_path, + "-filter:a", f"atempo={speed_ratio}", + output_path + ] + _run_ffmpeg(cmd) + return output_path + + +def get_audio_info(file_path: str) -> Dict[str, Any]: + """获取音频信息""" + return get_video_info(file_path) + + +def wrap_text_smart(text: str, max_chars: int = 15) -> str: + """ + 智能字幕换行(上短下长策略) + """ + if not text or len(text) <= max_chars: + return text + + # 优先在标点或空格处换行 + split_chars = [",", "。", "!", "?", " ", ",", ".", "!", "?"] + best_split = -1 + + # 寻找中间附近的分割点 + mid = len(text) // 2 + + for i in range(len(text)): + if text[i] in split_chars: + # 偏好后半部分(上短下长) + if abs(i - mid) < abs(best_split - mid): + best_split = i + + if best_split != -1 and best_split < len(text) - 1: + return text[:best_split+1] + "\n" + text[best_split+1:] + + # 强制换行(上短下长) + split_idx = int(len(text) * 0.4) # 上面 40% + return text[:split_idx] + "\n" + text[split_idx:] + + +def add_multiple_subtitles( + video_path: str, + subtitles: List[Dict[str, Any]], + output_path: str, + default_style: Dict[str, Any] = None +) -> str: + """ + 添加多条字幕 + """ + if not subtitles: + # 无字幕直接复制 + import shutil + shutil.copy(video_path, output_path) + return output_path + + default_style = default_style or {} + # 强制使用完整字体(先用项目内 NotoSansSC,如果不存在则回退 Droid) + font = "/root/video-flow/assets/fonts/NotoSansSC-Regular.otf" + if not (os.path.exists(font) and os.path.getsize(font) > 1024 * 100): # 至少100KB以上认为有效 + font = "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf" + if not (os.path.exists(font) and os.path.getsize(font) > 1024 * 100): + font = _get_font_path() + + print(f"[SubDebug] Using font for subtitles: {font}", flush=True) + + # 构建多个 drawtext filter + filters = [] + for sub in subtitles: + raw_text = sub.get("text", "") + # 打印原始文本的 repr 和 hex,以便排查特殊字符 + print(f"[SubDebug] Subtitle text repr: {repr(raw_text)}", flush=True) + print(f"[SubDebug] Subtitle 
def add_multiple_subtitles(
    video_path: str,
    subtitles: List[Dict[str, Any]],
    output_path: str,
    default_style: Dict[str, Any] = None
) -> str:
    """
    Burn multiple subtitles into a video in a single ffmpeg pass.

    Args:
        video_path: input video path
        subtitles: list of {text, start, duration, style} dicts; each
            per-item "style" overrides default_style
        output_path: output file path
        default_style: base drawtext style shared by all subtitles

    Returns:
        output_path (a plain copy of the input when subtitles is empty)
    """
    if not subtitles:
        # No subtitles: pass the video through unchanged.
        import shutil
        shutil.copy(video_path, output_path)
        return output_path

    default_style = default_style or {}
    # Force a font with full CJK coverage: prefer the bundled NotoSansSC,
    # fall back to Droid Sans Fallback, then to the detected system font.
    # NOTE(review): the absolute /root/video-flow path is machine-specific
    # — confirm it should come from config instead.
    font = "/root/video-flow/assets/fonts/NotoSansSC-Regular.otf"
    if not (os.path.exists(font) and os.path.getsize(font) > 1024 * 100):  # >100KB is treated as a valid font file
        font = "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf"
    if not (os.path.exists(font) and os.path.getsize(font) > 1024 * 100):
        font = _get_font_path()

    print(f"[SubDebug] Using font for subtitles: {font}", flush=True)

    # Build one drawtext filter per subtitle.
    filters = []
    for sub in subtitles:
        raw_text = sub.get("text", "")
        # Dump repr and hex of the raw text to help debug exotic characters.
        print(f"[SubDebug] Subtitle text repr: {repr(raw_text)}", flush=True)
        print(f"[SubDebug] Subtitle text hex: {' '.join(hex(ord(c)) for c in raw_text)}", flush=True)

        text = _sanitize_text(raw_text)
        # Auto-wrap long lines.
        text = wrap_text(text)

        start = sub.get("start", 0)
        duration = sub.get("duration", 3)
        style = {**default_style, **sub.get("style", {})}

        fontsize = style.get("fontsize", 48)
        fontcolor = style.get("fontcolor", "white")
        borderw = style.get("borderw", 3)
        bordercolor = style.get("bordercolor", "black")
        x = style.get("x", "(w-text_w)/2")
        y = style.get("y", "h-200")

        # Background box enabled by default to keep the text readable.
        box = style.get("box", 1)
        boxcolor = style.get("boxcolor", "black@0.5")
        boxborderw = style.get("boxborderw", 10)

        # Escape drawtext specials: backslash (first), quote, colon, percent.
        escaped_text = text.replace("\\", "\\\\").replace("'", "\\'").replace(":", "\\:").replace("%", "\\%")

        drawtext = (
            f"drawtext=text='{escaped_text}':"
            f"fontfile='{font}':"
            f"fontsize={fontsize}:"
            f"fontcolor={fontcolor}:"
            f"borderw={borderw}:"
            f"bordercolor={bordercolor}:"
            f"box={box}:boxcolor={boxcolor}:boxborderw={boxborderw}:"
            f"x={x}:y={y}:"
            f"enable='between(t,{start},{start + duration})'"
        )
        filters.append(drawtext)

    # Chain the drawtext filters with commas (sequential video filters).
    vf = ",".join(filters)

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-vf", vf,
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Added {len(subtitles)} subtitles")
    return output_path


def overlay_image(
    video_path: str,
    image_path: str,
    output_path: str,
    position: Tuple[int, int] = None,
    start: float = 0,
    duration: float = None,
    fade_in: float = 0,
    fade_out: float = 0
) -> str:
    """
    Overlay a transparent PNG (fancy text, watermark, ...) onto a video.

    Args:
        video_path: input video path
        image_path: PNG image path (alpha channel supported)
        output_path: output file path
        position: (x, y) placement; None centers the image
        start: start time (seconds)
        duration: display duration (seconds); None means until the end
        fade_in: fade-in time (seconds)
        fade_out: fade-out time (seconds)

    Returns:
        output_path
    """
    # Probe the video length to resolve duration=None.
    info = get_video_info(video_path)
    video_duration = info["duration"]

    if duration is None:
        duration = video_duration - start

    # Placement expression for the overlay filter.
    if position:
        x, y = position
        pos_str = f"x={x}:y={y}"
    else:
        pos_str = "x=(W-w)/2:y=(H-h)/2"  # centered

    # Time window during which the overlay is visible.
    enable = f"enable='between(t,{start},{start + duration})'"

    overlay_filter = f"overlay={pos_str}:{enable}"

    # Optional alpha fade in/out applied to the image stream before overlaying.
    if fade_in > 0 or fade_out > 0:
        fade_filter = []
        if fade_in > 0:
            fade_filter.append(f"fade=t=in:st={start}:d={fade_in}:alpha=1")
        if fade_out > 0:
            fade_out_start = start + duration - fade_out
            fade_filter.append(f"fade=t=out:st={fade_out_start}:d={fade_out}:alpha=1")

        img_filter = ",".join(fade_filter) if fade_filter else ""
        filter_complex = f"[1:v]{img_filter}[img];[0:v][img]{overlay_filter}[outv]"
    else:
        filter_complex = f"[0:v][1:v]{overlay_filter}[outv]"

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-i", image_path,
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-map", "0:a?",
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Overlaid image at {position or 'center'}, {start}s-{start+duration}s")
    return output_path


def overlay_multiple_images(
    video_path: str,
    images: List[Dict[str, Any]],
    output_path: str
) -> str:
    """
    Overlay several transparent PNGs onto a video in one ffmpeg pass.

    Args:
        video_path: input video path
        images: list of {path, x, y, start, duration} dicts
        output_path: output file path

    Returns:
        output_path (a plain copy of the input when images is empty)
    """
    if not images:
        import shutil
        shutil.copy(video_path, output_path)
        return output_path

    # One extra ffmpeg input per image.
    inputs = ["-i", video_path]
    for img in images:
        inputs.extend(["-i", img["path"]])

    # Chained overlays: each step consumes the previous composite stream.
    filter_parts = []
    prev_output = "0:v"

    for i, img in enumerate(images):
        x = img.get("x", "(W-w)/2")
        y = img.get("y", "(H-h)/2")
        start = img.get("start", 0)
        duration = img.get("duration", 999)

        enable = f"enable='between(t,{start},{start + duration})'"

        # The last overlay writes the final [outv] label.
        if i == len(images) - 1:
            out_label = "outv"
        else:
            out_label = f"tmp{i}"

        filter_parts.append(
            f"[{prev_output}][{i+1}:v]overlay=x={x}:y={y}:{enable}[{out_label}]"
        )
        prev_output = out_label

    filter_complex = ";".join(filter_parts)

    cmd = [FFMPEG_PATH, "-y"] + inputs + [
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-map", "0:a?",
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Overlaid {len(images)} images")
    return output_path


def mix_audio(
    video_path: str,
    audio_path: str,
    output_path: str,
    audio_volume: float = 1.0,
    video_volume: float = 0.1,
    audio_start: float = 0
) -> str:
    """
    Mix an audio track (voice-over, BGM, ...) into a video.

    Args:
        video_path: input video path
        audio_path: audio file to mix in
        output_path: output file path
        audio_volume: gain for the new audio (0-1)
        video_volume: gain for the video's original audio (0-1)
        audio_start: offset (seconds) at which the new audio starts

    Returns:
        output_path
    """
    logger.info(f"Mixing audio: {audio_path}")

    # Probe the video (also detects whether an audio track exists).
    # NOTE(review): video_duration is unused below — confirm whether it
    # was meant to bound the amix duration.
    info = get_video_info(video_path)
    video_duration = info["duration"]

    # adelay postpones the new audio by audio_start (milliseconds).
    delay_ms = int(audio_start * 1000)

    filter_complex = (
        f"[0:a]volume={video_volume}[va];"
        f"[1:a]adelay={delay_ms}|{delay_ms},volume={audio_volume}[aa];"
        f"[va][aa]amix=inputs=2:duration=longest:dropout_transition=0:normalize=0[outa]"
    )

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-i", audio_path,
        "-filter_complex", filter_complex,
        "-map", "0:v",
        "-map", "[outa]",
        "-c:v", "copy",
        "-c:a", "aac",
        "-b:a", "192k",
        output_path
    ]

    try:
        _run_ffmpeg(cmd)
    except subprocess.CalledProcessError:
        # The source video has no audio stream: map the new audio directly.
        logger.warning("Video has no audio track, adding audio directly")
        cmd = [
            FFMPEG_PATH, "-y",
            "-i", video_path,
            "-i", audio_path,
            "-map", "0:v",
            "-map", "1:a",
            "-c:v", "copy",
            "-c:a", "aac",
            "-b:a", "192k",
            output_path
        ]
        _run_ffmpeg(cmd)

    logger.info(f"Audio mixed: {output_path}")
    return output_path
+ +def add_bgm( + video_path: str, + bgm_path: str, + output_path: str, + bgm_volume: float = 0.06, + loop: bool = True, + ducking: bool = True, + duck_gain_db: float = -6.0, + fade_in: float = 1.0, + fade_out: float = 1.0 +) -> str: + """ + 添加背景音乐(自动循环以匹配视频长度) + + Args: + video_path: 输入视频路径 + bgm_path: BGM文件路径 + output_path: 输出路径 + bgm_volume: BGM音量 + loop: 是否循环BGM + """ + info = get_video_info(video_path) + video_duration = info["duration"] + + if loop: + bgm_chain = ( + f"[1:a]aloop=-1:size=2e+09,asetpts=N/SR/TB," + f"atrim=0:{video_duration}," + f"afade=t=in:st=0:d={fade_in}," + f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out}," + f"volume={bgm_volume}[bgm]" + ) + else: + bgm_chain = ( + f"[1:a]" + f"afade=t=in:st=0:d={fade_in}," + f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out}," + f"volume={bgm_volume}[bgm]" + ) + + if ducking: + # 使用安全参数的 sidechaincompress,避免 unsupported 参数 + filter_complex = ( + f"{bgm_chain};" + f"[0:a][bgm]sidechaincompress=threshold=0.1:ratio=4:attack=5:release=250:makeup=1:mix=1:level_in=1:level_sc=1[outa]" + ) + else: + filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first[outa]" + + cmd = [ + FFMPEG_PATH, "-y", + "-i", video_path, + "-stream_loop", "-1" if loop else "0", + "-i", bgm_path, + "-filter_complex", filter_complex, + "-map", "0:v", + "-map", "[outa]", + "-c:v", "copy", + "-c:a", "aac", + "-b:a", "192k", + "-t", str(video_duration), + output_path + ] + + try: + _run_ffmpeg(cmd) + except subprocess.CalledProcessError: + # sidechain失败时,回退为 amix(保留原有音频 + 低音量BGM) + logger.warning("Sidechain failed, fallback to simple amix for BGM") + filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first[outa]" + cmd = [ + FFMPEG_PATH, "-y", + "-i", video_path, + "-stream_loop", "-1" if loop else "0", + "-i", bgm_path, + "-filter_complex", filter_complex, + "-map", "0:v", + "-map", "[outa]", + "-c:v", "copy", + "-c:a", "aac", + "-b:a", "192k", + "-t", str(video_duration), + 
output_path + ] + _run_ffmpeg(cmd) + + logger.info(f"BGM added: {output_path}") + return output_path + + +def trim_video( + video_path: str, + output_path: str, + start: float = 0, + duration: float = None, + end: float = None +) -> str: + """ + 裁剪视频 + + Args: + video_path: 输入视频路径 + output_path: 输出路径 + start: 开始时间(秒) + duration: 持续时间(秒) + end: 结束时间(秒),与 duration 二选一 + """ + cmd = [ + FFMPEG_PATH, "-y", + "-i", video_path, + "-ss", str(start) + ] + + if duration: + cmd.extend(["-t", str(duration)]) + elif end: + cmd.extend(["-to", str(end)]) + + cmd.extend([ + "-c:v", "libx264", + "-preset", "fast", + "-crf", "23", + "-c:a", "copy", + output_path + ]) + + _run_ffmpeg(cmd) + logger.info(f"Trimmed video: {start}s - {end or start + duration}s") + return output_path + + +def speed_up_video( + video_path: str, + output_path: str, + speed: float = 1.5 +) -> str: + """ + 加速/减速视频 + + Args: + video_path: 输入视频路径 + output_path: 输出路径 + speed: 速度倍率(>1 加速,<1 减速) + """ + # setpts 控制视频速度,atempo 控制音频速度 + video_filter = f"setpts={1/speed}*PTS" + + # atempo 只支持 0.5-2.0,超出需要链式处理 + if speed > 2.0: + audio_filter = "atempo=2.0,atempo=" + str(speed / 2.0) + elif speed < 0.5: + audio_filter = "atempo=0.5,atempo=" + str(speed / 0.5) + else: + audio_filter = f"atempo={speed}" + + cmd = [ + FFMPEG_PATH, "-y", + "-i", video_path, + "-vf", video_filter, + "-af", audio_filter, + "-c:v", "libx264", + "-preset", "fast", + "-crf", "23", + "-c:a", "aac", + output_path + ] + + _run_ffmpeg(cmd) + logger.info(f"Speed changed to {speed}x: {output_path}") + return output_path diff --git a/modules/image_gen.py b/modules/image_gen.py new file mode 100644 index 0000000..ac510b5 --- /dev/null +++ b/modules/image_gen.py @@ -0,0 +1,491 @@ +""" +连贯生图模块 (Volcengine Doubao) +负责根据分镜脚本和原始素材生成一系列连贯的分镜图片 +""" +import base64 +import logging +import os +import time +import requests +import json +from pathlib import Path +from typing import List, Dict, Any, Optional +from PIL import Image +import io +from modules import 
class ImageGenerator:
    """Consistent storyboard image generator.

    Dispatches to one of three providers: Volcengine Doubao,
    Gemini (Wuyin Keji / NanoBanana-Pro), or Shubiaobiao.
    """

    def __init__(self):
        # Volcengine credentials / model come from config.
        self.api_key = config.VOLC_API_KEY
        # Endpoint: https://ark.cn-beijing.volces.com/api/v3/images/generations
        self.endpoint = f"https://ark.cn-beijing.volces.com/api/v3/images/generations"
        self.model = config.IMAGE_MODEL_ID

    def _encode_image(self, image_path: str) -> str:
        """Load an image, downscale to <=1024px, re-encode as JPEG and
        return it Base64-encoded ("" on any failure)."""
        try:
            with Image.open(image_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                # Cap the longest edge to keep the request payload small.
                max_size = 1024
                if max(img.size) > max_size:
                    img.thumbnail((max_size, max_size), Image.LANCZOS)

                buffer = io.BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                return base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return ""

    def generate_single_scene_image(
        self,
        scene: Dict[str, Any],
        original_image_path: Any,
        previous_image_path: Optional[str] = None,
        model_provider: str = "shubiaobiao",  # "shubiaobiao", "gemini", "doubao"
        visual_anchor: str = ""  # visual anchor, force-prepended to the prompt
    ) -> Optional[str]:
        """
        Generate one storyboard image (public entry point).

        Args:
            scene: scene dict with "id" and "visual_prompt"
            original_image_path: reference image path(s) — str or list
            previous_image_path: previous scene's image for continuity
            model_provider: which backend to use
            visual_anchor: product appearance anchor, prepended to the
                prompt to keep generated images consistent

        Returns:
            Local path of the generated image.

        Raises:
            Re-raises any provider error (PermissionError is logged as
            critical before re-raising).
        """
        scene_id = scene["id"]
        visual_prompt = scene.get("visual_prompt", "")

        # Force-prepend the visual anchor (keeps generation consistent).
        if visual_anchor and visual_anchor not in visual_prompt:
            visual_prompt = f"[{visual_anchor}] {visual_prompt}"
            logger.info(f"Scene {scene_id}: Prepended visual_anchor to prompt")

        logger.info(f"Generating image for Scene {scene_id} (Provider: {model_provider})...")

        input_images = []

        # original_image_path may be a single path or a list of paths.
        if isinstance(original_image_path, list):
            input_images.extend(original_image_path)
        elif isinstance(original_image_path, str) and original_image_path:
            input_images.append(original_image_path)

        if previous_image_path:
            input_images.append(previous_image_path)

        try:
            output_path = self._generate_single_image(
                prompt=visual_prompt,
                reference_images=input_images,
                output_filename=f"scene_{scene_id}_{int(time.time())}.png",
                provider=model_provider
            )

            if output_path:
                return output_path
            else:
                raise RuntimeError(f"Image generation returned empty for Scene {scene_id}")

        except PermissionError as e:
            logger.error(f"Critical API Error for Scene {scene_id}: {e}")
            raise e
        except Exception as e:
            logger.error(f"Image generation failed for Scene {scene_id}: {e}")
            raise e

    def generate_group_images_doubao(
        self,
        scenes: List[Dict[str, Any]],
        reference_images: List[str],
        visual_anchor: str = ""  # visual anchor
    ) -> Dict[int, str]:
        """
        Doubao batch ("group") generation — concatenate all scene prompts
        into one request and generate every image in a single call.

        Returns:
            {scene_id: local_image_path} for every image Doubao returned
            (assumes the response order matches the scene order).
        """
        logger.info("Starting Doubao Group Image Generation...")

        # 1. Concatenate the prompts.
        # Format: "Global: [Visual Anchor] ... | S1: ... | S2: ..."

        scene_prompts = []
        for scene in scenes:
            # Per-scene visual prompt.
            p = scene.get("visual_prompt", "")
            scene_prompts.append(f"S{scene['id']}:{p}")

        combined_scenes_text = " | ".join(scene_prompts)

        # Build the combined prompt — the visual_anchor goes in the Global part.
        global_context = f"[{visual_anchor}] Consistent product appearance & style." if visual_anchor else "Consistent product appearance & style."
        combined_prompt = (
            f"Global: {global_context}\n"
            f"{combined_scenes_text}\n"
            "Req: 1 img per scene. Follow specific angles."
        )

        logger.info(f"Visual Anchor applied to group prompt: {visual_anchor[:50]}..." if visual_anchor else "No visual_anchor")

        # Log the prompt length for reference.
        logger.info(f"Doubao Group Prompt Length: {len(combined_prompt)} chars")

        # 2. Build the payload.
        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": combined_prompt,
            "sequential_image_generation": "auto",  # enable group generation
            "sequential_image_generation_options": {
                "max_images": len(scenes)  # cap the number of images
            },
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        # 3. Handle reference images: upload local files and pass URLs.
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url: img_urls.append(url)
                    except Exception as e:
                        logger.warning(f"Failed to upload ref image {ref_path}: {e}")

        if img_urls:
            payload["image_urls"] = img_urls

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting Doubao Group Request (Scenes: {len(scenes)})...")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=240)
            resp.raise_for_status()

            data = resp.json()
            results = {}

            if "data" in data:
                items = data["data"]
                logger.info(f"Doubao returned {len(items)} images.")

                # Map the returned images back to scenes,
                # assuming the order matches.
                for i, item in enumerate(items):
                    if i < len(scenes):
                        scene_id = scenes[i]["id"]
                        image_url = item.get("url")

                        if image_url:
                            # Download the image to the temp dir.
                            img_resp = requests.get(image_url, timeout=60)
                            output_path = config.TEMP_DIR / f"scene_{scene_id}_{int(time.time())}.png"
                            with open(output_path, "wb") as f:
                                f.write(img_resp.content)
                            results[scene_id] = str(output_path)

            return results

        except Exception as e:
            logger.error(f"Doubao Group Generation Failed: {e}")
            raise e

    def _generate_single_image(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        provider: str = "shubiaobiao"
    ) -> Optional[str]:
        """Unified dispatch to the provider-specific implementations."""
        if provider == "doubao":
            return self._generate_single_image_doubao(prompt, reference_images, output_filename)
        elif provider == "gemini":
            return self._generate_single_image_gemini(prompt, reference_images, output_filename)
        else:
            return self._generate_single_image_shubiao(prompt, reference_images, output_filename)

    def _generate_single_image_doubao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Call the Volcengine Doubao image API for a single image."""

        # 1. Upload all reference images to R2 so Doubao can fetch them.
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded Doubao ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Failed to upload Doubao ref image {ref_path}: {e}")

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        if img_urls:
            payload["image_urls"] = img_urls
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', image_urls={len(img_urls)}")
        else:
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', no reference images")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting to Doubao Image: {self.endpoint}")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=180)

            if resp.status_code != 200:
                msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            # Download the first returned image URL to the temp dir.
            if "data" in data and len(data["data"]) > 0:
                image_url = data["data"][0].get("url")
                if image_url:
                    img_resp = requests.get(image_url, timeout=60)
                    img_resp.raise_for_status()

                    output_path = config.TEMP_DIR / output_filename
                    with open(output_path, "wb") as f:
                        f.write(img_resp.content)
                    return str(output_path)

            raise RuntimeError(f"No image URL in Doubao response: {data}")

        except Exception as e:
            logger.error(f"Doubao Gen Failed: {e}")
            raise e

    def _generate_single_image_shubiao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate via api2img.shubiaobiao.com (synchronous, returns base64)."""
        # Gemini-style "parts" payload; reference images are inlined as base64.
        parts = [{"text": prompt}]

        # Strictly de-duplicate and keep the order of reference images.
        valid_refs = []
        if reference_images:
            for p in reference_images:
                if p and os.path.exists(p) and p not in valid_refs:
                    valid_refs.append(p)

        logger.info(f"[Shubiaobiao] Input reference images ({len(valid_refs)}): {valid_refs}")

        if valid_refs:
            for ref_path in valid_refs:
                try:
                    encoded = self._encode_image(ref_path)
                    if encoded:
                        parts.append({
                            "inlineData": {
                                "mimeType": "image/jpeg",
                                "data": encoded
                            }
                        })
                except Exception as e:
                    logger.error(f"Failed to encode image {ref_path}: {e}")

        logger.info(f"[Shubiaobiao] Final payload parts count: {len(parts)} (1 prompt + {len(parts)-1} images)")

        payload = {
            "contents": [{
                "role": "user",
                "parts": parts
            }],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "imageConfig": {
                    "aspectRatio": "9:16",
                    "imageSize": "2K"
                }
            }
        }

        endpoint = f"{config.SHUBIAOBIAO_IMG_BASE_URL}/v1beta/models/{config.SHUBIAOBIAO_IMG_MODEL_NAME}:generateContent"
        headers = {
            "x-goog-api-key": config.SHUBIAOBIAO_IMG_KEY,
            "Content-Type": "application/json"
        }

        try:
            logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
            resp = requests.post(endpoint, json=payload, headers=headers, timeout=120)

            if resp.status_code != 200:
                msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            # Find the base64 image in the candidate parts.
            img_b64 = None
            candidates = data.get("candidates") or []
            if candidates:
                content_parts = candidates[0].get("content", {}).get("parts", [])
                for part in content_parts:
                    inline = part.get("inlineData") if isinstance(part, dict) else None
                    if inline and inline.get("data"):
                        img_b64 = inline["data"]
                        break

            if not img_b64:
                msg = f"Shubiaobiao 响应缺少图片数据: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            output_path = config.TEMP_DIR / output_filename
            with open(output_path, "wb") as f:
                f.write(base64.b64decode(img_b64))

            logger.info(f"Shubiaobiao Generation Success: {output_path}")
            return str(output_path)

        except Exception as e:
            logger.error(f"Shubiaobiao Generation Exception: {e}")
            raise

    def _generate_single_image_gemini(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image via Gemini (Wuyin Keji / NanoBanana-Pro):
        submit an async task, then poll until it succeeds, fails, or
        times out (~120s)."""

        # 1. Build the payload.
        payload = {
            "prompt": prompt,
            "aspectRatio": "9:16",
            "imageSize": "2K"
        }

        # Reference images (image-to-image): upload and pass public URLs.
        if reference_images:
            valid_paths = []
            seen = set()
            for p in reference_images:
                if p and os.path.exists(p) and p not in seen:
                    valid_paths.append(p)
                    seen.add(p)

            if valid_paths:
                img_urls = []
                for ref_path in valid_paths:
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Error uploading ref image {ref_path}: {e}")

                if img_urls:
                    payload["img_url"] = img_urls
                    logger.info(f"Using {len(img_urls)} reference images for Gemini Img2Img")

        headers = {
            "Authorization": config.GEMINI_IMG_KEY,
            "Content-Type": "application/json;charset:utf-8"
        }

        # 2. Submit the task.
        try:
            logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
            resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=30)

            if resp.status_code != 200:
                msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            if data.get("code") != 200:
                msg = f"Gemini 返回错误: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            task_id = data.get("data", {}).get("id")
            if not task_id:
                raise RuntimeError(f"Gemini 响应缺少 task id: {data}")

            logger.info(f"Gemini Task Submitted, ID: {task_id}")

            # 3. Poll the task status (60 tries x 2s; transient poll
            # failures are skipped and retried).
            max_retries = 60
            for i in range(max_retries):
                time.sleep(2)

                poll_url = f"{config.GEMINI_IMG_DETAIL_URL}?key={config.GEMINI_IMG_KEY}&id={task_id}"
                try:
                    poll_resp = requests.get(poll_url, headers=headers, timeout=30)
                except requests.Timeout:
                    continue
                except Exception as e:
                    continue

                if poll_resp.status_code != 200:
                    continue

                poll_data = poll_resp.json()
                if poll_data.get("code") != 200:
                    raise RuntimeError(f"Gemini 轮询返回错误: {poll_data}")

                result_data = poll_data.get("data", {}) or {}
                status = result_data.get("status")  # 0: queued, 1: generating, 2: success, 3: failed

                if status == 2:
                    image_url = result_data.get("image_url")
                    if not image_url:
                        raise RuntimeError("Gemini 成功但缺少 image_url")

                    logger.info(f"Gemini Generation Success: {image_url}")
                    img_resp = requests.get(image_url, timeout=60)
                    img_resp.raise_for_status()

                    output_path = config.TEMP_DIR / output_filename
                    with open(output_path, "wb") as f:
                        f.write(img_resp.content)

                    return str(output_path)

                if status == 3:
                    fail_reason = result_data.get("fail_reason", "Unknown")
                    raise RuntimeError(f"Gemini 生成失败: {fail_reason}")

            raise RuntimeError("Gemini 生成超时")

        except Exception as e:
            logger.error(f"Gemini Generation Exception: {e}")
            raise
+ """ + if not Path(video_path).exists(): + raise FileNotFoundError(f"Video not found: {video_path}") + + logger.info(f"Processing video: {video_path}") + + # 1. Upload to R2 + video_url = storage.upload_file(video_path) + if not video_url: + raise RuntimeError("Failed to upload video to R2") + + # 2. Extract Frames + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Cannot open video: {video_path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frame_indices = [ + int(total_frames * 0.1), + int(total_frames * 0.5), + int(total_frames * 0.9) + ] + + frame_urls = [] + for i, idx in enumerate(frame_indices): + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if ret: + frame_name = f"frame_{Path(video_path).stem}_{i}.jpg" + frame_path = config.TEMP_DIR / frame_name + cv2.imwrite(str(frame_path), frame) + + # Upload frame to R2 immediately + frame_url = storage.upload_file(str(frame_path)) + if frame_url: + frame_urls.append(frame_url) + + cap.release() + logger.info(f"Extracted and uploaded {len(frame_urls)} frames") + + return frame_urls, video_url diff --git a/modules/project.py b/modules/project.py new file mode 100644 index 0000000..4bd2d1c --- /dev/null +++ b/modules/project.py @@ -0,0 +1,151 @@ +""" +MatchMe Studio - Project State Management (R2 Persistence) +""" +import json +import logging +import uuid +from datetime import datetime +from typing import Dict, Any, Optional, List +from dataclasses import dataclass, asdict, field + +import config +from modules import storage + +logger = logging.getLogger(__name__) + + +@dataclass +class Scene: + id: int + duration: int = 5 + timeline: str = "" + keyframe: Dict[str, str] = field(default_factory=dict) + camera_movement: str = "" + story_beat: str = "" + voiceover: str = "" + rhythm: Dict[str, Any] = field(default_factory=dict) + image_url: str = "" + video_url: str = "" + + +@dataclass +class Project: + id: str = field(default_factory=lambda: 
str(uuid.uuid4())[:8]) + created_at: str = field(default_factory=lambda: datetime.now().isoformat()) + status: str = "draft" # draft | analyzing | scripting | imaging | video | rendering | done + + # Step 0: Input + input_mode: str = "" # text | images | video + prompt: str = "" + image_urls: List[str] = field(default_factory=list) + video_url: str = "" + asr_text: str = "" + + # Step 1: Analysis + analysis: str = "" + questions: List[Dict[str, Any]] = field(default_factory=list) + answers: Dict[str, str] = field(default_factory=dict) + + # Step 2: Script + hook: str = "" + scenes: List[Dict[str, Any]] = field(default_factory=list) + cta: str = "" + + # Step 6: Final + final_video_url: str = "" + bgm_url: str = "" + + +def save_project(project: Project) -> str: + """Save project state to R2 as JSON.""" + data = asdict(project) + json_str = json.dumps(data, ensure_ascii=False, indent=2) + + # Write to temp file + temp_path = config.TEMP_DIR / f"project_{project.id}.json" + with open(temp_path, "w", encoding="utf-8") as f: + f.write(json_str) + + # Upload to R2 + object_name = f"projects/{project.id}.json" + s3 = storage.get_s3_client() + + try: + s3.upload_file( + str(temp_path), + config.R2_BUCKET_NAME, + object_name, + ExtraArgs={'ContentType': 'application/json'} + ) + logger.info(f"Project {project.id} saved to R2") + return project.id + except Exception as e: + logger.error(f"Failed to save project: {e}") + raise + + +def load_project(project_id: str) -> Optional[Project]: + """Load project state from R2.""" + object_name = f"projects/{project_id}.json" + temp_path = config.TEMP_DIR / f"project_{project_id}.json" + + s3 = storage.get_s3_client() + + try: + s3.download_file(config.R2_BUCKET_NAME, object_name, str(temp_path)) + + with open(temp_path, "r", encoding="utf-8") as f: + data = json.load(f) + + # Reconstruct Project + project = Project( + id=data.get("id", project_id), + created_at=data.get("created_at", ""), + status=data.get("status", "draft"), + 
input_mode=data.get("input_mode", ""), + prompt=data.get("prompt", ""), + image_urls=data.get("image_urls", []), + video_url=data.get("video_url", ""), + asr_text=data.get("asr_text", ""), + analysis=data.get("analysis", ""), + questions=data.get("questions", []), + answers=data.get("answers", {}), + hook=data.get("hook", ""), + scenes=data.get("scenes", []), + cta=data.get("cta", ""), + final_video_url=data.get("final_video_url", ""), + bgm_url=data.get("bgm_url", "") + ) + + logger.info(f"Project {project_id} loaded from R2") + return project + + except Exception as e: + logger.warning(f"Failed to load project {project_id}: {e}") + return None + + +def create_project() -> Project: + """Create a new project with unique ID.""" + project = Project() + logger.info(f"Created new project: {project.id}") + return project + + + + + + + + + + + + + + + + + + + diff --git a/modules/script_gen.py b/modules/script_gen.py new file mode 100644 index 0000000..e380495 --- /dev/null +++ b/modules/script_gen.py @@ -0,0 +1,390 @@ +""" +脚本生成模块 (Gemini-3-Pro) +负责解析商品信息,生成分镜脚本 +""" +import base64 +import json +import logging +import os +import requests +from typing import Dict, Any, List, Optional +from pathlib import Path + +import config +from modules.db_manager import db + +logger = logging.getLogger(__name__) + +class ScriptGenerator: + """分镜脚本生成器""" + + def __init__(self): + self.api_key = config.SHUBIAOBIAO_KEY + # 注意:API 地址可能需要适配 gemini-3-pro-preview 的具体路径 + # 根据 demo: https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent + # 这里我们先假设 base_url 是 v1beta/models/ + self.endpoint = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent" + + # Default System Prompt + self.default_system_prompt = """ +你是一个专业的抖音电商短视频导演。请根据提供的商品信息和图片,设计一个高转化率的商品详情页首图视频脚本。 + +## 目标 +- 提升商品详情页的 GPM 和下单转化率 +- 视频时长 9-12 秒 (由 3-4 个分镜组成) +- **每个分镜时长固定为 3 秒** (duration: 3),不要超过 3 秒 +- 必须包含:目标人群分析、卖点提炼、分镜设计 + +## 分镜设计原则 +1. 
**单分镜单主体**:每个分镜聚焦一个视觉主体或动作,避免复杂运镜,因为 AI 生视频在长时间(>3秒)容易出现画面异常。 +2. **旁白跨分镜**:一段完整的旁白/卖点可以跨越多个分镜。在 voiceover_timeline 中,通过 start_time 和 duration (秒) 控制旁白的绝对时间位置,无需与分镜一一对应。 +3. **节奏感**:分镜之间保持视觉连贯,通过景别变化(特写 -> 中景 -> 全景)制造节奏。 +4. **语速控制**:旁白语速约 4 字/秒,12字旁白约需 3 秒。 + +## 输出格式要求 (JSON) +必须严格遵守以下 JSON 结构: +{ + "product_name": "商品名称", + "visual_anchor": "商品视觉锚点:材质+颜色+形状+包装特征(用于保持生图一致性)", + "selling_points": ["卖点1", "卖点2"], + "target_audience": "目标人群描述", + "video_style": "视频风格关键词", + "bgm_style": "BGM风格关键词", + "voiceover_timeline": [ + { + "id": 1, + "text": "旁白文案片段1(可横跨多个分镜)", + "subtitle": "字幕文案1 (简短有力)", + "start_time": 0.0, + "duration": 3.0 + }, + { + "id": 2, + "text": "旁白文案片段2", + "subtitle": "字幕文案2", + "start_time": 3.5, + "duration": 2.5 + } + ], + "scenes": [ + { + "id": 1, + "duration": 3, + "visual_prompt": "详细的画面描述,用于AI生图,包含主体、背景、构图、光影。英文描述。", + "video_prompt": "详细的动效描述,用于AI图生视频。英文描述。", + "fancy_text": { + "text": "花字文案 (最多6字)", + "style": "highlight", + "position": "center", + "start_time": 0.5, + "duration": 2.0 + } + } + ] +} + +## 注意事项 +1. **visual_prompt**: + - 必须是英文。 + - 描述要具体,例如 "Close-up shot of a hair clip, soft lighting, minimalist background". + - **CRITICAL**: 禁止 AI 额外生成装饰性文字、标语、水印。但必须保留商品包装自带的文字和 Logo(这是商品真实外观的一部分)。 + - 正确写法: "Product front view, keep original packaging design --no added text --no watermarks" + - **EMPHASIS**: Strictly follow the appearance of the product in the reference images. +2. **video_prompt**: 必须是英文,描述动作,例如 "Slow zoom in, the hair clip rotates slightly"。注意保持动作简单,避免复杂运镜和人体动作。 +3. **voiceover_timeline**: + - 这是整个视频的旁白和字幕时间轴,独立于分镜。 + - `start_time` 是旁白开始的绝对时间 (秒),`duration` 是旁白持续时长 (秒)。 + - **一段旁白可以横跨多个分镜**,例如:总时长 9 秒 (3 个分镜),一段旁白从 start_time=0,duration=5,则覆盖前两个分镜。 + - 两段旁白之间留 0.3-0.5 秒间隙(气口)。 +4. 
**fancy_text**: + - 花字要精简(最多 6 字),突出卖点。 + - **Style Selection**: + - `highlight`: 默认样式,适合通用卖点 (Yellow/Black)。 + - `warning`: 强调痛点或食欲 (Red/White)。 + - `price`: 价格显示 (Big Red)。 + - `bubble`: 旁白补充或用户评价 (Bubble)。 + - `minimal`: 高级感,适合时尚类 (Thin/White)。 + - `tech`: 数码类 (Cyan/Glow)。 + - `position` 默认 `center`,可选 top/bottom/top-left/bottom-right 等。 +5. **场景连贯性**: 确保分镜之间的逻辑和视觉风格连贯。每个分镜 duration 必须为 3。 +""" + + def _encode_image(self, image_path: str) -> str: + """读取图片并转为 Base64""" + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + def generate_script( + self, + product_name: str, + product_info: Dict[str, Any], + image_paths: List[str] = None, + model_provider: str = "shubiaobiao" # "shubiaobiao" or "doubao" + ) -> Dict[str, Any]: + """ + 生成分镜脚本 + """ + logger.info(f"Generating script for: {product_name} (Provider: {model_provider})") + + # 1. 构造 Prompt (优先从数据库读取配置) + system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt) + user_prompt = self._build_user_prompt(product_name, product_info) + + # Branch for Doubao + if model_provider == "doubao": + return self._generate_script_doubao(system_prompt, user_prompt, image_paths) + + # ... Existing Shubiaobiao Logic ... + + # 调试: 检查是否使用了自定义 Prompt + if system_prompt != self.default_system_prompt: + logger.info("Using CUSTOM system prompt from database") + else: + logger.info("Using DEFAULT system prompt") + + # 2. 
    def generate_script(
        self,
        product_name: str,
        product_info: Dict[str, Any],
        image_paths: List[str] = None,
        model_provider: str = "shubiaobiao"  # "shubiaobiao" (Gemini endpoint) or "doubao"
    ) -> Dict[str, Any]:
        """
        Generate a storyboard script for a product.

        Args:
            product_name: Display name of the product.
            product_info: Free-form product attributes; the special keys
                ``style_hint`` and ``uploaded_images`` receive dedicated
                handling in ``_build_user_prompt``.
            image_paths: Optional local image paths attached as multimodal
                input (at most 10 are used).
            model_provider: "shubiaobiao" (Gemini-style API) or "doubao"
                (Volcengine); the latter delegates to ``_generate_script_doubao``.

        Returns:
            The validated script dict with a ``_debug`` payload attached,
            or ``None`` on any failure.
        """
        logger.info(f"Generating script for: {product_name} (Provider: {model_provider})")

        # 1. Build prompts. A custom system prompt stored in the DB takes
        #    precedence over the built-in default.
        system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
        user_prompt = self._build_user_prompt(product_name, product_info)

        # Branch for Doubao
        if model_provider == "doubao":
            return self._generate_script_doubao(system_prompt, user_prompt, image_paths)

        # ... Shubiaobiao (Gemini) path from here on ...

        # Debug: report whether a custom prompt is in effect.
        if system_prompt != self.default_system_prompt:
            logger.info("Using CUSTOM system prompt from database")
        else:
            logger.info("Using DEFAULT system prompt")

        # 2. Build the request payload (Gemini/Shubiaobiao format).
        contents = []

        # User message parts
        user_parts = [{"text": user_prompt}]

        # Attach images (multimodal input).
        if image_paths:
            for path in image_paths[:10]:  # cap at 10; Gemini-3-Pro accepts multiple images
                if Path(path).exists():
                    try:
                        b64_img = self._encode_image(path)
                        user_parts.append({
                            "inline_data": {
                                "mime_type": "image/jpeg",  # assumes JPG/PNG input — TODO confirm actual file type
                                "data": b64_img
                            }
                        })
                    except Exception as e:
                        logger.warning(f"Failed to encode image {path}: {e}")

        contents.append({
            "role": "user",
            "parts": user_parts
        })

        # System instruction prepended as the first user part. NOTE: this
        # insert happens after contents.append above, but it still takes
        # effect because `contents` holds a reference to the same
        # `user_parts` list object.
        user_parts.insert(0, {"text": system_prompt})

        payload = {
            "contents": contents,
            "generationConfig": {
                "response_mime_type": "application/json",
                "temperature": 0.7
            }
        }

        headers = {
            "x-goog-api-key": self.api_key,
            "Content-Type": "application/json"
        }

        # 3. Call the API.
        try:
            response = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
            response.raise_for_status()

            result = response.json()

            # 4. Parse the result.
            if "candidates" in result and result["candidates"]:
                content_text = result["candidates"][0]["content"]["parts"][0]["text"]

                # Extract the JSON portion (handles Markdown code fences or plain text).
                script_json = self._extract_json_from_response(content_text)

                if script_json is None:
                    logger.error(f"Failed to extract JSON from response: {content_text[:500]}...")
                    return None

                final_script = self._validate_and_fix_script(script_json)

                # Attach debug info (including the raw model output).
                final_script["_debug"] = {
                    "system_prompt": system_prompt,
                    "user_prompt": user_prompt,
                    "raw_output": content_text,
                    "provider": "shubiaobiao"
                }
                return final_script
            else:
                logger.error(f"No candidates in response: {result}")
                return None

        except Exception as e:
            logger.error(f"Script generation failed: {e}")
            if 'response' in locals():
                logger.error(f"Response content: {response.text}")
            return None
    def _generate_script_doubao(
        self,
        system_prompt: str,
        user_prompt: str,
        image_paths: List[str]
    ) -> Dict[str, Any]:
        """
        Doubao (Volcengine Ark) multimodal script generation.

        Uses the OpenAI-compatible ``/chat/completions`` endpoint; images are
        inlined as base64 data URLs (at most 5). Returns the validated script
        dict with ``_debug`` attached, or ``None`` on failure.
        """
        # NOTE(review): an earlier integration sketch used
        # https://ark.cn-beijing.volces.com/api/v3/responses with a custom
        # "input" payload ({"role": "user", "content": [{"type": "input_image"...}]}).
        # The standard Chat Completions API is tried first because Doubao
        # ('ep-...' models) is OpenAI-compatible; a fallback to the legacy
        # endpoint is still TODO (see the warning below).
        endpoint = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"

        messages = [
            {"role": "system", "content": system_prompt}
        ]

        user_content = []

        # Attach images: Doubao vision accepts OpenAI-style image_url parts,
        # including "data:image/jpeg;base64,..." data URLs.
        if image_paths:
            for path in image_paths[:5]:  # Limit
                if os.path.exists(path):
                    try:
                        b64_img = self._encode_image(path)
                        user_content.append({
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{b64_img}"
                            }
                        })
                    except Exception as e:
                        logger.warning(f"Failed to encode image for Doubao: {e}")

        # Text part goes last.
        user_content.append({"type": "text", "text": user_prompt})

        messages.append({
            "role": "user",
            "content": user_content
        })

        payload = {
            "model": config.DOUBAO_SCRIPT_MODEL,
            "messages": messages,
            "stream": False,
            # "response_format": {"type": "json_object"}  # enable JSON mode if supported
        }

        headers = {
            "Authorization": f"Bearer {config.VOLC_API_KEY}",
            "Content-Type": "application/json"
        }

        try:
            # Standard chat/completions first.
            resp = requests.post(endpoint, headers=headers, json=payload, timeout=120)

            if resp.status_code != 200:
                # A 404 here may mean only the legacy 'responses' endpoint is
                # exposed; that fallback is not implemented yet, so we just
                # log and raise.
                logger.warning(f"Doubao Chat API failed ({resp.status_code}), trying legacy/custom endpoint...")
                resp.raise_for_status()

            result = resp.json()
            content_text = result["choices"][0]["message"]["content"]

            script_json = self._extract_json_from_response(content_text)

            if script_json is None:
                logger.error(f"Failed to extract JSON from Doubao response: {content_text[:500]}...")
                return None

            final_script = self._validate_and_fix_script(script_json)
            final_script["_debug"] = {
                "system_prompt": system_prompt,
                "user_prompt": user_prompt,
                "raw_output": content_text,
                "provider": "doubao"
            }
            return final_script

        except Exception as e:
            logger.error(f"Doubao script generation failed: {e}")
            if 'resp' in locals():
                logger.error(f"Response: {resp.text}")
            return None
+ logger.warning(f"Doubao Chat API failed ({resp.status_code}), trying legacy/custom endpoint...") + # Fallback to user provided structure if needed (implement later if this fails) + resp.raise_for_status() + + result = resp.json() + content_text = result["choices"][0]["message"]["content"] + + script_json = self._extract_json_from_response(content_text) + + if script_json is None: + logger.error(f"Failed to extract JSON from Doubao response: {content_text[:500]}...") + return None + + final_script = self._validate_and_fix_script(script_json) + final_script["_debug"] = { + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "raw_output": content_text, + "provider": "doubao" + } + return final_script + + except Exception as e: + logger.error(f"Doubao script generation failed: {e}") + if 'resp' in locals(): + logger.error(f"Response: {resp.text}") + return None + + def _extract_json_from_response(self, text: str) -> Optional[Dict]: + """ + 从 API 响应中提取 JSON 对象 + 支持: + 1. 纯 JSON 响应 + 2. Markdown 代码块包裹的 JSON (```json ... ```) + 3. 文本中嵌入的 JSON (找到第一个 { 和最后一个 }) + """ + import re + + # 方法1: 尝试直接解析(纯 JSON 情况) + try: + return json.loads(text.strip()) + except json.JSONDecodeError: + pass + + # 方法2: 提取 ```json ... ``` 代码块 + json_block_match = re.search(r'```json\s*([\s\S]*?)\s*```', text) + if json_block_match: + try: + return json.loads(json_block_match.group(1)) + except json.JSONDecodeError as e: + logger.warning(f"JSON block found but parse failed: {e}") + + # 方法3: 提取 ``` ... 
``` 代码块 (无 json 标记) + code_block_match = re.search(r'```\s*([\s\S]*?)\s*```', text) + if code_block_match: + try: + return json.loads(code_block_match.group(1)) + except json.JSONDecodeError: + pass + + # 方法4: 找到第一个 { 和最后一个 } 之间的内容 + first_brace = text.find('{') + last_brace = text.rfind('}') + if first_brace != -1 and last_brace != -1 and last_brace > first_brace: + try: + return json.loads(text[first_brace:last_brace + 1]) + except json.JSONDecodeError as e: + logger.warning(f"Brace extraction failed: {e}") + + return None + + def _build_user_prompt(self, product_name: str, product_info: Dict[str, Any]) -> str: + # 提取商家偏好提示 + style_hint = product_info.get("style_hint", "") + # 过滤掉不需要展示的字段 + filtered_info = {k: v for k, v in product_info.items() if k not in ["uploaded_images", "style_hint"]} + info_str = "\n".join([f"- {k}: {v}" for k, v in filtered_info.items()]) + + prompt = f""" +商品名称:{product_name} +商品信息: +{info_str} +""" + if style_hint: + prompt += f""" +## 商家特别要求 +{style_hint} +""" + prompt += "\n请根据以上信息设计视频脚本。" + return prompt + + def _validate_and_fix_script(self, script: Dict[str, Any]) -> Dict[str, Any]: + """校验并修复脚本结构""" + # 简单校验,确保必要字段存在 + if "scenes" not in script: + script["scenes"] = [] + return script diff --git a/modules/storage.py b/modules/storage.py new file mode 100644 index 0000000..86e5925 --- /dev/null +++ b/modules/storage.py @@ -0,0 +1,84 @@ +""" +MatchMe Studio - Storage Module (R2) +""" +import os +import logging +import time +import uuid +import boto3 +from botocore.exceptions import NoCredentialsError +from pathlib import Path +from typing import Optional + +import config + +logger = logging.getLogger(__name__) + +def get_s3_client(): + try: + return boto3.client( + 's3', + endpoint_url=config.R2_ENDPOINT, + aws_access_key_id=config.R2_ACCESS_KEY, + aws_secret_access_key=config.R2_SECRET_KEY, + region_name='auto' + ) + except Exception as e: + logger.error(f"Failed to create R2 client: {e}") + raise + +def upload_file(file_path: str) 
-> Optional[str]: + """Upload file to R2 and return Public URL.""" + if not os.path.exists(file_path): + logger.error(f"File not found: {file_path}") + return None + + # 使用 UUID 作为文件名,避免中文/特殊字符导致的 URL 问题 + original_name = Path(file_path).name + ext = Path(file_path).suffix.lower() or ".bin" + object_name = f"{uuid.uuid4().hex}{ext}" + + s3 = get_s3_client() + + try: + logger.info(f"Uploading {original_name} to R2 as {object_name}...") + + # 根据后缀设置正确的 Content-Type + if ext == ".png": + content_type = "image/png" + elif ext in [".jpg", ".jpeg"]: + content_type = "image/jpeg" + elif ext == ".mp4": + content_type = "video/mp4" + elif ext == ".mp3": + content_type = "audio/mpeg" + else: + content_type = "application/octet-stream" + + s3.upload_file( + file_path, + config.R2_BUCKET_NAME, + object_name, + ExtraArgs={'ContentType': content_type} + ) + + public_url = f"{config.R2_PUBLIC_URL}/{object_name}" + logger.info(f"Upload successful: {public_url}") + return public_url + + except Exception as e: + logger.error(f"R2 Upload Failed: {e}") + return None + +def cleanup_temp(max_age_seconds: int = 3600): + """Delete old temp files.""" + logger.info("Running cleanup_temp...") + now = time.time() + if not config.TEMP_DIR.exists(): return + + for f in config.TEMP_DIR.iterdir(): + try: + if f.is_file() and (now - f.stat().st_mtime) > max_age_seconds: + f.unlink() + except Exception as e: + logger.warning(f"Failed to delete {f}: {e}") diff --git a/modules/styles.py b/modules/styles.py new file mode 100644 index 0000000..7f1120d --- /dev/null +++ b/modules/styles.py @@ -0,0 +1,76 @@ +""" +花字样式预设库 +供 Design Agent 和 Renderer 使用 +""" + +STYLES = { + # 1. 醒目强调 (黄色高亮) + "highlight": { + "font_size": 60, + "font_color": "#FFE66D", # 亮黄 + "stroke": {"color": "#000000", "width": 4}, + "shadow": {"color": "#000000", "blur": 8, "offset": [4, 4], "opacity": 0.6} + }, + + # 2. 
# Fancy-text ("花字") style presets, consumed by the Design Agent and the
# text renderer. Each entry feeds TextRenderer.render(): font settings plus
# optional stroke / shadow / background layers.
STYLES = {
    # 1. Highlight — default style for generic selling points (yellow on black outline).
    "highlight": {
        "font_size": 60,
        "font_color": "#FFE66D",  # bright yellow
        "stroke": {"color": "#000000", "width": 4},
        "shadow": {"color": "#000000", "blur": 8, "offset": [4, 4], "opacity": 0.6}
    },

    # 2. Warning — pain points / appetite appeal (white text on a red box).
    "warning": {
        "font_size": 55,
        "font_color": "#FFFFFF",
        "stroke": {"color": "#FF0000", "width": 0},  # no stroke
        "background": {
            "type": "box",
            "color": "#FF4D4F",  # red box
            "corner_radius": 12,
            "padding": [15, 25, 15, 25]  # t, r, b, l
        },
        "shadow": {"color": "#990000", "blur": 0, "offset": [0, 6], "opacity": 0.4}  # hard drop shadow
    },

    # 3. Price / promotion — large bright red with a white outline.
    "price": {
        "font_size": 90,
        "font_color": "#FF2E2E",  # vivid red
        "stroke": {"color": "#FFFFFF", "width": 6},  # white outline
        "shadow": {"color": "#FF9999", "blur": 15, "offset": [0, 0], "opacity": 0.8}  # glow
    },

    # 4. Bubble — dialogue / user quotes (dark text on a rounded white box).
    "bubble": {
        "font_size": 48,
        "font_color": "#333333",
        "background": {
            "type": "box",
            "color": "#FFFFFF",
            "corner_radius": 40,  # large radius
            "padding": [20, 40, 20, 40]
        },
        "shadow": {"color": "#000000", "blur": 10, "offset": [2, 5], "opacity": 0.2}
    },

    # 5. Minimal — fashion / premium look (thin white with a subtle shadow).
    "minimal": {
        "font_size": 65,
        "font_color": "#FFFFFF",
        "stroke": {"color": "#000000", "width": 2},
        "shadow": {"color": "#000000", "blur": 2, "offset": [2, 2], "opacity": 0.8},
        "font_family": "NotoSansSC-Regular.otf"  # renderer falls back if this font is missing
    },

    # 6. Tech — digital products (cyan with glow).
    "tech": {
        "font_size": 60,
        "font_color": "#00FFFF",
        "stroke": {"color": "#003333", "width": 3},
        "shadow": {"color": "#00FFFF", "blur": 20, "offset": [0, 0], "opacity": 0.9}
    }
}


def get_style(style_name: str) -> dict:
    """Return the preset for ``style_name``, falling back to "highlight"."""
    try:
        return STYLES[style_name]
    except KeyError:
        return STYLES["highlight"]
系统字体回退 + candidates.extend([ + "/System/Library/Fonts/PingFang.ttc", + "/System/Library/Fonts/STHeiti Medium.ttc", + "C:/Windows/Fonts/msyh.ttc", + "C:/Windows/Fonts/simhei.ttf", + ]) + + for path in candidates: + if path and os.path.exists(path): + # 简单验证文件大小 + try: + if os.path.getsize(path) > 10000: + return path + except: + continue + + logger.warning("No valid font found, using default load_default()") + return None + + def _get_font(self, font_path: str, size: int) -> ImageFont.FreeTypeFont: + try: + if font_path: + return ImageFont.truetype(font_path, size) + except Exception as e: + logger.warning(f"Failed to load font {font_path}: {e}") + return ImageFont.load_default() + + def _parse_color(self, color: Union[str, Tuple]) -> Tuple[int, int, int, int]: + """解析颜色为 RGBA""" + if isinstance(color, str): + if color.startswith("#"): + rgb = ImageColor.getrgb(color) + return rgb + (255,) + # TODO: 支持 'rgba(r,g,b,a)' 格式 + if isinstance(color, tuple): + if len(color) == 3: + return color + (255,) + return color + return (0, 0, 0, 255) + + def render(self, text: str, style: Union[Dict[str, Any], str], cache: bool = True) -> str: + """ + 渲染文本并返回图片路径 + + style 结构: + { + "font_family": str, + "font_size": int, + "font_color": str, + "stroke": [{"color": str, "width": int}, ...], + "shadow": {"color": str, "blur": int, "offset": [x, y], "opacity": float}, + "background": { + "type": "box", "color": str/list, "corner_radius": int, "padding": [t, r, b, l] + } + } + """ + # 0. 解析样式 + if isinstance(style, str): + style = get_style(style) + + # 1. 缓存检查 + cache_key = hashlib.md5(f"{text}_{str(style)}".encode()).hexdigest() + if cache: + cache_path = CACHE_DIR / f"{cache_key}.png" + if cache_path.exists(): + return str(cache_path) + + # 2. 解析基本参数 + font_path = self._resolve_font_path(style.get("font_family")) + font_size = style.get("font_size", 60) + font = self._get_font(font_path, font_size) + font_color = self._parse_color(style.get("font_color", "#FFFFFF")) + + # 3. 
    def render(self, text: str, style: Union[Dict[str, Any], str], cache: bool = True) -> str:
        """
        Render ``text`` as a styled PNG (transparent background) and return its path.

        Args:
            text: The string to render.
            style: Either a preset name (resolved via ``get_style``) or a dict:
                {
                    "font_family": str,
                    "font_size": int,
                    "font_color": str,
                    "stroke": [{"color": str, "width": int}, ...],
                    "shadow": {"color": str, "blur": int, "offset": [x, y], "opacity": float},
                    "background": {
                        "type": "box", "color": str/list, "corner_radius": int, "padding": [t, r, b, l]
                    }
                }
            cache: When True, reuse a previously rendered PNG keyed on
                MD5 of (text, style repr).

        Returns:
            Path (str) of the rendered PNG inside CACHE_DIR.
        """
        # 0. Resolve a preset name into a concrete style dict.
        if isinstance(style, str):
            style = get_style(style)

        # 1. Cache lookup (key covers both the text and the full style dict).
        cache_key = hashlib.md5(f"{text}_{str(style)}".encode()).hexdigest()
        if cache:
            cache_path = CACHE_DIR / f"{cache_key}.png"
            if cache_path.exists():
                return str(cache_path)

        # 2. Basic font parameters.
        font_path = self._resolve_font_path(style.get("font_family"))
        font_size = style.get("font_size", 60)
        font = self._get_font(font_path, font_size)
        font_color = self._parse_color(style.get("font_color", "#FFFFFF"))

        # 3. Measure the text on a throwaway 1x1 canvas.
        dummy_draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
        bbox = dummy_draw.textbbox((0, 0), text, font=font)
        text_w = bbox[2] - bbox[0]
        text_h = bbox[3] - bbox[1]

        # 4. Compute total canvas size (padding + stroke + shadow allowances).
        strokes = style.get("stroke", [])
        if isinstance(strokes, dict): strokes = [strokes]  # accept legacy single-stroke dict

        max_stroke = 0
        for s in strokes:
            max_stroke = max(max_stroke, s.get("width", 0))

        shadow = style.get("shadow", {})
        shadow_blur = shadow.get("blur", 0)
        shadow_offset = shadow.get("offset", [0, 0])

        bg = style.get("background", {})
        padding = bg.get("padding", [0, 0, 0, 0])
        if isinstance(padding, int): padding = [padding] * 4
        if len(padding) == 2: padding = [padding[0], padding[1], padding[0], padding[1]]  # v, h -> t, r, b, l

        # Content area (text + padding).
        content_w = text_w + padding[1] + padding[3]
        content_h = text_h + padding[0] + padding[2]

        # Extra margin so stroke/shadow/blur are never clipped (+10 safety).
        extra_margin = max_stroke + shadow_blur + max(abs(shadow_offset[0]), abs(shadow_offset[1])) + 10

        canvas_w = content_w + extra_margin * 2
        canvas_h = content_h + extra_margin * 2

        # 5. Create the transparent canvas.
        img = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)

        # Anchor point (center of the text block).
        center_x = canvas_w // 2
        center_y = canvas_h // 2

        # 6. Draw order: shadow -> background -> stroke -> text.

        # --- Shadow (for the whole block). NOTE: a missing "shadow" key
        # yields an empty (falsy) dict, so such styles skip this entirely. ---
        if shadow:
            shadow_color = self._parse_color(shadow.get("color", "#000000"))
            opacity = shadow.get("opacity", 0.5)
            shadow_color = (shadow_color[0], shadow_color[1], shadow_color[2], int(255 * opacity))

            # Draw the silhouette on its own layer so it can be blurred.
            shadow_layer = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
            shadow_draw = ImageDraw.Draw(shadow_layer)

            # With a background, the shadow follows the box shape; otherwise
            # it follows the glyph outlines.
            if bg and bg.get("type") != "none":
                self._draw_background(shadow_draw, bg, center_x, center_y, content_w, content_h, shadow_color)
            else:
                # Text-shaped shadow.
                txt_x = center_x - text_w / 2
                txt_y = center_y - text_h / 2
                shadow_draw.text((txt_x, txt_y), text, font=font, fill=shadow_color)
                # Stroke-shaped shadow deliberately omitted: simulating it
                # would need many repeated draws and is too costly.
                for s in strokes:
                    width = s.get("width", 0)

            # Blur, then offset, then composite beneath the main canvas.
            if shadow_blur > 0:
                shadow_layer = shadow_layer.filter(ImageFilter.GaussianBlur(shadow_blur))

            final_shadow = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
            final_shadow.paste(shadow_layer, (int(shadow_offset[0]), int(shadow_offset[1])), mask=shadow_layer)

            img = Image.alpha_composite(final_shadow, img)
            draw = ImageDraw.Draw(img)  # rebind draw to the composited image

        # --- Background box/circle. ---
        if bg and bg.get("type") in ["box", "circle"]:
            bg_color = self._parse_color(bg.get("color", "#000000"))
            # TODO: gradient backgrounds
            self._draw_background(draw, bg, center_x, center_y, content_w, content_h, bg_color)

        # --- Strokes (text only), outermost first. ---
        txt_x = center_x - text_w / 2
        txt_y = center_y - text_h / 2

        for s in reversed(strokes):
            color = self._parse_color(s.get("color", "#000000"))
            width = s.get("width", 0)
            if width > 0:
                # Native Pillow stroke parameters approximate the outline.
                draw.text((txt_x, txt_y), text, font=font, fill=color, stroke_width=width, stroke_fill=color)

        # --- Fill text on top. ---
        draw.text((txt_x, txt_y), text, font=font, fill=font_color)

        # 7. Trim surplus transparent border.
        bbox = img.getbbox()
        if bbox:
            img = img.crop(bbox)

        # 8. Save. The file is written even when cache=False; the name is
        #    still the cache key, so a later cached call can reuse it.
        output_path = str(CACHE_DIR / f"{cache_key}.png")
        img.save(output_path)
        logger.info(f"Rendered text: {text} -> {output_path}")

        return output_path
def check_imagemagick() -> bool:
    """Return True when the ImageMagick ``convert`` binary is on PATH.

    Logs a warning and returns False otherwise (the warning notes that text
    overlays may fail without it).
    """
    import shutil

    found = shutil.which("convert") is not None
    if not found:
        logger.warning("ImageMagick not found. Text overlays may fail.")
    return found
def verify_assets(video_path: str, audio_path: str) -> Tuple[bool, str]:
    """
    Auto-QC: Verify generated assets quality.

    Checks:
    1. File size sanity check (video >= 50KB, audio >= 5KB)
    2. Duration matching (+/- 2s tolerance)
    3. Audio silence check (RMS over the decoded samples)

    Returns:
        (Passed: bool, Reason: str)
    """
    logger.info(f"Running Auto-QC on:\nVideo: {video_path}\nAudio: {audio_path}")

    try:
        # 1. File Size Check — catches obviously broken/black renders early.
        vid_size = os.path.getsize(video_path)
        if vid_size < 50 * 1024:  # < 50KB
            return False, f"Video file too small ({vid_size/1024:.1f}KB). Likely error/black screen."

        aud_size = os.path.getsize(audio_path)
        if aud_size < 5 * 1024:  # < 5KB
            return False, f"Audio file too small ({aud_size/1024:.1f}KB)."

        # 2. Duration Check
        try:
            v_clip = VideoFileClip(video_path)
            a_clip = AudioFileClip(audio_path)

            v_dur = v_clip.duration
            a_dur = a_clip.duration

            # Check for silence (RMS).
            # NOTE(review): despite the original "first 2 seconds" comment,
            # this decodes the WHOLE audio track — consider
            # a_clip.subclip(0, 2) if QC becomes slow on long audio.
            chunk = a_clip.to_soundarray(fps=44100, nbytes=2, buffersize=1000)
            if chunk is not None:
                rms = np.sqrt(np.mean(chunk**2))
                if rms < 0.001:
                    v_clip.close()
                    a_clip.close()
                    return False, "Audio appears to be silent (RMS < 0.001)"

            v_clip.close()
            a_clip.close()
            # NOTE(review): if an exception is raised before the closes above,
            # the clips leak — a try/finally (or context managers) would be safer.

            # Tolerance check
            if abs(v_dur - a_dur) > 2.0:
                return False, f"Duration mismatch: Video={v_dur:.1f}s, Audio={a_dur:.1f}s"

        except Exception as e:
            return False, f"Media analysis failed: {str(e)}"

        return True, "QC Passed"

    except Exception as e:
        logger.error(f"Auto-QC Error: {e}")
        return False, f"QC System Error: {e}"
def apply_ken_burns(
    image_path: str,
    duration: float = 5.0,
    zoom_ratio: float = 1.2,
    output_path: Optional[str] = None
) -> str:
    """Apply Ken Burns effect (slow zoom in) to a static image.

    Args:
        image_path: Source still image.
        duration: Output clip length in seconds.
        zoom_ratio: Final zoom factor (1.2 = 20% zoomed in at the end).
        output_path: Target mp4 path; defaults to
            ``config.OUTPUT_DIR/<stem>_ken_burns.mp4``.

    Returns:
        Path to the written mp4 (silent video, no audio track).
    """
    if output_path is None:
        base_name = Path(image_path).stem
        output_path = str(config.OUTPUT_DIR / f"{base_name}_ken_burns.mp4")

    logger.info(f"Applying Ken Burns effect to {image_path}")

    img = Image.open(image_path)
    img_width, img_height = img.size
    target_width = config.VIDEO_SETTINGS["width"]
    target_height = config.VIDEO_SETTINGS["height"]
    fps = config.VIDEO_SETTINGS["fps"]

    # Pre-scale the source so that even at maximum zoom the crop window
    # never exceeds the resized image on either axis.
    scale_w = (target_width * zoom_ratio) / img_width
    scale_h = (target_height * zoom_ratio) / img_height
    base_scale = max(scale_w, scale_h)

    new_width = int(img_width * base_scale)
    new_height = int(img_height * base_scale)
    img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    img_array = np.array(img_resized)

    def make_frame(t):
        # Cosine ease-in-out: the zoom starts and ends smoothly.
        progress = t / duration
        eased_progress = 0.5 - 0.5 * np.cos(np.pi * progress)
        current_zoom = 1 + (zoom_ratio - 1) * eased_progress

        # Centered crop window that shrinks as zoom increases.
        crop_width = int(target_width / current_zoom * (new_width / target_width))
        crop_height = int(target_height / current_zoom * (new_height / target_height))

        crop_width = min(crop_width, new_width)
        crop_height = min(crop_height, new_height)

        x_start = (new_width - crop_width) // 2
        y_start = (new_height - crop_height) // 2

        cropped = img_array[y_start:y_start + crop_height, x_start:x_start + crop_width]
        cropped_pil = Image.fromarray(cropped)
        resized = cropped_pil.resize((target_width, target_height), Image.Resampling.LANCZOS)
        return np.array(resized)

    # NOTE(review): ImageClip is constructed here with a frame *function*;
    # moviepy's documented type for a callable frame source is VideoClip —
    # confirm this works on the pinned moviepy==1.0.3.
    clip = ImageClip(make_frame, duration=duration)
    clip = clip.set_fps(fps)
    clip.write_videofile(output_path, fps=fps, codec=config.VIDEO_SETTINGS["codec"], audio=False, logger=None)
    clip.close()

    return output_path
    def recover_video_from_task(self, task_id: str, output_path: str) -> bool:
        """
        Try to recover a video from an existing task_id (poll once, download if done).

        Intended for restart/crash recovery when the task_id was persisted to
        the DB but the generated file was never fetched.

        Args:
            task_id: Previously submitted generation task id.
            output_path: Desired local path for the recovered video.

        Returns:
            True if the task succeeded and the file now exists at
            ``output_path``; False otherwise (still running, failed, or error).
        """
        try:
            status, video_url = self._check_task(task_id)
            logger.info(f"Recovering task {task_id}: status={status}")

            if status == "succeeded" and video_url:
                downloaded_path = self._download_video(video_url, os.path.basename(output_path))
                if downloaded_path:
                    # _download_video writes into TEMP_DIR using only the
                    # filename argument and returns that full path; if the
                    # caller asked for a different location, move the file.
                    if os.path.abspath(downloaded_path) != os.path.abspath(output_path):
                        import shutil
                        shutil.move(downloaded_path, output_path)
                    return True
            return False
        except Exception as e:
            logger.error(f"Failed to recover video task {task_id}: {e}")
            return False
    def _submit_task(self, image_url: str, prompt: str) -> str:
        """
        Submit an image-to-video generation task to the Volcengine Ark API.

        The text part carries inline generation flags appended to the prompt
        (1080p, 3-second duration, free camera, no watermark).

        Args:
            image_url: Publicly reachable URL of the source frame.
            prompt: Motion/effect description for the generation model.

        Returns:
            The task id string, or None on failure.
        """
        url = f"{self.base_url}/contents/generations/tasks"

        payload = {
            "model": self.model_id,
            "content": [
                {
                    "type": "text",
                    "text": f"{prompt} --resolution 1080p --duration 3 --camerafixed false --watermark false"
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url}
                }
            ]
        }

        try:
            response = requests.post(url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            data = response.json()
            # The task id may sit at the top level or inside a "data" object,
            # depending on the exact API version's response shape.
            task_id = data.get("id")
            if not task_id and "data" in data:
                task_id = data.get("data", {}).get("id")

            return task_id
        except Exception as e:
            logger.error(f"Task submission failed: {e}")
            if 'response' in locals():
                logger.error(f"Response: {response.text}")
            return None
    def _download_video(self, url: str, filename: str) -> str:
        """
        Stream a remote video into ``config.TEMP_DIR/filename``.

        Args:
            url: Download URL (may be empty/None, in which case nothing happens).
            filename: Target file name inside the temp directory.

        Returns:
            The local file path as a string, or None when the url is empty or
            the download fails.
        """
        if not url:
            return None

        try:
            response = requests.get(url, stream=True, timeout=60)
            response.raise_for_status()

            output_path = config.TEMP_DIR / filename
            # Stream in 8KB chunks to avoid holding the whole video in memory.
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            return str(output_path)
        except Exception as e:
            logger.error(f"Download video failed: {e}")
            return None
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_cluster(voice: str) -> str:
    """Map a voice id to its TTS cluster.

    Voice ids beginning with "S_" are voice-clone (ICL) voices served by the
    "volcano_icl" cluster; every other voice uses the default TTS cluster.
    """
    if voice.startswith("S_"):
        return "volcano_icl"
    return "volcano_tts"


async def main():
    """CLI entry point: synthesize --text over the binary TTS websocket and save the audio."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--appid", required=True, help="APP ID")
    parser.add_argument("--access_token", required=True, help="Access Token")
    parser.add_argument("--voice_type", required=True, help="Voice type")
    parser.add_argument("--cluster", default="", help="Cluster name")
    parser.add_argument("--text", required=True, help="Text to convert")
    parser.add_argument("--encoding", default="wav", help="Output file encoding")
    parser.add_argument(
        "--endpoint",
        default="wss://openspeech.bytedance.com/api/v1/tts/ws_binary",
        help="WebSocket endpoint URL",
    )

    args = parser.parse_args()

    # Explicit --cluster overrides the voice-derived default.
    cluster = args.cluster if args.cluster else get_cluster(args.voice_type)

    # The service expects "Bearer;<token>" (semicolon, not a space).
    headers = {
        "Authorization": f"Bearer;{args.access_token}",
    }

    logger.info(f"Connecting to {args.endpoint} with headers: {headers}")
    websocket = await websockets.connect(
        args.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
    )
    # NOTE(review): assumes the server always returns an x-tt-logid response
    # header; a missing header would raise KeyError here — confirm.
    logger.info(
        f"Connected to WebSocket server, Logid: {websocket.response.headers['x-tt-logid']}",
    )

    try:
        # Build the full TTS request payload.
        request = {
            "app": {
                "appid": args.appid,
                "token": args.access_token,
                "cluster": cluster,
            },
            "user": {
                "uid": str(uuid.uuid4()),
            },
            "audio": {
                "voice_type": args.voice_type,
                "encoding": args.encoding,
            },
            "request": {
                "reqid": str(uuid.uuid4()),
                "text": args.text,
                "operation": "submit",
                "with_timestamp": "1",
                "extra_param": json.dumps(
                    {
                        "disable_markdown_filter": False,
                    }
                ),
            },
        }

        # Send request
        await full_client_request(websocket, json.dumps(request).encode())

        # Accumulate audio frames until the server signals the last packet
        # (negative sequence number).
        audio_data = bytearray()
        while True:
            msg = await receive_message(websocket)

            if msg.type == MsgType.FrontEndResultServer:
                continue
            elif msg.type == MsgType.AudioOnlyServer:
                audio_data.extend(msg.payload)
                if msg.sequence < 0:  # Last message
                    break
            else:
                raise RuntimeError(f"TTS conversion failed: {msg}")

        # Check if we received any audio data
        if not audio_data:
            raise RuntimeError("No audio data received")

        # Save audio file
        filename = f"{args.voice_type}.{args.encoding}"
        with open(filename, "wb") as f:
            f.write(audio_data)
        # BUGFIX: the log previously printed a garbled placeholder instead of
        # the actual output filename.
        logger.info(f"Audio received: {len(audio_data)}, saved to {filename}")

    finally:
        await websocket.close()
        logger.info("Connection closed")


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
b/volcengine_binary_demo/protocols/protocols.py @@ -0,0 +1,543 @@ +import io +import logging +import struct +from dataclasses import dataclass +from enum import IntEnum +from typing import Callable, List + +import websockets + +logger = logging.getLogger(__name__) + + +class MsgType(IntEnum): + """Message type enumeration""" + + Invalid = 0 + FullClientRequest = 0b1 + AudioOnlyClient = 0b10 + FullServerResponse = 0b1001 + AudioOnlyServer = 0b1011 + FrontEndResultServer = 0b1100 + Error = 0b1111 + + # Alias + ServerACK = AudioOnlyServer + + def __str__(self) -> str: + return self.name if self.name else f"MsgType({self.value})" + + +class MsgTypeFlagBits(IntEnum): + """Message type flag bits""" + + NoSeq = 0 # Non-terminal packet with no sequence + PositiveSeq = 0b1 # Non-terminal packet with sequence > 0 + LastNoSeq = 0b10 # Last packet with no sequence + NegativeSeq = 0b11 # Last packet with sequence < 0 + WithEvent = 0b100 # Payload contains event number (int32) + + +class VersionBits(IntEnum): + """Version bits""" + + Version1 = 1 + Version2 = 2 + Version3 = 3 + Version4 = 4 + + +class HeaderSizeBits(IntEnum): + """Header size bits""" + + HeaderSize4 = 1 + HeaderSize8 = 2 + HeaderSize12 = 3 + HeaderSize16 = 4 + + +class SerializationBits(IntEnum): + """Serialization method bits""" + + Raw = 0 + JSON = 0b1 + Thrift = 0b11 + Custom = 0b1111 + + +class CompressionBits(IntEnum): + """Compression method bits""" + + None_ = 0 + Gzip = 0b1 + Custom = 0b1111 + + +class EventType(IntEnum): + """Event type enumeration""" + + None_ = 0 # Default event + + # 1 ~ 49 Upstream Connection events + StartConnection = 1 + StartTask = 1 # Alias of StartConnection + FinishConnection = 2 + FinishTask = 2 # Alias of FinishConnection + + # 50 ~ 99 Downstream Connection events + ConnectionStarted = 50 # Connection established successfully + TaskStarted = 50 # Alias of ConnectionStarted + ConnectionFailed = 51 # Connection failed (possibly due to authentication failure) + TaskFailed = 51 
class EventType(IntEnum):
    """Event numbers carried in the optional int32 event field.

    Numeric ranges partition the space by direction and scope:
    1-49 upstream connection, 50-99 downstream connection,
    100-149 upstream session, 150-199 downstream session,
    200-299 general, 300-399 TTS, 450-499 ASR, 500-599 dialogue,
    650-699 subtitles. Several names are aliases of the same value.
    """

    None_ = 0  # Default event

    # Upstream connection events (1-49)
    StartConnection = 1
    StartTask = 1               # Alias of StartConnection
    FinishConnection = 2
    FinishTask = 2              # Alias of FinishConnection

    # Downstream connection events (50-99)
    ConnectionStarted = 50      # Connection established successfully
    TaskStarted = 50            # Alias of ConnectionStarted
    ConnectionFailed = 51       # Connection failed (possibly authentication)
    TaskFailed = 51             # Alias of ConnectionFailed
    ConnectionFinished = 52     # Connection ended
    TaskFinished = 52           # Alias of ConnectionFinished

    # Upstream session events (100-149)
    StartSession = 100
    CancelSession = 101
    FinishSession = 102

    # Downstream session events (150-199)
    SessionStarted = 150
    SessionCanceled = 151
    SessionFinished = 152
    SessionFailed = 153
    UsageResponse = 154         # Usage response
    ChargeData = 154            # Alias of UsageResponse

    # Upstream general events (200-249)
    TaskRequest = 200
    UpdateConfig = 201

    # Downstream general events (250-299)
    AudioMuted = 250

    # Upstream TTS events (300-349)
    SayHello = 300

    # Downstream TTS events (350-399)
    TTSSentenceStart = 350
    TTSSentenceEnd = 351
    TTSResponse = 352
    TTSEnded = 359
    PodcastRoundStart = 360
    PodcastRoundResponse = 361
    PodcastRoundEnd = 362

    # Downstream ASR events (450-499)
    ASRInfo = 450
    ASRResponse = 451
    ASREnded = 459

    # Upstream dialogue events (500-549)
    ChatTTSText = 500           # (Ground-Truth-Alignment) text for speech synthesis

    # Downstream dialogue events (550-599)
    ChatResponse = 550
    ChatEnded = 559

    # Downstream subtitle events (650-699)
    SourceSubtitleStart = 650
    SourceSubtitleResponse = 651
    SourceSubtitleEnd = 652
    TranslationSubtitleStart = 653
    TranslationSubtitleResponse = 654
    TranslationSubtitleEnd = 655

    def __str__(self) -> str:
        # Fall back to a numeric repr if the member somehow has no name.
        return self.name if self.name else f"EventType({self.value})"
@dataclass
class Message:
    """A single binary-protocol frame.

    Header layout (big-endian throughout):
      byte 0: Version (4 bits) | Header Size in 4-byte units (4 bits)
      byte 1: Msg Type (4 bits) | Flags (4 bits)
      byte 2: Serialization (4 bits) | Compression (4 bits)
      byte 3..header end: reserved padding
    Followed, depending on type/flags, by: optional int32 event, optional
    length-prefixed session id / connect id, optional int32 sequence or
    uint32 error code, then a uint32-length-prefixed payload.
    """

    # Header fields (first 3 bytes, plus padding up to 4 * header_size).
    version: VersionBits = VersionBits.Version1
    header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
    type: MsgType = MsgType.Invalid
    flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
    serialization: SerializationBits = SerializationBits.JSON
    compression: CompressionBits = CompressionBits.None_

    # Optional post-header fields; which ones are present on the wire is
    # decided by `type`, `flag` and `event` (see _get_writers/_get_readers).
    event: EventType = EventType.None_
    session_id: str = ""
    connect_id: str = ""
    sequence: int = 0
    error_code: int = 0

    payload: bytes = b""

    @classmethod
    def from_bytes(cls, data: bytes) -> "Message":
        """Create message object from bytes.

        Peeks at byte 1 to fix type/flag first, because unmarshal's reader
        list depends on them; then unmarshal consumes the whole buffer.
        """
        if len(data) < 3:
            raise ValueError(
                f"Data too short: expected at least 3 bytes, got {len(data)}"
            )

        type_and_flag = data[1]
        msg_type = MsgType(type_and_flag >> 4)
        flag = MsgTypeFlagBits(type_and_flag & 0b00001111)

        msg = cls(type=msg_type, flag=flag)
        msg.unmarshal(data)
        return msg

    def marshal(self) -> bytes:
        """Serialize message to bytes."""
        buffer = io.BytesIO()

        # Write header: three packed bytes, then zero padding out to the
        # declared header size (4 * header_size bytes total).
        header = [
            (self.version << 4) | self.header_size,
            (self.type << 4) | self.flag,
            (self.serialization << 4) | self.compression,
        ]

        header_size = 4 * self.header_size
        if padding := header_size - len(header):
            header.extend([0] * padding)

        buffer.write(bytes(header))

        # Write other fields in the type/flag-dependent order.
        writers = self._get_writers()
        for writer in writers:
            writer(buffer)

        return buffer.getvalue()

    def unmarshal(self, data: bytes) -> None:
        """Deserialize message from bytes (type/flag must already be set)."""
        buffer = io.BytesIO(data)

        # Read version and header size
        version_and_header_size = buffer.read(1)[0]
        self.version = VersionBits(version_and_header_size >> 4)
        self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)

        # Skip second byte — type/flag were already extracted in from_bytes.
        buffer.read(1)

        # Read serialization and compression methods
        serialization_compression = buffer.read(1)[0]
        self.serialization = SerializationBits(serialization_compression >> 4)
        self.compression = CompressionBits(serialization_compression & 0b00001111)

        # Skip header padding beyond the 3 meaningful bytes.
        header_size = 4 * self.header_size
        read_size = 3
        if padding_size := header_size - read_size:
            buffer.read(padding_size)

        # Read other fields
        readers = self._get_readers()
        for reader in readers:
            reader(buffer)

        # A well-formed frame is consumed exactly; trailing bytes mean the
        # sender and receiver disagree about the layout.
        remaining = buffer.read()
        if remaining:
            raise ValueError(f"Unexpected data after message: {remaining}")

    def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get the ordered list of field writers for this type/flag combo."""
        writers = []

        if self.flag == MsgTypeFlagBits.WithEvent:
            writers.extend([self._write_event, self._write_session_id])

        if self.type in [
            MsgType.FullClientRequest,
            MsgType.FullServerResponse,
            MsgType.FrontEndResultServer,
            MsgType.AudioOnlyClient,
            MsgType.AudioOnlyServer,
        ]:
            # Sequence is only on the wire for explicitly sequenced flags
            # (never together with WithEvent, which is a distinct flag value).
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                writers.append(self._write_sequence)
        elif self.type == MsgType.Error:
            writers.append(self._write_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")

        writers.append(self._write_payload)
        return writers

    def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get the ordered list of field readers for this type/flag combo."""
        readers = []

        if self.type in [
            MsgType.FullClientRequest,
            MsgType.FullServerResponse,
            MsgType.FrontEndResultServer,
            MsgType.AudioOnlyClient,
            MsgType.AudioOnlyServer,
        ]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                readers.append(self._read_sequence)
        elif self.type == MsgType.Error:
            readers.append(self._read_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")

        if self.flag == MsgTypeFlagBits.WithEvent:
            # connect_id is server-sent only, hence read (for Connection*
            # events) but never written by _get_writers.
            readers.extend(
                [self._read_event, self._read_session_id, self._read_connect_id]
            )

        readers.append(self._read_payload)
        return readers

    def _write_event(self, buffer: io.BytesIO) -> None:
        """Write event as big-endian int32."""
        buffer.write(struct.pack(">i", self.event))

    def _write_session_id(self, buffer: io.BytesIO) -> None:
        """Write session ID as uint32 length prefix + UTF-8 bytes.

        Connection-level events carry no session id at all (not even an
        empty length prefix).
        """
        if self.event in [
            EventType.StartConnection,
            EventType.FinishConnection,
            EventType.ConnectionStarted,
            EventType.ConnectionFailed,
        ]:
            return

        session_id_bytes = self.session_id.encode("utf-8")
        size = len(session_id_bytes)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")

        buffer.write(struct.pack(">I", size))
        if size > 0:
            buffer.write(session_id_bytes)

    def _write_sequence(self, buffer: io.BytesIO) -> None:
        """Write sequence number as big-endian int32 (signed: <0 marks last)."""
        buffer.write(struct.pack(">i", self.sequence))

    def _write_error_code(self, buffer: io.BytesIO) -> None:
        """Write error code as big-endian uint32."""
        buffer.write(struct.pack(">I", self.error_code))

    def _write_payload(self, buffer: io.BytesIO) -> None:
        """Write payload as uint32 length prefix + raw bytes."""
        size = len(self.payload)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Payload size ({size}) exceeds max(uint32)")

        buffer.write(struct.pack(">I", size))
        buffer.write(self.payload)

    def _read_event(self, buffer: io.BytesIO) -> None:
        """Read event (int32); leaves the default on a truncated buffer."""
        event_bytes = buffer.read(4)
        if event_bytes:
            self.event = EventType(struct.unpack(">i", event_bytes)[0])

    def _read_session_id(self, buffer: io.BytesIO) -> None:
        """Read length-prefixed session ID; absent for connection-level events.

        NOTE(review): the skip list here includes ConnectionFinished while
        _write_session_id's does not — asymmetric, but ConnectionFinished is
        server-sent so the writer path never hits it in practice.
        """
        if self.event in [
            EventType.StartConnection,
            EventType.FinishConnection,
            EventType.ConnectionStarted,
            EventType.ConnectionFailed,
            EventType.ConnectionFinished,
        ]:
            return

        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                session_id_bytes = buffer.read(size)
                if len(session_id_bytes) == size:
                    self.session_id = session_id_bytes.decode("utf-8")

    def _read_connect_id(self, buffer: io.BytesIO) -> None:
        """Read length-prefixed connection ID (only present on Connection* events)."""
        if self.event in [
            EventType.ConnectionStarted,
            EventType.ConnectionFailed,
            EventType.ConnectionFinished,
        ]:
            size_bytes = buffer.read(4)
            if size_bytes:
                size = struct.unpack(">I", size_bytes)[0]
                if size > 0:
                    self.connect_id = buffer.read(size).decode("utf-8")

    def _read_sequence(self, buffer: io.BytesIO) -> None:
        """Read sequence number (signed int32)."""
        sequence_bytes = buffer.read(4)
        if sequence_bytes:
            self.sequence = struct.unpack(">i", sequence_bytes)[0]

    def _read_error_code(self, buffer: io.BytesIO) -> None:
        """Read error code (uint32)."""
        error_code_bytes = buffer.read(4)
        if error_code_bytes:
            self.error_code = struct.unpack(">I", error_code_bytes)[0]

    def _read_payload(self, buffer: io.BytesIO) -> None:
        """Read uint32-length-prefixed payload."""
        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                self.payload = buffer.read(size)

    def __str__(self) -> str:
        """Human-readable summary; audio payloads are shown by size only."""
        if self.type in [MsgType.AudioOnlyServer, MsgType.AudioOnlyClient]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, PayloadSize: {len(self.payload)}"
            return f"MsgType: {self.type}, EventType:{self.event}, PayloadSize: {len(self.payload)}"
        elif self.type == MsgType.Error:
            return f"MsgType: {self.type}, EventType:{self.event}, ErrorCode: {self.error_code}, Payload: {self.payload.decode('utf-8', 'ignore')}"
        else:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, Payload: {self.payload.decode('utf-8', 'ignore')}"
            return f"MsgType: {self.type}, EventType:{self.event}, Payload: {self.payload.decode('utf-8', 'ignore')}"
async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
    """Receive one binary frame and parse it into a Message.

    Raises ValueError for text frames (the protocol is binary-only) and
    re-raises any parse/transport error after logging it.
    """
    try:
        data = await websocket.recv()
        if isinstance(data, str):
            raise ValueError(f"Unexpected text message: {data}")
        elif isinstance(data, bytes):
            msg = Message.from_bytes(data)
            logger.info(f"Received: {msg}")
            return msg
        else:
            raise ValueError(f"Unexpected message type: {type(data)}")
    except Exception as e:
        logger.error(f"Failed to receive message: {e}")
        raise


async def wait_for_event(
    websocket: websockets.WebSocketClientProtocol,
    msg_type: MsgType,
    event_type: EventType,
) -> Message:
    """Receive the next message and require it to match the expected type/event.

    Raises ValueError on a mismatch. (The original wrapped this in a
    `while True` loop, but any non-matching message raised on the first
    iteration, so the loop could never repeat — behavior is unchanged.)
    """
    msg = await receive_message(websocket)
    if msg.type != msg_type or msg.event != event_type:
        raise ValueError(f"Unexpected message: {msg}")
    return msg


async def full_client_request(
    websocket: websockets.WebSocketClientProtocol, payload: bytes
) -> None:
    """Send a FullClientRequest frame with no sequence number."""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.NoSeq)
    msg.payload = payload
    logger.info(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def audio_only_client(
    websocket: websockets.WebSocketClientProtocol, payload: bytes, flag: MsgTypeFlagBits
) -> None:
    """Send an AudioOnlyClient frame; caller chooses the flag (e.g. last-packet)."""
    msg = Message(type=MsgType.AudioOnlyClient, flag=flag)
    msg.payload = payload
    logger.info(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def _send_event(
    websocket: websockets.WebSocketClientProtocol,
    event: EventType,
    payload: bytes = b"{}",
    session_id: str = "",
) -> None:
    """Shared helper: send a FullClientRequest frame carrying an event.

    Connection-level events omit the session id on the wire regardless of
    this argument (Message._write_session_id skips it).
    """
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = event
    msg.session_id = session_id
    msg.payload = payload
    logger.info(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Start connection."""
    await _send_event(websocket, EventType.StartConnection)


async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Finish connection."""
    await _send_event(websocket, EventType.FinishConnection)


async def start_session(
    websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str
) -> None:
    """Start session."""
    await _send_event(websocket, EventType.StartSession, payload, session_id)


async def finish_session(
    websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
    """Finish session."""
    await _send_event(websocket, EventType.FinishSession, session_id=session_id)


async def cancel_session(
    websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
    """Cancel session."""
    await _send_event(websocket, EventType.CancelSession, session_id=session_id)


async def task_request(
    websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str
) -> None:
    """Send task request."""
    await _send_event(websocket, EventType.TaskRequest, payload, session_id)
def init_session():
    """Initialize session state.

    Creates a fresh project exactly once per browser session; Streamlit
    reruns reuse whatever is already in st.session_state.
    """
    if "proj" not in st.session_state:
        st.session_state.proj = project.create_project()
    if "step" not in st.session_state:
        st.session_state.step = 0
    if "brief" not in st.session_state:
        st.session_state.brief = {}


def render_sidebar():
    """Render sidebar with project info, load/reset controls and a step tracker."""
    with st.sidebar:
        st.header("项目控制台")

        proj = st.session_state.proj
        st.text(f"项目 ID: {proj.id}")
        st.text(f"状态: {proj.status}")

        st.divider()

        # Restore a previously saved project by its id.
        load_id = st.text_input("恢复项目 (输入ID)")
        if st.button("加载"):
            loaded = project.load_project(load_id)
            if loaded:
                st.session_state.proj = loaded
                st.success(f"已加载项目 {load_id}")
                st.rerun()
            else:
                st.error("项目不存在")

        st.divider()

        # Full reset: new project object and back to step 0.
        if st.button("重置项目"):
            st.session_state.proj = project.create_project()
            st.session_state.step = 0
            st.session_state.brief = {}
            st.rerun()

        st.divider()
        # Step tracker: arrow for current, check for done, circle for pending.
        steps = ["素材提交", "AI分析", "脚本生成", "画面生成", "视频生成", "最终合成"]
        for i, name in enumerate(steps):
            if i == st.session_state.step:
                st.markdown(f"**→ {i}. {name}**")
            elif i < st.session_state.step:
                st.markdown(f"✅ {i}. {name}")
            else:
                st.markdown(f"○ {i}. {name}")
def step0_ingest():
    """Step 0: Material Submission — collect prompt plus optional images/video."""
    # NOTE(review): the original wrapped this header in an HTML snippet
    # (hence unsafe_allow_html=True); the exact markup was lost in extraction.
    st.markdown('Step 0: 素材提交', unsafe_allow_html=True)

    proj = st.session_state.proj

    mode = st.radio(
        "选择输入方式",
        ["纯文本创意", "图片 + 描述", "视频 + 描述"],
        horizontal=True
    )

    prompt = st.text_area("创意描述 / 产品卖点", height=100, placeholder="描述你想要的视频内容...")

    if mode == "纯文本创意":
        proj.input_mode = "text"

    elif mode == "图片 + 描述":
        proj.input_mode = "images"
        uploaded = st.file_uploader("上传图片 (支持多张)", type=["jpg", "png", "jpeg"], accept_multiple_files=True)

        if uploaded:
            urls = []
            with st.spinner("上传图片中..."):
                # Spool each upload to the temp dir, then push to remote storage.
                for f in uploaded:
                    temp_path = config.TEMP_DIR / f.name
                    with open(temp_path, "wb") as fp:
                        fp.write(f.getbuffer())
                    url = storage.upload_file(str(temp_path))
                    if url:
                        urls.append(url)
                    else:
                        st.error(f"上传失败: {f.name}")

            if urls:
                proj.image_urls = urls
                st.image(urls, width=150)
                st.success(f"成功上传 {len(urls)} 张图片")

    elif mode == "视频 + 描述":
        proj.input_mode = "video"
        uploaded = st.file_uploader("上传视频", type=["mp4"])

        if uploaded:
            with st.spinner("处理视频中..."):
                temp_path = config.TEMP_DIR / uploaded.name
                with open(temp_path, "wb") as f:
                    f.write(uploaded.getbuffer())

                # Frame extraction and ASR are handled independently so a
                # failure in one does not block the other (ASR is best-effort).
                try:
                    frame_urls, video_url = ingest.process_uploaded_video(str(temp_path))
                    proj.image_urls = frame_urls
                    proj.video_url = video_url
                    st.image(frame_urls, width=150, caption=["帧1", "帧2", "帧3"])
                except Exception as e:
                    st.error(f"视频处理失败: {e}")

                try:
                    asr_text = asr.transcribe_video(str(temp_path))
                    proj.asr_text = asr_text
                    st.info(f"语音识别: {asr_text[:100]}...")
                except Exception as e:
                    st.warning(f"语音识别失败: {e}")

    proj.prompt = prompt

    # A prompt is mandatory in every mode before moving on.
    if st.button("下一步: AI 分析", disabled=not prompt):
        proj.status = "analyzing"
        project.save_project(proj)
        st.session_state.step = 1
        st.rerun()
def step1_analyze():
    """Step 1: AI Analysis & Questions with multi-select and custom input."""
    # NOTE(review): HTML header wrapper lost in extraction; text preserved.
    st.markdown('Step 1: AI 深度分析', unsafe_allow_html=True)

    proj = st.session_state.proj

    # Run analysis only once; results are persisted on the project.
    if not proj.analysis:
        with st.spinner("AI 正在分析素材..."):
            result = brain.analyze_materials(
                prompt=proj.prompt,
                image_urls=proj.image_urls if proj.image_urls else None,
                asr_text=proj.asr_text
            )
            proj.analysis = result.get("analysis", "")
            proj.questions = result.get("questions", [])
            project.save_project(proj)

    st.subheader("分析结果")
    st.write(proj.analysis)

    # Show clarifying questions, each supporting multi-select and a
    # free-text supplement depending on per-question flags.
    if proj.questions:
        st.subheader("补充信息")
        st.caption("请回答以下问题,帮助 AI 更好地理解你的需求")

        answers = {}
        for q in proj.questions:
            qid = q["id"]
            # NOTE(review): original emitted an opening HTML card wrapper
            # here (markup lost in extraction).
            st.markdown(f'', unsafe_allow_html=True)

            allow_multiple = q.get("allow_multiple", False)
            allow_custom = q.get("allow_custom", True)

            if allow_multiple:
                selected = st.multiselect(
                    q["text"],
                    q["options"],
                    key=f"q_{qid}"
                )
                answers[qid] = {"selected": selected}
            else:
                # Radio answers are normalized to a list for a uniform schema.
                selected = st.radio(
                    q["text"],
                    q["options"],
                    key=f"q_{qid}"
                )
                answers[qid] = {"selected": [selected] if selected else []}

            # Optional free-text supplement per question.
            if allow_custom:
                custom = st.text_input(
                    "补充说明 (选填)",
                    key=f"custom_{qid}",
                    placeholder="如有其他想法,请在此补充..."
                )
                answers[qid]["custom"] = custom

            # NOTE(review): closing HTML card wrapper (markup lost in extraction).
            st.markdown('', unsafe_allow_html=True)

        if st.button("确认回答,生成创意简报"):
            proj.answers = answers

            # Merge answers into a refined creative brief.
            with st.spinner("整合创意简报中..."):
                brief_result = brain.refine_brief(
                    proj.prompt,
                    {"analysis": proj.analysis},
                    answers,
                    proj.image_urls
                )
                st.session_state.brief = brief_result.get("brief", {})

                # Carry the creative summary forward for later steps.
                if "creative_summary" in brief_result:
                    st.session_state.brief["creative_summary"] = brief_result["creative_summary"]

            project.save_project(proj)
            st.session_state.step = 2
            st.rerun()
    else:
        # No questions needed: fall back to a minimal default brief.
        if st.button("下一步: 生成脚本"):
            st.session_state.brief = {
                "product": proj.prompt,
                "selling_points": [],
                "style": "现代广告"
            }
            st.session_state.step = 2
            st.rerun()
def step2_script():
    """Step 2: Script Generation — generate, review and regenerate scenes."""
    # NOTE(review): HTML header wrapper lost in extraction; text preserved.
    st.markdown('Step 2: 脚本生成', unsafe_allow_html=True)

    proj = st.session_state.proj
    brief = st.session_state.brief

    # Show creative summary
    if brief.get("creative_summary"):
        st.info(f"🎯 创意方向: {brief['creative_summary']}")

    if brief.get("style"):
        st.caption(f"视频风格: {brief['style']}")

    # Generate the script only once; results are persisted on the project.
    if not proj.scenes:
        with st.spinner("AI 正在创作脚本..."):
            script = brain.generate_script(brief, proj.image_urls)
            proj.hook = script.get("hook", "")
            proj.scenes = script.get("scenes", [])
            proj.cta = script.get("cta", "")

            # Prefer the script's own creative summary when available.
            if script.get("creative_summary"):
                brief["creative_summary"] = script["creative_summary"]
                st.session_state.brief = brief

            proj.status = "scripting"
            project.save_project(proj)

    # Display script
    st.subheader(f"🎣 Hook: {proj.hook}")

    if brief.get("creative_summary"):
        st.markdown(f"**整体创意**: {brief['creative_summary']}")

    # One expander per scene with its creative metadata and regenerate control.
    for i, scene in enumerate(proj.scenes):
        with st.expander(f"分镜 {scene.get('id', i+1)}: {scene.get('timeline', '')}"):
            col1, col2 = st.columns(2)

            with col1:
                st.write(f"**时长**: {scene.get('duration', 5)}秒")
                st.write(f"**运镜**: {scene.get('camera_movement', '')}")
                st.write(f"**故事节拍**: {scene.get('story_beat', '')}")
                st.write(f"**音效设计**: {scene.get('sound_design', '')}")

            with col2:
                kf = scene.get("keyframe", {})
                st.write(f"**色调**: {kf.get('color_tone', '')}")
                st.write(f"**环境**: {kf.get('environment', '')}")
                st.write(f"**焦点**: {kf.get('focus', '')}")
                st.write(f"**构图**: {kf.get('composition', '')}")

            # Image prompt is the key input for the next (image) step.
            st.write("**生图Prompt**:")
            st.code(scene.get("image_prompt", "(未生成)"), language=None)

            st.write(f"**旁白**: {scene.get('voiceover', '(无)')}")

            # Per-scene regeneration with optional user feedback.
            feedback = st.text_input(f"修改意见", key=f"fb_{i}")
            if st.button(f"重新生成此分镜", key=f"regen_{i}"):
                with st.spinner("重新生成中..."):
                    new_scene = brain.regenerate_scene(
                        {"hook": proj.hook, "scenes": proj.scenes, "cta": proj.cta},
                        scene.get("id", i+1),
                        feedback,
                        brief
                    )
                    proj.scenes[i] = new_scene
                    project.save_project(proj)
                    st.rerun()

    # CTA may come back as a dict from the LLM — coerce to display text.
    cta_text = proj.cta
    if isinstance(cta_text, dict):
        cta_text = cta_text.get("text", str(cta_text))
    st.subheader(f"📢 CTA: {cta_text}")

    col1, col2 = st.columns(2)
    with col1:
        # Whole-script regeneration with overall feedback.
        regen_feedback = st.text_input("整体修改意见")
        if st.button("重新生成整个脚本"):
            with st.spinner("重新生成中..."):
                script = brain.generate_script(brief, proj.image_urls, regen_feedback)
                proj.hook = script.get("hook", "")
                proj.scenes = script.get("scenes", [])
                proj.cta = script.get("cta", "")
                project.save_project(proj)
                st.rerun()

    with col2:
        if st.button("确认脚本,下一步"):
            st.session_state.step = 3
            st.rerun()
def step3_images():
    """Step 3: Image Generation (concurrent) using Gemini Image."""
    # NOTE(review): HTML header wrapper lost in extraction; text preserved.
    st.markdown('Step 3: 画面生成 (Gemini Image)', unsafe_allow_html=True)

    proj = st.session_state.proj
    brief = st.session_state.brief

    # Show reference images (used to keep the product consistent across scenes).
    if proj.image_urls:
        st.caption("参考素材(用于保持产品一致性):")
        st.image(proj.image_urls[:3], width=100)

    has_images = all(s.get("image_url") for s in proj.scenes)

    if not has_images:
        if st.button("开始生成所有画面 (并发)"):
            progress = st.progress(0)
            status = st.empty()

            try:
                status.text("正在并发生成所有分镜画面...")
                # Pass the user's reference images so generated frames keep
                # the product consistent.
                image_urls = factory.generate_all_scene_images_concurrent(
                    proj.scenes,
                    brief,
                    reference_images=proj.image_urls,
                    max_workers=3
                )

                for i, url in enumerate(image_urls):
                    if url:
                        proj.scenes[i]["image_url"] = url
                    progress.progress((i + 1) / len(proj.scenes))

                proj.status = "imaging"
                project.save_project(proj)
                st.rerun()

            except Exception as e:
                st.error(f"生成失败: {e}")
                import traceback
                st.code(traceback.format_exc())

    # Display images in a grid of at most 4 columns; with min(4, n) columns
    # and i < n, the index i % 4 always stays within range.
    cols = st.columns(min(4, len(proj.scenes)))
    for i, scene in enumerate(proj.scenes):
        with cols[i % 4]:
            img_url = scene.get("image_url", "")
            if img_url:
                st.image(img_url, caption=f"分镜 {scene.get('id', i+1)}")

            # Per-scene regeneration.
            if st.button(f"重新生成", key=f"img_regen_{i}"):
                with st.spinner("生成中..."):
                    url = factory.generate_scene_image(scene, brief, proj.image_urls)
                    proj.scenes[i]["image_url"] = url
                    project.save_project(proj)
                    st.rerun()

            # Manual replacement with a user-provided image.
            custom = st.file_uploader(f"替换", key=f"img_up_{i}", type=["jpg", "png"])
            if custom:
                temp_path = config.TEMP_DIR / custom.name
                with open(temp_path, "wb") as f:
                    f.write(custom.getbuffer())
                url = storage.upload_file(str(temp_path))
                if url:
                    proj.scenes[i]["image_url"] = url
                    project.save_project(proj)
                    st.rerun()

            # Inline voiceover editing; saved as soon as the text changes.
            vo = st.text_area(f"旁白", scene.get("voiceover", ""), key=f"vo_{i}", height=80)
            if vo != scene.get("voiceover", ""):
                proj.scenes[i]["voiceover"] = vo
                project.save_project(proj)

    if has_images and st.button("下一步: 生成视频"):
        st.session_state.step = 4
        st.rerun()
def step4_videos():
    """Step 4: Video Generation (concurrent) using Sora 2."""
    # NOTE(review): HTML header wrapper lost in extraction; text preserved.
    st.markdown('Step 4: 分镜视频生成 (Sora 2)', unsafe_allow_html=True)

    proj = st.session_state.proj

    has_videos = all(s.get("video_url") for s in proj.scenes)

    if not has_videos:
        if st.button("开始生成所有视频 (并发)"):
            progress = st.progress(0)
            status = st.empty()

            try:
                # One keyframe image per scene drives the image-to-video model.
                image_urls = [s.get("image_url") for s in proj.scenes]

                status.text("正在并发生成所有分镜视频 (Sora 2)...")
                video_urls = factory.generate_all_scene_videos_concurrent(
                    proj.scenes,
                    image_urls,
                    max_workers=2
                )

                for i, url in enumerate(video_urls):
                    if url:
                        proj.scenes[i]["video_url"] = url
                    progress.progress((i + 1) / len(proj.scenes))

                proj.status = "video"
                project.save_project(proj)
                st.rerun()

            except Exception as e:
                st.error(f"视频生成失败: {e}")
                import traceback
                st.code(traceback.format_exc())

    # Display each generated clip with a per-scene regenerate control.
    for i, scene in enumerate(proj.scenes):
        vid_url = scene.get("video_url", "")
        if vid_url:
            col1, col2 = st.columns([3, 1])
            with col1:
                st.video(vid_url)
            with col2:
                st.write(f"分镜 {scene.get('id', i+1)}")
                st.write(f"{scene.get('duration', 5)}秒")

                if st.button(f"重新生成", key=f"vid_regen_{i}"):
                    with st.spinner("生成中..."):
                        image_url = scene.get("image_url", "")
                        url = factory.generate_scene_video(
                            image_url,
                            scene.get("camera_movement", "slow zoom"),
                            scene.get("duration", 5)
                        )
                        proj.scenes[i]["video_url"] = url
                        project.save_project(proj)
                        st.rerun()

    if has_videos and st.button("下一步: 合成"):
        st.session_state.step = 5
        st.rerun()
def step5_render():
    """Step 5: Final Rendering — assemble clips, voiceover, BGM and subtitles."""
    # NOTE(review): HTML header wrapper lost in extraction; text preserved.
    st.markdown('Step 5: 最终合成', unsafe_allow_html=True)

    proj = st.session_state.proj
    brief = st.session_state.brief

    col1, col2 = st.columns(2)

    with col1:
        add_subtitles = st.checkbox("烧录字幕", value=True)
        add_voiceover = st.checkbox("添加旁白配音", value=True)

    with col2:
        add_bgm = st.checkbox("添加背景音乐", value=False)
        bgm_file = None
        if add_bgm:
            bgm_file = st.file_uploader("上传 BGM", type=["mp3", "wav"])

    if st.button("开始合成"):
        with st.spinner("合成中,请稍候..."):
            video_urls = [s.get("video_url") for s in proj.scenes]

            # Empty string means "no voiceover" / "no bgm" downstream.
            vo_url = ""
            if add_voiceover:
                style = brief.get("style", "")
                vo_url = factory.generate_full_voiceover(proj.scenes, style)

            bgm_url = ""
            if bgm_file:
                temp_path = config.TEMP_DIR / bgm_file.name
                with open(temp_path, "wb") as f:
                    f.write(bgm_file.getbuffer())
                bgm_url = storage.upload_file(str(temp_path))

            # Passing an empty scene list disables subtitle burn-in.
            final_url = editor.assemble_final_video(
                video_urls=video_urls,
                scenes=proj.scenes if add_subtitles else [],
                voiceover_url=vo_url,
                bgm_url=bgm_url
            )

            proj.final_video_url = final_url
            proj.status = "done"
            project.save_project(proj)

            st.success("🎉 视频合成完成!")
            st.video(final_url)
            st.markdown(f"### [📥 下载高清视频]({final_url})")

            # Remove intermediate artifacts once the final cut exists.
            storage.cleanup_temp()


def main():
    """Wizard entry point: dispatch to the handler for the current step."""
    init_session()
    render_sidebar()

    st.title("MatchMe 视频工场 🎬")
    st.caption("AI 驱动的短视频创作平台")

    step = st.session_state.step

    if step == 0:
        step0_ingest()
    elif step == 1:
        step1_analyze()
    elif step == 2:
        step2_script()
    elif step == 3:
        step3_images()
    elif step == 4:
        step4_videos()
    elif step == 5:
        step5_render()


if __name__ == "__main__":
    main()