Files
video-flow/modules/script_gen.py
Tony Zhang 33a165a615 feat: video-flow initial commit
- app.py: Streamlit UI for video generation workflow
- main_flow.py: CLI tool with argparse support
- modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.)
- config.py: Configuration with API keys and paths
- requirements.txt: Python dependencies
- docs/: System prompt documentation
2025-12-12 19:18:27 +08:00

391 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
脚本生成模块 (Gemini-3-Pro)
负责解析商品信息,生成分镜脚本
"""
import base64
import json
import logging
import mimetypes
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import requests

import config
from modules.db_manager import db
logger = logging.getLogger(__name__)


class ScriptGenerator:
    """Storyboard (shot-list) script generator for e-commerce product videos.

    Sends product information plus optional reference images to an LLM and
    parses the JSON script it returns. Two providers are supported: Gemini via
    the shubiaobiao proxy (the default) and Doubao via Volcengine Ark.
    """

    # Gemini-compatible generateContent endpoint on the shubiaobiao proxy.
    GEMINI_ENDPOINT = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent"
    # Volcengine Ark OpenAI-compatible chat completions endpoint used for Doubao.
    DOUBAO_ENDPOINT = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
    # Per-provider caps on the number of reference images sent in one request.
    GEMINI_MAX_IMAGES = 10
    DOUBAO_MAX_IMAGES = 5
    # HTTP timeouts in seconds.
    GEMINI_TIMEOUT = 60
    DOUBAO_TIMEOUT = 120
    # MIME type used when a file's extension does not identify an image type.
    DEFAULT_IMAGE_MIME = "image/jpeg"

    # A ```json ... ``` fenced block, and any ``` ... ``` fenced block.
    _JSON_FENCE_RE = re.compile(r"```json\s*([\s\S]*?)\s*```")
    _ANY_FENCE_RE = re.compile(r"```\s*([\s\S]*?)\s*```")

    # Built-in system prompt. A prompt stored in the DB under "prompt_script_gen"
    # takes precedence (see generate_script).
    DEFAULT_SYSTEM_PROMPT = """
你是一个专业的抖音电商短视频导演。请根据提供的商品信息和图片,设计一个高转化率的商品详情页首图视频脚本。
## 目标
- 提升商品详情页的 GPM 和下单转化率
- 视频时长 9-12 秒 (由 3-4 个分镜组成)
- **每个分镜时长固定为 3 秒** (duration: 3),不要超过 3 秒
- 必须包含:目标人群分析、卖点提炼、分镜设计
## 分镜设计原则
1. **单分镜单主体**:每个分镜聚焦一个视觉主体或动作,避免复杂运镜,因为 AI 生视频在长时间(>3秒)容易出现画面异常。
2. **旁白跨分镜**:一段完整的旁白/卖点可以跨越多个分镜。在 voiceover_timeline 中,通过 start_time 和 duration (秒) 控制旁白的绝对时间位置,无需与分镜一一对应。
3. **节奏感**:分镜之间保持视觉连贯,通过景别变化(特写 -> 中景 -> 全景)制造节奏。
4. **语速控制**:旁白语速约 4 字/秒,12字旁白约需 3 秒。
## 输出格式要求 (JSON)
必须严格遵守以下 JSON 结构:
{
"product_name": "商品名称",
"visual_anchor": "商品视觉锚点:材质+颜色+形状+包装特征(用于保持生图一致性)",
"selling_points": ["卖点1", "卖点2"],
"target_audience": "目标人群描述",
"video_style": "视频风格关键词",
"bgm_style": "BGM风格关键词",
"voiceover_timeline": [
{
"id": 1,
"text": "旁白文案片段1可横跨多个分镜",
"subtitle": "字幕文案1 (简短有力)",
"start_time": 0.0,
"duration": 3.0
},
{
"id": 2,
"text": "旁白文案片段2",
"subtitle": "字幕文案2",
"start_time": 3.5,
"duration": 2.5
}
],
"scenes": [
{
"id": 1,
"duration": 3,
"visual_prompt": "详细的画面描述用于AI生图包含主体、背景、构图、光影。英文描述。",
"video_prompt": "详细的动效描述用于AI图生视频。英文描述。",
"fancy_text": {
"text": "花字文案 (最多6字)",
"style": "highlight",
"position": "center",
"start_time": 0.5,
"duration": 2.0
}
}
]
}
## 注意事项
1. **visual_prompt**:
- 必须是英文。
- 描述要具体,例如 "Close-up shot of a hair clip, soft lighting, minimalist background".
- **CRITICAL**: 禁止 AI 额外生成装饰性文字、标语、水印。但必须保留商品包装自带的文字和 Logo(这是商品真实外观的一部分)
- 正确写法: "Product front view, keep original packaging design --no added text --no watermarks"
- **EMPHASIS**: Strictly follow the appearance of the product in the reference images.
2. **video_prompt**: 必须是英文,描述动作,例如 "Slow zoom in, the hair clip rotates slightly"。注意保持动作简单,避免复杂运镜和人体动作。
3. **voiceover_timeline**:
- 这是整个视频的旁白和字幕时间轴,独立于分镜。
- `start_time` 是旁白开始的绝对时间 (秒),`duration` 是旁白持续时长 (秒)。
- **一段旁白可以横跨多个分镜**,例如:总时长 9 秒 (3 个分镜),一段旁白从 start_time=0,duration=5,则覆盖前两个分镜。
- 两段旁白之间留 0.3-0.5 秒间隙(气口)。
4. **fancy_text**:
- 花字要精简(最多 6 字),突出卖点。
- **Style Selection**:
- `highlight`: 默认样式,适合通用卖点 (Yellow/Black)。
- `warning`: 强调痛点或食欲 (Red/White)。
- `price`: 价格显示 (Big Red)。
- `bubble`: 旁白补充或用户评价 (Bubble)。
- `minimal`: 高级感,适合时尚类 (Thin/White)。
- `tech`: 数码类 (Cyan/Glow)。
- `position` 默认 `center`,可选 top/bottom/top-left/bottom-right 等。
5. **场景连贯性**: 确保分镜之间的逻辑和视觉风格连贯。每个分镜 duration 必须为 3。
"""

    def __init__(self):
        self.api_key = config.SHUBIAOBIAO_KEY
        # Kept as an instance attribute so callers can point at a different proxy path.
        self.endpoint = self.GEMINI_ENDPOINT
        self.default_system_prompt = self.DEFAULT_SYSTEM_PROMPT

    def _encode_image(self, image_path: str) -> str:
        """Read a file and return its bytes base64-encoded as an ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    @classmethod
    def _guess_image_mime(cls, path: str) -> str:
        """Guess an image MIME type from the file extension, defaulting to JPEG."""
        mime_type, _ = mimetypes.guess_type(path)
        if mime_type and mime_type.startswith("image/"):
            return mime_type
        return cls.DEFAULT_IMAGE_MIME

    def _load_images(
        self, image_paths: Optional[List[str]], limit: int
    ) -> List[Tuple[str, str]]:
        """Encode up to ``limit`` of ``image_paths`` as ``(mime_type, base64)`` pairs.

        Paths that do not exist are skipped silently; files that fail to read are
        logged and skipped, so one bad reference image never aborts generation.
        """
        encoded: List[Tuple[str, str]] = []
        for path in (image_paths or [])[:limit]:
            if not Path(path).exists():
                continue
            try:
                encoded.append((self._guess_image_mime(path), self._encode_image(path)))
            except Exception as e:
                logger.warning(f"Failed to encode image {path}: {e}")
        return encoded

    def generate_script(
        self,
        product_name: str,
        product_info: Dict[str, Any],
        image_paths: Optional[List[str]] = None,
        model_provider: str = "shubiaobiao",  # "shubiaobiao" or "doubao"
    ) -> Optional[Dict[str, Any]]:
        """Generate a storyboard script for a product.

        Args:
            product_name: Product name shown to the model.
            product_info: Extra product attributes. ``style_hint`` is rendered as a
                merchant-requirements section; ``uploaded_images`` is omitted.
            image_paths: Optional local reference images sent as multimodal input.
            model_provider: ``"doubao"`` selects Doubao; any other value uses
                Gemini via shubiaobiao.

        Returns:
            The parsed script dict with a ``_debug`` entry, or None on failure.
        """
        logger.info(f"Generating script for: {product_name} (Provider: {model_provider})")

        # Prefer a prompt configured in the database, falling back to the built-in one.
        system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
        user_prompt = self._build_user_prompt(product_name, product_info)

        if system_prompt != self.default_system_prompt:
            logger.info("Using CUSTOM system prompt from database")
        else:
            logger.info("Using DEFAULT system prompt")

        if model_provider == "doubao":
            return self._generate_script_doubao(system_prompt, user_prompt, image_paths)
        return self._generate_script_shubiaobiao(system_prompt, user_prompt, image_paths)

    def _generate_script_shubiaobiao(
        self,
        system_prompt: str,
        user_prompt: str,
        image_paths: Optional[List[str]],
    ) -> Optional[Dict[str, Any]]:
        """Generate the script with Gemini through the shubiaobiao generateContent API.

        Returns the validated script (with ``_debug`` info), or None on any failure.
        """
        # The system prompt is sent as the first text part of the user turn, ahead of
        # the product prompt and the images, rather than as a separate system field.
        parts: List[Dict[str, Any]] = [{"text": system_prompt}, {"text": user_prompt}]
        for mime_type, b64_img in self._load_images(image_paths, self.GEMINI_MAX_IMAGES):
            parts.append({"inline_data": {"mime_type": mime_type, "data": b64_img}})

        payload = {
            "contents": [{"role": "user", "parts": parts}],
            "generationConfig": {
                "response_mime_type": "application/json",
                "temperature": 0.7,
            },
        }
        headers = {
            "x-goog-api-key": self.api_key,
            "Content-Type": "application/json",
        }

        response = None
        try:
            response = requests.post(
                self.endpoint, headers=headers, json=payload, timeout=self.GEMINI_TIMEOUT
            )
            response.raise_for_status()
            result = response.json()

            if not (isinstance(result, dict) and result.get("candidates")):
                logger.error(f"No candidates in response: {result}")
                return None

            content_text = self._extract_gemini_text(result)
            if content_text is None:
                # e.g. the candidate was blocked and carries no content parts
                finish_reason = result["candidates"][0].get("finishReason")
                logger.error(
                    f"No text in first candidate (finishReason={finish_reason}): {result}"
                )
                return None

            script_json = self._extract_json_from_response(content_text)
            if script_json is None:
                logger.error(f"Failed to extract JSON from response: {content_text[:500]}...")
                return None
            return self._finalize_script(
                script_json, system_prompt, user_prompt, content_text, "shubiaobiao"
            )
        except Exception as e:
            logger.error(f"Script generation failed: {e}")
            if response is not None:
                logger.error(f"Response content: {response.text}")
            return None

    @staticmethod
    def _extract_gemini_text(result: Dict[str, Any]) -> Optional[str]:
        """Join the non-thought text parts of the first candidate of a Gemini response.

        Returns None when there is no candidate or it has no text parts (for
        example a candidate without ``content``).
        """
        candidates = result.get("candidates") or []
        if not candidates:
            return None
        parts = (candidates[0].get("content") or {}).get("parts") or []
        texts = [
            part["text"]
            for part in parts
            if isinstance(part, dict)
            and isinstance(part.get("text"), str)
            and not part.get("thought")  # skip thought-summary parts
        ]
        return "".join(texts) if texts else None

    def _generate_script_doubao(
        self,
        system_prompt: str,
        user_prompt: str,
        image_paths: Optional[List[str]],
    ) -> Optional[Dict[str, Any]]:
        """Generate the script with Doubao via Volcengine Ark's OpenAI-compatible chat API.

        Reference images are sent inline as base64 ``data:`` URLs in ``image_url``
        content parts, followed by the user prompt text. Returns the validated
        script (with ``_debug`` info), or None on any failure.
        """
        # NOTE: Ark also exposes a /responses API with a different payload shape
        # ("input" with input_image/input_text parts); only chat/completions is used here.
        user_content: List[Dict[str, Any]] = [
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_img}"}}
            for mime_type, b64_img in self._load_images(image_paths, self.DOUBAO_MAX_IMAGES)
        ]
        user_content.append({"type": "text", "text": user_prompt})

        payload = {
            "model": config.DOUBAO_SCRIPT_MODEL,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            "stream": False,
            # NOTE(review): JSON mode ("response_format": {"type": "json_object"}) is left
            # disabled; enable it only if the configured model supports it.
        }
        headers = {
            "Authorization": f"Bearer {config.VOLC_API_KEY}",
            "Content-Type": "application/json",
        }

        resp = None
        try:
            resp = requests.post(
                self.DOUBAO_ENDPOINT, headers=headers, json=payload, timeout=self.DOUBAO_TIMEOUT
            )
            resp.raise_for_status()
            result = resp.json()
            # `content` may be null; normalise to "" so the error log below can slice it.
            content_text = result["choices"][0]["message"].get("content") or ""

            script_json = self._extract_json_from_response(content_text)
            if script_json is None:
                logger.error(f"Failed to extract JSON from Doubao response: {content_text[:500]}...")
                return None
            return self._finalize_script(
                script_json, system_prompt, user_prompt, content_text, "doubao"
            )
        except Exception as e:
            logger.error(f"Doubao script generation failed: {e}")
            if resp is not None:
                logger.error(f"Response: {resp.text}")
            return None

    def _finalize_script(
        self,
        script_json: Any,
        system_prompt: str,
        user_prompt: str,
        raw_output: str,
        provider: str,
    ) -> Dict[str, Any]:
        """Validate a parsed script and attach prompts/raw model output under ``_debug``.

        Raises:
            ValueError: if ``script_json`` cannot be coerced into a script object.
        """
        final_script = self._validate_and_fix_script(script_json)
        final_script["_debug"] = {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "raw_output": raw_output,
            "provider": provider,
        }
        return final_script

    def _extract_json_from_response(self, text: Optional[str]) -> Optional[Any]:
        """Extract a JSON value from a model's text response.

        Strategies, tried in order:
        1. the whole (stripped) text is JSON;
        2. a ```json ... ``` fenced block;
        3. any ``` ... ``` fenced block;
        4. the span from the first "{" to the last "}".

        Returns the parsed value (normally a dict), or None when text is empty or
        nothing parses.
        """
        if not text:
            return None

        try:
            return json.loads(text.strip())
        except json.JSONDecodeError:
            pass

        match = self._JSON_FENCE_RE.search(text)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError as e:
                logger.warning(f"JSON block found but parse failed: {e}")

        match = self._ANY_FENCE_RE.search(text)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass

        first_brace = text.find('{')
        last_brace = text.rfind('}')
        if first_brace != -1 and last_brace > first_brace:
            try:
                return json.loads(text[first_brace:last_brace + 1])
            except json.JSONDecodeError as e:
                logger.warning(f"Brace extraction failed: {e}")
        return None

    def _build_user_prompt(
        self, product_name: str, product_info: Optional[Dict[str, Any]]
    ) -> str:
        """Render the user prompt from the product name and its attributes.

        ``uploaded_images`` and ``style_hint`` are excluded from the attribute
        list; a non-empty ``style_hint`` is appended as a merchant-requirements
        section. ``product_info`` may be None (treated as empty).
        """
        product_info = product_info or {}
        style_hint = product_info.get("style_hint", "")
        info_str = "\n".join(
            f"- {k}: {v}"
            for k, v in product_info.items()
            if k not in ("uploaded_images", "style_hint")
        )

        sections = [f"\n商品名称:{product_name}\n商品信息:\n{info_str}\n"]
        if style_hint:
            sections.append(f"\n## 商家特别要求\n{style_hint}\n")
        sections.append("\n请根据以上信息设计视频脚本。")
        return "".join(sections)

    def _validate_and_fix_script(self, script: Any) -> Dict[str, Any]:
        """Coerce parsed model output into a script dict with a list-valued ``scenes``.

        A top-level list is unwrapped to its first object element. A missing or
        non-list ``scenes`` is replaced with ``[]``. The dict is modified in place
        and returned.

        Raises:
            ValueError: if no JSON object can be found in ``script``.
        """
        if isinstance(script, list):
            first_obj = next((item for item in script if isinstance(item, dict)), None)
            if first_obj is None:
                raise ValueError("Script JSON is a list without any object element")
            logger.warning("Script JSON was wrapped in a list; using its first object")
            script = first_obj

        if not isinstance(script, dict):
            raise ValueError(f"Script JSON must be an object, got {type(script).__name__}")

        scenes = script.get("scenes")
        if not isinstance(scenes, list):
            if scenes is not None:
                logger.warning(f"Replacing non-list 'scenes' ({type(scenes).__name__}) with []")
            script["scenes"] = []
        return script