feat: video-flow initial commit
- app.py: Streamlit UI for video generation workflow - main_flow.py: CLI tool with argparse support - modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.) - config.py: Configuration with API keys and paths - requirements.txt: Python dependencies - docs/: System prompt documentation
This commit is contained in:
56
.gitignore
vendored
Normal file
56
.gitignore
vendored
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
*.egg-info/
|
||||||
|
|
||||||
|
# Output
|
||||||
|
output/
|
||||||
|
*.mp4
|
||||||
|
*.mp3
|
||||||
|
*.wav
|
||||||
|
*.m4a
|
||||||
|
|
||||||
|
# Assets (downloaded)
|
||||||
|
assets/fonts/*.ttf
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# 参考
|
||||||
|
参考/
|
||||||
|
|
||||||
|
# 素材
|
||||||
|
素材/
|
||||||
|
|
||||||
|
# Images
|
||||||
|
*.png
|
||||||
|
*.jpeg
|
||||||
|
*.jpg
|
||||||
|
|
||||||
|
# Database & Logs
|
||||||
|
*.db
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Temp files
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# Binaries
|
||||||
|
bin/
|
||||||
|
|
||||||
2152
assets/fonts/NotoSansSC-Bold.otf
Normal file
2152
assets/fonts/NotoSansSC-Bold.otf
Normal file
File diff suppressed because one or more lines are too long
BIN
assets/fonts/NotoSansSC-Regular.otf
Normal file
BIN
assets/fonts/NotoSansSC-Regular.otf
Normal file
Binary file not shown.
181
config.py
Normal file
181
config.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - Configuration
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# API Keys
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Volcengine / Doubao (Official)
|
||||||
|
VOLC_API_KEY = os.getenv("VOLC_API_KEY", "05aed9c1-f5e6-487b-9273-fe7d6be51957")
|
||||||
|
VOLC_BASE_URL = os.getenv("VOLC_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
|
||||||
|
|
||||||
|
# Models (Updated with User-Provided Endpoint IDs)
|
||||||
|
# LLM: Doubao Pro 1.5 (Using provided brain/vision endpoint)
|
||||||
|
BRAIN_MODEL_ID = os.getenv("BRAIN_MODEL_ID", "ep-20251203231055-dpsp7")
|
||||||
|
# Vision: Doubao Vision Pro 1.5
|
||||||
|
VISION_MODEL_ID = os.getenv("VISION_MODEL_ID", "ep-20251203232121-xjt2s")
|
||||||
|
# Image: Doubao Image (Updated to user provided model)
|
||||||
|
IMAGE_MODEL_ID = os.getenv("IMAGE_MODEL_ID", "ep-20251203231641-wg9nb")
|
||||||
|
# Video: Doubao Video (PixelDance)
|
||||||
|
VIDEO_MODEL_ID = os.getenv("VIDEO_MODEL_ID", "ep-20251207100506-rjx4x")
|
||||||
|
|
||||||
|
# Doubao Specifics (User Provided)
|
||||||
|
DOUBAO_SCRIPT_MODEL = "ep-20251203231055-dpsp7"
|
||||||
|
DOUBAO_IMG_MODEL = "ep-20251203231641-wg9nb"
|
||||||
|
|
||||||
|
|
||||||
|
# Text/Brain API (Legacy)
|
||||||
|
SHUBIAOBIAO_KEY = os.getenv("SHUBIAOBIAO_KEY", "sk-aL167A8sQEyvs40yBfC140Fc0fDa4c198f029aAcF0429108")
|
||||||
|
SHUBIAOBIAO_BASE_URL = os.getenv("SHUBIAOBIAO_BASE_URL", "https://api.shubiaobiao.cn/v1")
|
||||||
|
SHUBIAOBIAO_MODEL_TEXT = "gemini-3-pro-preview"
|
||||||
|
|
||||||
|
# Image Generation API (Updated)
|
||||||
|
# Host: https://api.wuyinkeji.com/
|
||||||
|
# Model: nanoBanana-pro (Gemini)
|
||||||
|
GEMINI_IMG_KEY = os.getenv("GEMINI_IMG_KEY", "G9rXx3Ag2Xfa7Gs8zou6t6HqeZ")
|
||||||
|
GEMINI_IMG_API_URL = os.getenv("GEMINI_IMG_API_URL", "https://api.wuyinkeji.com/api/img/nanoBanana-pro")
|
||||||
|
GEMINI_IMG_DETAIL_URL = os.getenv("GEMINI_IMG_DETAIL_URL", "https://api.wuyinkeji.com/api/img/drawDetail")
|
||||||
|
|
||||||
|
# Legacy Image API
|
||||||
|
SHUBIAOBIAO_IMG_KEY = os.getenv("SHUBIAOBIAO_IMG_KEY", "sk-1yr2h4sJybHB7DED57CeF446D08c4bC989F621Db5b48E70d")
|
||||||
|
SHUBIAOBIAO_IMG_BASE_URL = os.getenv("SHUBIAOBIAO_IMG_BASE_URL", "https://api2img.shubiaobiao.com")
|
||||||
|
SHUBIAOBIAO_IMG_MODEL_NAME = "gemini-3-pro-image-preview"
|
||||||
|
|
||||||
|
# Backup
|
||||||
|
FAL_KEY = os.getenv("FAL_KEY", "")
|
||||||
|
KLING_ACCESS_KEY = os.getenv("KLING_ACCESS_KEY", "")
|
||||||
|
KLING_SECRET_KEY = os.getenv("KLING_SECRET_KEY", "")
|
||||||
|
|
||||||
|
XI_KEY = os.getenv("XI_KEY", "")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Cloudflare R2 Storage
|
||||||
|
# ============================================================
|
||||||
|
R2_ENDPOINT = os.getenv("R2_ENDPOINT", "")
|
||||||
|
R2_ACCESS_KEY = os.getenv("R2_ACCESS_KEY", "")
|
||||||
|
R2_SECRET_KEY = os.getenv("R2_SECRET_KEY", "")
|
||||||
|
R2_BUCKET_NAME = os.getenv("R2_BUCKET_NAME", "mms-assets")
|
||||||
|
# Public URL for accessing uploaded files
|
||||||
|
R2_PUBLIC_URL = os.getenv("R2_PUBLIC_URL", "https://pub-7942a75aa66d4315a628ee464267ebf4.r2.dev")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# ElevenLabs Settings (Legacy - for English)
|
||||||
|
# ============================================================
|
||||||
|
ELEVENLABS_VOICE_ID = os.getenv("XI_VOICE_ID", "21m00Tcm4TlvDq8ikWAM")
|
||||||
|
ELEVENLABS_MODEL = "eleven_turbo_v2_5"
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Volcengine TTS Settings (火山引擎语音合成 - 中文)
|
||||||
|
# ============================================================
|
||||||
|
# 申请地址: https://console.volcengine.com/speech/service/8
|
||||||
|
VOLC_TTS_APPID = os.getenv("VOLC_TTS_APPID", "6771884088")
|
||||||
|
VOLC_TTS_ACCESS_TOKEN = os.getenv("VOLC_TTS_ACCESS_TOKEN", "Q5sR2SNfxO8Vb9g2ucsaqfUGOpcpZi3S")
|
||||||
|
VOLC_TTS_SECRET_KEY = os.getenv("VOLC_TTS_SECRET_KEY", "RXc2WiA6OK6G1xuEZ7cyAU3Q3B5Z1oUx")
|
||||||
|
|
||||||
|
# 默认音色
|
||||||
|
# 抖音热门带货音色推荐:
|
||||||
|
# - BV700_streaming: 甜美小媛(甜美活泼,适合美妆/好物)- 可能无权限
|
||||||
|
# - zh_female_santongyongns_saturn_bigtts: 三通永(已验证可用)
|
||||||
|
# - zh_female_meilinvyou_saturn_bigtts: 美丽女友(已验证可用)
|
||||||
|
VOLC_TTS_DEFAULT_VOICE = os.getenv("VOLC_TTS_VOICE", "zh_female_santongyongns_saturn_bigtts")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Video Settings
|
||||||
|
# ============================================================
|
||||||
|
VIDEO_SETTINGS = {
|
||||||
|
"width": 1080,
|
||||||
|
"height": 1920,
|
||||||
|
"fps": 30,
|
||||||
|
"format": "mp4",
|
||||||
|
"codec": "libx264",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Paths
|
||||||
|
# ============================================================
|
||||||
|
BASE_DIR = Path(__file__).parent
|
||||||
|
OUTPUT_DIR = BASE_DIR / "output"
|
||||||
|
TEMP_DIR = BASE_DIR / "temp"
|
||||||
|
ASSETS_DIR = BASE_DIR / "assets"
|
||||||
|
FONTS_DIR = ASSETS_DIR / "fonts"
|
||||||
|
|
||||||
|
# Ensure directories exist
|
||||||
|
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||||
|
TEMP_DIR.mkdir(exist_ok=True)
|
||||||
|
ASSETS_DIR.mkdir(exist_ok=True)
|
||||||
|
FONTS_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Database Configuration
|
||||||
|
# ============================================================
|
||||||
|
# Format: postgresql://user:password@host:port/dbname
|
||||||
|
# Default to SQLite if not provided
|
||||||
|
DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING", f"sqlite:///{BASE_DIR}/video_flow.db")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Font Settings (字体配置)
|
||||||
|
# ============================================================
|
||||||
|
# 优先检测系统字体,防止乱码
|
||||||
|
SYSTEM_FONTS = [
|
||||||
|
str(FONTS_DIR / "SmileySans-Oblique.otf"),
|
||||||
|
str(FONTS_DIR / "HarmonyOS-Sans-SC-Regular.ttf"),
|
||||||
|
str(FONTS_DIR / "HarmonyOS-Sans-SC-Bold.ttf"),
|
||||||
|
str(FONTS_DIR / "NotoSansSC-Regular.otf"),
|
||||||
|
str(FONTS_DIR / "NotoSansSC-Bold.otf"),
|
||||||
|
"/System/Library/Fonts/PingFang.ttc",
|
||||||
|
"/System/Library/Fonts/STHeiti Medium.ttc",
|
||||||
|
"/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_FONT = str(FONTS_DIR / "NotoSansSC-Regular.otf")
|
||||||
|
DEFAULT_FONT_BOLD = str(FONTS_DIR / "NotoSansSC-Bold.otf")
|
||||||
|
|
||||||
|
# 检查项目字体是否存在,不存在则使用系统字体
|
||||||
|
def pick_font():
|
||||||
|
for f in SYSTEM_FONTS:
|
||||||
|
if os.path.exists(f) and os.path.getsize(f) > 1000:
|
||||||
|
return f
|
||||||
|
return "/System/Library/Fonts/PingFang.ttc"
|
||||||
|
|
||||||
|
DEFAULT_FONT = pick_font()
|
||||||
|
DEFAULT_FONT_BOLD = DEFAULT_FONT
|
||||||
|
|
||||||
|
# 花字样式预设
|
||||||
|
FANCY_TEXT_STYLES = {
|
||||||
|
"subtitle": {
|
||||||
|
"font_size": 48,
|
||||||
|
"font_color": "#FFFFFF",
|
||||||
|
"stroke_color": "#000000",
|
||||||
|
"stroke_width": 3
|
||||||
|
},
|
||||||
|
"highlight": {
|
||||||
|
"font_size": 56,
|
||||||
|
"font_color": "#FFE66D",
|
||||||
|
"stroke_color": "#000000",
|
||||||
|
"stroke_width": 4
|
||||||
|
},
|
||||||
|
"warning": {
|
||||||
|
"font_size": 52,
|
||||||
|
"font_color": "#FF4444",
|
||||||
|
"stroke_color": "#FFFFFF",
|
||||||
|
"stroke_width": 4
|
||||||
|
},
|
||||||
|
"price": {
|
||||||
|
"font_size": 72,
|
||||||
|
"price_color": "#FF4444",
|
||||||
|
"stroke_color": "#FFFFFF",
|
||||||
|
"stroke_width": 5
|
||||||
|
},
|
||||||
|
"button": {
|
||||||
|
"font_size": 36,
|
||||||
|
"font_color": "#FFFFFF",
|
||||||
|
"bg_color": "#FF6B35",
|
||||||
|
"corner_radius": 25
|
||||||
|
}
|
||||||
|
}
|
||||||
255
deploy.py
Normal file
255
deploy.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""
|
||||||
|
Gloda Video Factory - Deployment Script
|
||||||
|
One-click deployment to remote server using Fabric.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from fabric import Connection, Config
|
||||||
|
from invoke import task
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Server configuration
|
||||||
|
SERVER_IP = os.getenv("SERVER_IP", "")
|
||||||
|
SERVER_USER = os.getenv("SERVER_USER", "root")
|
||||||
|
SERVER_PASS = os.getenv("SERVER_PASS", "")
|
||||||
|
|
||||||
|
# Remote paths
|
||||||
|
REMOTE_APP_DIR = "/opt/gloda-factory"
|
||||||
|
REMOTE_VENV = f"{REMOTE_APP_DIR}/venv"
|
||||||
|
|
||||||
|
# Files to upload
|
||||||
|
LOCAL_FILES = [
|
||||||
|
"config.py",
|
||||||
|
"web_app.py",
|
||||||
|
"requirements.txt",
|
||||||
|
".env",
|
||||||
|
"modules/__init__.py",
|
||||||
|
"modules/utils.py",
|
||||||
|
"modules/brain.py",
|
||||||
|
"modules/factory.py",
|
||||||
|
"modules/editor.py",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_connection() -> Connection:
    """Build an SSH connection to the deployment server.

    Credentials come from the environment (SERVER_IP / SERVER_USER /
    SERVER_PASS). Raises ValueError when the required ones are missing.
    """
    if not (SERVER_IP and SERVER_PASS):
        raise ValueError("SERVER_IP and SERVER_PASS must be set in .env")

    fabric_cfg = Config(overrides={"sudo": {"password": SERVER_PASS}})
    return Connection(
        host=SERVER_IP,
        user=SERVER_USER,
        connect_kwargs={"password": SERVER_PASS},
        config=fabric_cfg,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def deploy():
    """Full deployment: setup server, upload code, start app."""
    print("🚀 Starting deployment...")

    conn = get_connection()

    # Each step prints its banner and then runs against the connection.
    steps = [
        ("\n📦 Step 1/5: Installing system dependencies...", install_dependencies),
        ("\n📁 Step 2/5: Setting up directories...", setup_directories),
        ("\n📤 Step 3/5: Uploading code...", upload_code),
        ("\n🐍 Step 4/5: Setting up Python environment...", setup_python),
        ("\n🎬 Step 5/5: Starting application...", start_app),
    ]
    for banner, step in steps:
        print(banner)
        step(conn)

    print(f"\n✅ Deployment complete!")
    print(f"🌐 Access the app at: http://{SERVER_IP}:8501")

    conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def install_dependencies(conn: Connection):
    """Install system-level dependencies."""
    for cmd in (
        "apt-get update -qq",
        "apt-get install -y -qq python3 python3-pip python3-venv",
        "apt-get install -y -qq ffmpeg imagemagick",
        "apt-get install -y -qq fonts-liberation fonts-dejavu-core",
    ):
        print(f" Running: {cmd[:50]}...")
        conn.sudo(cmd, hide=True)

    # Configure ImageMagick policy (allow PDF/SVG for text rendering)
    policy_fix = """
sed -i 's/<policy domain="path" rights="none" pattern="@\\*"/<policy domain="path" rights="read|write" pattern="@*"/g' /etc/ImageMagick-6/policy.xml 2>/dev/null || true
"""
    conn.sudo(policy_fix, hide=True, warn=True)

    print(" ✅ System dependencies installed")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_directories(conn: Connection):
    """Create application directories on remote server."""
    for subdir in ("modules", "output", "assets/fonts"):
        conn.sudo(f"mkdir -p {REMOTE_APP_DIR}/{subdir}", hide=True)
    # Hand ownership to the deploy user so later non-sudo uploads work.
    conn.sudo(f"chown -R {SERVER_USER}:{SERVER_USER} {REMOTE_APP_DIR}", hide=True)

    print(f" ✅ Directories created at {REMOTE_APP_DIR}")
|
||||||
|
|
||||||
|
|
||||||
|
def upload_code(conn: Connection):
    """Upload application code to remote server."""
    local_base = Path(__file__).parent

    for file_path in LOCAL_FILES:
        src = local_base / file_path
        dst = f"{REMOTE_APP_DIR}/{file_path}"

        if not src.exists():
            print(f" ⚠️ Skipped (not found): {file_path}")
            continue

        # Ensure the remote parent directory exists, then copy the file.
        conn.run(f"mkdir -p {str(Path(dst).parent)}", hide=True)
        conn.put(str(src), dst)
        print(f" ✅ Uploaded: {file_path}")

    print(" ✅ Code uploaded")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_python(conn: Connection):
    """Setup Python virtual environment and install dependencies."""
    pip = f"{REMOTE_VENV}/bin/pip"
    with conn.cd(REMOTE_APP_DIR):
        # Fresh venv, then up-to-date pip, then project requirements.
        conn.run(f"python3 -m venv {REMOTE_VENV}", hide=True)
        conn.run(f"{pip} install --upgrade pip -q", hide=True)
        conn.run(f"{pip} install -r requirements.txt -q", hide=True)

    print(" ✅ Python environment ready")
|
||||||
|
|
||||||
|
|
||||||
|
def start_app(conn: Connection):
    """Start the Streamlit application."""
    import time

    # Stop existing process if any
    conn.run("pkill -f 'streamlit run web_app.py' || true", hide=True, warn=True)

    # Start in background with nohup; stdout/stderr go to the log file.
    start_cmd = f"""
cd {REMOTE_APP_DIR} && \
nohup {REMOTE_VENV}/bin/streamlit run web_app.py \
--server.port 8501 \
--server.address 0.0.0.0 \
--server.headless true \
--browser.gatherUsageStats false \
> /var/log/gloda-factory.log 2>&1 &
"""
    conn.run(start_cmd, hide=True)

    # Give the process a moment to come up, then verify it is running.
    time.sleep(3)

    result = conn.run("pgrep -f 'streamlit run web_app.py'", hide=True, warn=True)
    if result.ok:
        print(f" ✅ Application started (PID: {result.stdout.strip()})")
    else:
        print(" ⚠️ Application may not have started. Check logs.")
|
||||||
|
|
||||||
|
|
||||||
|
def stop_app():
    """Stop the running application.

    Fix: the SSH connection is now closed in a finally block so it is
    released even when the remote command raises.
    """
    print("🛑 Stopping application...")
    conn = get_connection()
    try:
        conn.run("pkill -f 'streamlit run web_app.py' || true", hide=True, warn=True)
    finally:
        conn.close()
    print("✅ Application stopped")
|
||||||
|
|
||||||
|
|
||||||
|
def logs():
    """Show application logs (last 50 lines of the remote log file).

    Fix: the SSH connection is now closed in a finally block so it is
    released even when the remote command raises.
    """
    print("📋 Recent logs:")
    conn = get_connection()
    try:
        conn.run("tail -50 /var/log/gloda-factory.log", warn=True)
    finally:
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def status():
    """Check application status by looking for the Streamlit process.

    Fix: the SSH connection is now closed in a finally block so it is
    released even when the remote command raises.
    """
    print("📊 Checking status...")
    conn = get_connection()
    try:
        result = conn.run("pgrep -f 'streamlit run web_app.py'", hide=True, warn=True)
        if result.ok:
            print(f"✅ Application is running (PID: {result.stdout.strip()})")
            print(f"🌐 URL: http://{SERVER_IP}:8501")
        else:
            print("❌ Application is not running")
    finally:
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def restart():
    """Restart the application (kill, brief pause, start).

    Fix: the SSH connection is now closed in a finally block so it is
    released even when stopping or starting raises.
    """
    import time

    print("🔄 Restarting application...")
    conn = get_connection()
    try:
        # Stop any running instance, wait for the port to free up, restart.
        conn.run("pkill -f 'streamlit run web_app.py' || true", hide=True, warn=True)
        time.sleep(2)
        start_app(conn)
    finally:
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Gloda Factory Deployment Tool")
|
||||||
|
parser.add_argument(
|
||||||
|
"command",
|
||||||
|
choices=["deploy", "start", "stop", "restart", "status", "logs"],
|
||||||
|
help="Deployment command to run"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
commands = {
|
||||||
|
"deploy": deploy,
|
||||||
|
"stop": stop_app,
|
||||||
|
"status": status,
|
||||||
|
"logs": logs,
|
||||||
|
"restart": restart,
|
||||||
|
"start": lambda: start_app(get_connection()),
|
||||||
|
}
|
||||||
|
|
||||||
|
commands[args.command]()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
317
docs/SYSTEM_PROMPT_VIDEO_SCRIPT_v2.md
Normal file
317
docs/SYSTEM_PROMPT_VIDEO_SCRIPT_v2.md
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
# SYSTEM CONTEXT
|
||||||
|
|
||||||
|
**Role**: 你是一名精通抖音电商算法、搜索转化心理学与 AI 视频工程化的创意总监。
|
||||||
|
**Task**: 为商品详情页(PDP)首图设计高转化率、可直接执行的 AI 视频脚本 (JSON)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🎯 GOALS & KPI (业务核心)
|
||||||
|
|
||||||
|
1. **GPM First**: 一切为了提升千次曝光成交额 (GPM) 和下单转化率。
|
||||||
|
|
||||||
|
2. **搜索心智 (Search Intent)**: 用户通过搜索关键词或商品卡进入,处于"决策验证期"。视频必须**"所见即所得"**,前 3 秒直接承接搜索预期。
|
||||||
|
|
||||||
|
3. **静音法则 (Mute Play)**: 默认静音播放。必须依赖高视觉冲击力和醒目花字 (Fancy Text) 在前 3 秒留住用户。
|
||||||
|
|
||||||
|
4. **全品类转化逻辑**: 必须根据商品属性匹配最佳脚本策略(见思维链)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# ⏱️ 时长规范 (Duration Rules)
|
||||||
|
|
||||||
|
- **总时长**: 9-12 秒 (由 3-4 个分镜组成)
|
||||||
|
- **单分镜**: 固定 **3 秒** (`duration: 3`),严禁超过 3 秒
|
||||||
|
- **原因**: AI 生成视频超过 3 秒容易出现主体变形、画面抖动、物理异常
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🧠 THINKING CHAIN (思维链 - 执行逻辑)
|
||||||
|
|
||||||
|
在输出 JSON 前,必须按以下步骤思考:
|
||||||
|
|
||||||
|
## Step 1: Input Analysis & Categorization (定性)
|
||||||
|
|
||||||
|
分析商品属性,将其归类为以下四种类型之一:
|
||||||
|
|
||||||
|
| Type | 类型 | 典型品类 | 脚本策略 |
|
||||||
|
|------|------|----------|----------|
|
||||||
|
| A | 功能型 | 清洁/收纳/工具/家电 | 痛点 → 解决方案 → 爽点 |
|
||||||
|
| B | 审美型 | 服装/首饰/彩妆/摆件 | 高颜全貌 → 细节质感 → 上身/氛围 |
|
||||||
|
| C | 感官型 | 零食/饮料/水果/预制菜 | 瞬间冲击 → 微观纹理 → 食欲诱惑 |
|
||||||
|
| D | 信任型 | 母婴/滋补/茶叶/高客单 | 源头/原料 → 权威背书 → 结果呈现 |
|
||||||
|
|
||||||
|
## Step 2: Visual Anchor Extraction (定锚)
|
||||||
|
|
||||||
|
基于参考图,提取一段包含 **材质、颜色、形状、包装特征** 的标准视觉描述 (Visual Anchor)。
|
||||||
|
这是防止 AI 视频变形的"防伪码",**必须复用于所有分镜的 visual_prompt**。
|
||||||
|
|
||||||
|
示例:`"深棕色圆形曲奇饼干,表面嵌入巧克力碎块,牛皮纸包装袋印有品牌Logo"`
|
||||||
|
|
||||||
|
## Step 3: Scripting Strategy (编排)
|
||||||
|
|
||||||
|
| 分镜 | 时间 | 功能 | 设计要点 |
|
||||||
|
|------|------|------|----------|
|
||||||
|
| Scene 1 | 0-3s | 搜索承接 | Visual Anchor 全貌 + 核心卖点花字 |
|
||||||
|
| Scene 2 | 3-6s | 自适应 | Type A:功能演示 / B:细节质感 / C:食欲特写 / D:原料溯源 |
|
||||||
|
| Scene 3 | 6-9s | 深化 | 对比效果 / 动态美感 / 爆浆拉丝 / 权威背书 |
|
||||||
|
| Scene 4 | 9-12s | 收尾 (可选) | 信任背书 / 使用后美好状态 / 行动号召 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🎙️ 旁白设计规范 (Voiceover Rules)
|
||||||
|
|
||||||
|
## 核心定位 ⚠️
|
||||||
|
旁白是**卖点传递的主力军**,不是画面解说词。10秒内必须完成:场景共鸣 → 核心卖点 → 信任背书 → 行动召唤。
|
||||||
|
|
||||||
|
## 技术规范
|
||||||
|
1. **语速**: **5 字/秒** (9秒视频 = 45-50字旁白),可略超视频时长,后期 1.1x 倍速压入
|
||||||
|
2. **气口间隔**: 两段旁白之间留 **0.3-0.5 秒**
|
||||||
|
3. **时间控制**: `start_time` 和 `duration` 单位为秒
|
||||||
|
4. **字幕同步**: `subtitle` 与 `text` 完全一致
|
||||||
|
|
||||||
|
## 写作禁忌
|
||||||
|
- ❌ 描述画面:"这是一款发夹" → ✅ 带入场景:"想要千金范?这款发夹绝了"
|
||||||
|
- ❌ 空洞形容:"非常好看" → ✅ 具体感受:"黑发棕发都显贵气"
|
||||||
|
- ❌ 无信任背书 → ✅ 加数据:"月销3万单,回购率超高"
|
||||||
|
- ❌ 无行动召唤 → ✅ 加引导:"现在下单,还送同款小号"
|
||||||
|
|
||||||
|
## 示例对比
|
||||||
|
```
|
||||||
|
❌ 旧版 (24字,信息不足):
|
||||||
|
"秋冬氛围感,财阀千金风" + "毛绒质感,搭配璀璨水钻" + "精致耐看,百搭不挑人"
|
||||||
|
|
||||||
|
✅ 新版 (52字,信息密集):
|
||||||
|
"想要秋冬千金范?这款发夹绝了" + "奥地利进口水钻,手工镶嵌不掉钻" +
|
||||||
|
"黑发棕发都显贵,扎个马尾直接气质拉满" + "月销3万单,现在下单送同款小号"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 🎨 商家风格提示 (Style Hint - Optional)
|
||||||
|
|
||||||
|
如果用户提供了风格关键词(如"韩风"、"高级感"、"日系"),需融入:
|
||||||
|
- `video_style`: 调整色调、光影、构图
|
||||||
|
- 韩风 → 低饱和、柔光、简洁留白
|
||||||
|
- 高级感 → 暗调、金属质感、几何构图
|
||||||
|
- 日系 → 自然光、木质/棉麻元素、温暖色调
|
||||||
|
- `fancy_text.style`: 选择匹配的字幕风格
|
||||||
|
- 高级感 → `minimal` (白字)
|
||||||
|
- 活力 → `highlight` (黄字)
|
||||||
|
- 食欲/警示 → `warning` (红字)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# ⚠️ EXECUTION CONSTRAINTS (执行红线)
|
||||||
|
|
||||||
|
## 视觉干净度 (Visual Cleanliness)
|
||||||
|
|
||||||
|
**禁止 AI 额外生成**:装饰性文字、标语、水印、非商品元素
|
||||||
|
**必须保留**:商品包装自带的文字、Logo、品牌标识(这是商品真实外观的一部分)
|
||||||
|
|
||||||
|
正确写法:
|
||||||
|
```
|
||||||
|
✅ "商品正面全貌,保留包装原有设计 --no added text --no watermarks"
|
||||||
|
❌ "--no text" (这会错误移除包装文字)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 视觉一致性 (Visual Consistency)
|
||||||
|
|
||||||
|
所有分镜的 `visual_prompt` **必须包含完整的 Visual Anchor**,确保主体外观不变形。
|
||||||
|
|
||||||
|
## 运动控制 (Motion Control)
|
||||||
|
|
||||||
|
| 允许 ✅ | 禁止 ❌ |
|
||||||
|
|---------|---------|
|
||||||
|
| 物理运镜: Zoom In/Out, Pan, Tilt | 复杂生物动作: 手部翻转、穿衣、咀嚼 |
|
||||||
|
| 环境微动: 光影流动、水珠滑落、蒸汽升腾 | 主体形变: 产品旋转360°、折叠展开 |
|
||||||
|
| 物理动态: 掰开、倾倒、碎屑飞溅 | 长时间连续动作 (>3秒) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# ❌ 禁止示例 (Counter-examples)
|
||||||
|
|
||||||
|
## Bad visual_prompt
|
||||||
|
```
|
||||||
|
❌ "一只手拿起曲奇,放入嘴中咀嚼"
|
||||||
|
→ 手部和嘴部动作必然变形
|
||||||
|
✅ "曲奇被掰开的瞬间,巧克力流心缓缓溢出,微距特写"
|
||||||
|
→ 物理动作,无人体
|
||||||
|
```
|
||||||
|
|
||||||
|
## Bad video_prompt
|
||||||
|
```
|
||||||
|
❌ "镜头跟随产品旋转一周,展示各个角度"
|
||||||
|
→ 超出 3 秒,旋转运动易变形
|
||||||
|
✅ "Slow Zoom In, 光影在表面流动, 背景蒸汽微动"
|
||||||
|
→ 简单运镜 + 物理微动
|
||||||
|
```
|
||||||
|
|
||||||
|
## Bad fancy_text
|
||||||
|
```
|
||||||
|
❌ "进口黄油手工烘焙每日新鲜发货限时特惠"
|
||||||
|
→ 超过 6 字,静音下无法快速阅读
|
||||||
|
✅ "进口黄油"
|
||||||
|
→ 核心卖点浓缩,一眼可读
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 📐 Visual Prompt 语法规范
|
||||||
|
|
||||||
|
## 结构模板
|
||||||
|
```
|
||||||
|
[Visual Anchor] + [主体状态/动作] + [景别] + [环境/光影] + [否定提示]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 完整示例
|
||||||
|
```
|
||||||
|
"[深棕色圆形曲奇,表面嵌入巧克力碎块,牛皮纸包装] +
|
||||||
|
饼干被掰开,流心巧克力缓缓流出 +
|
||||||
|
微距特写,浅景深 +
|
||||||
|
暖黄色逆光,大理石台面 +
|
||||||
|
--no added text --no watermarks --no hands"
|
||||||
|
```
|
||||||
|
|
||||||
|
## 否定提示规范 (--no)
|
||||||
|
- `--no added text` (禁止AI添加的文字,保留包装原有文字)
|
||||||
|
- `--no watermarks` (禁止水印)
|
||||||
|
- `--no hands` / `--no human body` (如非必要)
|
||||||
|
- `--no complex motion` (禁止复杂动作)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 📄 OUTPUT FORMAT (Strict JSON Schema)
|
||||||
|
|
||||||
|
**重要**:必须保留以下顶层字段,确保与现有系统兼容。
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"product_name": "商品名称",
|
||||||
|
"visual_anchor": "商品视觉锚点:材质+颜色+形状+包装特征,用于保持生图一致性",
|
||||||
|
"selling_points": ["核心卖点1", "核心卖点2", "核心卖点3"],
|
||||||
|
"target_audience": "目标人群描述",
|
||||||
|
"video_style": "视频风格 (色调/光影/构图)",
|
||||||
|
"bgm_style": "BGM风格",
|
||||||
|
"voiceover_timeline": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"text": "旁白文案 (口语化, 4字/秒)",
|
||||||
|
"subtitle": "字幕文案 (与text完全一致)",
|
||||||
|
"start_time": 0.0,
|
||||||
|
"duration": 3.0
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"scenes": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"duration": 3,
|
||||||
|
"visual_prompt": "[Visual Anchor] + 场景描述 --no added text --no watermarks",
|
||||||
|
"video_prompt": "运镜 + 物理动态描述",
|
||||||
|
"fancy_text": {
|
||||||
|
"text": "最多6字",
|
||||||
|
"style": "highlight | warning | minimal",
|
||||||
|
"position": "top | center | bottom",
|
||||||
|
"start_time": 0.5,
|
||||||
|
"duration": 2.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# 📝 完整示例 (Type C - 爆浆曲奇)
|
||||||
|
|
||||||
|
**Input**: 商品名"爆浆流心曲奇",参考图为深棕色曲奇+巧克力流心特写
|
||||||
|
|
||||||
|
**Output**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"product_name": "爆浆流心曲奇",
|
||||||
|
"visual_anchor": "深棕色圆形曲奇饼干,表面嵌入巧克力碎块,内部巧克力流心,牛皮纸包装袋印有品牌Logo",
|
||||||
|
"selling_points": ["真·爆浆流心", "进口黄油", "香浓不腻"],
|
||||||
|
"target_audience": "18-35岁女性,追求零食品质,喜欢巧克力甜品",
|
||||||
|
"video_style": "Macro photography, warm golden backlight, shallow DOF, rustic wood surface",
|
||||||
|
"bgm_style": "ASMR crackling + light upbeat rhythm",
|
||||||
|
"voiceover_timeline": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"text": "下午嘴馋了?来一口真·爆浆流心曲奇",
|
||||||
|
"subtitle": "下午嘴馋了?来一口真·爆浆流心曲奇",
|
||||||
|
"start_time": 0.0,
|
||||||
|
"duration": 3.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"text": "新西兰进口黄油,纯可可脂,咬开瞬间流心爆浆",
|
||||||
|
"subtitle": "新西兰进口黄油,纯可可脂,咬开瞬间流心爆浆",
|
||||||
|
"start_time": 3.2,
|
||||||
|
"duration": 3.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"text": "已售50万盒,回购率超高,现在下单买二送一",
|
||||||
|
"subtitle": "已售50万盒,回购率超高,现在下单买二送一",
|
||||||
|
"start_time": 7.0,
|
||||||
|
"duration": 3.0
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"scenes": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"duration": 3,
|
||||||
|
"visual_prompt": "[深棕色圆形曲奇饼干,表面嵌入巧克力碎块,牛皮纸包装印有品牌标识] 正面全貌堆叠展示,大理石台面,暖黄逆光,浅景深 --no added text --no watermarks",
|
||||||
|
"video_prompt": "Slow Zoom In, 光影在曲奇表面缓缓流动,背景轻微虚化",
|
||||||
|
"fancy_text": {
|
||||||
|
"text": "爆浆流心",
|
||||||
|
"style": "warning",
|
||||||
|
"position": "center",
|
||||||
|
"start_time": 0.5,
|
||||||
|
"duration": 2.0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"duration": 3,
|
||||||
|
"visual_prompt": "[深棕色圆形曲奇饼干,内部巧克力流心] 饼干被掰开的瞬间,巧克力流心缓缓溢出,微距特写,暖色调 --no hands --no added text --no watermarks",
|
||||||
|
"video_prompt": "Static macro shot, 流心自然流动,碎屑微微散落",
|
||||||
|
"fancy_text": {
|
||||||
|
"text": "真·爆浆",
|
||||||
|
"style": "highlight",
|
||||||
|
"position": "bottom",
|
||||||
|
"start_time": 0.3,
|
||||||
|
"duration": 2.0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"duration": 3,
|
||||||
|
"visual_prompt": "[深棕色圆形曲奇饼干,牛皮纸包装印有品牌标识] 包装盒俯拍,旁边散落黄油块和可可豆原料,简洁浅色背景 --no added text --no watermarks",
|
||||||
|
"video_prompt": "Slow Pan Right, 依次掠过原料和包装",
|
||||||
|
"fancy_text": {
|
||||||
|
"text": "进口黄油",
|
||||||
|
"style": "minimal",
|
||||||
|
"position": "top",
|
||||||
|
"start_time": 0.5,
|
||||||
|
"duration": 2.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# ✅ 输出前自检清单
|
||||||
|
|
||||||
|
1. [ ] `product_name`, `visual_anchor`, `selling_points`, `target_audience` 是否存在于顶层?
|
||||||
|
2. [ ] `visual_anchor` 是否包含:材质+颜色+形状+包装特征?
|
||||||
|
3. [ ] `video_style`, `bgm_style` 是否存在于顶层?
|
||||||
|
4. [ ] 每个分镜 duration 是否 = 3?
|
||||||
|
5. [ ] 总时长是否在 9-12 秒范围内?
|
||||||
|
6. [ ] voiceover_timeline 使用的是 `start_time` 和 `duration` (秒) 而非 ratio?
|
||||||
|
7. [ ] 旁白语速是否 ≤ 5字/秒?
|
||||||
|
8. [ ] fancy_text 是否 ≤ 6 字?
|
||||||
|
9. [ ] 是否使用 `--no added text` 而非 `--no text`?
|
||||||
|
10. [ ] 是否避免了复杂人体动作描述?
|
||||||
339
main_flow.py
Normal file
339
main_flow.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""
|
||||||
|
Video Flow v2.0 - 命令行主流程控制器
|
||||||
|
|
||||||
|
独立的 CLI 入口,支持命令行参数调用完整的视频生成流程。
|
||||||
|
与 app.py (Streamlit UI) 分离,共用 modules 层。
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python main_flow.py --help
|
||||||
|
|
||||||
|
python main_flow.py \
|
||||||
|
--product-name "网红气质大号发量多!高马尾香蕉夹" \
|
||||||
|
--images /path/to/主图1.png /path/to/主图2.png \
|
||||||
|
--category "钟表配饰-时尚饰品-发饰" \
|
||||||
|
--price "3.99元" \
|
||||||
|
--tags "回头客|款式好看|材质好" \
|
||||||
|
--model doubao \
|
||||||
|
--output final_hairclip
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(sys.stdout),
|
||||||
|
logging.FileHandler("video_flow.log")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules.script_gen import ScriptGenerator
|
||||||
|
from modules.image_gen import ImageGenerator
|
||||||
|
from modules.video_gen import VideoGenerator
|
||||||
|
from modules.composer import VideoComposer
|
||||||
|
|
||||||
|
logger = logging.getLogger("MainFlow")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse and return CLI arguments.

    Fix: the usage example in the epilog referenced a nonexistent
    `--model` flag; the actual flag is `--script-model`.
    """
    parser = argparse.ArgumentParser(
        description="Video Flow CLI - 商品短视频自动生成命令行工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 使用默认测试数据
  python main_flow.py --demo

  # 指定商品信息
  python main_flow.py \\
    --product-name "网红气质大号发量多!高马尾香蕉夹" \\
    --images ./素材/发夹/原始稿/主图1.png ./素材/发夹/原始稿/主图2.png \\
    --category "钟表配饰-时尚饰品-发饰" \\
    --price "3.99元" \\
    --tags "回头客|款式好看|材质好" \\
    --script-model doubao \\
    --output final_hairclip
"""
    )

    # Basic options
    parser.add_argument("--demo", action="store_true", help="使用内置测试数据(发夹案例)")
    parser.add_argument("--product-name", type=str, help="商品标题")
    parser.add_argument("--images", nargs="+", type=str, help="商品主图路径列表 (建议 3-5 张)")

    # Product information
    parser.add_argument("--category", type=str, default="", help="商品类目")
    parser.add_argument("--price", type=str, default="", help="商品价格")
    parser.add_argument("--tags", type=str, default="", help="评价标签 (用于提炼卖点)")
    parser.add_argument("--params", type=str, default="", help="商品参数")
    parser.add_argument("--style-hint", type=str, default="", help="风格提示 (如: 韩风、高级感)")

    # Model selection
    parser.add_argument("--script-model", choices=["shubiaobiao", "doubao"], default="doubao",
                        help="脚本生成模型 (default: doubao)")
    parser.add_argument("--image-model", choices=["shubiaobiao", "doubao", "gemini", "doubao-group"],
                        default="doubao", help="图片生成模型 (default: doubao)")

    # Output options
    parser.add_argument("--output", type=str, default="final_video", help="输出文件名 (不含扩展名)")
    parser.add_argument("--project-id", type=str, default=None, help="项目ID (默认自动生成)")

    # Optional step control
    parser.add_argument("--skip-video", action="store_true", help="跳过视频生成步骤 (仅生成脚本和图片)")
    parser.add_argument("--skip-compose", action="store_true", help="跳过合成步骤")

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def get_demo_data() -> tuple:
    """Built-in demo fixture (the hair-clip case).

    Returns:
        A ``(product_name, product_info, original_images)`` tuple where
        ``product_info`` is a dict of listing metadata and
        ``original_images`` is a list of absolute image paths.
    """
    demo_name = "网红气质大号发量多!高马尾香蕉夹 马尾显发量蓬松神器马尾夹"

    demo_info = {
        "category": "钟表配饰-时尚饰品-发饰",
        "price": "3.99元",
        "tags": "回头客|款式好看|材质好|尺寸合适|颜色好看|很好用|做工好|质感不错|很牢固",
        "params": "金属材质:非金属; 非金属材质:树脂; 发夹分类:香蕉夹; 风格:日韩|简约风|法式|瑞丽风",
        "style_hint": ""
    }

    # Source images live on a developer machine; paths are machine-specific.
    assets_root = Path("/Volumes/Tony/video-flow/素材/发夹/原始稿")
    demo_images = [str(assets_root / f"主图{i}.png") for i in (1, 2, 3)]

    return demo_name, demo_info, demo_images
||||||
|
|
||||||
|
def match_bgm_by_style(bgm_style: str, bgm_dir: Path) -> Optional[str]:
    """
    根据脚本 bgm_style 智能匹配 BGM 文件
    - 匹配成功:随机选一个匹配的 BGM
    - 匹配失败:随机选任意一个 BGM

    Args:
        bgm_style: Free-form style description emitted by the script model.
        bgm_dir: Directory containing candidate BGM files.

    Returns:
        Path string of the chosen BGM, or None when the directory has no
        usable .mp3/.mp4 files.
    """
    # Collect candidate BGM files (.mp3 / .mp4, case-insensitive).
    # BUGFIX: the previous pair of globs ("*.[mM][pP][34]" plus "*.[mM][pP]3")
    # listed every .mp3 twice, which skewed random.choice toward mp3 files;
    # a single suffix check yields each file exactly once.
    bgm_files = [
        f for f in sorted(bgm_dir.glob("*"))
        if f.is_file()
        and not f.name.startswith('.')
        and f.suffix.lower() in (".mp3", ".mp4")
    ]

    if not bgm_files:
        return None

    # Keyword match: keep only files whose name contains a keyword that also
    # appears in the requested style.
    if bgm_style:
        style_lower = bgm_style.lower()
        keywords = ["活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏"]
        matched_keywords = [kw for kw in keywords if kw in style_lower]

        matched_files = [
            f for f in bgm_files
            if any(kw in f.name for kw in matched_keywords)
        ]
        if matched_files:
            return str(random.choice(matched_files))

    # No match: fall back to any BGM at random.
    return str(random.choice(bgm_files))
||||||
|
|
||||||
|
def run_video_flow(args) -> Optional[str]:
    """Run the full video-generation pipeline for one product.

    Steps: prepare inputs -> generate script -> generate scene images ->
    generate scene videos -> compose final video (with TTS voiceover + BGM).

    Args:
        args: Parsed argparse namespace (see parse_args).

    Returns:
        Path to the final composed video, or None when a step fails or the
        flow stops early (--skip-video / --skip-compose).
    """

    # ===== 1. Prepare input data =====
    if args.demo:
        logger.info("Using DEMO data (发夹案例)...")
        product_name, product_info, original_images = get_demo_data()
    else:
        if not args.product_name or not args.images:
            logger.error("Must provide --product-name and --images, or use --demo")
            return None

        product_name = args.product_name
        product_info = {
            "category": args.category,
            "price": args.price,
            "tags": args.tags,
            "params": args.params,
            "style_hint": args.style_hint
        }
        original_images = args.images

    # Drop any input image path that does not exist on disk.
    valid_images = [p for p in original_images if Path(p).exists()]
    if not valid_images:
        logger.error("No valid input images found!")
        logger.error(f"Checked paths: {original_images}")
        return None

    logger.info(f"Found {len(valid_images)} valid images")

    # Project ID groups all artifacts of this run; defaults to a timestamp.
    project_id = args.project_id or f"CLI-{int(time.time())}"
    logger.info(f"Project ID: {project_id}")

    # ===== 2. Generate script =====
    logger.info("="*50)
    logger.info("Step 1: Generating Script...")
    logger.info("="*50)

    script_gen = ScriptGenerator()
    script = script_gen.generate_script(
        product_name,
        product_info,
        valid_images,
        model_provider=args.script_model
    )

    if not script:
        logger.error("Script generation failed.")
        return None

    # Persist the script so it can be inspected / reused.
    script_path = config.OUTPUT_DIR / f"script_{project_id}.json"
    with open(script_path, "w", encoding="utf-8") as f:
        json.dump(script, f, ensure_ascii=False, indent=2)
    logger.info(f"Script saved to {script_path}")

    scenes = script.get("scenes", [])
    logger.info(f"Generated {len(scenes)} scenes")

    # ===== 3. Generate scene images =====
    logger.info("="*50)
    logger.info("Step 2: Generating Scene Images...")
    logger.info("="*50)

    image_gen = ImageGenerator()
    visual_anchor = script.get("visual_anchor", "")

    scene_images: Dict[int, str] = {}

    if args.image_model == "doubao-group":
        # Group-image mode: one call generates all scenes together.
        logger.info("Using Doubao Group Image Generation...")
        scene_images = image_gen.generate_group_images_doubao(
            scenes=scenes,
            reference_images=valid_images,
            visual_anchor=visual_anchor
        )
    else:
        # Sequential mode: each generated image is appended to the reference
        # list so later scenes stay visually consistent with earlier ones.
        current_refs = list(valid_images)

        for idx, scene in enumerate(scenes):
            scene_id = scene["id"]
            logger.info(f"Generating image for Scene {scene_id} ({idx+1}/{len(scenes)})...")

            img_path = image_gen.generate_single_scene_image(
                scene=scene,
                original_image_path=current_refs,
                previous_image_path=None,
                model_provider=args.image_model,
                visual_anchor=visual_anchor
            )

            if img_path:
                scene_images[scene_id] = img_path
                current_refs.append(img_path)
                logger.info(f"Scene {scene_id} image: {img_path}")
            else:
                # Best-effort: a failed scene is skipped, not fatal.
                logger.warning(f"Failed to generate image for Scene {scene_id}")

    if not scene_images:
        logger.error("Image generation failed (no images generated).")
        return None

    logger.info(f"Generated {len(scene_images)} scene images.")

    if args.skip_video:
        logger.info("Skipping video generation (--skip-video)")
        return None

    # ===== 4. Generate scene videos =====
    logger.info("="*50)
    logger.info("Step 3: Generating Scene Videos...")
    logger.info("="*50)

    video_gen = VideoGenerator()
    scene_videos = video_gen.generate_scene_videos(project_id, script, scene_images)

    if not scene_videos:
        logger.error("Video generation failed (or partially failed).")
        return None

    logger.info(f"Generated {len(scene_videos)} scene videos.")

    if args.skip_compose:
        logger.info("Skipping composition (--skip-compose)")
        return None

    # ===== 5. Compose the final video =====
    logger.info("="*50)
    logger.info("Step 4: Composing Final Video...")
    logger.info("="*50)

    composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE)

    # Pick a BGM matching the script's requested style (falls back to random).
    bgm_style = script.get("bgm_style", "")
    bgm_path = match_bgm_by_style(bgm_style, config.ASSETS_DIR / "bgm")
    if bgm_path:
        logger.info(f"Selected BGM: {Path(bgm_path).name} (style: {bgm_style or 'default'})")

    # Compose subtitles, voiceover and BGM into the final deliverable.
    output_name = f"{args.output}_{project_id}"
    final_video = composer.compose_from_script(
        script=script,
        video_map=scene_videos,
        bgm_path=bgm_path,
        output_name=output_name
    )

    logger.info("="*50)
    logger.info(f"✅ Workflow Complete!")
    logger.info(f" Final Video: {final_video}")
    logger.info(f" Script: {script_path}")
    logger.info("="*50)

    return final_video
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: validate arguments, run the flow, map result to exit code.

    Exit codes: 0 on success, 1 on failure/invalid args, 130 on Ctrl-C.
    """
    args = parse_args()

    # Validate arguments up front. BUGFIX: the error message requires both
    # --product-name and --images, but previously only --product-name was
    # checked here; a missing --images fell through to a later failure.
    if not args.demo and (not args.product_name or not args.images):
        print("Error: Must provide --product-name and --images, or use --demo")
        print("Run with --help for usage information.")
        sys.exit(1)

    try:
        result = run_video_flow(args)
        if result:
            sys.exit(0)
        else:
            sys.exit(1)
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        sys.exit(130)  # conventional exit code for SIGINT
    except Exception as e:
        # Last-resort boundary: log the full traceback, fail with code 1.
        logger.exception(f"Unexpected error: {e}")
        sys.exit(1)
||||||
|
|
||||||
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
||||||
14
modules/__init__.py
Normal file
14
modules/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Gloda Video Factory - Modules Package
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"utils",
|
||||||
|
"brain",
|
||||||
|
"factory",
|
||||||
|
"editor",
|
||||||
|
"ffmpeg_utils",
|
||||||
|
"fancy_text",
|
||||||
|
"composer"
|
||||||
|
]
|
||||||
|
|
||||||
81
modules/asr.py
Normal file
81
modules/asr.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - ASR Module (Whisper via ShuBiaoBiao)
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Module-level OpenAI-compatible client pointed at the ShuBiaoBiao gateway;
# created once at import time and shared by all functions in this module.
client = OpenAI(
    api_key=config.SHUBIAOBIAO_KEY,
    base_url=config.SHUBIAOBIAO_BASE_URL
)
|
||||||
|
|
||||||
|
def extract_audio_from_video(video_path: str) -> str:
    """Extract the audio track from a video using ffmpeg.

    The audio is downmixed to 16 kHz mono MP3, the format Whisper expects.

    Args:
        video_path: Path to the source video file.

    Returns:
        Path string of the extracted MP3 in config.TEMP_DIR.

    Raises:
        RuntimeError: when ffmpeg exits non-zero (original error chained).
    """
    src = Path(video_path)  # avoid shadowing the str parameter with a Path
    audio_path = config.TEMP_DIR / f"{src.stem}_audio.mp3"

    cmd = [
        "ffmpeg", "-y",
        "-i", str(src),
        "-vn",                    # no video stream
        "-acodec", "libmp3lame",
        "-ar", "16000",           # 16 kHz for Whisper
        "-ac", "1",               # mono
        str(audio_path)
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True)
        logger.info(f"Audio extracted to {audio_path}")
        return str(audio_path)
    except subprocess.CalledProcessError as e:
        logger.error(f"FFmpeg error: {e.stderr.decode()}")
        # BUGFIX: chain the original CalledProcessError so the ffmpeg exit
        # status and stderr are preserved in the traceback.
        raise RuntimeError("Failed to extract audio from video") from e
|
||||||
|
|
||||||
|
def transcribe(audio_path: str) -> str:
    """Run Whisper speech-to-text on *audio_path* and return the transcript.

    The API is asked for Chinese output in plain-text format; errors from the
    API are logged and re-raised unchanged.
    """
    logger.info(f"Transcribing {audio_path}...")

    try:
        with open(audio_path, "rb") as audio_file:
            response = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                language="zh",  # Chinese
                response_format="text"
            )

        # Depending on SDK version the text format may come back as a bare
        # string or as an object carrying a .text attribute.
        if isinstance(response, str):
            text = response
        else:
            text = response.text
        logger.info(f"Transcription complete: {len(text)} chars")
        return text

    except Exception as e:
        logger.error(f"Whisper API error: {e}")
        raise
|
|
||||||
|
|
||||||
|
def transcribe_video(video_path: str) -> str:
    """Convenience wrapper: pull the audio out of *video_path*, then transcribe it."""
    return transcribe(extract_audio_from_video(video_path))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
346
modules/brain.py
Normal file
346
modules/brain.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - Brain Module (Multi-stage Analysis & Script Generation)
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Use Volcengine (Doubao) via OpenAI Compatible Interface.
# Single module-level client shared by every stage in this module.
client = OpenAI(
    api_key=config.VOLC_API_KEY,
    base_url=config.VOLC_BASE_URL
)
|
|
||||||
|
# ============================================================
# Stage 1: Analyze Materials
# ============================================================

# System prompt (Chinese) for the vision model: analyze user materials,
# extract product/visual/audience info, and emit follow-up questions as
# strict JSON with keys analysis / detected_info / missing_info / questions /
# ready. Runtime string — do not translate.
ANALYZE_SYSTEM_PROMPT = """你是一位资深短视频创作总监,专精TikTok/抖音爆款内容。

任务:深度分析用户提供的素材和需求,识别产品特性、使用场景、目标人群。

分析维度:
1. 产品/服务核心卖点(从素材中提取视觉特征)
2. 视觉风格特征(颜色、质感、包装)
3. 潜在目标受众
4. 内容调性建议

然后检查是否缺少关键信息,如果缺少,生成2-5个问题帮助完善需求。
每个问题必须与短视频创作直接相关。

输出严格JSON格式:
{
  "analysis": "详细分析结果,包括从素材中识别到的视觉元素...",
  "detected_info": {
    "product": "识别到的产品名称和类型",
    "visual_features": ["视觉特征1", "视觉特征2"],
    "audience": "推测的目标人群",
    "style": "推测的风格"
  },
  "missing_info": ["缺少的信息1", "缺少的信息2"],
  "questions": [
    {
      "id": "q1",
      "text": "问题文字(说明为什么这个问题重要)",
      "options": ["选项A", "选项B", "选项C"],
      "allow_multiple": true,
      "allow_custom": true
    }
  ],
  "ready": false
}

如果信息足够,ready=true,questions为空数组。
"""
|
|
||||||
|
def analyze_materials(
    prompt: str,
    image_urls: Optional[List[str]] = None,
    asr_text: str = ""
) -> Dict[str, Any]:
    """
    Deep analysis of user materials.

    Sends the user's requirement text, optional ASR transcript and material
    images to the Doubao vision model and parses its strict-JSON reply.

    Args:
        prompt: Free-form user requirement text.
        image_urls: Optional list of material image URLs.
        asr_text: Optional ASR transcript of a reference video.

    Returns:
        Parsed JSON dict (analysis, detected_info, missing_info, questions,
        ready) as defined by ANALYZE_SYSTEM_PROMPT.

    Raises:
        Exception: propagated from the chat API or JSON parsing after logging.
    """
    logger.info("Brain: Analyzing materials...")

    # Vision-model message format: a content list mixing text and image_url parts.
    content_parts = [{"type": "text", "text": f"用户需求: {prompt}"}]

    if asr_text:
        content_parts.append({"type": "text", "text": f"\n视频原声(ASR转写): {asr_text}"})

    if image_urls:
        content_parts.append({"type": "text", "text": "\n用户上传的素材图片(请仔细分析这些图片中的产品特征):"})
        for url in image_urls:
            content_parts.append({
                "type": "image_url",
                "image_url": {"url": url}
            })

    messages = [
        # Note: some vision models handle a 'system' role with images poorly,
        # but Doubao follows the standard chat structure. If the system prompt
        # fails, prepend it to the user content instead.
        {"role": "system", "content": ANALYZE_SYSTEM_PROMPT},
        {"role": "user", "content": content_parts}
    ]

    try:
        # Use the vision model so the material images are actually analyzed.
        response = client.chat.completions.create(
            model=config.VISION_MODEL_ID,
            messages=messages,
            temperature=0.7,
            max_tokens=4000
        )

        content = response.choices[0].message.content.strip()
        # BUGFIX: robustly strip an optional ```json ... ``` markdown fence.
        # The previous version sliced a fixed content[4:] and never stripped
        # the leading newline/whitespace left inside the fence.
        if content.startswith("```"):
            parts = content.split("```")
            if len(parts) > 1:
                content = parts[1]
                if content.startswith("json"):
                    content = content[len("json"):]
                content = content.strip()

        return json.loads(content)

    except Exception as e:
        logger.error(f"Brain Analyze Error: {e}")
        raise
|
|
||||||
|
|
||||||
|
# ============================================================
# Stage 2: Refine Brief with Answers
# ============================================================

# System prompt (Chinese) for the text model: merge the original request,
# the stage-1 analysis and the user's answers into one creative brief JSON
# (brief / creative_summary / ready). Runtime string — do not translate.
REFINE_SYSTEM_PROMPT = """你是短视频创作总监。
根据原始需求、AI分析结果、用户补充回答,整合为完整的创意简报。

注意:用户选择的风格偏好(如ASMR、剧情、视觉流等)必须作为核心创作方向贯穿整个简报。

输出JSON:
{
  "brief": {
    "product": "产品名称",
    "product_visual_description": "产品视觉描述(颜色、形状、包装、质感等,用于后续图片生成)",
    "selling_points": ["卖点1", "卖点2"],
    "target_audience": "目标人群",
    "platform": "投放平台",
    "style": "视频风格(必须明确,如ASMR/剧情/视觉流等)",
    "style_requirements": "该风格的具体创作要求(如ASMR需要:开盖声、质感特写、无人脸等)",
    "creativity_level": "创意程度",
    "reference": "对标账号/竞品",
    "user_assets_description": "用户上传素材的描述(用于后续继承)"
  },
  "creative_summary": "整体创意概述(50字以内,描述这个视频的核心创意方向)",
  "ready": true
}
"""
|
|
||||||
|
def refine_brief(
    original_prompt: str,
    analysis: Dict[str, Any],
    answers: Dict[str, Any],
    image_urls: List[str] = None
) -> Dict[str, Any]:
    """
    Integrate user answers into a complete creative brief.

    Combines the original request, the stage-1 analysis and the user's
    answers into the REFINE_SYSTEM_PROMPT JSON schema via the text model.
    """
    logger.info("Brain: Refining brief with answers...")

    user_content = f"""
原始需求: {original_prompt}

AI分析结果: {json.dumps(analysis, ensure_ascii=False)}

用户补充回答: {json.dumps(answers, ensure_ascii=False)}

用户上传的素材URL: {json.dumps(image_urls or [], ensure_ascii=False)}
"""

    try:
        # Plain text reasoning — the brain (Doubao Pro) model is sufficient
        # here since no new images are being inspected.
        response = client.chat.completions.create(
            model=config.BRAIN_MODEL_ID,
            messages=[
                {"role": "system", "content": REFINE_SYSTEM_PROMPT},
                {"role": "user", "content": user_content}
            ],
            temperature=0.5,
            max_tokens=3000
        )

        raw = response.choices[0].message.content.strip()
        # Strip an optional markdown code fence before parsing.
        if raw.startswith("```"):
            segments = raw.split("```")
            if len(segments) > 1:
                raw = segments[1]
                if raw.startswith("json"):
                    raw = raw[4:]
        return json.loads(raw)

    except Exception as e:
        logger.error(f"Brain Refine Error: {e}")
        raise
|
|
||||||
|
|
||||||
|
# ============================================================
# Stage 3: Generate Script
# ============================================================

# System prompt (Chinese) for script generation. The literal "{style}"
# placeholders are substituted via str.replace in generate_script (NOT
# str.format — the JSON braces inside would break format()).
# Runtime string — do not translate.
SCRIPT_SYSTEM_PROMPT = """你是顶级短视频编导,专精{style}风格内容创作。

根据创意简报生成爆款脚本。必须严格遵循用户选择的风格要求。

脚本结构要求:
1. creative_summary: 整体创意概述(这条视频的核心创意是什么)
2. hook: 前3秒钩子设计(必须抓眼球,符合{style}风格)
3. scenes: 3-8个分镜
4. cta: 结尾行动号召(纯文本字符串)

每个分镜(scene)必须包含:
- id: 分镜编号
- duration: 时长(5/10/15秒,符合视频模型参数)
- timeline: 时间轴 (如 "0:00-0:05")
- image_prompt: 【关键】用于AI生图的详细英文prompt,必须包含:
  * 产品的具体视觉描述(继承自brief中的product_visual_description)
  * 8k, hyper-realistic, cinematic lighting
  * 色调、环境、构图、焦点
  * 风格要求(如ASMR需要:macro shot, satisfying texture, no human face)
- keyframe: {
    "color_tone": "色调",
    "environment": "环境/背景",
    "foreground": "前景元素",
    "focus": "视觉焦点",
    "subject": "主体描述",
    "composition": "构图方式"
  }
- camera_movement: 运镜描述(如:slow zoom in, pan left, static)
- story_beat: 这个分镜在整体故事中的作用
- voiceover: 旁白文字({style}风格,如ASMR应简短或无旁白,用音效代替)
- sound_design: 音效设计(如:开盖声、水滴声、环境白噪音)
- rhythm: {"change": "保持/加快/放慢", "multiplier": 1.0}

旁白要求:
- 必须连贯,形成完整的叙事
- 符合{style}风格(ASMR风格应极简或无旁白)
- 每句旁白要能独立成句,但连起来是完整故事

输出严格JSON格式。
"""
|
|
||||||
|
def generate_script(
    brief: Dict[str, Any],
    image_urls: List[str] = None,
    regenerate_feedback: str = ""
) -> Dict[str, Any]:
    """
    Generate complete video script with scenes.

    The brief's style is substituted into the system prompt, optional user
    feedback and reference images are attached, and the vision model's
    strict-JSON reply is parsed and returned.
    """
    logger.info("Brain: Generating script...")

    # Substitute the style via replace(), not format(): the prompt contains
    # literal JSON braces that would break str.format.
    style = brief.get("style", "现代广告")
    system_prompt = SCRIPT_SYSTEM_PROMPT.replace("{style}", style)

    content_parts = [{"type": "text", "text": f"创意简报: {json.dumps(brief, ensure_ascii=False)}"}]

    if regenerate_feedback:
        content_parts.append({"type": "text", "text": f"\n用户反馈(请据此调整): {regenerate_feedback}"})

    if image_urls:
        content_parts.append({"type": "text", "text": "\n用户上传的参考素材(生成的image_prompt必须参考这些素材中的产品外观):"})
        content_parts.extend(
            {"type": "image_url", "image_url": {"url": url}}
            for url in image_urls
        )

    try:
        # Vision model so reference images inform the image_prompt fields.
        response = client.chat.completions.create(
            model=config.VISION_MODEL_ID,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content_parts}
            ],
            temperature=0.8,
            max_tokens=8000
        )

        raw = response.choices[0].message.content.strip()
        # Strip an optional markdown code fence before parsing.
        if raw.startswith("```"):
            segments = raw.split("```")
            if len(segments) > 1:
                raw = segments[1]
                if raw.startswith("json"):
                    raw = raw[4:]
        return json.loads(raw)

    except Exception as e:
        logger.error(f"Brain Script Error: {e}")
        raise
|
|
||||||
|
|
||||||
|
# ============================================================
# Stage 4: Regenerate Single Scene
# ============================================================

def regenerate_scene(
    full_script: Dict[str, Any],
    scene_id: int,
    feedback: str,
    brief: Dict[str, Any] = None
) -> Dict[str, Any]:
    """
    Regenerate a single scene based on feedback.

    Sends the whole script plus the target scene id and user feedback to the
    text model and parses the replacement scene object it returns.
    """
    logger.info(f"Brain: Regenerating scene {scene_id}...")

    style = brief.get("style", "现代广告") if brief else "现代广告"

    system_prompt = f"""你是短视频编导,专精{style}风格。根据用户反馈重新生成指定分镜。
保持与其他分镜的风格连贯性。
image_prompt必须继承产品的视觉描述。
只输出新的scene对象(JSON)。
"""

    user_content = f"""
完整脚本: {json.dumps(full_script, ensure_ascii=False)}

创意简报: {json.dumps(brief, ensure_ascii=False) if brief else "无"}

需要重新生成的分镜ID: {scene_id}

用户反馈: {feedback}
"""

    try:
        response = client.chat.completions.create(
            model=config.BRAIN_MODEL_ID,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content}
            ],
            temperature=0.8,
            max_tokens=2000
        )

        raw = response.choices[0].message.content.strip()
        # Strip an optional markdown code fence before parsing.
        if raw.startswith("```"):
            segments = raw.split("```")
            if len(segments) > 1:
                raw = segments[1]
                if raw.startswith("json"):
                    raw = raw[4:]
        return json.loads(raw)

    except Exception as e:
        logger.error(f"Brain Regenerate Scene Error: {e}")
        raise
||||||
717
modules/composer.py
Normal file
717
modules/composer.py
Normal file
@@ -0,0 +1,717 @@
|
|||||||
|
"""
|
||||||
|
视频合成器模块
|
||||||
|
整合视频拼接、花字叠加、旁白配音的完整流程
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Optional, Union
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules import ffmpeg_utils, fancy_text, factory, storage
|
||||||
|
from modules.text_renderer import renderer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoComposer:
|
||||||
|
"""视频合成器"""
|
||||||
|
|
||||||
|
    def __init__(
        self,
        output_dir: str = None,
        target_size: tuple = (1080, 1920),
        voice_type: str = "sweet_female"
    ):
        """
        Initialize the composer.

        Args:
            output_dir: Output directory; falls back to config.OUTPUT_DIR.
            target_size: Target resolution as (width, height); default is
                9:16 portrait (1080x1920).
            voice_type: Default narration voice for TTS.
        """
        self.output_dir = Path(output_dir) if output_dir else config.OUTPUT_DIR
        self.output_dir.mkdir(exist_ok=True)
        self.target_size = target_size
        self.voice_type = voice_type

        # Intermediate files created during composition, removed by cleanup().
        self._temp_files = []
|
|
||||||
|
def _add_temp(self, path: str):
|
||||||
|
"""记录临时文件"""
|
||||||
|
if path:
|
||||||
|
self._temp_files.append(path)
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""清理临时文件"""
|
||||||
|
for f in self._temp_files:
|
||||||
|
try:
|
||||||
|
if os.path.exists(f):
|
||||||
|
os.remove(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cleanup {f}: {e}")
|
||||||
|
self._temp_files = []
|
||||||
|
|
||||||
|
    def compose(
        self,
        video_paths: List[str],
        subtitles: List[Dict[str, Any]] = None,
        fancy_texts: List[Dict[str, Any]] = None,
        voiceover_text: str = None,
        voiceover_segments: List[Dict[str, Any]] = None,
        bgm_path: str = None,
        bgm_volume: float = 0.15,
        output_name: str = None,
        upload_to_r2: bool = False
    ) -> str:
        """
        Full video-composition pipeline.

        Pipeline order: concat -> ensure audio track -> subtitles -> fancy
        text overlays -> voiceover -> BGM -> copy to output dir (-> R2).
        Intermediate files are tracked and removed in the finally block.

        Args:
            video_paths: Scene video paths, concatenated in order.
            subtitles: Subtitle configs [{text, start, duration, style}].
            fancy_texts: Fancy-text configs [{text, style, x, y, start, duration}].
            voiceover_text: Full narration text (TTS is generated and mixed in).
            voiceover_segments: Segmented narration [{text, start}];
                mutually exclusive with voiceover_text (text wins).
            bgm_path: Background-music file path.
            bgm_volume: BGM volume multiplier.
            output_name: Output file name without extension.
            upload_to_r2: Whether to upload the result to R2 storage.

        Returns:
            Final video path (or the R2 URL when upload_to_r2 is True).

        Raises:
            ValueError: when video_paths is empty.
        """
        if not video_paths:
            raise ValueError("No video paths provided")

        timestamp = int(time.time())
        output_name = output_name or f"composed_{timestamp}"

        logger.info(f"Starting composition: {len(video_paths)} videos")

        try:
            # Step 1: concatenate scene videos at the target resolution.
            merged_path = str(config.TEMP_DIR / f"{output_name}_merged.mp4")
            ffmpeg_utils.concat_videos(video_paths, merged_path, self.target_size)
            self._add_temp(merged_path)
            current_video = merged_path

            # Step 1.1: add a silent audio bed if the video has no audio
            # track, so later filters can always find stream 0:a.
            silent_path = str(config.TEMP_DIR / f"{output_name}_silent.mp4")
            ffmpeg_utils.add_silence_audio(current_video, silent_path)
            self._add_temp(silent_path)
            current_video = silent_path

            # Step 2: subtitles (white text, black border, no background box,
            # centered in the lower area).
            if subtitles:
                subtitled_path = str(config.TEMP_DIR / f"{output_name}_subtitled.mp4")
                subtitle_style = {
                    "font": ffmpeg_utils._get_font_path(),
                    "fontsize": 60,
                    "fontcolor": "white",
                    "borderw": 5,
                    "bordercolor": "black",
                    "box": 0,  # no background box
                    "y": "h-200",  # lower-area placement
                }
                ffmpeg_utils.add_multiple_subtitles(
                    current_video, subtitles, subtitled_path, default_style=subtitle_style
                )
                self._add_temp(subtitled_path)
                current_video = subtitled_path

            # Step 3: fancy-text overlays (supports atomic style params).
            if fancy_texts:
                overlay_configs = []
                for ft in fancy_texts:
                    text = ft.get("text", "")
                    style = ft.get("style")
                    custom_style = ft.get("custom_style")

                    # A dict style means atomic rendering params — use directly.
                    if isinstance(style, dict):
                        img_path = renderer.render(text, style, cache=False)
                    elif custom_style and isinstance(custom_style, dict):
                        # Legacy compatibility: route through the atomic
                        # renderer when the custom style looks atomic.
                        if "font_size" in custom_style:
                            img_path = renderer.render(text, custom_style, cache=False)
                        else:
                            # Fall back to the legacy fancy_text renderer.
                            # NOTE(review): hard-coded macOS font path — breaks
                            # on other platforms; confirm deployment target.
                            img_path = fancy_text.create_fancy_text(
                                text=text,
                                style=style if isinstance(style, str) else "subtitle",
                                custom_style={
                                    **(custom_style or {}),
                                    "font_name": "/System/Library/Fonts/PingFang.ttc",
                                },
                                cache=False
                            )
                    else:
                        # Legacy path without any custom style.
                        img_path = fancy_text.create_fancy_text(
                            text=text,
                            style=style if isinstance(style, str) else "subtitle",
                            custom_style={
                                "font_name": "/System/Library/Fonts/PingFang.ttc",
                            },
                            cache=False
                        )

                    overlay_configs.append({
                        "path": img_path,
                        "x": ft.get("x", "(W-w)/2"),
                        "y": ft.get("y", "(H-h)/2"),
                        "start": ft.get("start", 0),
                        "duration": ft.get("duration", 999)
                    })

                fancy_path = str(config.TEMP_DIR / f"{output_name}_fancy.mp4")
                ffmpeg_utils.overlay_multiple_images(
                    current_video, overlay_configs, fancy_path
                )
                self._add_temp(fancy_path)
                current_video = fancy_path

            # Step 4: generate and mix the voiceover (Volcengine WS first,
            # falls back to Edge TTS inside factory).
            if voiceover_text:
                vo_path = factory.generate_voiceover_volcengine(
                    text=voiceover_text,
                    voice_type=self.voice_type
                )
                self._add_temp(vo_path)

                voiced_path = str(config.TEMP_DIR / f"{output_name}_voiced.mp4")
                ffmpeg_utils.mix_audio(
                    current_video, vo_path, voiced_path,
                    audio_volume=1.5,   # boost narration
                    video_volume=0.2    # duck original audio under narration
                )
                self._add_temp(voiced_path)
                current_video = voiced_path

            elif voiceover_segments:
                current_video = self._add_segmented_voiceover(
                    current_video, voiceover_segments, output_name
                )

            # Step 5: BGM with fade in/out; ducking is disabled here for
            # compatibility — the low bgm_volume keeps it under the narration.
            if bgm_path:
                bgm_output = str(config.TEMP_DIR / f"{output_name}_bgm.mp4")
                ffmpeg_utils.add_bgm(
                    current_video, bgm_path, bgm_output,
                    bgm_volume=bgm_volume,
                    ducking=False,  # avoid sidechain-compat issues; keep BGM quiet instead
                    duck_gain_db=-6.0,
                    fade_in=1.0,
                    fade_out=1.0
                )
                self._add_temp(bgm_output)
                current_video = bgm_output

            # Step 6: copy the last intermediate to the final output path.
            final_path = str(self.output_dir / f"{output_name}.mp4")

            import shutil
            shutil.copy(current_video, final_path)

            logger.info(f"Composition complete: {final_path}")

            # Optional upload to R2 object storage; the URL becomes the result.
            if upload_to_r2:
                r2_url = storage.upload_file(final_path)
                logger.info(f"Uploaded to R2: {r2_url}")
                return r2_url

            return final_path

        finally:
            # Remove all tracked intermediates (the final output is kept).
            self.cleanup()
|
|
||||||
|
def _add_segmented_voiceover(
|
||||||
|
self,
|
||||||
|
video_path: str,
|
||||||
|
segments: List[Dict[str, Any]],
|
||||||
|
output_name: str
|
||||||
|
) -> str:
|
||||||
|
"""添加分段旁白"""
|
||||||
|
if not segments:
|
||||||
|
return video_path
|
||||||
|
|
||||||
|
# 为每段生成音频
|
||||||
|
audio_files = []
|
||||||
|
for i, seg in enumerate(segments):
|
||||||
|
text = seg.get("text", "")
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
voice = seg.get("voice_type", self.voice_type)
|
||||||
|
audio_path = factory.generate_voiceover_volcengine(
|
||||||
|
text=text,
|
||||||
|
voice_type=voice,
|
||||||
|
output_path=str(config.TEMP_DIR / f"{output_name}_seg_{i}.mp3")
|
||||||
|
)
|
||||||
|
|
||||||
|
if audio_path:
|
||||||
|
audio_files.append({
|
||||||
|
"path": audio_path,
|
||||||
|
"start": seg.get("start", 0)
|
||||||
|
})
|
||||||
|
self._add_temp(audio_path)
|
||||||
|
|
||||||
|
if not audio_files:
|
||||||
|
return video_path
|
||||||
|
|
||||||
|
# 依次混入音频
|
||||||
|
current = video_path
|
||||||
|
for i, af in enumerate(audio_files):
|
||||||
|
output = str(config.TEMP_DIR / f"{output_name}_seg_mixed_{i}.mp4")
|
||||||
|
ffmpeg_utils.mix_audio(
|
||||||
|
current, af["path"], output,
|
||||||
|
audio_volume=1.0,
|
||||||
|
video_volume=0.2 if i == 0 else 1.0, # 只在第一次降低原视频音量
|
||||||
|
audio_start=af["start"]
|
||||||
|
)
|
||||||
|
self._add_temp(output)
|
||||||
|
current = output
|
||||||
|
|
||||||
|
return current
|
||||||
|
|
||||||
|
    def compose_from_script(
        self,
        script: Dict[str, Any],
        video_map: Dict[int, str],
        bgm_path: str = None,
        output_name: str = None
    ) -> str:
        """
        Compose the final video from a normalized storyboard script plus a
        scene-id -> rendered-video-path mapping.

        Pipeline: collect scene clips and fancy-text overlays, concatenate,
        build a single voiceover track from ``voiceover_timeline``, mix it in,
        burn subtitles, overlay fancy text, add BGM, then copy the result to
        ``self.output_dir`` and clean temp files.

        Args:
            script: normalized storyboard script (``scenes`` list, optional
                ``voiceover_timeline``)
            video_map: mapping of scene id -> video file path
            bgm_path: optional background music file
            output_name: output file stem; defaults to ``composed_<timestamp>``

        Returns:
            Path of the final composed mp4.

        Raises:
            ValueError: if the script contains no scenes.
        """
        scenes = script.get("scenes", [])
        if not scenes:
            raise ValueError("Empty script")

        video_paths = []
        fancy_texts = []

        # 1. Collect video paths and fancy-text overlays (in storyboard order).
        total_duration = 0.0

        for scene in scenes:
            scene_id = scene["id"]
            video_path = video_map.get(scene_id)

            if not video_path or not os.path.exists(video_path):
                logger.warning(f"Missing video for scene {scene_id}, skipping")
                continue

            # Probe the actual clip duration; fall back to 5s on any probe failure.
            try:
                info = ffmpeg_utils.get_video_info(video_path)
                duration = float(info.get("duration", 5.0))
            except:
                duration = 5.0

            video_paths.append(video_path)

            # Fancy text (white glyphs with black outline, no backing box,
            # centered in the upper area of the frame).
            if "fancy_text" in scene:
                ft = scene["fancy_text"]
                if isinstance(ft, dict):
                    text = ft.get("text", "")

                    if text:
                        # Fixed style: white fill, black stroke, no box.
                        fixed_style = {
                            "font_size": 72,
                            "font_color": "#FFFFFF",
                            "stroke": {"color": "#000000", "width": 5}
                            # no "background" key -> no backing box
                        }

                        fancy_texts.append({
                            "text": text,
                            "style": fixed_style,
                            "x": "(W-w)/2",  # horizontally centered
                            "y": "180",  # upper area of the frame
                            # start is absolute: offset of this scene plus the
                            # per-scene start_time.
                            "start": total_duration + float(ft.get("start_time", 0)),
                            "duration": float(ft.get("duration", duration))
                        })

            total_duration += duration

        # 2. Concatenate the scene clips.
        timestamp = int(time.time())
        output_name = output_name or f"composed_{timestamp}"

        merged_path = str(config.TEMP_DIR / f"{output_name}_merged.mp4")
        ffmpeg_utils.concat_videos(video_paths, merged_path, self.target_size)
        self._add_temp(merged_path)
        current_video = merged_path

        # 3. Build the overall voiceover timeline (new logic).
        voiceover_timeline = script.get("voiceover_timeline", [])
        mixed_audio_path = str(config.TEMP_DIR / f"{output_name}_mixed_vo.mp3")

        # Start from a silent base track spanning total_duration; each TTS
        # clip is then mixed onto it at its absolute offset.
        ffmpeg_utils._run_ffmpeg([
            ffmpeg_utils.FFMPEG_PATH, "-y",
            "-f", "lavfi", "-i", "anullsrc=r=44100:cl=stereo",
            "-t", str(total_duration),
            "-c:a", "mp3",
            mixed_audio_path
        ])
        self._add_temp(mixed_audio_path)

        subtitles = []

        if voiceover_timeline:
            for i, item in enumerate(voiceover_timeline):
                text = item.get("text", "")
                sub_text = item.get("subtitle", text)

                # Two supported item formats:
                # new: start_time (s), duration (s) - absolute times
                # old: start_ratio (0-1), duration_ratio (0-1) - proportional
                if "start_time" in item:
                    # New format: absolute seconds.
                    target_start = float(item.get("start_time", 0))
                    target_duration = float(item.get("duration", 3))
                else:
                    # Old format: proportional to total duration (backward compat).
                    start_ratio = float(item.get("start_ratio", 0))
                    duration_ratio = float(item.get("duration_ratio", 0))
                    target_start = start_ratio * total_duration
                    target_duration = duration_ratio * total_duration

                if not text: continue

                # Generate TTS for this timeline item.
                tts_path = factory.generate_voiceover_volcengine(
                    text=text,
                    voice_type=self.voice_type,
                    output_path=str(config.TEMP_DIR / f"{output_name}_vo_{i}.mp3")
                )
                self._add_temp(tts_path)

                # Stretch/trim the clip to exactly the target duration.
                adjusted_path = str(config.TEMP_DIR / f"{output_name}_vo_adj_{i}.mp3")
                ffmpeg_utils.adjust_audio_duration(tts_path, target_duration, adjusted_path)
                self._add_temp(adjusted_path)

                # Mix into the master voiceover track at the target offset.
                new_mixed = str(config.TEMP_DIR / f"{output_name}_mixed_{i}.mp3")
                ffmpeg_utils.mix_audio_at_offset(mixed_audio_path, adjusted_path, target_start, new_mixed)
                mixed_audio_path = new_mixed  # Update current mixed path
                self._add_temp(new_mixed)

                # Queue a subtitle entry fully synchronized with the audio.
                subtitles.append({
                    "text": ffmpeg_utils.wrap_text_smart(sub_text),
                    "start": target_start,
                    "duration": target_duration,
                    "style": {}  # Default
                })

        # 4. Mix the assembled voiceover track into the video.
        voiced_path = str(config.TEMP_DIR / f"{output_name}_voiced.mp4")
        ffmpeg_utils.mix_audio(
            current_video, mixed_audio_path, voiced_path,
            audio_volume=1.5,
            video_volume=0.2  # duck the original audio
        )
        self._add_temp(voiced_path)
        current_video = voiced_path

        # 5. Burn subtitles (via ffmpeg_utils.add_multiple_subtitles).
        if subtitles:
            subtitled_path = str(config.TEMP_DIR / f"{output_name}_subtitled.mp4")
            subtitle_style = {
                "font": ffmpeg_utils._get_font_path(),
                "fontsize": 60,
                "fontcolor": "white",
                "borderw": 5,
                "bordercolor": "black",
                "box": 0,  # no backing box
                "y": "h-200",  # centered in the lower area
            }
            ffmpeg_utils.add_multiple_subtitles(
                current_video, subtitles, subtitled_path, default_style=subtitle_style
            )
            self._add_temp(subtitled_path)
            current_video = subtitled_path

        # 6. Overlay the fancy-text images.
        if fancy_texts:
            fancy_path = str(config.TEMP_DIR / f"{output_name}_fancy.mp4")

            overlay_configs = []
            for ft in fancy_texts:
                # Render the fancy text to a transparent image.
                img_path = renderer.render(ft["text"], ft["style"], cache=False)
                overlay_configs.append({
                    "path": img_path,
                    "x": ft["x"],
                    "y": ft["y"],
                    "start": ft["start"],
                    "duration": ft["duration"]
                })

            ffmpeg_utils.overlay_multiple_images(
                current_video, overlay_configs, fancy_path
            )
            self._add_temp(fancy_path)
            current_video = fancy_path

        # 7. Add background music.
        if bgm_path:
            bgm_output = str(config.TEMP_DIR / f"{output_name}_bgm.mp4")
            ffmpeg_utils.add_bgm(
                current_video, bgm_path, bgm_output,
                bgm_volume=0.15
            )
            self._add_temp(bgm_output)
            current_video = bgm_output

        # 8. Copy the last intermediate to the final output location.
        final_path = str(self.output_dir / f"{output_name}.mp4")
        import shutil
        shutil.copy(current_video, final_path)

        logger.info(f"Composition complete: {final_path}")

        self.cleanup()
        return final_path
||||||
|
|
||||||
|
|
||||||
|
def compose_standard_task(self, task_config: Dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
执行标准合成任务 (Legacy)
|
||||||
|
"""
|
||||||
|
settings = task_config.get("settings", {})
|
||||||
|
self.voice_type = settings.get("voice_type", self.voice_type)
|
||||||
|
|
||||||
|
# 1. 准备视频片段
|
||||||
|
video_paths = []
|
||||||
|
for seg in task_config.get("segments", []):
|
||||||
|
path = seg.get("path") or seg.get("video_path")
|
||||||
|
if not path: continue
|
||||||
|
video_paths.append(path)
|
||||||
|
|
||||||
|
# 2. 解析时间轴
|
||||||
|
subtitles = []
|
||||||
|
fancy_texts = []
|
||||||
|
voiceover_segments = []
|
||||||
|
|
||||||
|
for item in task_config.get("timeline", []):
|
||||||
|
itype = item.get("type")
|
||||||
|
|
||||||
|
if not itype:
|
||||||
|
if "text" in item and ("style" in item or "x" in item or "y" in item):
|
||||||
|
itype = "fancy_text"
|
||||||
|
elif "text" in item and "duration" in item and "start" in item:
|
||||||
|
itype = "subtitle"
|
||||||
|
elif "text" in item and "start" in item:
|
||||||
|
itype = "voiceover"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if itype == "subtitle":
|
||||||
|
subtitles.append(item)
|
||||||
|
elif itype == "fancy_text":
|
||||||
|
if "x" not in item and "position" in item:
|
||||||
|
item["x"] = item["position"].get("x")
|
||||||
|
item["y"] = item["position"].get("y")
|
||||||
|
fancy_texts.append(item)
|
||||||
|
elif itype == "voiceover":
|
||||||
|
voiceover_segments.append(item)
|
||||||
|
|
||||||
|
return self.compose(
|
||||||
|
video_paths=video_paths,
|
||||||
|
subtitles=subtitles,
|
||||||
|
fancy_texts=fancy_texts,
|
||||||
|
voiceover_segments=voiceover_segments,
|
||||||
|
bgm_path=settings.get("bgm_path"),
|
||||||
|
bgm_volume=settings.get("bgm_volume", 0.06),
|
||||||
|
output_name=settings.get("output_name"),
|
||||||
|
upload_to_r2=settings.get("upload_to_r2", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compose_product_video(
    video_paths: List[str],
    subtitle_configs: List[Dict[str, Any]] = None,
    fancy_text_configs: List[Dict[str, Any]] = None,
    voiceover_text: str = None,
    bgm_path: str = None,
    output_path: str = None,
    voice_type: str = "sweet_female"
) -> str:
    """Convenience wrapper: compose a product short video in a single call.

    When *output_path* is given, its stem becomes the output file name and
    its parent directory overrides the composer's default output directory.
    Returns the path of the composed video.
    """
    composer = VideoComposer(voice_type=voice_type)

    output_name = None
    if output_path:
        target = Path(output_path)
        output_name = target.stem
        composer.output_dir = target.parent

    return composer.compose(
        video_paths=video_paths,
        subtitles=subtitle_configs,
        fancy_texts=fancy_text_configs,
        voiceover_text=voiceover_text,
        bgm_path=bgm_path,
        output_name=output_name
    )
|
||||||
|
|
||||||
|
|
||||||
|
def quick_compose(
    video_folder: str,
    script: List[Dict[str, Any]],
    output_path: str = None,
    voice_type: str = "sweet_female",
    bgm_path: str = None
) -> str:
    """Quick composition: pair clips from *video_folder* with *script* items.

    Each script item may carry:
        video        explicit clip filename inside the folder (otherwise the
                     i-th file, sorted by name, is used)
        subtitle     caption text shown for the clip's duration
        fancy_text   str or dict describing an overlay text
        voiceover    narration sentence (all joined into one TTS track)
        duration     fallback duration (s) if the clip cannot be probed

    Returns:
        Path of the composed video.
    """
    folder = Path(video_folder)

    # Clips sorted by filename; used when a script item names no video.
    video_files = sorted([
        f for f in folder.iterdir()
        if f.suffix.lower() in ['.mp4', '.mov', '.avi', '.mkv']
    ])

    video_paths = []
    subtitles = []
    fancy_texts = []
    voiceovers = []

    current_time = 0

    for i, item in enumerate(script):
        # Resolve the clip for this script item.
        if "video" in item:
            vp = folder / item["video"]
        elif i < len(video_files):
            vp = video_files[i]
        else:
            logger.warning(f"No video for script item {i}")
            continue

        video_paths.append(str(vp))

        # Probe the real clip duration; fall back to the scripted value.
        # (float() for consistency: get_video_info may report it as a string.)
        try:
            info = ffmpeg_utils.get_video_info(str(vp))
            duration = float(info.get("duration", 5))
        except Exception:  # probe failure: missing file, bad container, ...
            duration = item.get("duration", 5)

        if "subtitle" in item:
            subtitles.append({
                "text": item["subtitle"],
                "start": current_time,
                "duration": duration,
                "style": item.get("subtitle_style", {})
            })

        if "fancy_text" in item:
            ft = item["fancy_text"]
            if isinstance(ft, str):
                ft = {"text": ft}  # bare string -> default style/position
            fancy_texts.append({
                "text": ft.get("text", ""),
                "style": ft.get("style", "highlight"),
                "custom_style": ft.get("custom_style"),
                "x": ft.get("x", "(W-w)/2"),
                "y": ft.get("y", 200),
                "start": current_time,
                "duration": duration
            })

        if "voiceover" in item:
            voiceovers.append(item["voiceover"])

        current_time += duration

    # Join all narration sentences into one continuous voiceover.
    voiceover_text = "。".join(voiceovers) if voiceovers else None

    return compose_product_video(
        video_paths=video_paths,
        subtitle_configs=subtitles if subtitles else None,
        fancy_text_configs=fancy_texts if fancy_texts else None,
        voiceover_text=voiceover_text,
        bgm_path=bgm_path,
        output_path=output_path,
        voice_type=voice_type
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 示例用法
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def example_hairclip_video():
    """Example: compose the hairclip product demo video from local footage.

    Demonstrates the ``quick_compose`` API: one script item per clip, paired
    positionally with the sorted files in the assets folder.
    (The previous version also built an unused ``video_paths`` list; removed.)
    """
    assets_dir = Path("/Volumes/Tony/video-flow/素材/发夹/合成图拆分镜")

    script = [
        {
            "subtitle": "塌马尾 vs 高颅顶",
            "fancy_text": {
                "text": "塌马尾 vs 高颅顶",
                "style": "comparison",
                "y": 150
            },
            "voiceover": "普通马尾和高颅顶马尾的区别,你看出来了吗",
        },
        {
            "subtitle": "3秒出门,无需皮筋",
            "fancy_text": {"text": "发量+50%", "style": "bubble", "y": 300},
            "voiceover": "只需要三秒钟,不需要皮筋,发量瞬间增加百分之五十",
        },
        {
            "subtitle": "发量+50%",
            "voiceover": "蓬松的高颅顶效果,让你瞬间变美",
        },
        {
            "subtitle": "狂甩不掉!",
            "fancy_text": {"text": "狂甩不掉!", "style": "warning", "y": 400},
            "voiceover": "而且超级牢固,怎么甩都不会掉",
        },
        {
            "subtitle": "¥3.99 立即抢购",
            "fancy_text": {"text": "3.99", "style": "price", "y": 500},
            "voiceover": "只要三块九毛九,点击下方链接立即购买",
        },
    ]

    output = quick_compose(
        video_folder=str(assets_dir),
        script=script,
        output_path="/Volumes/Tony/video-flow/output/发夹_合成视频.mp4",
        voice_type="sweet_female"
    )

    print(f"视频合成完成: {output}")
    return output
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: compose the sample hairclip product video.
    logging.basicConfig(level=logging.INFO)
    example_hairclip_video()
|
||||||
305
modules/db_manager.py
Normal file
305
modules/db_manager.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
数据库管理模块 (SQLAlchemy)
|
||||||
|
负责项目数据、任务状态、素材路径的持久化存储
|
||||||
|
支持 SQLite 和 PostgreSQL
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, Column, String, Integer, Text, Float, UniqueConstraint, func
|
||||||
|
from sqlalchemy.orm import sessionmaker, scoped_session, declarative_base
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
class Project(Base):
    """ORM model for one video-generation project.

    ``product_info`` and ``script_data`` hold JSON-encoded strings; callers
    (see DBManager) encode/decode them explicitly. Timestamps are unix
    epoch seconds stored as floats.
    """
    __tablename__ = 'projects'

    id = Column(String, primary_key=True)
    name = Column(String)
    status = Column(String)  # created, script_generated, images_generated, videos_generated, completed
    product_info = Column(Text)  # JSON string (SQLite) or JSONB (PG - using Text for compat)
    script_data = Column(Text)  # JSON string
    created_at = Column(Float, default=time.time)  # unix epoch seconds
    updated_at = Column(Float, default=time.time, onupdate=time.time)  # refreshed on every UPDATE
|
||||||
|
|
||||||
|
class SceneAsset(Base):
    """ORM model for one generated asset (image or video) of a scene.

    At most one row may exist per (project, scene, asset type) — enforced by
    the unique constraint below — so DBManager.save_asset can UPSERT.
    """
    __tablename__ = 'scene_assets'

    id = Column(Integer, primary_key=True, autoincrement=True)
    project_id = Column(String, index=True)
    scene_id = Column(Integer)
    asset_type = Column(String)  # image, video
    status = Column(String)  # pending, processing, completed, failed
    local_path = Column(Text, nullable=True)
    remote_url = Column(Text, nullable=True)
    task_id = Column(String, nullable=True)  # task id assigned by the external generation API
    metadata_json = Column("metadata", Text, nullable=True)  # JSON string (renamed to avoid conflict with metadata attr)
    created_at = Column(Float, default=time.time)  # unix epoch seconds
    updated_at = Column(Float, default=time.time, onupdate=time.time)

    # One asset row per (project, scene, type) — basis for the UPSERT logic.
    __table_args__ = (UniqueConstraint('project_id', 'scene_id', 'asset_type', name='uix_project_scene_asset'),)
|
||||||
|
|
||||||
|
class AppConfig(Base):
    """Simple key/value store for application configuration and prompts.

    Values are JSON-encoded strings; DBManager.get_config falls back to the
    raw string when a stored value fails to decode.
    """
    __tablename__ = 'app_config'

    key = Column(String, primary_key=True)
    value = Column(Text)  # JSON string
    description = Column(Text, nullable=True)
    updated_at = Column(Float, default=time.time, onupdate=time.time)  # unix epoch seconds
|
||||||
|
|
||||||
|
class DBManager:
    """Persistence facade over SQLAlchemy for projects, scene assets and config.

    Works with SQLite or PostgreSQL depending on the connection string.
    Most write operations are best-effort: failures are rolled back and
    logged; ``create_project`` is the exception and re-raises.

    Fix vs. previous revision: the bare ``except:`` in ``get_config`` (which
    also swallowed KeyboardInterrupt/SystemExit) is narrowed to the JSON
    decode failures it was meant to catch.
    """

    def __init__(self, connection_string: str = None):
        """Create the engine and session factory, and ensure tables exist.

        Args:
            connection_string: SQLAlchemy URL; defaults to
                ``config.DB_CONNECTION_STRING``.
        """
        if not connection_string:
            connection_string = config.DB_CONNECTION_STRING

        # pool_recycle avoids stale pooled connections on long-lived processes.
        self.engine = create_engine(connection_string, pool_recycle=3600)
        self.Session = scoped_session(sessionmaker(bind=self.engine))
        self._init_db()

    def _init_db(self):
        """Create all mapped tables that do not exist yet."""
        Base.metadata.create_all(self.engine)

    def _get_session(self):
        """Return a (scoped) session; the caller is responsible for closing it."""
        return self.Session()

    # --- Project Operations ---

    def create_project(self, project_id: str, name: str, product_info: Dict[str, Any]):
        """Insert a new project row; no-op (with a warning) if the id exists.

        Raises:
            Exception: re-raises any database error after rollback.
        """
        session = self._get_session()
        try:
            # Check if exists
            existing = session.query(Project).filter_by(id=project_id).first()
            if existing:
                logger.warning(f"Project {project_id} already exists.")
                return

            new_project = Project(
                id=project_id,
                name=name,
                status="created",
                product_info=json.dumps(product_info, ensure_ascii=False),
                created_at=time.time(),
                updated_at=time.time()
            )
            session.add(new_project)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error creating project: {e}")
            raise
        finally:
            session.close()

    def update_project_script(self, project_id: str, script: Dict[str, Any]):
        """Store the generated script JSON and advance the project status.

        Best-effort: errors are rolled back and logged, not raised.
        """
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.script_data = json.dumps(script, ensure_ascii=False)
                project.status = "script_generated"
                project.updated_at = time.time()
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating script: {e}")
        finally:
            session.close()

    def update_project_status(self, project_id: str, status: str):
        """Update the project's status field (best-effort, logged on error)."""
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                project.status = status
                project.updated_at = time.time()
                session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error updating status: {e}")
        finally:
            session.close()

    def get_project(self, project_id: str) -> Optional[Dict[str, Any]]:
        """Return the project as a plain dict (JSON fields decoded), or None."""
        session = self._get_session()
        try:
            project = session.query(Project).filter_by(id=project_id).first()
            if project:
                data = {
                    "id": project.id,
                    "name": project.name,
                    "status": project.status,
                    "product_info": json.loads(project.product_info) if project.product_info else {},
                    "script_data": json.loads(project.script_data) if project.script_data else None,
                    "created_at": project.created_at,
                    "updated_at": project.updated_at
                }
                return data
            return None
        finally:
            session.close()

    def list_projects(self) -> List[Dict[str, Any]]:
        """Return summaries of all projects, most recently updated first."""
        session = self._get_session()
        try:
            projects = session.query(Project).order_by(Project.updated_at.desc()).all()
            results = []
            for p in projects:
                results.append({
                    "id": p.id,
                    "name": p.name,
                    "status": p.status,
                    "updated_at": p.updated_at
                })
            return results
        finally:
            session.close()

    # --- Asset/Task Operations ---

    def save_asset(self, project_id: str, scene_id: int, asset_type: str,
                   status: str, local_path: str = None, remote_url: str = None,
                   task_id: str = None, metadata: Dict = None):
        """Insert or update the asset row for (project, scene, type) — UPSERT.

        All mutable columns are overwritten, including with None, so callers
        must pass the full current state. Best-effort: errors are logged.
        """
        session = self._get_session()
        try:
            asset = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type
            ).first()

            meta_json = json.dumps(metadata, ensure_ascii=False) if metadata else "{}"

            if asset:
                asset.status = status
                asset.local_path = local_path
                asset.remote_url = remote_url
                asset.task_id = task_id
                asset.metadata_json = meta_json
                asset.updated_at = time.time()
            else:
                new_asset = SceneAsset(
                    project_id=project_id,
                    scene_id=scene_id,
                    asset_type=asset_type,
                    status=status,
                    local_path=local_path,
                    remote_url=remote_url,
                    task_id=task_id,
                    metadata_json=meta_json,
                    created_at=time.time(),
                    updated_at=time.time()
                )
                session.add(new_asset)

            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error saving asset: {e}")
        finally:
            session.close()

    def get_assets(self, project_id: str, asset_type: str = None) -> List[Dict[str, Any]]:
        """Return all assets for a project, optionally filtered by type."""
        session = self._get_session()
        try:
            query = session.query(SceneAsset).filter_by(project_id=project_id)
            if asset_type:
                query = query.filter_by(asset_type=asset_type)

            assets = query.all()
            results = []
            for a in assets:
                data = {
                    "id": a.id,
                    "project_id": a.project_id,
                    "scene_id": a.scene_id,
                    "asset_type": a.asset_type,
                    "status": a.status,
                    "local_path": a.local_path,
                    "remote_url": a.remote_url,
                    "task_id": a.task_id,
                    "metadata": json.loads(a.metadata_json) if a.metadata_json else {},
                    "updated_at": a.updated_at
                }
                results.append(data)
            return results
        finally:
            session.close()

    def get_asset(self, project_id: str, scene_id: int, asset_type: str) -> Optional[Dict[str, Any]]:
        """Return one asset as a plain dict, or None if it does not exist."""
        session = self._get_session()
        try:
            a = session.query(SceneAsset).filter_by(
                project_id=project_id,
                scene_id=scene_id,
                asset_type=asset_type
            ).first()

            if a:
                return {
                    "id": a.id,
                    "project_id": a.project_id,
                    "scene_id": a.scene_id,
                    "asset_type": a.asset_type,
                    "status": a.status,
                    "local_path": a.local_path,
                    "remote_url": a.remote_url,
                    "task_id": a.task_id,
                    "metadata": json.loads(a.metadata_json) if a.metadata_json else {},
                    "updated_at": a.updated_at
                }
            return None
        finally:
            session.close()

    # --- Config/Prompt Operations ---

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the config value for *key*, JSON-decoded when possible.

        Falls back to the raw stored string for legacy non-JSON values, and
        to *default* when the key does not exist.
        """
        session = self._get_session()
        try:
            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg:
                try:
                    return json.loads(cfg.value)
                except (TypeError, ValueError):
                    # Legacy rows may hold raw (non-JSON) strings.
                    return cfg.value
            return default
        finally:
            session.close()

    def set_config(self, key: str, value: Any, description: str = None):
        """Store *value* (JSON-encoded) under *key*, creating or updating.

        The existing description is kept when *description* is None.
        Best-effort: errors are rolled back and logged.
        """
        session = self._get_session()
        try:
            json_val = json.dumps(value, ensure_ascii=False)

            cfg = session.query(AppConfig).filter_by(key=key).first()
            if cfg:
                cfg.value = json_val
                if description:
                    cfg.description = description
                cfg.updated_at = time.time()
            else:
                new_cfg = AppConfig(
                    key=key,
                    value=json_val,
                    description=description,
                    updated_at=time.time()
                )
                session.add(new_cfg)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Error setting config: {e}")
        finally:
            session.close()
|
||||||
|
|
||||||
|
# Singleton instance shared by the rest of the app.
# NOTE(review): instantiated at import time, which opens the engine and runs
# create_all as a side effect of importing this module — confirm intended.
db = DBManager()
|
||||||
269
modules/editor.py
Normal file
269
modules/editor.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - Editor Module (Assembly + BGM)
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from moviepy.editor import (
|
||||||
|
VideoFileClip, AudioFileClip, TextClip,
|
||||||
|
CompositeVideoClip, CompositeAudioClip,
|
||||||
|
concatenate_videoclips
|
||||||
|
)
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules import storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Video Assembly
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def download_video(url: str) -> str:
    """Download a remote file (video or audio) from *url* into TEMP_DIR.

    Streams the response to disk in chunks instead of buffering the whole
    body in memory, and fails loudly on HTTP errors instead of silently
    saving an error page as a "video".

    Returns:
        The local file path as a string.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.Timeout: if the request exceeds the timeout.
    """
    filename = f"dl_{Path(url).name}"
    local_path = config.TEMP_DIR / filename

    resp = requests.get(url, stream=True, timeout=60)
    resp.raise_for_status()

    with open(local_path, "wb") as f:
        for chunk in resp.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
            f.write(chunk)

    return str(local_path)
|
||||||
|
|
||||||
|
|
||||||
|
def concatenate_scenes(video_urls: List[str]) -> str:
    """Concatenate multiple remote video clips into one 1080x1920 mp4.

    Each URL is downloaded to TEMP_DIR, resized to 9:16 if needed, and the
    clips are joined in order.

    Fixes vs. previous revision: the ``__import__('time')`` hack is replaced
    with a normal import, and clip handles are released in ``finally`` so
    decoder processes are not leaked when encoding fails.

    Returns:
        Path of the merged video file.
    """
    import time  # local import: module-level imports not touched here

    logger.info(f"Concatenating {len(video_urls)} clips...")

    clips = []
    final = None
    try:
        for url in video_urls:
            local_path = download_video(url)
            clip = VideoFileClip(local_path)

            # Normalize every clip to the 9:16 target resolution.
            if clip.w != 1080 or clip.h != 1920:
                clip = clip.resize(newsize=(1080, 1920))

            clips.append(clip)

        final = concatenate_videoclips(clips, method="compose")

        output_path = config.TEMP_DIR / f"merged_{int(time.time())}.mp4"
        final.write_videofile(
            str(output_path),
            fps=30,
            codec="libx264",
            audio_codec="aac",
            threads=4,
            logger=None
        )
    finally:
        # Always release reader handles, even when download/encode fails.
        for clip in clips:
            clip.close()
        if final is not None:
            final.close()

    return str(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Subtitle Burning
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def burn_subtitles(
    video_path: str,
    scenes: List[Dict[str, Any]]
) -> str:
    """Burn per-scene voiceover text onto the video as subtitles.

    Each scene contributes one caption (its ``voiceover`` text) displayed
    for the scene's ``duration`` (default 5s), positioned near the bottom
    of a 1080x1920 frame.

    Fix vs. previous revision: ``__import__('time')`` hack replaced with a
    normal import.

    Returns:
        Path of the subtitled video file.
    """
    import time  # local import: module-level imports not touched here

    logger.info("Burning subtitles...")

    clip = VideoFileClip(video_path)
    subtitle_clips = []

    current_time = 0
    for scene in scenes:
        voiceover = scene.get("voiceover", "")
        duration = scene.get("duration", 5)

        if voiceover:
            try:
                txt = TextClip(
                    voiceover,
                    fontsize=48,
                    color='white',
                    stroke_color='black',
                    stroke_width=2,
                    font='DejaVu-Sans',
                    method='caption',
                    size=(900, None)
                ).set_position(('center', 1600)).set_start(current_time).set_duration(duration)

                subtitle_clips.append(txt)
            except Exception as e:
                # Best-effort: a missing font / ImageMagick failure should not
                # kill the whole render — skip just this caption.
                logger.warning(f"Subtitle error: {e}")

        current_time += duration

    if subtitle_clips:
        final = CompositeVideoClip([clip] + subtitle_clips)
    else:
        final = clip  # no captions produced: re-encode the input unchanged

    output_path = config.TEMP_DIR / f"subtitled_{int(time.time())}.mp4"
    final.write_videofile(
        str(output_path),
        fps=30,
        codec="libx264",
        audio_codec="aac",
        threads=4,
        logger=None
    )

    clip.close()
    final.close()

    return str(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Voiceover Mixing
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def mix_voiceover(video_path: str, voiceover_url: str) -> str:
    """Overlay a narration track on a video.

    The voiceover is downloaded, trimmed to the video length if longer,
    and mixed over the original audio (which is ducked to 30% volume).

    Args:
        video_path: Local path of the video to narrate.
        voiceover_url: Remote URL of the narration audio; when empty the
            input path is returned unchanged.

    Returns:
        Path (str) to the newly written MP4 in config.TEMP_DIR, or the
        input path when there is no voiceover.
    """
    if not voiceover_url:
        return video_path

    import time  # replaces the original inline __import__('time') hack

    logger.info("Mixing voiceover...")

    # Download voiceover to a local file for moviepy.
    vo_local = download_video(voiceover_url)

    video = VideoFileClip(video_path)
    voiceover = AudioFileClip(vo_local)

    # Trim voiceover if longer than video.
    if voiceover.duration > video.duration:
        voiceover = voiceover.subclip(0, video.duration)

    # Mix with original audio (if any), ducking it under the narration.
    if video.audio:
        mixed = CompositeAudioClip([
            video.audio.volumex(0.3),  # lower original track
            voiceover.volumex(1.0)
        ])
    else:
        mixed = voiceover

    final = video.set_audio(mixed)

    output_path = config.TEMP_DIR / f"voiced_{int(time.time())}.mp4"
    final.write_videofile(
        str(output_path),
        fps=30,
        codec="libx264",
        audio_codec="aac",
        threads=4,
        logger=None
    )

    video.close()
    voiceover.close()
    final.close()

    return str(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# BGM Mixing
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def mix_bgm(
    video_path: str,
    bgm_path: str,
    bgm_volume: float = 0.2
) -> str:
    """Mix background music under a video's existing audio.

    The BGM is looped to cover the full video length, trimmed to the
    exact duration, and attenuated to ``bgm_volume`` before being
    layered on top of any existing soundtrack.

    Args:
        video_path: Local path of the video to score.
        bgm_path: Local path of the music file.
        bgm_volume: Gain applied to the BGM track (0.2 = 20%).

    Returns:
        Path (str) to the newly written MP4 in config.TEMP_DIR.
    """
    import time  # replaces the original inline __import__('time') hack

    logger.info("Mixing BGM...")

    video = VideoFileClip(video_path)
    bgm = AudioFileClip(bgm_path)

    # Loop BGM if shorter than the video (over-shoot by one loop, then trim).
    if bgm.duration < video.duration:
        loops_needed = int(video.duration / bgm.duration) + 1
        bgm = bgm.loop(loops_needed)

    # Trim to video length and apply the volume attenuation.
    bgm = bgm.subclip(0, video.duration).volumex(bgm_volume)

    # Layer under the existing soundtrack, if there is one.
    if video.audio:
        mixed = CompositeAudioClip([video.audio, bgm])
    else:
        mixed = bgm

    final = video.set_audio(mixed)

    output_path = config.TEMP_DIR / f"bgm_{int(time.time())}.mp4"
    final.write_videofile(
        str(output_path),
        fps=30,
        codec="libx264",
        audio_codec="aac",
        threads=4,
        logger=None
    )

    video.close()
    bgm.close()
    final.close()

    return str(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Full Pipeline
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def assemble_final_video(
    video_urls: List[str],
    scenes: List[Dict[str, Any]],
    voiceover_url: str = "",
    bgm_url: str = ""
) -> str:
    """Run the full assembly pipeline and return the final R2 URL.

    Pipeline: concatenate scene videos -> burn subtitles -> optionally
    mix the voiceover -> optionally mix BGM -> upload the result to R2.

    Args:
        video_urls: Remote URLs of the per-scene video clips, in order.
        scenes: Scene dicts used for subtitle text and durations.
        voiceover_url: Optional narration audio URL.
        bgm_url: Optional background-music URL.

    Returns:
        Public URL of the uploaded final video.
    """
    logger.info("Starting full assembly...")

    # Concatenate all scene clips, then burn subtitle overlays.
    current = concatenate_scenes(video_urls)
    current = burn_subtitles(current, scenes)

    # Optional narration track.
    if voiceover_url:
        current = mix_voiceover(current, voiceover_url)

    # Optional background music (downloaded locally first).
    if bgm_url:
        current = mix_bgm(current, download_video(bgm_url))

    # Publish and return the public URL.
    final_url = storage.upload_file(current)
    logger.info(f"Final video uploaded: {final_url}")

    return final_url
|
||||||
157
modules/export_utils.py
Normal file
157
modules/export_utils.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
import os
|
||||||
|
import zipfile
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def format_timestamp(seconds: float) -> str:
    """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)."""
    whole = int(seconds)
    millis = int((seconds - whole) * 1000)  # truncated, not rounded
    minutes, secs = divmod(whole, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
|
||||||
|
|
||||||
|
def generate_srt(script_data: Dict[str, Any], video_map: Dict[int, str]) -> str:
    """Generate SRT subtitle content from script data.

    Scene durations are probed from the real clip files in ``video_map``
    when available; otherwise each scene falls back to 5 seconds.

    Args:
        script_data: Script dict with a "scenes" list; each scene needs
            an "id" and optionally "subtitle".
        video_map: Mapping of scene id -> local video file path.

    Returns:
        The full SRT document as a string (empty if no scene has a
        subtitle).
    """
    scenes = script_data.get("scenes", [])
    srt_content = ""
    current_time = 0.0

    for i, scene in enumerate(scenes):
        scene_id = scene["id"]
        duration = 5.0  # fallback when the clip is missing or unreadable
        if scene_id in video_map and os.path.exists(video_map[scene_id]):
            try:
                # Lazy project import: only needed when probing real files,
                # so the pure-fallback path has no modules.ffmpeg_utils dependency.
                from modules import ffmpeg_utils
                info = ffmpeg_utils.get_video_info(video_map[scene_id])
                duration = info.get("duration", 5.0)
            except Exception as e:
                # Was a bare `except: pass`; keep best-effort but log it.
                logger.warning(f"Could not probe duration for scene {scene_id}: {e}")

        start_time = current_time
        end_time = current_time + duration
        current_time = end_time

        text = scene.get("subtitle", "")
        if text:
            srt_content += f"{i+1}\n"
            srt_content += f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}\n"
            srt_content += f"{text}\n\n"

    return srt_content
|
||||||
|
|
||||||
|
def create_capcut_package(project_id: str, script_data: Dict[str, Any], assets: Dict[str, str]) -> str:
    """Create a ZIP package for CapCut (JianYing) import.

    Archive layout:
        videos/        numbered scene videos (01_scene_<id>.mp4, ...)
        audios/        regenerated full voiceover track (if any narration)
        images/        re-rendered fancy-text overlay PNGs
        subtitles.srt  subtitle track derived from scene subtitles

    Args:
        project_id: Used in the staging-dir and zip file names.
        script_data: Script dict with a "scenes" list.
        assets: Asset map; expects "scene_videos" (scene id -> local path).

    Returns:
        Path (str) to the created ZIP in config.TEMP_DIR.
    """
    package_dir = config.TEMP_DIR / f"capcut_pkg_{project_id}_{os.getpid()}"
    if package_dir.exists():
        shutil.rmtree(package_dir)
    package_dir.mkdir()

    (package_dir / "videos").mkdir()
    (package_dir / "audios").mkdir()
    (package_dir / "images").mkdir()

    # 1. Subtitles (SRT), timed from the scene videos' actual durations.
    scene_videos = assets.get("scene_videos", {})
    srt_content = generate_srt(script_data, scene_videos)
    with open(package_dir / "subtitles.srt", "w", encoding="utf-8") as f:
        f.write(srt_content)

    # 2. Scene videos, renamed with a sequence prefix for easy sorting.
    scenes = script_data.get("scenes", [])
    for i, scene in enumerate(scenes):
        sid = scene["id"]
        if sid in scene_videos and os.path.exists(scene_videos[sid]):
            ext = Path(scene_videos[sid]).suffix
            dest_name = f"{i+1:02d}_scene_{sid}{ext}"
            shutil.copy(scene_videos[sid], package_dir / "videos" / dest_name)

    # 3. Voiceover: the mixed/intermediate TTS audio is not kept as a
    # standalone asset, so re-generate the full narration track here.
    from modules import factory
    full_vo_text = " ".join(s.get("voiceover", "") for s in scenes if s.get("voiceover"))
    if full_vo_text:
        try:
            # Assumes the project default voice -- TODO confirm the
            # project's chosen voice is persisted anywhere to reuse here.
            voice_type = config.VOLC_TTS_DEFAULT_VOICE
            vo_path = factory.generate_voiceover_volcengine(full_vo_text, voice_type)
            shutil.copy(vo_path, package_dir / "audios" / "full_voiceover.mp3")
        except Exception as e:
            logger.warning(f"Failed to generate export voiceover: {e}")

    # NOTE(review): BGM is a composer-level setting not tracked per
    # project here, so no BGM file is bundled.

    # 4. Fancy-text overlays: the composer renders these to temp files
    # that may be gone by export time, so re-render with a default style
    # (the composer's full style-resolution logic is not replicated here).
    from modules.text_renderer import renderer
    for i, scene in enumerate(scenes):
        ft = scene.get("fancy_text")
        if ft:
            text = ft.get("text", "") if isinstance(ft, dict) else ""
            if text:
                try:
                    img_path = renderer.render(text, {"font_size": 60, "font_color": "#FFFFFF"}, cache=False)
                    shutil.copy(img_path, package_dir / "images" / f"{i+1:02d}_text_{scene['id']}.png")
                except Exception as e:
                    # Was a bare `except: pass`; keep best-effort but log it.
                    logger.warning(f"Failed to render export text for scene {scene['id']}: {e}")

    # 5. Zip everything and remove the staging directory.
    zip_path = config.TEMP_DIR / f"capcut_export_{project_id}.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(package_dir):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, package_dir)
                zipf.write(file_path, arcname)

    shutil.rmtree(package_dir)
    return str(zip_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
801
modules/factory.py
Normal file
801
modules/factory.py
Normal file
@@ -0,0 +1,801 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - Factory Module (Concurrent Scene Generation)
|
||||||
|
Using Volcengine (Doubao) API for Image and Video
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from elevenlabs import ElevenLabs, VoiceSettings
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules import storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize OpenAI Client for Volcengine Image Generation
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=config.VOLC_API_KEY,
|
||||||
|
base_url=config.VOLC_BASE_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Helper Functions
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def _download_as_base64(url: str) -> str:
    """Download an image from a URL and return it Base64-encoded.

    Args:
        url: HTTP(S) URL of the image.

    Returns:
        Base64-encoded ASCII string of the response body, or "" on any
        download/encoding failure (logged, not raised).
    """
    try:
        # Bound the request so a stalled CDN cannot hang the pipeline.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return base64.b64encode(response.content).decode('utf-8')
    except Exception as e:
        logger.error(f"Failed to download/encode image: {e}")
        return ""
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Image Generation (Doubao / Volcengine)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def generate_scene_image(
    scene: Dict[str, Any],
    brief: Dict[str, Any] = None,
    reference_images: List[str] = None
) -> str:
    """
    Generate an image via the Volcengine (Doubao) image API.

    Uses raw HTTP requests to match the vendor's documented curl example
    exactly. The result is saved locally, uploaded to R2, and the R2 URL
    is returned.

    Args:
        scene: Scene dict; reads "image_prompt", "keyframe", and "id".
        brief: Optional campaign brief; "product_visual_description" is
            prepended to the prompt to keep the product consistent.
        reference_images: Accepted for API symmetry; not used in this
            implementation.

    Returns:
        Public R2 URL of the uploaded image.

    Raises:
        ValueError: On a non-200 API response or missing image data.
    """
    # Build the prompt: prefer the script's own image_prompt.
    image_prompt = scene.get("image_prompt", "")
    if not image_prompt:
        # Fallback prompt constructed from the keyframe description.
        keyframe = scene.get("keyframe", {})
        # Style-consistency intro applied to every fallback prompt.
        parts = ["Cinematic shot, 8k, photorealistic"]
        if brief:
            if brief.get("product_visual_description"):
                parts.append(f"Product: {brief['product_visual_description']}")
        parts.extend([
            f"Subject: {keyframe.get('subject', 'product')}",
            f"Environment: {keyframe.get('environment', 'studio')}",
            f"Action: {keyframe.get('focus', '')}"
        ])
        image_prompt = ", ".join(parts)

    # Enforce product consistency: prepend the product description if the
    # prompt does not already contain it.
    if brief and brief.get("product_visual_description"):
        if brief['product_visual_description'] not in image_prompt:
            image_prompt = f"{brief['product_visual_description']}, {image_prompt}"

    logger.info(f"Generating image (Volcengine): {image_prompt[:50]}...")

    url = f"{config.VOLC_BASE_URL}/images/generations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.VOLC_API_KEY}"
    }

    # Payload matching the vendor's curl example.
    payload = {
        "model": config.IMAGE_MODEL_ID,
        "prompt": image_prompt,
        "sequential_image_generation": "disabled",
        "response_format": "b64_json",  # base64 avoids temp-URL expiration issues
        "size": "2K",  # 2K output as specified
        "stream": False,
        "watermark": True
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)

        if response.status_code != 200:
            logger.error(f"Image API Error: {response.text}")
            raise ValueError(f"Image API failed: {response.status_code} - {response.text}")

        data = response.json()

        # Extract the image: prefer inline base64, fall back to a URL.
        image_data = None
        if "data" in data and len(data["data"]) > 0:
            image_data = data["data"][0].get("b64_json")
            if not image_data:
                # Fallback: the API returned a URL instead of b64.
                img_url = data["data"][0].get("url")
                if img_url:
                    # Download so we hold the bytes before the temp URL expires.
                    image_data = _download_as_base64(img_url)

        if not image_data:
            raise ValueError("No image data returned")

        # Decode and save to the local temp dir.
        filename = f"scene_{scene.get('id', 0)}_{int(time.time())}.jpg"
        local_path = config.TEMP_DIR / filename

        with open(local_path, "wb") as f:
            f.write(base64.b64decode(image_data))

        # Upload to R2 and return the public URL.
        r2_url = storage.upload_file(str(local_path))
        logger.info(f"Scene {scene.get('id', '?')} image uploaded: {r2_url}")
        return r2_url

    except Exception as e:
        logger.error(f"Image Generation Failed: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def generate_all_scene_images_concurrent(
    scenes: List[Dict[str, Any]],
    brief: Dict[str, Any] = None,
    reference_images: List[str] = None,
    max_workers: int = 3
) -> List[str]:
    """Generate keyframe images for every scene in parallel.

    Results are returned in scene order. A scene whose generation fails
    keeps ``None`` in its slot; the error is logged, not raised.

    Args:
        scenes: Scene dicts, one image generated per scene.
        brief: Optional campaign brief forwarded to generate_scene_image.
        reference_images: Forwarded to generate_scene_image.
        max_workers: Thread-pool size.

    Returns:
        List of R2 image URLs (or None for failed scenes), scene-ordered.
    """
    logger.info(f"Generating {len(scenes)} images concurrently...")
    image_urls = [None] * len(scenes)

    def _worker(idx: int, item: Dict[str, Any]) -> tuple:
        # Returns (slot, url) so results land back in scene order.
        return idx, generate_scene_image(item, brief, reference_images)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {pool.submit(_worker, i, s): i for i, s in enumerate(scenes)}

        for done in as_completed(pending):
            slot = pending[done]
            try:
                _, r2_url = done.result()
                image_urls[slot] = r2_url
            except Exception as exc:
                logger.error(f"Scene {slot+1} failed: {exc}")

    return image_urls
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Video Generation (Doubao Video / PixelDance)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def generate_scene_video(
    start_frame_url: str,
    motion_prompt: str,
    duration: int = 5
) -> str:
    """
    Generate a video via the Volcengine API (async task flow).

    Creates a generation task, polls its status every 5 seconds for up
    to ~5 minutes, downloads the finished video, and re-uploads it to R2.

    Args:
        start_frame_url: Optional image URL used as the clip's first frame.
        motion_prompt: Text prompt describing content / camera motion.
        duration: Requested clip length in seconds.

    Returns:
        Public R2 URL of the uploaded video.

    Raises:
        ValueError: On task-creation failure, generation failure, or a
            failed download of the finished video.
        TimeoutError: If polling exhausts its retries without a URL.
    """
    logger.info(f"Generating video (Volcengine): {motion_prompt[:50]}...")

    # 1. Create the generation task.
    create_url = f"{config.VOLC_BASE_URL}/contents/generations/tasks"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.VOLC_API_KEY}"
    }

    # Content list: text prompt (with inline generation flags) plus an
    # optional first-frame image.
    content_list = [
        {
            "type": "text",
            "text": f"{motion_prompt} --resolution 1080p --duration {duration} --camerafixed false --watermark true"
        }
    ]

    if start_frame_url:
        content_list.append({
            "type": "image_url",
            "image_url": {"url": start_frame_url}
        })

    payload = {
        "model": config.VIDEO_MODEL_ID,
        "content": content_list
    }

    try:
        response = requests.post(create_url, headers=headers, json=payload, timeout=30)
        if response.status_code != 200:
            # 202 Accepted is also a valid response for async tasks.
            if response.status_code != 202:
                logger.error(f"Video Task Creation Error: {response.text}")
                raise ValueError(f"Video Task failed: {response.status_code} - {response.text}")

        data = response.json()
        task_id = data.get("id")
        if not task_id:
            # Some responses nest the id under "data".
            task_id = data.get("data", {}).get("id")

        if not task_id:
            raise ValueError(f"No Task ID returned: {data}")

        logger.info(f"Video Task Created: {task_id}. Polling for result...")

        # 2. Poll for the result:
        #    GET /contents/generations/tasks/{id}
        max_retries = 60  # 60 polls x 5s sleep = ~5 minutes max
        video_url = None

        for _ in range(max_retries):
            time.sleep(5)
            status_url = f"{config.VOLC_BASE_URL}/contents/generations/tasks/{task_id}"
            resp = requests.get(status_url, headers=headers, timeout=30)

            if resp.status_code == 200:
                res_data = resp.json()
                # The status may appear either at the top level or under
                # "data" ("succeeded" / "running" / "failed", case varies).
                status = res_data.get("status")
                if not status and "data" in res_data:
                    status = res_data["data"].get("status")

                if status == "succeeded" or status == "SUCCEEDED":
                    # Extract the output URL; content may be nested under
                    # "data" or sit at the top level.
                    content = res_data.get("data", {}).get("content", [])
                    if not content and "content" in res_data:
                        content = res_data["content"]

                    # Content is typically a list of dicts carrying either
                    # a "video_url" or a plain "url".
                    for item in content:
                        if item.get("video_url"):
                            video_url = item["video_url"]
                            break
                        if item.get("url"):  # sometimes just "url"
                            video_url = item["url"]
                            break

                    if video_url:
                        break
                elif status == "failed" or status == "FAILED":
                    reason = res_data.get("data", {}).get("error", "Unknown error")
                    raise ValueError(f"Video Generation Failed: {reason}")

            # Still running/queued (or transient non-200): keep waiting.

        if not video_url:
            raise TimeoutError("Video generation timed out or failed to return URL.")

        # 3. Download the finished clip and upload it to R2.
        logger.info(f"Video Generated. Downloading: {video_url}")
        filename = f"vid_doubao_{int(time.time())}.mp4"
        local_path = config.TEMP_DIR / filename

        resp = requests.get(video_url, stream=True)
        if resp.status_code != 200:
            raise ValueError(f"Failed to download generated video: {resp.status_code}")

        with open(local_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                f.write(chunk)

        r2_url = storage.upload_file(str(local_path))
        return r2_url

    except Exception as e:
        logger.error(f"Video Generation Error: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def generate_all_scene_videos_concurrent(
    scenes: List[Dict[str, Any]],
    image_urls: List[str],
    max_workers: int = 2
) -> List[str]:
    """Animate every scene keyframe into a video clip in parallel.

    The motion prompt for each scene combines its "image_prompt" (when
    present) with its "camera_movement". Results keep scene order; a
    failed scene leaves ``None`` in its slot (logged, not raised).

    Args:
        scenes: Scene dicts with optional "image_prompt",
            "camera_movement", and "duration" keys.
        image_urls: First-frame image URL per scene, index-aligned.
        max_workers: Thread-pool size.

    Returns:
        List of R2 video URLs (or None for failed scenes), scene-ordered.
    """
    logger.info(f"Generating {len(scenes)} videos concurrently...")
    video_urls = [None] * len(scenes)

    def _worker(idx: int, item: Dict[str, Any], frame_url: str) -> tuple:
        # Compose the motion prompt from scene description + camera move.
        prompt = item.get("camera_movement", "slow zoom")
        if item.get("image_prompt"):
            prompt = f"{item['image_prompt']}. {prompt}"
        clip_len = item.get("duration", 5)
        return idx, generate_scene_video(frame_url, prompt, clip_len)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {
            pool.submit(_worker, i, s, image_urls[i]): i
            for i, s in enumerate(scenes)
        }

        for done in as_completed(pending):
            slot = pending[done]
            try:
                _, r2_url = done.result()
                video_urls[slot] = r2_url
            except Exception as exc:
                logger.error(f"Scene {slot+1} video failed: {exc}")

    return video_urls
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Audio Generation (ElevenLabs)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def generate_voiceover(text: str, style: str = "") -> str:
    """Synthesize narration audio via ElevenLabs and upload it to R2.

    An "ASMR" style gets lower stability / higher similarity for a
    softer delivery.

    Args:
        text: Narration text; blank input short-circuits to "".
        style: Free-form style tag; only checked for "ASMR".

    Returns:
        The R2 URL of the uploaded mp3, or "" on empty input or any
        synthesis/upload failure (errors are logged, not raised).
    """
    if not text or not text.strip():
        return ""

    # Softer delivery parameters for ASMR-style narration.
    is_asmr = "ASMR" in style
    stability = 0.3 if is_asmr else 0.5
    similarity = 0.9 if is_asmr else 0.8

    logger.info(f"Generating voiceover ({len(text)} chars, style={style})...")

    try:
        el_client = ElevenLabs(api_key=config.XI_KEY)

        audio_stream = el_client.text_to_speech.convert(
            voice_id=config.ELEVENLABS_VOICE_ID,
            text=text,
            model_id=config.ELEVENLABS_MODEL,
            voice_settings=VoiceSettings(stability=stability, similarity_boost=similarity)
        )

        local_path = config.TEMP_DIR / f"vo_{int(time.time())}.mp3"
        with open(local_path, "wb") as out:
            for piece in audio_stream:
                out.write(piece)

        return storage.upload_file(str(local_path))
    except Exception as e:
        logger.error(f"Voiceover failed: {e}")
        return ""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_full_voiceover(scenes: List[Dict[str, Any]], style: str = "") -> str:
    """Concatenate all scene narration and synthesize it as one track.

    Scene entries whose voiceover is blank or begins with "(" (stage
    directions) are skipped.

    Args:
        scenes: Scene dicts with optional "voiceover" text.
        style: Forwarded to generate_voiceover().

    Returns:
        "" when no usable narration remains, otherwise the result of
        generate_voiceover() on the joined text.
    """
    parts = [
        vo.strip()
        for vo in (s.get("voiceover", "") for s in scenes)
        if vo and vo.strip() and not vo.startswith("(")
    ]

    if not parts:
        return ""

    return generate_voiceover(" ".join(parts), style)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Audio Generation (Edge TTS - 免费中文语音合成)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Edge TTS Chinese voice presets (free, good quality).
# Keys are the logical voice types used across this module; values are
# Edge TTS neural voice names.
EDGE_TTS_VOICES = {
    # Female voices
    "sweet_female": "zh-CN-XiaoxiaoNeural",      # Xiaoxiao - sweet and lively (recommended)
    "gentle_female": "zh-CN-XiaoyiNeural",       # Xiaoyi - gentle and refined
    "lively_female": "zh-CN-XiaochenNeural",     # Xiaochen - lively and cute
    "broadcast_female": "zh-CN-XiaoqiuNeural",   # Xiaoqiu - news broadcasting
    # Male voices
    "general_male": "zh-CN-YunxiNeural",         # Yunxi - warm male voice
    "broadcast_male": "zh-CN-YunjianNeural",     # Yunjian - professional broadcasting
}
|
||||||
|
|
||||||
|
# Volcengine TTS voice presets (paid service; subscription required).
# Voices chosen to be Douyin product-video friendly.
VOLC_TTS_VOICES = {
    # Douyin-friendly female voices
    "sweet_female": "zh_female_vv_uranus_bigtts",            # "viv 2.0" general female (sweet)
    "lively_female": "zh_female_jitangnv_saturn_bigtts",     # "Jitangnv" (energetic)
    # NOTE(review): "broadcast_female" currently maps to a MALE voice id
    # ("Ruyachen", news broadcast); the original comment suggests
    # zh_female_meilinyou_saturn_bigtts for a female broadcaster - confirm.
    "broadcast_female": "zh_male_ruyaichen_saturn_bigtts",
    "meilinvyou": "zh_female_meilinvyou_saturn_bigtts",
    # Male voices
    "general_male": "zh_male_dayi_saturn_bigtts",            # "Dayi" (steady male voice)
}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_voiceover_edge(
    text: str,
    voice_type: str = "sweet_female",
    rate: str = "+0%",
    volume: str = "+0%",
    output_path: str = None
) -> str:
    """
    Generate Chinese narration with Edge TTS (free, good quality).

    Args:
        text: Narration text.
        voice_type: Preset key (see EDGE_TTS_VOICES) or a raw Edge TTS
            voice name.
        rate: Speaking-rate adjustment, e.g. "+10%", "-20%".
        volume: Volume adjustment, e.g. "+10%", "-20%".
        output_path: Output file path; auto-generated under
            config.TEMP_DIR when None.

    Returns:
        Path to the generated audio file, or "" on empty input or after
        all retries fail.
    """
    import asyncio
    import edge_tts

    if not text or not text.strip():
        logger.warning("Empty text provided for TTS")
        return ""

    # Resolve the preset key to a concrete voice name; unknown keys fall
    # through unchanged so callers may pass a full voice name directly.
    voice = EDGE_TTS_VOICES.get(voice_type, voice_type)

    logger.info(f"Generating voiceover (Edge TTS): {len(text)} chars, voice={voice}")

    if not output_path:
        filename = f"vo_edge_{int(time.time())}.mp3"
        output_path = str(config.TEMP_DIR / filename)

    async def _generate():
        communicate = edge_tts.Communicate(text, voice, rate=rate, volume=volume)
        await communicate.save(output_path)

    # Simple retry loop: Edge TTS can fail transiently.
    max_retries = 3
    for i in range(max_retries):
        try:
            asyncio.run(_generate())
            # Guard against zero-byte output from a silently failed call.
            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                logger.info(f"Edge TTS voiceover generated: {output_path}")
                return output_path
        except Exception as e:
            logger.warning(f"Edge TTS attempt {i+1} failed: {e}")
            time.sleep(1.0)  # wait before retry

    logger.error("Edge TTS failed after retries.")
    return ""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_voiceover_volcengine_ws(
    text: str,
    voice_type: str = "sweet_female",
    output_path: str = None,
    timeout: int = 120
) -> str:
    """
    Generate TTS audio via the Volcengine WebSocket binary demo.

    Shells out to the vendored demo script under
    /Volumes/Tony/video-flow/volcengine_binary_demo/.venv/bin/python.
    NOTE(review): these absolute, machine-specific paths should move to
    config so this feature works outside this one machine.

    Args:
        text: Narration text.
        voice_type: Preset key (see VOLC_TTS_VOICES) or a raw voice id.
        output_path: Destination mp3 path; auto-generated under
            config.TEMP_DIR when None.
        timeout: Subprocess timeout in seconds.

    Returns:
        Path to the generated mp3, or "" on any failure (logged, not
        raised).
    """
    if not text or not text.strip():
        logger.warning("Empty text provided for TTS (ws)")
        return ""

    # Resolve preset key to a concrete voice id (raw ids pass through).
    voice_id = VOLC_TTS_VOICES.get(voice_type, voice_type)

    # Hard-coded demo installation paths (see NOTE above).
    venv_python = Path("/Volumes/Tony/video-flow/volcengine_binary_demo/.venv/bin/python")
    demo_script = Path("/Volumes/Tony/video-flow/volcengine_binary_demo/examples/volcengine/binary.py")

    if not venv_python.exists() or not demo_script.exists():
        logger.error("Volcengine WS demo or venv not found. Please install under volcengine_binary_demo/.venv")
        return ""

    if not output_path:
        output_path = str(config.TEMP_DIR / f"vo_volc_ws_{int(time.time())}.mp3")

    cmd = [
        str(venv_python),
        str(demo_script),
        "--appid", config.VOLC_TTS_APPID,
        "--access_token", config.VOLC_TTS_ACCESS_TOKEN,
        "--voice_type", voice_id,
        "--text", text,
        "--encoding", "mp3",
    ]

    logger.info(f"Calling Volcengine WS TTS: voice={voice_id}, len={len(text)}")
    try:
        result = subprocess.run(
            cmd,
            cwd="/Volumes/Tony/video-flow/volcengine_binary_demo",
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            logger.error(f"Volc WS TTS failed: {result.stderr}")
            return ""

        # The demo writes its output as <voice_type>.mp3 in its own cwd.
        demo_out = Path("/Volumes/Tony/video-flow/volcengine_binary_demo") / f"{voice_id}.mp3"
        if not demo_out.exists():
            logger.error("Volc WS TTS output not found")
            return ""

        # Copy the demo's output bytes to the caller's requested path.
        Path(output_path).write_bytes(demo_out.read_bytes())
        logger.info(f"Volc WS TTS saved to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Volc WS TTS error: {e}")
        return ""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_voiceover_volcengine(
    text: str,
    voice_type: str = "sweet_female",
    speed_ratio: float = 1.0,
    volume_ratio: float = 1.0,
    pitch_ratio: float = 1.0,
    output_path: str = None
) -> str:
    """
    Generate a Chinese voiceover with Volcengine TTS.

    Tries the WebSocket binary endpoint first (verified via the official
    demo), then the HTTP API; on any Volcengine failure it falls back to
    Edge TTS so the pipeline still produces audio.

    Args:
        text: Narration text.
        voice_type: Voice preset key (see VOLC_TTS_VOICES) or a raw voice ID.
        speed_ratio: Speech speed (0.5-2.0, default 1.0).
        volume_ratio: Volume (0.5-2.0, default 1.0).
        pitch_ratio: Pitch (0.5-2.0, default 1.0).
        output_path: Output file path (auto-generated when omitted).

    Returns:
        Path to the generated audio file, or "" for empty input.
    """
    import uuid

    if not text or not text.strip():
        logger.warning("Empty text provided for TTS")
        return ""

    def _edge_fallback() -> str:
        # Single fallback path (previously copy-pasted three times).
        # Edge TTS only understands its own preset names, so map unknown
        # voices to a safe default before falling back.
        fallback_voice = "sweet_female" if voice_type not in EDGE_TTS_VOICES else voice_type
        return generate_voiceover_edge(text, fallback_voice, output_path=output_path)

    # Resolve preset name -> Volcengine voice ID (unknown keys pass through).
    voice_id = VOLC_TTS_VOICES.get(voice_type, voice_type)

    logger.info(f"Generating voiceover (Volcengine TTS): {len(text)} chars, voice={voice_id}")

    # Prefer the WebSocket binary protocol (validated via the official demo).
    ws_path = generate_voiceover_volcengine_ws(text, voice_type, output_path)
    if ws_path:
        return ws_path

    # WS failed or unavailable -> try the HTTP endpoint.
    url = "https://openspeech.bytedance.com/api/v1/tts"

    headers = {
        "Content-Type": "application/json",
        # NOTE: Volcengine expects the unusual "Bearer;<token>" form here.
        "Authorization": f"Bearer;{config.VOLC_TTS_ACCESS_TOKEN}"
    }

    payload = {
        "app": {
            "appid": config.VOLC_TTS_APPID,
            "token": config.VOLC_TTS_ACCESS_TOKEN,
            "cluster": "volcano_tts"
        },
        "user": {
            "uid": "video_flow_user"
        },
        "audio": {
            "voice_type": voice_id,
            "encoding": "mp3",
            "speed_ratio": speed_ratio,
            "volume_ratio": volume_ratio,
            "pitch_ratio": pitch_ratio
        },
        "request": {
            "reqid": str(uuid.uuid4()),
            "text": text,
            "text_type": "plain",
            "operation": "query",
            "with_timestamp": "1",
            "extra_param": json.dumps({
                "disable_markdown_filter": False
            })
        }
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)

        if response.status_code != 200:
            logger.error(f"Volcengine TTS Error: {response.status_code} - {response.text}")
            return _edge_fallback()

        data = response.json()

        # 3000 / 20000000 are additional "success" codes returned by this API.
        ret_code = data.get("code")
        if ret_code not in (0, 3000, 20000000):
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Volcengine TTS Error: {error_msg}")
            return _edge_fallback()

        # Audio is returned base64-encoded in the "data" field.
        audio_data = data.get("data", "")
        if not audio_data:
            raise ValueError("No audio data returned")

        if not output_path:
            filename = f"vo_volc_{int(time.time())}.mp3"
            output_path = str(config.TEMP_DIR / filename)

        with open(output_path, "wb") as f:
            f.write(base64.b64decode(audio_data))

        logger.info(f"Voiceover generated (HTTP): {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Volcengine TTS HTTP error: {e}")
        return _edge_fallback()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_voiceover_volcengine_long(
    text: str,
    voice_type: str = "sweet_female",
    speed_ratio: float = 1.0,
    output_path: str = None,
    max_chunk_length: int = 300
) -> str:
    """
    Volcengine TTS for long text (automatic chunked synthesis).

    Text longer than max_chunk_length is split at sentence boundaries,
    each chunk is synthesized separately, and the pieces are concatenated
    losslessly with FFmpeg (stream copy, no re-encode).

    Args:
        text: Narration text.
        voice_type: Voice preset key or raw voice ID.
        speed_ratio: Speech speed (0.5-2.0).
        output_path: Output file path (auto-generated when omitted).
        max_chunk_length: Maximum characters per synthesis request.

    Returns:
        Path to the merged audio file.

    Raises:
        ValueError: If every chunk fails to synthesize.
    """
    # Short enough: a single request handles it.
    if len(text) <= max_chunk_length:
        return generate_voiceover_volcengine(
            text=text,
            voice_type=voice_type,
            speed_ratio=speed_ratio,
            output_path=output_path
        )

    logger.info(f"Long text ({len(text)} chars), splitting into chunks...")

    # Split on sentence-ending punctuation (CJK and ASCII); the capture
    # group keeps each delimiter as its own list element.
    import re
    sentences = re.split(r'([。!?;.!?;])', text)

    chunks = []
    current_chunk = ""

    # Re-pair each sentence with its trailing delimiter, then greedily pack
    # sentences into chunks of at most max_chunk_length characters.
    for i in range(0, len(sentences) - 1, 2):
        sentence = sentences[i] + (sentences[i + 1] if i + 1 < len(sentences) else "")

        if len(current_chunk) + len(sentence) <= max_chunk_length:
            current_chunk += sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    # A trailing fragment with no closing punctuation (odd element count).
    if len(sentences) % 2 == 1 and sentences[-1]:
        if chunks:
            chunks[-1] += sentences[-1]
        else:
            chunks.append(sentences[-1])

    logger.info(f"Split into {len(chunks)} chunks")

    # Synthesize each chunk independently; a failed chunk is skipped so the
    # rest of the narration is still produced.
    chunk_files = []
    for i, chunk in enumerate(chunks):
        chunk_path = str(config.TEMP_DIR / f"vo_chunk_{i}_{int(time.time())}.mp3")
        try:
            path = generate_voiceover_volcengine(
                text=chunk,
                voice_type=voice_type,
                speed_ratio=speed_ratio,
                output_path=chunk_path
            )
            chunk_files.append(path)
        except Exception as e:
            logger.error(f"Chunk {i} failed: {e}")

    if not chunk_files:
        raise ValueError("All TTS chunks failed")

    # Single surviving chunk: no merge needed.
    if len(chunk_files) == 1:
        if output_path:
            import shutil
            shutil.move(chunk_files[0], output_path)
            return output_path
        return chunk_files[0]

    # Build the FFmpeg concat-demuxer list file.
    concat_list = config.TEMP_DIR / f"concat_audio_{os.getpid()}.txt"
    with open(concat_list, "w") as f:
        for cf in chunk_files:
            f.write(f"file '{cf}'\n")

    if not output_path:
        output_path = str(config.TEMP_DIR / f"vo_volc_merged_{int(time.time())}.mp3")

    # Lossless concatenation via the concat demuxer.
    import subprocess
    cmd = [
        "ffmpeg", "-y",
        "-f", "concat",
        "-safe", "0",
        "-i", str(concat_list),
        "-c", "copy",
        output_path
    ]

    subprocess.run(cmd, capture_output=True, check=True)

    # Best-effort cleanup of intermediate files; a narrow OSError catch
    # (instead of a bare except) avoids swallowing KeyboardInterrupt.
    for cf in chunk_files:
        try:
            os.remove(cf)
        except OSError:
            pass
    concat_list.unlink(missing_ok=True)

    logger.info(f"Merged voiceover: {output_path}")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def generate_scene_voiceovers_volcengine(
    scenes: List[Dict[str, Any]],
    voice_type: str = "sweet_female",
    output_dir: str = None
) -> List[str]:
    """
    Generate a separate narration audio file for every scene.

    Args:
        scenes: Scene dicts, each carrying a "voiceover" field.
        voice_type: Voice preset key.
        output_dir: Destination directory (defaults to config.TEMP_DIR).

    Returns:
        One audio file path per scene; "" for scenes that have no usable
        narration or whose synthesis failed.
    """
    target_dir = config.TEMP_DIR
    if output_dir:
        target_dir = Path(output_dir)
        target_dir.mkdir(exist_ok=True)

    results = []

    for idx, scene in enumerate(scenes, start=1):
        narration = scene.get("voiceover", "")

        # Skip empty narration and parenthesized stage directions.
        usable = bool(narration) and bool(narration.strip()) and not narration.startswith("(")
        if not usable:
            results.append("")
            continue

        try:
            dest = str(target_dir / f"scene_{idx}_vo.mp3")
            results.append(generate_voiceover_volcengine(
                text=narration.strip(),
                voice_type=voice_type,
                output_path=dest
            ))
        except Exception as exc:
            logger.error(f"Scene {idx} voiceover failed: {exc}")
            results.append("")

    return results
|
||||||
708
modules/fancy_text.py
Normal file
708
modules/fancy_text.py
Normal file
@@ -0,0 +1,708 @@
|
|||||||
|
"""
|
||||||
|
抖音风格花字生成模块
|
||||||
|
使用 Pillow 生成透明 PNG 图片,支持描边、渐变、气泡框等效果
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Tuple, List, Optional
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 花字缓存目录
|
||||||
|
FANCY_TEXT_CACHE_DIR = config.TEMP_DIR / "fancy_text_cache"
|
||||||
|
FANCY_TEXT_CACHE_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_font(font_name: str = None, size: int = 48) -> ImageFont.FreeTypeFont:
    """Load a usable font, trying candidates in priority order.

    Resolution order: an explicitly given (existing) font path; otherwise the
    project's bundled fonts, then common system CJK fonts for macOS, Linux
    and Windows. Falls back to Pillow's built-in default font only when every
    candidate fails to load.

    Args:
        font_name: Optional path to a specific font file.
        size: Point size for the loaded font.

    Returns:
        A Pillow font object (FreeTypeFont, or the bitmap default as a
        last resort).
    """
    candidates = []
    if font_name and os.path.exists(font_name):
        candidates.append(font_name)
    else:
        # Project-bundled fonts first (Path objects under config.FONTS_DIR).
        candidates.extend([
            config.FONTS_DIR / "AlibabaPuHuiTi-Bold.ttf",
            config.FONTS_DIR / "AlibabaPuHuiTi-Regular.ttf",
            config.FONTS_DIR / "NotoSansSC-Bold.otf",
            config.FONTS_DIR / "NotoSansSC-Regular.otf",
        ])
    # System font fallbacks: macOS, Linux, Windows (string paths).
    candidates.extend([
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/STHeiti Medium.ttc",
        "/Library/Fonts/Arial Unicode.ttf",
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "C:/Windows/Fonts/msyh.ttc",
        "C:/Windows/Fonts/simhei.ttf",
    ])
    for path in candidates:
        if not path:
            continue
        p = str(path)
        if not os.path.exists(p):
            continue
        # Reject suspiciously small bundled files (likely placeholders,
        # e.g. Git LFS pointers). Only Path-typed (bundled) candidates are
        # checked; system font paths are trusted as-is.
        if isinstance(path, Path) and path.stat().st_size < 10000:
            continue
        try:
            return ImageFont.truetype(p, size)
        except Exception as e:
            # A corrupt or unsupported file should not abort the search.
            logger.warning(f"Failed to load font {p}: {e}")
            continue
    logger.warning("No suitable font found, using default")
    return ImageFont.load_default()
|
||||||
|
|
||||||
|
|
||||||
|
def _hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
|
||||||
|
"""十六进制颜色转 RGB"""
|
||||||
|
hex_color = hex_color.lstrip("#")
|
||||||
|
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
    """Measure the rendered (width, height) of *text* in the given font."""
    # A throwaway 1x1 canvas is enough: textbbox only needs a draw context.
    probe = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
    left, top, right, bottom = probe.textbbox((0, 0), text, font=font)
    return right - left, bottom - top
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_key(text: str, style: Dict) -> str:
|
||||||
|
"""生成缓存键"""
|
||||||
|
content = f"{text}_{str(sorted(style.items()))}"
|
||||||
|
return hashlib.md5(content.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_with_stroke(
    text: str,
    font_size: int = 60,
    font_color: str = "#FFFFFF",
    stroke_color: str = "#000000",
    stroke_width: int = 4,
    font_name: str = None,
    padding: int = 20
) -> Image.Image:
    """
    Render *text* as a transparent PNG with a solid outline.

    Args:
        text: Text content.
        font_size: Font size in pixels.
        font_color: Fill color (hex).
        stroke_color: Outline color (hex).
        stroke_width: Outline thickness in pixels.
        font_name: Optional font file path.
        padding: Inner padding in pixels.

    Returns:
        RGBA image with a transparent background.
    """
    face = _get_font(font_name, font_size)
    tw, th = _get_text_size(text, face)

    # Canvas leaves room for the outline plus padding on every side.
    margin = stroke_width + padding
    canvas = Image.new("RGBA", (tw + margin * 2, th + margin * 2), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)

    ox, oy = margin, margin

    # Outline: stamp the text at every offset inside a disc of radius
    # stroke_width around the origin.
    outline_fill = _hex_to_rgb(stroke_color) + (255,)
    for dx in range(-stroke_width, stroke_width + 1):
        for dy in range(-stroke_width, stroke_width + 1):
            if dx * dx + dy * dy <= stroke_width * stroke_width:
                pen.text((ox + dx, oy + dy), text, font=face, fill=outline_fill)

    # Main fill on top of the outline.
    pen.text((ox, oy), text, font=face, fill=_hex_to_rgb(font_color) + (255,))

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_with_shadow(
    text: str,
    font_size: int = 60,
    font_color: str = "#FFFFFF",
    shadow_color: str = "#000000",
    shadow_offset: Tuple[int, int] = (4, 4),
    shadow_blur: int = 5,
    font_name: str = None,
    padding: int = 30,
    stroke_color: str = None,
    stroke_width: int = 0
) -> Image.Image:
    """
    Render text with a soft drop shadow and an optional outline
    (the outline gives a two-layer "safety stroke" look).
    """
    face = _get_font(font_name, font_size)
    tw, th = _get_text_size(text, face)

    # Reserve room for whichever is larger: the blur halo or the outline.
    halo = max(shadow_blur, stroke_width * 2)
    width = tw + abs(shadow_offset[0]) + halo * 2 + padding * 2
    height = th + abs(shadow_offset[1]) + halo * 2 + padding * 2

    # Shadow layer: offset text at ~70% alpha, blurred before anything
    # else is drawn so only the shadow gets softened.
    canvas = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    ImageDraw.Draw(canvas).text(
        (padding + halo + shadow_offset[0], padding + halo + shadow_offset[1]),
        text, font=face, fill=_hex_to_rgb(shadow_color) + (180,)
    )
    canvas = canvas.filter(ImageFilter.GaussianBlur(shadow_blur))

    pen = ImageDraw.Draw(canvas)
    tx, ty = padding + halo, padding + halo

    # Optional outline around the main glyphs.
    if stroke_color and stroke_width > 0:
        outline_fill = _hex_to_rgb(stroke_color) + (255,)
        for dx in range(-stroke_width, stroke_width + 1):
            for dy in range(-stroke_width, stroke_width + 1):
                if dx * dx + dy * dy <= stroke_width * stroke_width:
                    pen.text((tx + dx, ty + dy), text, font=face, fill=outline_fill)

    # Main fill on top.
    pen.text((tx, ty), text, font=face, fill=_hex_to_rgb(font_color) + (255,))

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_with_gradient(
    text: str,
    font_size: int = 60,
    gradient_colors: List[str] = None,
    gradient_direction: str = "vertical",  # vertical, horizontal
    stroke_color: str = "#000000",
    stroke_width: int = 3,
    font_name: str = None,
    padding: int = 20
) -> Image.Image:
    """
    Render text filled with a 2- or 3-stop color gradient over a solid outline.

    Args:
        text: Text content.
        font_size: Font size in pixels.
        gradient_colors: Gradient stops, e.g. ["#FF6B6B", "#FFE66D"].
            With two stops the first segment spans the first half of the
            axis and the second half holds the last stop; a third stop
            is interpolated across the second half.
        gradient_direction: "vertical" or "horizontal".
        stroke_color: Outline color (hex).
        stroke_width: Outline thickness in pixels.
        font_name: Optional font file path.
        padding: Inner padding in pixels.

    Returns:
        Transparent RGBA image.
    """
    if not gradient_colors:
        gradient_colors = ["#FF6B6B", "#FFE66D"]  # default red-to-yellow

    font = _get_font(font_name, font_size)
    text_w, text_h = _get_text_size(text, font)

    img_w = text_w + stroke_width * 2 + padding * 2
    img_h = text_h + stroke_width * 2 + padding * 2

    # Full-canvas gradient, painted one scanline at a time.
    gradient = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    gradient_draw = ImageDraw.Draw(gradient)

    colors = [_hex_to_rgb(c) for c in gradient_colors]

    span = img_h if gradient_direction == "vertical" else img_w
    for i in range(span):
        ratio = i / span
        # Two linear segments: stops[0]->stops[1], then stops[1]->stops[2]
        # (the second segment is constant when only two stops are given).
        if ratio < 0.5:
            r = ratio * 2
            c1, c2 = colors[0], colors[min(1, len(colors) - 1)]
        else:
            r = (ratio - 0.5) * 2
            c1 = colors[min(1, len(colors) - 1)]
            c2 = colors[min(2, len(colors) - 1)] if len(colors) > 2 else c1

        color = tuple(int(c1[j] + (c2[j] - c1[j]) * r) for j in range(3)) + (255,)

        if gradient_direction == "vertical":
            gradient_draw.line([(0, i), (img_w, i)], fill=color)
        else:
            gradient_draw.line([(i, 0), (i, img_h)], fill=color)

    x = padding + stroke_width
    y = padding + stroke_width

    # Outline layer first. (A previous intermediate "mask" layer that was
    # drawn but never used has been removed — output is unchanged.)
    result = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    stroke_img = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    stroke_draw = ImageDraw.Draw(stroke_img)
    stroke_rgb = _hex_to_rgb(stroke_color) + (255,)

    for dx in range(-stroke_width, stroke_width + 1):
        for dy in range(-stroke_width, stroke_width + 1):
            if dx * dx + dy * dy <= stroke_width * stroke_width:
                stroke_draw.text((x + dx, y + dy), text, font=font, fill=stroke_rgb)

    result = Image.alpha_composite(result, stroke_img)

    # Then the gradient, clipped to the glyph shapes, composited on top.
    text_mask = Image.new("L", (img_w, img_h), 0)
    ImageDraw.Draw(text_mask).text((x, y), text, font=font, fill=255)

    gradient_text = Image.new("RGBA", (img_w, img_h), (0, 0, 0, 0))
    gradient_text.paste(gradient, mask=text_mask)

    result = Image.alpha_composite(result, gradient_text)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def create_bubble_text(
    text: str,
    font_size: int = 48,
    font_color: str = "#333333",
    bg_color: str = "#FFFFFF",
    border_color: str = "#CCCCCC",
    border_width: int = 2,
    corner_radius: int = 20,
    padding: Tuple[int, int] = (30, 15),
    font_name: str = None,
    tail_direction: str = None  # "left", "right", "bottom", None
) -> Image.Image:
    """
    Render text inside a rounded speech bubble (dialog-box look).

    Args:
        tail_direction: Which side the bubble tail points from
            ("left", "right", "bottom"), or None for no tail.
    """
    face = _get_font(font_name, font_size)
    tw, th = _get_text_size(text, face)

    # Bubble body size: text plus horizontal/vertical padding.
    body_w = tw + padding[0] * 2
    body_h = th + padding[1] * 2

    # Extra canvas room for the tail, when requested.
    tail = 20 if tail_direction else 0
    if tail_direction in ["left", "right"]:
        canvas_w, canvas_h = body_w + tail, body_h
    elif tail_direction == "bottom":
        canvas_w, canvas_h = body_w, body_h + tail
    else:
        canvas_w, canvas_h = body_w, body_h

    canvas = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)

    # A left-side tail shifts the bubble body right to make room for it.
    bx = tail if tail_direction == "left" else 0
    by = 0

    fill_rgba = _hex_to_rgb(bg_color) + (255,)
    edge_rgba = _hex_to_rgb(border_color) + (255,)

    pen.rounded_rectangle(
        [bx, by, bx + body_w, by + body_h],
        radius=corner_radius,
        fill=fill_rgba,
        outline=edge_rgba,
        width=border_width
    )

    # Tail triangle: drawn once with its outline, then refilled so the
    # border line does not cut across the bubble/tail junction.
    if tail_direction == "left":
        tri = [(bx, body_h // 2 - 10), (0, body_h // 2), (bx, body_h // 2 + 10)]
    elif tail_direction == "right":
        tri = [(bx + body_w, body_h // 2 - 10), (canvas_w, body_h // 2), (bx + body_w, body_h // 2 + 10)]
    elif tail_direction == "bottom":
        tri = [(body_w // 2 - 10, body_h), (body_w // 2, canvas_h), (body_w // 2 + 10, body_h)]
    else:
        tri = None

    if tri:
        pen.polygon(tri, fill=fill_rgba, outline=edge_rgba)
        pen.polygon(tri, fill=fill_rgba)

    # Text inside the bubble body.
    pen.text((bx + padding[0], by + padding[1]), text, font=face, fill=_hex_to_rgb(font_color) + (255,))

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def create_price_tag(
    price: str,
    currency: str = "¥",
    font_size: int = 72,
    price_color: str = "#FF4444",
    currency_color: str = "#FF4444",
    stroke_color: str = "#FFFFFF",
    stroke_width: int = 4,
    font_name: str = None
) -> Image.Image:
    """
    Render an e-commerce style price tag: a small currency symbol next to
    a large price figure, both with an outline.
    """
    big = _get_font(font_name, font_size)
    small = _get_font(font_name, int(font_size * 0.5))

    # Measure both parts; they sit side by side with a 5px gap.
    cur_w, cur_h = _get_text_size(currency, small)
    price_w, price_h = _get_text_size(price, big)

    total_w = cur_w + price_w + 5
    total_h = max(cur_h, price_h)

    pad = stroke_width + 10
    canvas = Image.new("RGBA", (total_w + pad * 2, total_h + pad * 2), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)

    # Both parts are vertically centered against the taller of the two.
    cur_pos = (pad, pad + (total_h - cur_h) // 2)
    price_pos = (pad + cur_w + 5, pad + (total_h - price_h) // 2)

    # Outline pass: stamp both texts at every offset within the disc.
    outline = _hex_to_rgb(stroke_color) + (255,)
    for dx in range(-stroke_width, stroke_width + 1):
        for dy in range(-stroke_width, stroke_width + 1):
            if dx * dx + dy * dy <= stroke_width * stroke_width:
                pen.text((cur_pos[0] + dx, cur_pos[1] + dy), currency, font=small, fill=outline)
                pen.text((price_pos[0] + dx, price_pos[1] + dy), price, font=big, fill=outline)

    # Color pass on top.
    pen.text(cur_pos, currency, font=small, fill=_hex_to_rgb(currency_color) + (255,))
    pen.text(price_pos, price, font=big, fill=_hex_to_rgb(price_color) + (255,))

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def create_button(
    text: str,
    font_size: int = 36,
    font_color: str = "#FFFFFF",
    bg_color: str = "#FF6B35",
    corner_radius: int = 25,
    padding: Tuple[int, int] = (40, 15),
    font_name: str = None,
    shadow: bool = True
) -> Image.Image:
    """
    Render a CTA-button style label (e.g. "Buy now") as a rounded pill
    with an optional drop shadow.
    """
    face = _get_font(font_name, font_size)
    tw, th = _get_text_size(text, face)

    btn_w = tw + padding[0] * 2
    btn_h = th + padding[1] * 2

    # Shadow is drawn offset down-right, so the canvas grows by that much.
    offset = 4 if shadow else 0
    canvas = Image.new("RGBA", (btn_w + offset, btn_h + offset), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)

    # Soft shadow pill behind the button.
    if shadow:
        pen.rounded_rectangle(
            [offset, offset, btn_w + offset, btn_h + offset],
            radius=corner_radius,
            fill=(0, 0, 0, 80)
        )

    # Button background.
    pen.rounded_rectangle(
        [0, 0, btn_w, btn_h],
        radius=corner_radius,
        fill=_hex_to_rgb(bg_color) + (255,)
    )

    # Label.
    pen.text((padding[0], padding[1]), text, font=face, fill=_hex_to_rgb(font_color) + (255,))

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def create_comparison_text(
    left_text: str,
    right_text: str,
    vs_text: str = "vs",
    font_size: int = 48,
    left_color: str = "#666666",
    right_color: str = "#FF6B35",
    vs_color: str = "#FF0000",
    font_name: str = None
) -> Image.Image:
    """
    Render an A-vs-B comparison line, with the "vs" slightly smaller and
    each segment individually colored over a black outline.
    """
    main_font = _get_font(font_name, font_size)
    vs_font = _get_font(font_name, int(font_size * 0.8))

    lw, lh = _get_text_size(left_text, main_font)
    vw, vh = _get_text_size(vs_text, vs_font)
    rw, rh = _get_text_size(right_text, main_font)

    gap = 15
    row_w = lw + vw + rw + gap * 2
    row_h = max(lh, vh, rh)

    pad = 20
    outline_w = 3
    canvas = Image.new(
        "RGBA",
        (row_w + pad * 2 + outline_w * 2, row_h + pad * 2 + outline_w * 2),
        (0, 0, 0, 0)
    )
    pen = ImageDraw.Draw(canvas)

    ox = pad + outline_w
    oy = pad + outline_w

    # Each segment is vertically centered against the tallest one.
    segments = [
        (left_text, main_font, ox, oy + (row_h - lh) // 2, _hex_to_rgb(left_color) + (255,)),
        (vs_text, vs_font, ox + lw + gap, oy + (row_h - vh) // 2, _hex_to_rgb(vs_color) + (255,)),
        (right_text, main_font, ox + lw + gap + vw + gap, oy + (row_h - rh) // 2, _hex_to_rgb(right_color) + (255,)),
    ]

    # Black outline pass over all three segments.
    black = (0, 0, 0, 255)
    for dx in range(-outline_w, outline_w + 1):
        for dy in range(-outline_w, outline_w + 1):
            if dx * dx + dy * dy <= outline_w * outline_w:
                for seg_text, seg_font, sx, sy, _ in segments:
                    pen.text((sx + dx, sy + dy), seg_text, font=seg_font, fill=black)

    # Color pass on top.
    for seg_text, seg_font, sx, sy, fill in segments:
        pen.text((sx, sy), seg_text, font=seg_font, fill=fill)

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
# Preset styles
# ============================================================
|
||||||
|
|
||||||
|
# Named style presets consumed by create_fancy_text(). "type" selects the
# renderer (shadow / price / button / ...; default is plain stroke);
# "version" is a tag included in the cache key so preset tweaks invalidate
# cached images.
PRESET_STYLES = {
    "subtitle": {
        "font_size": 48,
        "font_color": "#FFFFFF",
        "stroke_color": "#000000",
        "stroke_width": 3,
        "version": "v2"
    },
    "highlight": {
        # Warm off-white fill + light stroke + dark shadow, matching light-brown backgrounds
        "font_size": 90,
        "font_color": "#F7E7D3",
        "stroke_color": "#C9B59A",  # light stroke
        "stroke_width": 4,
        "type": "shadow",
        "shadow_color": "#3A2C1F",  # dark brown shadow
        "shadow_offset": (3, 3),
        "shadow_blur": 10,
        "padding": 32,
        "version": "gloda"
    },
    "warning": {
        # Desaturated terracotta red + cream stroke + dark brown shadow
        "font_size": 80,
        "font_color": "#D96B4F",
        "stroke_color": "#F6E5D6",
        "stroke_width": 4,
        "type": "shadow",
        "shadow_color": "#3A2C1F",
        "shadow_offset": (3, 3),
        "shadow_blur": 10,
        "padding": 30,
        "version": "gloda"
    },
    "success": {
        "font_size": 52,
        "font_color": "#4CAF50",
        "stroke_color": "#FFFFFF",
        "stroke_width": 4,
        "version": "v2"
    },
    "price": {
        # Price tag: warm red price + cream currency symbol + dark stroke
        "font_size": 110,
        "price_color": "#E25B4F",
        "currency_color": "#F6E5D6",
        "stroke_color": "#3A2C1F",
        "stroke_width": 8,
        "type": "price",
        "version": "gloda"
    },
    "cta_button": {
        # Warm orange button with a light shadow
        "font_size": 46,
        "font_color": "#FFFFFF",
        "bg_color": "#E6763A",
        "corner_radius": 32,
        "type": "button",
        "shadow": True,
        "version": "gloda"
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_fancy_text(
    text: str,
    style: str = "subtitle",
    custom_style: Dict[str, Any] = None,
    cache: bool = True
) -> str:
    """
    Unified entry point for rendering a fancy-text PNG.

    Args:
        text: Text content.
        style: Preset style name (unknown names fall back to "subtitle").
        custom_style: Overrides merged on top of the preset.
        cache: Whether to reuse / store the rendered image in the cache dir.

    Returns:
        Path to the rendered PNG image.
    """
    # Merge the preset with caller overrides
    merged = PRESET_STYLES.get(style, PRESET_STYLES["subtitle"]).copy()
    if custom_style:
        merged.update(custom_style)

    # Cache lookup keyed on (text, effective style)
    cache_path = None
    if cache:
        cache_path = FANCY_TEXT_CACHE_DIR / f"{_cache_key(text, merged)}.png"
        if cache_path.exists():
            return str(cache_path)

    # "type" selects a specialized renderer; pop it so it is never forwarded
    # as a renderer kwarg
    style_type = merged.pop("type", None)

    def _kwargs(allowed):
        # Forward only the kwargs the chosen renderer accepts
        return {k: v for k, v in merged.items() if k in allowed}

    if style == "price" or style_type == "price":
        img = create_price_tag(text, **_kwargs([
            "currency", "font_size", "price_color", "currency_color", "stroke_color", "stroke_width", "font_name"
        ]))
    elif style == "cta_button" or style_type == "button":
        img = create_button(text, **_kwargs([
            "font_size", "font_color", "bg_color", "corner_radius", "padding", "font_name", "shadow"
        ]))
    elif style_type == "bubble":
        img = create_bubble_text(text, **_kwargs([
            "font_size", "font_color", "bg_color", "border_color", "border_width",
            "corner_radius", "padding", "font_name", "tail_direction"
        ]))
    elif style_type == "gradient":
        img = create_text_with_gradient(text, **_kwargs([
            "font_size", "gradient_colors", "gradient_direction", "stroke_color", "stroke_width", "font_name", "padding"
        ]))
    elif style_type == "shadow":
        img = create_text_with_shadow(text, **_kwargs([
            "font_size", "font_color", "shadow_color", "shadow_offset", "shadow_blur", "font_name", "padding"
        ]))
    else:
        # Plain stroked text is the default renderer
        img = create_text_with_stroke(text, **_kwargs([
            "font_size", "font_color", "stroke_color", "stroke_width", "font_name", "padding"
        ]))

    # Persist either into the cache or into a process-unique temp file
    if cache:
        output_path = str(cache_path)
    else:
        output_path = str(config.TEMP_DIR / f"fancy_{hash(text)}_{os.getpid()}.png")

    img.save(output_path, "PNG")
    logger.info(f"Created fancy text: '{text[:20]}...' -> {output_path}")

    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def batch_create_fancy_texts(
    configs: List[Dict[str, Any]]
) -> List[str]:
    """
    Render a batch of fancy-text images.

    Args:
        configs: List of config dicts {text, style, custom_style}.

    Returns:
        PNG file paths, one per config, in input order.
    """
    return [
        create_fancy_text(
            text=entry.get("text", ""),
            style=entry.get("style", "subtitle"),
            custom_style=entry.get("custom_style")
        )
        for entry in configs
    ]
|
||||||
|
|
||||||
960
modules/ffmpeg_utils.py
Normal file
960
modules/ffmpeg_utils.py
Normal file
@@ -0,0 +1,960 @@
|
|||||||
|
"""
|
||||||
|
FFmpeg 视频处理工具模块
|
||||||
|
支持规模化批量视频处理:拼接、字幕、叠加、混音
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# FFmpeg/FFprobe 路径(优先使用项目内的二进制)
|
||||||
|
FFMPEG_PATH = str(config.BASE_DIR / "bin" / "ffmpeg") if (config.BASE_DIR / "bin" / "ffmpeg").exists() else "ffmpeg"
|
||||||
|
FFPROBE_PATH = str(config.BASE_DIR / "bin" / "ffprobe") if (config.BASE_DIR / "bin" / "ffprobe").exists() else "ffprobe"
|
||||||
|
|
||||||
|
# 字体路径优先使用项目自带中文字体,其次使用 Linux 系统字体,最后再回退到 macOS 路径
|
||||||
|
DEFAULT_FONT_PATHS = [
|
||||||
|
# 优先使用 Linux 系统级中文字体 (服务器环境最稳健)
|
||||||
|
"/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
|
||||||
|
"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
|
||||||
|
|
||||||
|
# 项目内字体 (注意:需确保文件不是 LFS 指针)
|
||||||
|
str(config.FONTS_DIR / "HarmonyOS-Sans-SC-Regular.ttf"),
|
||||||
|
str(config.FONTS_DIR / "AlibabaPuHuiTi-Regular.ttf"),
|
||||||
|
|
||||||
|
# macOS 字体(仅本地调试生效)
|
||||||
|
"/System/Library/Fonts/PingFang.ttc",
|
||||||
|
"/System/Library/Fonts/STHeiti Medium.ttc",
|
||||||
|
"/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_font_path() -> str:
    """Return the first usable font file from DEFAULT_FONT_PATHS.

    A candidate counts as usable when it exists and is larger than 1000
    bytes (tiny files are assumed to be broken, e.g. Git LFS pointer stubs).
    Falls back to "Arial" so callers never crash outright.
    """
    usable = (
        candidate for candidate in DEFAULT_FONT_PATHS
        if os.path.exists(candidate) and os.path.getsize(candidate) > 1000
    )
    return next(usable, "Arial")
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_text(text: str) -> str:
|
||||||
|
"""
|
||||||
|
去除可能导致 ffmpeg 命令行错误的特殊控制字符,但保留 Emoji、数字、标点和各国语言。
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 不再过滤任何字符,只确保不是 None
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def add_silence_audio(video_path: str, output_path: str) -> str:
    """Mux a silent stereo 44.1 kHz track into a video lacking audio.

    Prevents later filter graphs from failing when they reference 0:a.
    Video is stream-copied; the generated silence is encoded as AAC and
    cut to the video length via -shortest.
    """
    silent_source = "anullsrc=channel_layout=stereo:sample_rate=44100"
    cmd = [FFMPEG_PATH, "-y", "-i", video_path]
    cmd += ["-f", "lavfi", "-i", silent_source]
    cmd += ["-shortest", "-c:v", "copy", "-c:a", "aac", output_path]
    _run_ffmpeg(cmd)
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ffmpeg(cmd: List[str], check: bool = True) -> subprocess.CompletedProcess:
    """Execute an FFmpeg command and surface its stderr for diagnostics.

    Args:
        cmd: Full argv list, FFMPEG_PATH included.
        check: When True, a non-zero exit raises CalledProcessError.

    Returns:
        The CompletedProcess from subprocess.run.

    Raises:
        subprocess.CalledProcessError: When check=True and the command fails.
    """
    logger.debug(f"FFmpeg command: {' '.join(cmd)}")
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=check
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"FFmpeg failed: {e.stderr}")
        raise
    # Dump stderr even on success so font warnings etc. are visible
    if result.stderr:
        print(f"[FFmpeg stderr] {result.stderr}", flush=True)
    if result.returncode != 0:
        logger.error(f"FFmpeg stderr: {result.stderr}")
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_info(video_path: str) -> Dict[str, Any]:
    """Probe a media file with ffprobe and return its key properties.

    Returns:
        Dict with "duration" (seconds, float), "width", "height" and "fps".
        Width/height stay 0 when no video stream is present; fps defaults
        to 30 when it cannot be parsed.

    Raises:
        ValueError: If ffprobe cannot read the file.
    """
    import json

    probe_cmd = [
        FFPROBE_PATH,
        "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        "-show_streams",
        video_path
    ]
    probe = subprocess.run(probe_cmd, capture_output=True, text=True)
    if probe.returncode != 0:
        raise ValueError(f"Failed to probe video: {video_path}")

    payload = json.loads(probe.stdout)

    info = {
        "duration": float(payload.get("format", {}).get("duration", 0)),
        "width": 0,
        "height": 0,
        "fps": 30
    }

    # Only the first video stream is inspected
    video_stream = next(
        (s for s in payload.get("streams", []) if s.get("codec_type") == "video"),
        None
    )
    if video_stream is not None:
        info["width"] = video_stream.get("width", 0)
        info["height"] = video_stream.get("height", 0)
        # r_frame_rate may be a fraction like "30000/1001" or a plain number
        rate = video_stream.get("r_frame_rate", "30/1")
        if "/" in rate:
            num, den = rate.split("/")
            info["fps"] = float(num) / float(den) if float(den) != 0 else 30
        else:
            info["fps"] = float(rate)

    return info
|
||||||
|
|
||||||
|
|
||||||
|
def concat_videos(
    video_paths: List[str],
    output_path: str,
    target_size: Tuple[int, int] = (1080, 1920)
) -> str:
    """
    Concatenate multiple videos (video streams only) into one file.

    Every input is scaled to fit target_size (aspect ratio preserved),
    padded with black bars and SAR-normalized, then joined with the concat
    filter. Audio is dropped; use concat_videos_with_audio() to keep it.

    Args:
        video_paths: List of video file paths.
        output_path: Output file path.
        target_size: (width, height); default is portrait 1080x1920.

    Returns:
        output_path

    Raises:
        ValueError: If video_paths is empty.
    """
    if not video_paths:
        raise ValueError("No video paths provided")

    logger.info(f"Concatenating {len(video_paths)} videos...")

    # NOTE: an earlier revision wrote a concat-demuxer list file here, but
    # the command below feeds every input via "-i" and joins them with the
    # concat *filter*, so the list file was dead code (created, never used,
    # then deleted) and has been removed.

    width, height = target_size

    # Normalize each input: scale to fit, pad to exact size, reset SAR
    filter_parts = []
    for i in range(len(video_paths)):
        filter_parts.append(
            f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
            f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
        )

    # Join all normalized video streams
    concat_inputs = "".join([f"[v{i}]" for i in range(len(video_paths))])
    filter_parts.append(f"{concat_inputs}concat=n={len(video_paths)}:v=1:a=0[outv]")

    filter_complex = ";".join(filter_parts)

    cmd = [FFMPEG_PATH, "-y"]
    for vp in video_paths:
        cmd.extend(["-i", vp])

    cmd.extend([
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-pix_fmt", "yuv420p",
        output_path
    ])

    _run_ffmpeg(cmd)

    logger.info(f"Concatenated video saved: {output_path}")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def concat_videos_with_audio(
    video_paths: List[str],
    output_path: str,
    target_size: Tuple[int, int] = (1080, 1920)
) -> str:
    """
    Concatenate videos while keeping their audio tracks.

    Each input's video is scaled/padded to target_size; each audio stream is
    normalized to 44.1 kHz stereo so heterogeneous inputs can be joined.
    If the audio graph fails (e.g. an input has no audio stream), falls back
    to the video-only concat_videos().

    Args:
        video_paths: List of input video files.
        output_path: Output file path.
        target_size: (width, height); default is portrait 1080x1920.

    Returns:
        output_path

    Raises:
        ValueError: If video_paths is empty.
    """
    if not video_paths:
        raise ValueError("No video paths provided")

    logger.info(f"Concatenating {len(video_paths)} videos with audio...")

    width, height = target_size
    n = len(video_paths)

    # Build the filter_complex graph
    filter_parts = []

    # Video processing: scale to fit, pad to exact size, reset SAR
    for i in range(n):
        filter_parts.append(
            f"[{i}:v]scale={width}:{height}:force_original_aspect_ratio=decrease,"
            f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:black,setsar=1[v{i}]"
        )

    # Audio processing: normalize format so the concat inputs match.
    # NOTE(review): this references [i:a] for every input, so inputs without
    # an audio stream make ffmpeg fail and trigger the fallback below.
    for i in range(n):
        filter_parts.append(f"[{i}:a]aformat=sample_rates=44100:channel_layouts=stereo[a{i}]")

    # Concatenate video and audio chains separately
    v_concat = "".join([f"[v{i}]" for i in range(n)])
    a_concat = "".join([f"[a{i}]" for i in range(n)])
    filter_parts.append(f"{v_concat}concat=n={n}:v=1:a=0[outv]")
    filter_parts.append(f"{a_concat}concat=n={n}:v=0:a=1[outa]")

    filter_complex = ";".join(filter_parts)

    cmd = [FFMPEG_PATH, "-y"]
    for vp in video_paths:
        cmd.extend(["-i", vp])

    cmd.extend([
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-map", "[outa]",
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "aac",
        "-b:a", "128k",
        "-pix_fmt", "yuv420p",
        output_path
    ])

    try:
        _run_ffmpeg(cmd)
    except subprocess.CalledProcessError:
        # Audio concat failed; fall back to the video-only path
        logger.warning("Audio concat failed, falling back to video only")
        return concat_videos(video_paths, output_path, target_size)

    logger.info(f"Concatenated video with audio saved: {output_path}")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def add_subtitle(
    video_path: str,
    text: str,
    start: float,
    duration: float,
    output_path: str,
    style: Dict[str, Any] = None
) -> str:
    """
    Burn a single subtitle into a video using the drawtext filter.

    Args:
        video_path: Input video path.
        text: Subtitle text.
        start: Start time in seconds.
        duration: Display duration in seconds.
        output_path: Output path.
        style: Style overrides {
            fontsize: font size,
            fontcolor: font color,
            borderw: stroke width,
            bordercolor: stroke color,
            x: x position (drawtext expression, e.g. "(w-text_w)/2"),
            y: y position,
            font: font file path or name
        }

    Returns:
        output_path
    """
    style = style or {}

    # Default style values
    fontsize = style.get("fontsize", 48)
    fontcolor = style.get("fontcolor", "white")
    borderw = style.get("borderw", 3)
    bordercolor = style.get("bordercolor", "black")
    x = style.get("x", "(w-text_w)/2")  # horizontally centered by default
    y = style.get("y", "h-200")  # near the bottom by default

    # Prefer the dynamically detected working font over a possibly broken
    # hard-coded path
    default_font_path = _get_font_path()
    font = style.get("font", default_font_path)

    # Escape drawtext special characters. BUGFIX: backslash and '%' were
    # previously left unescaped here (unlike add_multiple_subtitles), which
    # broke the filter for texts containing them. Backslash must be
    # escaped first to avoid double-escaping the others.
    escaped_text = (
        text.replace("\\", "\\\\")
            .replace("'", "\\'")
            .replace(":", "\\:")
            .replace("%", "\\%")
    )

    # drawtext filter with a time-window enable expression
    drawtext = (
        f"drawtext=text='{escaped_text}':"
        f"fontfile='{font}':"
        f"fontsize={fontsize}:"
        f"fontcolor={fontcolor}:"
        f"borderw={borderw}:"
        f"bordercolor={bordercolor}:"
        f"x={x}:y={y}:"
        f"enable='between(t,{start},{start + duration})'"
    )

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-vf", drawtext,
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Added subtitle: '{text[:20]}...' at {start}s")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_text(text: str, max_chars: int = 18) -> str:
    """Insert hard line breaks every ``max_chars`` characters.

    Text that already contains a newline is assumed to be manually wrapped
    and is returned unchanged. Width is estimated as one unit per character
    (CJK and Latin alike) — a deliberate simplification; mixed-width text
    is not measured precisely.
    """
    if not text:
        return ""

    # Respect manual wrapping
    if "\n" in text:
        return text

    chunks = [text[i:i + max_chars] for i in range(0, len(text), max_chars)]
    return "\n".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def mix_audio_at_offset(
    base_audio: str,
    overlay_audio: str,
    offset: float,
    output_path: str,
    base_volume: float = 1.0,
    overlay_volume: float = 1.0
) -> str:
    """Mix overlay_audio on top of base_audio starting at ``offset`` seconds.

    The output length follows base_audio (amix duration=first). When the
    base file is missing, returns overlay_audio untouched as a best-effort
    fallback.
    """
    if not os.path.exists(base_audio):
        logger.warning(f"Base audio not found: {base_audio}")
        return overlay_audio

    # adelay expects milliseconds, one value per channel
    delay_ms = int(offset * 1000)
    graph = (
        f"[0:a]volume={base_volume}[a0];"
        f"[1:a]volume={overlay_volume},adelay={delay_ms}|{delay_ms}[a1];"
        f"[a0][a1]amix=inputs=2:duration=first:dropout_transition=0:normalize=0[out]"
    )

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", base_audio,
        "-i", overlay_audio,
        "-filter_complex", graph,
        "-map", "[out]",
        "-c:a", "mp3",  # Use MP3 for audio only mixing
        output_path
    ]
    _run_ffmpeg(cmd)
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_audio_duration(
    input_path: str,
    target_duration: float,
    output_path: str
) -> str:
    """Fit an audio file to a target duration by speeding it up when too long.

    Policy (per product requirements):
        - audio longer than target  -> speed up (capped at 2x to limit pitch
          artifacts)
        - audio shorter or equal    -> keep the original speed (never slow down)

    Returns output_path normally, input_path when the duration cannot be
    determined, or None when the file does not exist.
    """
    if not os.path.exists(input_path):
        return None

    current = float(get_audio_info(input_path).get("duration", 0))
    if current <= 0:
        return input_path

    if current <= target_duration:
        # Already short enough: copy through at the original speed
        import shutil
        shutil.copy(input_path, output_path)
        logger.info(f"Audio ({current:.2f}s) <= target ({target_duration:.2f}s), keeping original speed")
        return output_path

    # Too long: speed up, but never beyond 2x
    speed_ratio = min(current / target_duration, 2.0)
    logger.info(f"Audio ({current:.2f}s) > target ({target_duration:.2f}s), speeding up {speed_ratio:.2f}x")

    _run_ffmpeg([
        FFMPEG_PATH, "-y",
        "-i", input_path,
        "-filter:a", f"atempo={speed_ratio}",
        output_path
    ])
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_info(file_path: str) -> Dict[str, Any]:
    """Probe an audio file; ffprobe parsing is shared with get_video_info().

    For pure audio files "width"/"height" stay 0 and "fps" keeps its
    default; callers typically only use "duration".
    """
    return get_video_info(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_text_smart(text: str, max_chars: int = 15) -> str:
    """Break a long subtitle into two lines.

    Prefers splitting right after a punctuation mark or space closest to
    the middle of the string; when none is usable, forces a 40/60 split so
    the top line ends up shorter than the bottom one. Text within
    ``max_chars`` is returned unchanged.
    """
    if not text or len(text) <= max_chars:
        return text

    # Candidate break positions: punctuation (fullwidth + ASCII) or spaces
    breakers = {",", "。", "!", "?", " ", ",", ".", "!", "?"}
    mid = len(text) // 2
    candidates = [i for i, ch in enumerate(text) if ch in breakers]

    if candidates:
        # Pick the breaker nearest the middle (ties go to the earlier one)
        pivot = min(candidates, key=lambda i: abs(i - mid))
        if pivot < len(text) - 1:
            return text[:pivot + 1] + "\n" + text[pivot + 1:]

    # No usable punctuation: force a split with ~40% on the top line
    cut = int(len(text) * 0.4)
    return text[:cut] + "\n" + text[cut:]
|
||||||
|
|
||||||
|
|
||||||
|
def add_multiple_subtitles(
    video_path: str,
    subtitles: List[Dict[str, Any]],
    output_path: str,
    default_style: Dict[str, Any] = None
) -> str:
    """
    Burn multiple subtitles into a video in a single ffmpeg pass.

    Args:
        video_path: Input video path.
        subtitles: Entries of {text, start, duration, style}.
        output_path: Output path.
        default_style: Base style merged under each subtitle's own "style".

    Returns:
        output_path
    """
    if not subtitles:
        # Nothing to draw: pass the video through unchanged
        import shutil
        shutil.copy(video_path, output_path)
        return output_path

    default_style = default_style or {}

    def _usable_font(path: str) -> bool:
        # Fonts under ~100KB are treated as broken (e.g. Git LFS pointer stubs)
        return os.path.exists(path) and os.path.getsize(path) > 1024 * 100

    # Font preference: project-bundled NotoSansSC, then the Droid CJK
    # fallback, then whatever _get_font_path() can find.
    # BUGFIX: this used the hard-coded absolute path
    # "/root/video-flow/assets/fonts/NotoSansSC-Regular.otf", which broke on
    # any machine with a different checkout location; config.FONTS_DIR is
    # the same convention DEFAULT_FONT_PATHS already uses.
    font = str(config.FONTS_DIR / "NotoSansSC-Regular.otf")
    if not _usable_font(font):
        font = "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf"
        if not _usable_font(font):
            font = _get_font_path()

    logger.debug(f"Using font for subtitles: {font}")

    # Build one drawtext filter per subtitle
    filters = []
    for sub in subtitles:
        raw_text = sub.get("text", "")
        # Keep the repr at debug level to diagnose odd characters
        # (replaces the old noisy [SubDebug] prints to stdout)
        logger.debug(f"Subtitle text repr: {repr(raw_text)}")

        text = _sanitize_text(raw_text)
        # Automatic line wrapping
        text = wrap_text(text)

        start = sub.get("start", 0)
        duration = sub.get("duration", 3)
        style = {**default_style, **sub.get("style", {})}

        fontsize = style.get("fontsize", 48)
        fontcolor = style.get("fontcolor", "white")
        borderw = style.get("borderw", 3)
        bordercolor = style.get("bordercolor", "black")
        x = style.get("x", "(w-text_w)/2")
        y = style.get("y", "h-200")

        # Background box on by default for readability
        box = style.get("box", 1)
        boxcolor = style.get("boxcolor", "black@0.5")
        boxborderw = style.get("boxborderw", 10)

        # Escape: backslash first, then single quote, colon and percent
        escaped_text = text.replace("\\", "\\\\").replace("'", "\\'").replace(":", "\\:").replace("%", "\\%")

        drawtext = (
            f"drawtext=text='{escaped_text}':"
            f"fontfile='{font}':"
            f"fontsize={fontsize}:"
            f"fontcolor={fontcolor}:"
            f"borderw={borderw}:"
            f"bordercolor={bordercolor}:"
            f"box={box}:boxcolor={boxcolor}:boxborderw={boxborderw}:"
            f"x={x}:y={y}:"
            f"enable='between(t,{start},{start + duration})'"
        )
        filters.append(drawtext)

    # Chain all drawtext filters with commas
    vf = ",".join(filters)

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-vf", vf,
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Added {len(subtitles)} subtitles")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def overlay_image(
    video_path: str,
    image_path: str,
    output_path: str,
    position: Tuple[int, int] = None,
    start: float = 0,
    duration: float = None,
    fade_in: float = 0,
    fade_out: float = 0
) -> str:
    """
    Overlay a transparent PNG (fancy text, watermark, ...) onto a video.

    Args:
        video_path: Input video path.
        image_path: PNG path (alpha channel supported).
        output_path: Output path.
        position: (x, y) pixel position; None centers the image.
        start: Overlay start time in seconds.
        duration: Overlay duration in seconds; None means until video end.
        fade_in: Fade-in time in seconds (alpha fade on the image).
        fade_out: Fade-out time in seconds (alpha fade on the image).

    Returns:
        output_path
    """
    # Probe the video so a missing duration can default to "until the end"
    info = get_video_info(video_path)
    video_duration = info["duration"]

    if duration is None:
        duration = video_duration - start

    # Position expression for the overlay filter
    if position:
        x, y = position
        pos_str = f"x={x}:y={y}"
    else:
        pos_str = "x=(W-w)/2:y=(H-h)/2"  # centered

    # Show the overlay only inside the [start, start+duration] window
    enable = f"enable='between(t,{start},{start + duration})'"

    # Base overlay filter
    overlay_filter = f"overlay={pos_str}:{enable}"

    # Optional alpha fade in/out applied to the image stream before overlay
    if fade_in > 0 or fade_out > 0:
        fade_filter = []
        if fade_in > 0:
            fade_filter.append(f"fade=t=in:st={start}:d={fade_in}:alpha=1")
        if fade_out > 0:
            fade_out_start = start + duration - fade_out
            fade_filter.append(f"fade=t=out:st={fade_out_start}:d={fade_out}:alpha=1")

        img_filter = ",".join(fade_filter) if fade_filter else ""
        filter_complex = f"[1:v]{img_filter}[img];[0:v][img]{overlay_filter}[outv]"
    else:
        filter_complex = f"[0:v][1:v]{overlay_filter}[outv]"

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-i", image_path,
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-map", "0:a?",  # keep the source audio track if present
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Overlaid image at {position or 'center'}, {start}s-{start+duration}s")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def overlay_multiple_images(
    video_path: str,
    images: List[Dict[str, Any]],
    output_path: str
) -> str:
    """
    Overlay several transparent PNGs onto a video in one pass.

    Args:
        video_path: Input video path.
        images: Image configs [{path, x, y, start, duration}]; x/y default
            to centered, start to 0, duration to 999 (effectively "until
            the end").
        output_path: Output path.

    Returns:
        output_path
    """
    if not images:
        # Nothing to overlay: copy the input through unchanged
        import shutil
        shutil.copy(video_path, output_path)
        return output_path

    # Input list: the video first, then one input per image
    inputs = ["-i", video_path]
    for img in images:
        inputs.extend(["-i", img["path"]])

    # Chained overlays: each step draws one image onto the previous result
    filter_parts = []
    prev_output = "0:v"

    for i, img in enumerate(images):
        x = img.get("x", "(W-w)/2")
        y = img.get("y", "(H-h)/2")
        start = img.get("start", 0)
        duration = img.get("duration", 999)

        enable = f"enable='between(t,{start},{start + duration})'"

        # Final overlay writes the [outv] label the -map below expects
        if i == len(images) - 1:
            out_label = "outv"
        else:
            out_label = f"tmp{i}"

        filter_parts.append(
            f"[{prev_output}][{i+1}:v]overlay=x={x}:y={y}:{enable}[{out_label}]"
        )
        prev_output = out_label

    filter_complex = ";".join(filter_parts)

    cmd = [FFMPEG_PATH, "-y"] + inputs + [
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-map", "0:a?",  # keep the source audio track if present
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        "-pix_fmt", "yuv420p",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Overlaid {len(images)} images")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def mix_audio(
    video_path: str,
    audio_path: str,
    output_path: str,
    audio_volume: float = 1.0,
    video_volume: float = 0.1,
    audio_start: float = 0
) -> str:
    """
    Mix an audio file (voice-over, BGM, ...) into a video's soundtrack.

    Args:
        video_path: Input video path.
        audio_path: Audio file path.
        output_path: Output path.
        audio_volume: Volume of the new audio (0-1).
        video_volume: Volume of the video's own audio (0-1).
        audio_start: Offset in seconds at which the new audio starts.

    Returns:
        output_path
    """
    logger.info(f"Mixing audio: {audio_path}")

    # Probe the video (also sanity-checks that it is readable)
    info = get_video_info(video_path)
    video_duration = info["duration"]  # currently informational only

    # Build the filter_complex graph.
    # adelay expects milliseconds, one value per channel.
    delay_ms = int(audio_start * 1000)

    filter_complex = (
        f"[0:a]volume={video_volume}[va];"
        f"[1:a]adelay={delay_ms}|{delay_ms},volume={audio_volume}[aa];"
        f"[va][aa]amix=inputs=2:duration=longest:dropout_transition=0:normalize=0[outa]"
    )

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-i", audio_path,
        "-filter_complex", filter_complex,
        "-map", "0:v",
        "-map", "[outa]",
        "-c:v", "copy",
        "-c:a", "aac",
        "-b:a", "192k",
        output_path
    ]

    try:
        _run_ffmpeg(cmd)
    except subprocess.CalledProcessError:
        # The source video has no audio stream: just attach the new audio
        logger.warning("Video has no audio track, adding audio directly")
        cmd = [
            FFMPEG_PATH, "-y",
            "-i", video_path,
            "-i", audio_path,
            "-map", "0:v",
            "-map", "1:a",
            "-c:v", "copy",
            "-c:a", "aac",
            "-b:a", "192k",
            output_path
        ]
        _run_ffmpeg(cmd)

    logger.info(f"Audio mixed: {output_path}")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def add_bgm(
    video_path: str,
    bgm_path: str,
    output_path: str,
    bgm_volume: float = 0.06,
    loop: bool = True,
    ducking: bool = True,
    duck_gain_db: float = -6.0,
    fade_in: float = 1.0,
    fade_out: float = 1.0
) -> str:
    """
    Add background music, looped to match the video length.

    Args:
        video_path: Input video path.
        bgm_path: BGM file path.
        output_path: Output path.
        bgm_volume: BGM volume multiplier.
        loop: Loop the BGM to cover the whole video.
        ducking: Duck the BGM under the main audio via sidechaincompress;
            falls back to a plain amix when the sidechain graph fails.
        duck_gain_db: NOTE(review): currently unused by the filter graph.
        fade_in: BGM fade-in duration in seconds.
        fade_out: BGM fade-out duration in seconds.

    Returns:
        output_path
    """
    info = get_video_info(video_path)
    video_duration = info["duration"]

    if loop:
        # Loop indefinitely, trim to the video length, fade in/out, set volume
        bgm_chain = (
            f"[1:a]aloop=-1:size=2e+09,asetpts=N/SR/TB,"
            f"atrim=0:{video_duration},"
            f"afade=t=in:st=0:d={fade_in},"
            f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out},"
            f"volume={bgm_volume}[bgm]"
        )
    else:
        bgm_chain = (
            f"[1:a]"
            f"afade=t=in:st=0:d={fade_in},"
            f"afade=t=out:st={max(video_duration - fade_out, 0)}:d={fade_out},"
            f"volume={bgm_volume}[bgm]"
        )

    if ducking:
        # Conservative sidechaincompress parameters to avoid options that
        # some ffmpeg builds do not support
        filter_complex = (
            f"{bgm_chain};"
            f"[0:a][bgm]sidechaincompress=threshold=0.1:ratio=4:attack=5:release=250:makeup=1:mix=1:level_in=1:level_sc=1[outa]"
        )
    else:
        filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first[outa]"

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-stream_loop", "-1" if loop else "0",
        "-i", bgm_path,
        "-filter_complex", filter_complex,
        "-map", "0:v",
        "-map", "[outa]",
        "-c:v", "copy",
        "-c:a", "aac",
        "-b:a", "192k",
        "-t", str(video_duration),
        output_path
    ]

    try:
        _run_ffmpeg(cmd)
    except subprocess.CalledProcessError:
        # Sidechain failed: fall back to a plain amix
        # (original audio + low-volume BGM)
        logger.warning("Sidechain failed, fallback to simple amix for BGM")
        filter_complex = f"{bgm_chain};[0:a][bgm]amix=inputs=2:duration=first[outa]"
        cmd = [
            FFMPEG_PATH, "-y",
            "-i", video_path,
            "-stream_loop", "-1" if loop else "0",
            "-i", bgm_path,
            "-filter_complex", filter_complex,
            "-map", "0:v",
            "-map", "[outa]",
            "-c:v", "copy",
            "-c:a", "aac",
            "-b:a", "192k",
            "-t", str(video_duration),
            output_path
        ]
        _run_ffmpeg(cmd)

    logger.info(f"BGM added: {output_path}")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def trim_video(
    video_path: str,
    output_path: str,
    start: float = 0,
    duration: float = None,
    end: float = None
) -> str:
    """
    Trim a video to a sub-range.

    Args:
        video_path: Input video path.
        output_path: Output file path.
        start: Start time in seconds.
        duration: Clip length in seconds (takes precedence over ``end``).
        end: Absolute end time in seconds; used only when ``duration`` is None.

    Returns:
        The output path.
    """
    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        # -ss placed after -i: slower but frame-accurate seeking
        "-ss", str(start)
    ]

    # Use `is not None` so an explicit 0 is not silently ignored
    # (the old truthiness check dropped duration=0 / end=0).
    if duration is not None:
        cmd.extend(["-t", str(duration)])
    elif end is not None:
        cmd.extend(["-to", str(end)])

    cmd.extend([
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "copy",
        output_path
    ])

    _run_ffmpeg(cmd)
    # Avoid the old TypeError (start + None) when neither duration nor end given.
    if end is not None:
        end_label = end
    elif duration is not None:
        end_label = start + duration
    else:
        end_label = "eof"
    logger.info(f"Trimmed video: {start}s - {end_label}s")
    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def speed_up_video(
    video_path: str,
    output_path: str,
    speed: float = 1.5
) -> str:
    """
    Speed up / slow down a video.

    Args:
        video_path: Input video path.
        output_path: Output file path.
        speed: Speed multiplier (>1 faster, <1 slower). Must be positive.

    Returns:
        The output path.

    Raises:
        ValueError: if ``speed`` is not positive.
    """
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed}")

    # setpts controls video speed, atempo controls audio speed.
    video_filter = f"setpts={1/speed}*PTS"

    # atempo only accepts 0.5-2.0 per filter stage, so chain stages until
    # the remaining factor fits. (The old two-stage version produced an
    # invalid atempo value for speed > 4.0 or speed < 0.25.)
    stages = []
    remaining = speed
    while remaining > 2.0:
        stages.append("atempo=2.0")
        remaining /= 2.0
    while remaining < 0.5:
        stages.append("atempo=0.5")
        remaining /= 0.5
    stages.append(f"atempo={remaining}")
    audio_filter = ",".join(stages)

    cmd = [
        FFMPEG_PATH, "-y",
        "-i", video_path,
        "-vf", video_filter,
        "-af", audio_filter,
        "-c:v", "libx264",
        "-preset", "fast",
        "-crf", "23",
        "-c:a", "aac",
        output_path
    ]

    _run_ffmpeg(cmd)
    logger.info(f"Speed changed to {speed}x: {output_path}")
    return output_path
|
||||||
491
modules/image_gen.py
Normal file
491
modules/image_gen.py
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
"""
|
||||||
|
连贯生图模块 (Volcengine Doubao)
|
||||||
|
负责根据分镜脚本和原始素材生成一系列连贯的分镜图片
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
from modules import storage
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ImageGenerator:
|
||||||
|
"""连贯图片生成器 (Volcengine Provider)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.api_key = config.VOLC_API_KEY
|
||||||
|
# Endpoint: https://ark.cn-beijing.volces.com/api/v3/images/generations
|
||||||
|
self.endpoint = f"https://ark.cn-beijing.volces.com/api/v3/images/generations"
|
||||||
|
self.model = config.IMAGE_MODEL_ID
|
||||||
|
|
||||||
|
def _encode_image(self, image_path: str) -> str:
|
||||||
|
"""读取图片,调整大小并转为 Base64"""
|
||||||
|
try:
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
if img.mode != 'RGB':
|
||||||
|
img = img.convert('RGB')
|
||||||
|
|
||||||
|
max_size = 1024
|
||||||
|
if max(img.size) > max_size:
|
||||||
|
img.thumbnail((max_size, max_size), Image.LANCZOS)
|
||||||
|
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
img.save(buffer, format="JPEG", quality=80)
|
||||||
|
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing image {image_path}: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def generate_single_scene_image(
|
||||||
|
self,
|
||||||
|
scene: Dict[str, Any],
|
||||||
|
original_image_path: Any,
|
||||||
|
previous_image_path: Optional[str] = None,
|
||||||
|
model_provider: str = "shubiaobiao", # "shubiaobiao", "gemini", "doubao"
|
||||||
|
visual_anchor: str = "" # 视觉锚点,强制拼接到 prompt 前
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
生成单张分镜图片 (Public)
|
||||||
|
"""
|
||||||
|
scene_id = scene["id"]
|
||||||
|
visual_prompt = scene.get("visual_prompt", "")
|
||||||
|
|
||||||
|
# 强制拼接 Visual Anchor (确保生图一致性)
|
||||||
|
if visual_anchor and visual_anchor not in visual_prompt:
|
||||||
|
visual_prompt = f"[{visual_anchor}] {visual_prompt}"
|
||||||
|
logger.info(f"Scene {scene_id}: Prepended visual_anchor to prompt")
|
||||||
|
|
||||||
|
logger.info(f"Generating image for Scene {scene_id} (Provider: {model_provider})...")
|
||||||
|
|
||||||
|
input_images = []
|
||||||
|
|
||||||
|
# Handle original_image_path (can be str or list)
|
||||||
|
if isinstance(original_image_path, list):
|
||||||
|
input_images.extend(original_image_path)
|
||||||
|
elif isinstance(original_image_path, str) and original_image_path:
|
||||||
|
input_images.append(original_image_path)
|
||||||
|
|
||||||
|
if previous_image_path:
|
||||||
|
input_images.append(previous_image_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
output_path = self._generate_single_image(
|
||||||
|
prompt=visual_prompt,
|
||||||
|
reference_images=input_images,
|
||||||
|
output_filename=f"scene_{scene_id}_{int(time.time())}.png",
|
||||||
|
provider=model_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
return output_path
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Image generation returned empty for Scene {scene_id}")
|
||||||
|
|
||||||
|
except PermissionError as e:
|
||||||
|
logger.error(f"Critical API Error for Scene {scene_id}: {e}")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Image generation failed for Scene {scene_id}: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
    def generate_group_images_doubao(
        self,
        scenes: List[Dict[str, Any]],
        reference_images: List[str],
        visual_anchor: str = ""  # visual anchor for cross-image consistency
    ) -> Dict[int, str]:
        """
        Doubao batch ("group") image generation: concatenates all scene
        prompts into one request and asks for one image per scene.

        Args:
            scenes: Scene dicts; each needs an "id" and a "visual_prompt".
            reference_images: Local paths of reference images; each is
                uploaded (best-effort) so the API can fetch it by URL.
            visual_anchor: Product-appearance anchor text, embedded in the
                Global section of the combined prompt.

        Returns:
            Mapping of scene id -> downloaded local image path. Scenes the
            API returned no image for are simply absent from the mapping.
        """
        logger.info("Starting Doubao Group Image Generation...")

        # 1. Build the combined prompt.
        # Format: "Global: [Visual Anchor] ... | S1: ... | S2: ..."

        scene_prompts = []
        for scene in scenes:
            # Per-scene visual prompt
            p = scene.get("visual_prompt", "")
            scene_prompts.append(f"S{scene['id']}:{p}")

        combined_scenes_text = " | ".join(scene_prompts)

        # Embed the visual_anchor into the Global section of the prompt.
        global_context = f"[{visual_anchor}] Consistent product appearance & style." if visual_anchor else "Consistent product appearance & style."
        combined_prompt = (
            f"Global: {global_context}\n"
            f"{combined_scenes_text}\n"
            "Req: 1 img per scene. Follow specific angles."
        )

        logger.info(f"Visual Anchor applied to group prompt: {visual_anchor[:50]}..." if visual_anchor else "No visual_anchor")

        # Log the prompt length for reference.
        logger.info(f"Doubao Group Prompt Length: {len(combined_prompt)} chars")

        # 2. Build the request payload.
        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": combined_prompt,
            "sequential_image_generation": "auto",  # enable group generation
            "sequential_image_generation_options": {
                "max_images": len(scenes)  # cap the number of images
            },
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        # 3. Upload reference images so they can be passed by URL
        #    (upload failures are logged and skipped, not fatal).
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url: img_urls.append(url)
                    except Exception as e:
                        logger.warning(f"Failed to upload ref image {ref_path}: {e}")

        if img_urls:
            payload["image_urls"] = img_urls

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting Doubao Group Request (Scenes: {len(scenes)})...")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=240)
            resp.raise_for_status()

            data = resp.json()
            results = {}

            if "data" in data:
                items = data["data"]
                logger.info(f"Doubao returned {len(items)} images.")

                # Map the returned images back onto scenes,
                # assuming the API preserves request order.
                for i, item in enumerate(items):
                    if i < len(scenes):
                        scene_id = scenes[i]["id"]
                        image_url = item.get("url")

                        if image_url:
                            # Download
                            img_resp = requests.get(image_url, timeout=60)
                            output_path = config.TEMP_DIR / f"scene_{scene_id}_{int(time.time())}.png"
                            with open(output_path, "wb") as f:
                                f.write(img_resp.content)
                            results[scene_id] = str(output_path)

            return results

        except Exception as e:
            logger.error(f"Doubao Group Generation Failed: {e}")
            raise e
|
||||||
|
|
||||||
|
def _generate_single_image(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
reference_images: List[str],
|
||||||
|
output_filename: str,
|
||||||
|
provider: str = "shubiaobiao"
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""统一入口"""
|
||||||
|
if provider == "doubao":
|
||||||
|
return self._generate_single_image_doubao(prompt, reference_images, output_filename)
|
||||||
|
elif provider == "gemini":
|
||||||
|
return self._generate_single_image_gemini(prompt, reference_images, output_filename)
|
||||||
|
else:
|
||||||
|
return self._generate_single_image_shubiao(prompt, reference_images, output_filename)
|
||||||
|
|
||||||
|
    def _generate_single_image_doubao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image via the Volcengine Doubao image API.

        Reference images are uploaded to remote storage first (best-effort)
        and passed by URL; the resulting image is downloaded into
        ``config.TEMP_DIR / output_filename``.

        Raises:
            RuntimeError: on a non-200 response or a response without an
                image URL.
        """

        # 1. Upload all reference images so the API can fetch them by URL.
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded Doubao ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Failed to upload Doubao ref image {ref_path}: {e}")

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        if img_urls:
            payload["image_urls"] = img_urls
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', image_urls={len(img_urls)}")
        else:
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', no reference images")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting to Doubao Image: {self.endpoint}")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=180)

            if resp.status_code != 200:
                msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            if "data" in data and len(data["data"]) > 0:
                image_url = data["data"][0].get("url")
                if image_url:
                    # Download the generated image to the local temp dir.
                    img_resp = requests.get(image_url, timeout=60)
                    img_resp.raise_for_status()

                    output_path = config.TEMP_DIR / output_filename
                    with open(output_path, "wb") as f:
                        f.write(img_resp.content)
                    return str(output_path)

            raise RuntimeError(f"No image URL in Doubao response: {data}")

        except Exception as e:
            logger.error(f"Doubao Gen Failed: {e}")
            raise e
|
||||||
|
|
||||||
|
    def _generate_single_image_shubiao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image via the api2img.shubiaobiao.com channel
        (synchronous call; the image comes back base64-encoded).

        Reference images are inlined into the request as base64 JPEG.

        Raises:
            RuntimeError: on a non-200 response or when the response
                contains no image data.
        """
        # Build request parts; reference images are inlined as base64.
        parts = [{"text": prompt}]

        # Strictly filter reference images: existing files only, de-duplicated,
        # original order preserved.
        valid_refs = []
        if reference_images:
            for p in reference_images:
                if p and os.path.exists(p) and p not in valid_refs:
                    valid_refs.append(p)

        logger.info(f"[Shubiaobiao] Input reference images ({len(valid_refs)}): {valid_refs}")

        if valid_refs:
            for ref_path in valid_refs:
                try:
                    encoded = self._encode_image(ref_path)
                    # _encode_image returns "" on failure; skip those.
                    if encoded:
                        parts.append({
                            "inlineData": {
                                "mimeType": "image/jpeg",
                                "data": encoded
                            }
                        })
                except Exception as e:
                    logger.error(f"Failed to encode image {ref_path}: {e}")

        logger.info(f"[Shubiaobiao] Final payload parts count: {len(parts)} (1 prompt + {len(parts)-1} images)")

        payload = {
            "contents": [{
                "role": "user",
                "parts": parts
            }],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "imageConfig": {
                    "aspectRatio": "9:16",
                    "imageSize": "2K"
                }
            }
        }

        endpoint = f"{config.SHUBIAOBIAO_IMG_BASE_URL}/v1beta/models/{config.SHUBIAOBIAO_IMG_MODEL_NAME}:generateContent"
        headers = {
            "x-goog-api-key": config.SHUBIAOBIAO_IMG_KEY,
            "Content-Type": "application/json"
        }

        try:
            logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
            resp = requests.post(endpoint, json=payload, headers=headers, timeout=120)

            if resp.status_code != 200:
                msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            # Locate the base64 image among the response candidates.
            img_b64 = None
            candidates = data.get("candidates") or []
            if candidates:
                content_parts = candidates[0].get("content", {}).get("parts", [])
                for part in content_parts:
                    inline = part.get("inlineData") if isinstance(part, dict) else None
                    if inline and inline.get("data"):
                        img_b64 = inline["data"]
                        break

            if not img_b64:
                msg = f"Shubiaobiao 响应缺少图片数据: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            output_path = config.TEMP_DIR / output_filename
            with open(output_path, "wb") as f:
                f.write(base64.b64decode(img_b64))

            logger.info(f"Shubiaobiao Generation Success: {output_path}")
            return str(output_path)

        except Exception as e:
            logger.error(f"Shubiaobiao Generation Exception: {e}")
            raise
|
||||||
|
|
||||||
|
def _generate_single_image_gemini(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
reference_images: List[str],
|
||||||
|
output_filename: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""调用 Gemini (Wuyin Keji / NanoBanana-Pro) 生成单张图片"""
|
||||||
|
|
||||||
|
# 1. 构造 Payload
|
||||||
|
payload = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"aspectRatio": "9:16",
|
||||||
|
"imageSize": "2K"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 处理参考图 (Image-to-Image)
|
||||||
|
if reference_images:
|
||||||
|
valid_paths = []
|
||||||
|
seen = set()
|
||||||
|
for p in reference_images:
|
||||||
|
if p and os.path.exists(p) and p not in seen:
|
||||||
|
valid_paths.append(p)
|
||||||
|
seen.add(p)
|
||||||
|
|
||||||
|
if valid_paths:
|
||||||
|
img_urls = []
|
||||||
|
for ref_path in valid_paths:
|
||||||
|
try:
|
||||||
|
url = storage.upload_file(ref_path)
|
||||||
|
if url:
|
||||||
|
img_urls.append(url)
|
||||||
|
logger.info(f"Uploaded ref image: {url}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error uploading ref image {ref_path}: {e}")
|
||||||
|
|
||||||
|
if img_urls:
|
||||||
|
payload["img_url"] = img_urls
|
||||||
|
logger.info(f"Using {len(img_urls)} reference images for Gemini Img2Img")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": config.GEMINI_IMG_KEY,
|
||||||
|
"Content-Type": "application/json;charset:utf-8"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. 提交任务
|
||||||
|
try:
|
||||||
|
logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
|
||||||
|
resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=30)
|
||||||
|
|
||||||
|
if resp.status_code != 200:
|
||||||
|
msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
|
||||||
|
logger.error(msg)
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
if data.get("code") != 200:
|
||||||
|
msg = f"Gemini 返回错误: {data}"
|
||||||
|
logger.error(msg)
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
task_id = data.get("data", {}).get("id")
|
||||||
|
if not task_id:
|
||||||
|
raise RuntimeError(f"Gemini 响应缺少 task id: {data}")
|
||||||
|
|
||||||
|
logger.info(f"Gemini Task Submitted, ID: {task_id}")
|
||||||
|
|
||||||
|
# 3. 轮询状态
|
||||||
|
max_retries = 60
|
||||||
|
for i in range(max_retries):
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
poll_url = f"{config.GEMINI_IMG_DETAIL_URL}?key={config.GEMINI_IMG_KEY}&id={task_id}"
|
||||||
|
try:
|
||||||
|
poll_resp = requests.get(poll_url, headers=headers, timeout=30)
|
||||||
|
except requests.Timeout:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if poll_resp.status_code != 200:
|
||||||
|
continue
|
||||||
|
|
||||||
|
poll_data = poll_resp.json()
|
||||||
|
if poll_data.get("code") != 200:
|
||||||
|
raise RuntimeError(f"Gemini 轮询返回错误: {poll_data}")
|
||||||
|
|
||||||
|
result_data = poll_data.get("data", {}) or {}
|
||||||
|
status = result_data.get("status") # 0:排队, 1:生成中, 2:成功, 3:失败
|
||||||
|
|
||||||
|
if status == 2:
|
||||||
|
image_url = result_data.get("image_url")
|
||||||
|
if not image_url:
|
||||||
|
raise RuntimeError("Gemini 成功但缺少 image_url")
|
||||||
|
|
||||||
|
logger.info(f"Gemini Generation Success: {image_url}")
|
||||||
|
img_resp = requests.get(image_url, timeout=60)
|
||||||
|
img_resp.raise_for_status()
|
||||||
|
|
||||||
|
output_path = config.TEMP_DIR / output_filename
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
f.write(img_resp.content)
|
||||||
|
|
||||||
|
return str(output_path)
|
||||||
|
|
||||||
|
if status == 3:
|
||||||
|
fail_reason = result_data.get("fail_reason", "Unknown")
|
||||||
|
raise RuntimeError(f"Gemini 生成失败: {fail_reason}")
|
||||||
|
|
||||||
|
raise RuntimeError("Gemini 生成超时")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Gemini Generation Exception: {e}")
|
||||||
|
raise
|
||||||
60
modules/ingest.py
Normal file
60
modules/ingest.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - Ingest Module (Video Processing)
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
import config
|
||||||
|
from modules import storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def process_uploaded_video(video_path: str) -> Tuple[List[str], str]:
    """
    Process an uploaded video:
    1. Upload the raw video to R2.
    2. Extract 3 keyframes (at 10%, 50%, 90% of the frame count).
    3. Upload each frame and return (frame URLs, video URL).

    Args:
        video_path: Local path of the uploaded video.

    Returns:
        Tuple of (list of frame URLs, R2 video URL). Frames whose read or
        upload failed are silently skipped, so the list may be shorter than 3.

    Raises:
        FileNotFoundError: if the input file does not exist.
        RuntimeError: if the video upload fails.
        IOError: if OpenCV cannot open the video.
    """
    if not Path(video_path).exists():
        raise FileNotFoundError(f"Video not found: {video_path}")

    logger.info(f"Processing video: {video_path}")

    # 1. Upload to R2
    video_url = storage.upload_file(video_path)
    if not video_url:
        raise RuntimeError("Failed to upload video to R2")

    # 2. Extract frames
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")

    frame_urls = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_indices = [
            int(total_frames * 0.1),
            int(total_frames * 0.5),
            int(total_frames * 0.9)
        ]

        for i, idx in enumerate(frame_indices):
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame_name = f"frame_{Path(video_path).stem}_{i}.jpg"
                frame_path = config.TEMP_DIR / frame_name
                cv2.imwrite(str(frame_path), frame)

                # Upload frame to R2 immediately
                frame_url = storage.upload_file(str(frame_path))
                if frame_url:
                    frame_urls.append(frame_url)
    finally:
        # Always release the capture — the original leaked it when frame
        # extraction or upload raised.
        cap.release()

    logger.info(f"Extracted and uploaded {len(frame_urls)} frames")

    return frame_urls, video_url
|
||||||
151
modules/project.py
Normal file
151
modules/project.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - Project State Management (R2 Persistence)
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules import storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Scene:
    """Per-shot storyboard state, filled in progressively by the pipeline."""
    id: int  # scene index within the script
    duration: int = 5  # shot length — presumably seconds (TODO confirm; script_gen uses 3s shots)
    timeline: str = ""
    keyframe: Dict[str, str] = field(default_factory=dict)
    camera_movement: str = ""
    story_beat: str = ""
    voiceover: str = ""
    rhythm: Dict[str, Any] = field(default_factory=dict)
    image_url: str = ""  # generated still for this scene (remote URL)
    video_url: str = ""  # generated clip for this scene (remote URL)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Project:
    """Full workflow state for one video project; persisted to R2 as JSON
    by save_project()/load_project()."""
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])  # short unique id (first 8 chars of a uuid4)
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())  # local-time ISO timestamp
    status: str = "draft"  # draft | analyzing | scripting | imaging | video | rendering | done

    # Step 0: Input
    input_mode: str = ""  # text | images | video
    prompt: str = ""
    image_urls: List[str] = field(default_factory=list)
    video_url: str = ""
    asr_text: str = ""  # presumably a speech-recognition transcript of the input video — confirm against caller

    # Step 1: Analysis
    analysis: str = ""
    questions: List[Dict[str, Any]] = field(default_factory=list)
    answers: Dict[str, str] = field(default_factory=dict)

    # Step 2: Script
    hook: str = ""
    scenes: List[Dict[str, Any]] = field(default_factory=list)
    cta: str = ""

    # Step 6: Final
    final_video_url: str = ""
    bgm_url: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def save_project(project: Project) -> str:
    """Save project state to R2 as JSON.

    Serializes the dataclass to a local temp file, then uploads it to
    ``projects/<id>.json`` in the R2 bucket.

    Args:
        project: The project to persist.

    Returns:
        The project id.

    Raises:
        Whatever the S3 client raises on upload failure (logged first).
    """
    data = asdict(project)
    json_str = json.dumps(data, ensure_ascii=False, indent=2)

    # Write to a temp file first — the S3 client's upload_file wants a path.
    temp_path = config.TEMP_DIR / f"project_{project.id}.json"
    temp_path.write_text(json_str, encoding="utf-8")

    # Upload to R2
    object_name = f"projects/{project.id}.json"
    s3 = storage.get_s3_client()

    try:
        s3.upload_file(
            str(temp_path),
            config.R2_BUCKET_NAME,
            object_name,
            ExtraArgs={'ContentType': 'application/json'}
        )
        logger.info(f"Project {project.id} saved to R2")
        return project.id
    except Exception as e:
        logger.error(f"Failed to save project: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def load_project(project_id: str) -> Optional[Project]:
    """Load project state from R2.

    Args:
        project_id: The project's short id.

    Returns:
        The reconstructed Project, or None when the object is missing or
        unreadable (loading is best-effort by design).
    """
    from dataclasses import fields  # local import: only needed here

    object_name = f"projects/{project_id}.json"
    temp_path = config.TEMP_DIR / f"project_{project_id}.json"

    s3 = storage.get_s3_client()

    try:
        s3.download_file(config.R2_BUCKET_NAME, object_name, str(temp_path))

        with open(temp_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Rebuild the Project from whatever known fields are present.
        # Filtering by the declared dataclass fields keeps loading tolerant
        # of both older files (missing keys -> dataclass defaults) and newer
        # files (unknown keys ignored), instead of hand-listing every field.
        known = {fld.name for fld in fields(Project)}
        kwargs = {k: v for k, v in data.items() if k in known}
        kwargs.setdefault("id", project_id)  # fall back to the requested id
        project = Project(**kwargs)

        logger.info(f"Project {project_id} loaded from R2")
        return project

    except Exception as e:
        logger.warning(f"Failed to load project {project_id}: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_project() -> Project:
    """Create and return a brand-new Project with a fresh unique id."""
    fresh = Project()
    logger.info(f"Created new project: {fresh.id}")
    return fresh
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
390
modules/script_gen.py
Normal file
390
modules/script_gen.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
"""
|
||||||
|
脚本生成模块 (Gemini-3-Pro)
|
||||||
|
负责解析商品信息,生成分镜脚本
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules.db_manager import db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ScriptGenerator:
|
||||||
|
"""分镜脚本生成器"""
|
||||||
|
|
||||||
|
    def __init__(self):
        """Set up API credentials, the model endpoint, and the default
        system prompt (which the database config "prompt_script_gen" can
        override at generation time)."""
        self.api_key = config.SHUBIAOBIAO_KEY
        # NOTE: the API URL may need adapting to the concrete
        # gemini-3-pro-preview path. Per the demo:
        # https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent
        # For now we assume the base_url is v1beta/models/.
        self.endpoint = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent"

        # Default System Prompt (Chinese — this is runtime data sent to the
        # model, so it is kept verbatim, not translated).
        self.default_system_prompt = """
你是一个专业的抖音电商短视频导演。请根据提供的商品信息和图片,设计一个高转化率的商品详情页首图视频脚本。

## 目标
- 提升商品详情页的 GPM 和下单转化率
- 视频时长 9-12 秒 (由 3-4 个分镜组成)
- **每个分镜时长固定为 3 秒** (duration: 3),不要超过 3 秒
- 必须包含:目标人群分析、卖点提炼、分镜设计

## 分镜设计原则
1. **单分镜单主体**:每个分镜聚焦一个视觉主体或动作,避免复杂运镜,因为 AI 生视频在长时间(>3秒)容易出现画面异常。
2. **旁白跨分镜**:一段完整的旁白/卖点可以跨越多个分镜。在 voiceover_timeline 中,通过 start_time 和 duration (秒) 控制旁白的绝对时间位置,无需与分镜一一对应。
3. **节奏感**:分镜之间保持视觉连贯,通过景别变化(特写 -> 中景 -> 全景)制造节奏。
4. **语速控制**:旁白语速约 4 字/秒,12字旁白约需 3 秒。

## 输出格式要求 (JSON)
必须严格遵守以下 JSON 结构:
{
    "product_name": "商品名称",
    "visual_anchor": "商品视觉锚点:材质+颜色+形状+包装特征(用于保持生图一致性)",
    "selling_points": ["卖点1", "卖点2"],
    "target_audience": "目标人群描述",
    "video_style": "视频风格关键词",
    "bgm_style": "BGM风格关键词",
    "voiceover_timeline": [
        {
            "id": 1,
            "text": "旁白文案片段1(可横跨多个分镜)",
            "subtitle": "字幕文案1 (简短有力)",
            "start_time": 0.0,
            "duration": 3.0
        },
        {
            "id": 2,
            "text": "旁白文案片段2",
            "subtitle": "字幕文案2",
            "start_time": 3.5,
            "duration": 2.5
        }
    ],
    "scenes": [
        {
            "id": 1,
            "duration": 3,
            "visual_prompt": "详细的画面描述,用于AI生图,包含主体、背景、构图、光影。英文描述。",
            "video_prompt": "详细的动效描述,用于AI图生视频。英文描述。",
            "fancy_text": {
                "text": "花字文案 (最多6字)",
                "style": "highlight",
                "position": "center",
                "start_time": 0.5,
                "duration": 2.0
            }
        }
    ]
}

## 注意事项
1. **visual_prompt**:
   - 必须是英文。
   - 描述要具体,例如 "Close-up shot of a hair clip, soft lighting, minimalist background".
   - **CRITICAL**: 禁止 AI 额外生成装饰性文字、标语、水印。但必须保留商品包装自带的文字和 Logo(这是商品真实外观的一部分)。
   - 正确写法: "Product front view, keep original packaging design --no added text --no watermarks"
   - **EMPHASIS**: Strictly follow the appearance of the product in the reference images.
2. **video_prompt**: 必须是英文,描述动作,例如 "Slow zoom in, the hair clip rotates slightly"。注意保持动作简单,避免复杂运镜和人体动作。
3. **voiceover_timeline**:
   - 这是整个视频的旁白和字幕时间轴,独立于分镜。
   - `start_time` 是旁白开始的绝对时间 (秒),`duration` 是旁白持续时长 (秒)。
   - **一段旁白可以横跨多个分镜**,例如:总时长 9 秒 (3 个分镜),一段旁白从 start_time=0,duration=5,则覆盖前两个分镜。
   - 两段旁白之间留 0.3-0.5 秒间隙(气口)。
4. **fancy_text**:
   - 花字要精简(最多 6 字),突出卖点。
   - **Style Selection**:
     - `highlight`: 默认样式,适合通用卖点 (Yellow/Black)。
     - `warning`: 强调痛点或食欲 (Red/White)。
     - `price`: 价格显示 (Big Red)。
     - `bubble`: 旁白补充或用户评价 (Bubble)。
     - `minimal`: 高级感,适合时尚类 (Thin/White)。
     - `tech`: 数码类 (Cyan/Glow)。
   - `position` 默认 `center`,可选 top/bottom/top-left/bottom-right 等。
5. **场景连贯性**: 确保分镜之间的逻辑和视觉风格连贯。每个分镜 duration 必须为 3。
"""
|
||||||
|
|
||||||
|
def _encode_image(self, image_path: str) -> str:
|
||||||
|
"""读取图片并转为 Base64"""
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
|
||||||
|
def generate_script(
    self,
    product_name: str,
    product_info: Dict[str, Any],
    image_paths: Optional[List[str]] = None,
    model_provider: str = "shubiaobiao"  # "shubiaobiao" (Gemini-style API) or "doubao"
) -> Optional[Dict[str, Any]]:
    """
    Generate a storyboard script for a product via a multimodal LLM.

    Args:
        product_name: Display name of the product.
        product_info: Key/value product attributes folded into the user prompt.
        image_paths: Optional local paths of product photos to attach as inputs.
        model_provider: "shubiaobiao" (default) routes to the Gemini-compatible
            endpoint; "doubao" delegates to _generate_script_doubao().

    Returns:
        The validated script dict, with a "_debug" key carrying the prompts and
        the raw model output, or None on any failure.
    """
    logger.info(f"Generating script for: {product_name} (Provider: {model_provider})")

    # 1. Build prompts (a custom system prompt stored in the DB takes priority
    #    over the hard-coded default).
    system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
    user_prompt = self._build_user_prompt(product_name, product_info)

    # Branch for Doubao
    if model_provider == "doubao":
        return self._generate_script_doubao(system_prompt, user_prompt, image_paths)

    # ... Existing Shubiaobiao Logic ...

    # Debug aid: report whether the DB override is in effect.
    if system_prompt != self.default_system_prompt:
        logger.info("Using CUSTOM system prompt from database")
    else:
        logger.info("Using DEFAULT system prompt")

    # 2. Build the request payload (Gemini/Shubiaobiao wire format).
    contents = []

    # User message parts
    user_parts = [{"text": user_prompt}]

    # Attach images (multimodal input).
    if image_paths:
        for path in image_paths[:10]:  # cap at 10 images; Gemini-3-Pro accepts multiple
            if Path(path).exists():
                try:
                    b64_img = self._encode_image(path)
                    user_parts.append({
                        "inline_data": {
                            "mime_type": "image/jpeg",  # assumes JPG/PNG input — TODO confirm for other formats
                            "data": b64_img
                        }
                    })
                except Exception as e:
                    logger.warning(f"Failed to encode image {path}: {e}")

    contents.append({
        "role": "user",
        "parts": user_parts
    })

    # System instruction: Gemini accepts it either as a dedicated field or
    # prepended to the user parts; this code prepends. Note user_parts is the
    # same list object already appended to contents, so this insert is visible
    # in the payload.
    user_parts.insert(0, {"text": system_prompt})

    payload = {
        "contents": contents,
        "generationConfig": {
            "response_mime_type": "application/json",
            "temperature": 0.7
        }
    }

    headers = {
        "x-goog-api-key": self.api_key,
        "Content-Type": "application/json"
    }

    # 3. Call the API.
    try:
        response = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
        response.raise_for_status()

        result = response.json()

        # 4. Parse the result.
        if "candidates" in result and result["candidates"]:
            content_text = result["candidates"][0]["content"]["parts"][0]["text"]

            # Extract the JSON portion (handles Markdown code fences or plain text).
            script_json = self._extract_json_from_response(content_text)

            if script_json is None:
                logger.error(f"Failed to extract JSON from response: {content_text[:500]}...")
                return None

            final_script = self._validate_and_fix_script(script_json)

            # Attach debug info (including the raw model output) for the UI.
            final_script["_debug"] = {
                "system_prompt": system_prompt,
                "user_prompt": user_prompt,
                "raw_output": content_text,
                "provider": "shubiaobiao"
            }
            return final_script
        else:
            logger.error(f"No candidates in response: {result}")
            return None

    except Exception as e:
        logger.error(f"Script generation failed: {e}")
        # 'response' only exists if requests.post() returned before raising.
        if 'response' in locals():
            logger.error(f"Response content: {response.text}")
        return None
|
||||||
|
|
||||||
|
def _generate_script_doubao(
    self,
    system_prompt: str,
    user_prompt: str,
    image_paths: List[str]
) -> Optional[Dict[str, Any]]:
    """
    Doubao (Volcengine) script generation, multimodal.

    Uses the standard OpenAI-compatible /chat/completions endpoint with
    base64 data-URL images. Returns the validated script dict (with "_debug"
    metadata) or None on failure.
    """
    # User Provided: https://ark.cn-beijing.volces.com/api/v3/responses
    # But for 'responses' API, structure is specific. Let's try to match user's curl format exactly but adapting content.
    # User curl uses "input": [{"role": "user", "content": [{"type": "input_image"...}, {"type": "input_text"...}]}]

    endpoint = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"  # Recommend standard Chat API first as 'responses' is usually non-standard or older
    # However, user explicitly provided /responses curl. Let's try to stick to standard Chat Completions first because Doubao Pro 1.5 is OpenAI compatible.
    # If that fails or if user insists on the specific structure, we can adapt.
    # Volcengine 'ep-...' models are usually served via standard /chat/completions.

    # Let's try standard OpenAI format which Doubao supports perfectly.

    messages = [
        {"role": "system", "content": system_prompt}
    ]

    user_content = []

    # Add Images (Doubao Vision supports image_url)
    if image_paths:
        for path in image_paths[:5]:  # Limit
            if os.path.exists(path):
                # For Volcengine, need to upload or use base64?
                # Standard OpenAI format supports base64 data urls.
                # "image_url": {"url": "data:image/jpeg;base64,..."}
                try:
                    b64_img = self._encode_image(path)
                    user_content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64_img}"
                        }
                    })
                except Exception as e:
                    logger.warning(f"Failed to encode image for Doubao: {e}")

    # Add Text
    user_content.append({"type": "text", "text": user_prompt})

    messages.append({
        "role": "user",
        "content": user_content
    })

    payload = {
        "model": config.DOUBAO_SCRIPT_MODEL,
        "messages": messages,
        "stream": False,
        # "response_format": {"type": "json_object"} # Try enabling JSON mode if supported
    }

    headers = {
        "Authorization": f"Bearer {config.VOLC_API_KEY}",
        "Content-Type": "application/json"
    }

    try:
        # Try standard chat/completions first
        resp = requests.post(endpoint, headers=headers, json=payload, timeout=120)

        if resp.status_code != 200:
            # If 404, maybe endpoint is wrong, try the user's 'responses' endpoint?
            # But 'responses' usually implies a different payload structure.
            logger.warning(f"Doubao Chat API failed ({resp.status_code}), trying legacy/custom endpoint...")
            # Fallback to user provided structure if needed (implement later if this fails)
            resp.raise_for_status()

        result = resp.json()
        content_text = result["choices"][0]["message"]["content"]

        script_json = self._extract_json_from_response(content_text)

        if script_json is None:
            logger.error(f"Failed to extract JSON from Doubao response: {content_text[:500]}...")
            return None

        final_script = self._validate_and_fix_script(script_json)
        final_script["_debug"] = {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "raw_output": content_text,
            "provider": "doubao"
        }
        return final_script

    except Exception as e:
        logger.error(f"Doubao script generation failed: {e}")
        # 'resp' only exists if requests.post() returned before raising.
        if 'resp' in locals():
            logger.error(f"Response: {resp.text}")
        return None
|
||||||
|
|
||||||
|
def _extract_json_from_response(self, text: str) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
从 API 响应中提取 JSON 对象
|
||||||
|
支持:
|
||||||
|
1. 纯 JSON 响应
|
||||||
|
2. Markdown 代码块包裹的 JSON (```json ... ```)
|
||||||
|
3. 文本中嵌入的 JSON (找到第一个 { 和最后一个 })
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 方法1: 尝试直接解析(纯 JSON 情况)
|
||||||
|
try:
|
||||||
|
return json.loads(text.strip())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 方法2: 提取 ```json ... ``` 代码块
|
||||||
|
json_block_match = re.search(r'```json\s*([\s\S]*?)\s*```', text)
|
||||||
|
if json_block_match:
|
||||||
|
try:
|
||||||
|
return json.loads(json_block_match.group(1))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"JSON block found but parse failed: {e}")
|
||||||
|
|
||||||
|
# 方法3: 提取 ``` ... ``` 代码块 (无 json 标记)
|
||||||
|
code_block_match = re.search(r'```\s*([\s\S]*?)\s*```', text)
|
||||||
|
if code_block_match:
|
||||||
|
try:
|
||||||
|
return json.loads(code_block_match.group(1))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 方法4: 找到第一个 { 和最后一个 } 之间的内容
|
||||||
|
first_brace = text.find('{')
|
||||||
|
last_brace = text.rfind('}')
|
||||||
|
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
|
||||||
|
try:
|
||||||
|
return json.loads(text[first_brace:last_brace + 1])
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"Brace extraction failed: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _build_user_prompt(self, product_name: str, product_info: Dict[str, Any]) -> str:
|
||||||
|
# 提取商家偏好提示
|
||||||
|
style_hint = product_info.get("style_hint", "")
|
||||||
|
# 过滤掉不需要展示的字段
|
||||||
|
filtered_info = {k: v for k, v in product_info.items() if k not in ["uploaded_images", "style_hint"]}
|
||||||
|
info_str = "\n".join([f"- {k}: {v}" for k, v in filtered_info.items()])
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
商品名称:{product_name}
|
||||||
|
商品信息:
|
||||||
|
{info_str}
|
||||||
|
"""
|
||||||
|
if style_hint:
|
||||||
|
prompt += f"""
|
||||||
|
## 商家特别要求
|
||||||
|
{style_hint}
|
||||||
|
"""
|
||||||
|
prompt += "\n请根据以上信息设计视频脚本。"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def _validate_and_fix_script(self, script: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""校验并修复脚本结构"""
|
||||||
|
# 简单校验,确保必要字段存在
|
||||||
|
if "scenes" not in script:
|
||||||
|
script["scenes"] = []
|
||||||
|
return script
|
||||||
84
modules/storage.py
Normal file
84
modules/storage.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - Storage Module (R2)
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import NoCredentialsError
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def get_s3_client():
    """Build a boto3 S3 client pointed at the configured Cloudflare R2 endpoint."""
    try:
        client_kwargs = {
            "endpoint_url": config.R2_ENDPOINT,
            "aws_access_key_id": config.R2_ACCESS_KEY,
            "aws_secret_access_key": config.R2_SECRET_KEY,
            "region_name": "auto",
        }
        return boto3.client("s3", **client_kwargs)
    except Exception as e:
        logger.error(f"Failed to create R2 client: {e}")
        raise
|
||||||
|
|
||||||
|
def upload_file(file_path: str) -> Optional[str]:
    """Upload file to R2 and return Public URL."""
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return None

    # Use a random UUID object key: Chinese / special characters in original
    # filenames break naive public URLs.
    original_name = Path(file_path).name
    ext = Path(file_path).suffix.lower() or ".bin"
    object_name = f"{uuid.uuid4().hex}{ext}"

    s3 = get_s3_client()

    # Map the suffix to a Content-Type so browsers render instead of download.
    suffix_to_type = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".mp4": "video/mp4",
        ".mp3": "audio/mpeg",
    }
    content_type = suffix_to_type.get(ext, "application/octet-stream")

    try:
        logger.info(f"Uploading {original_name} to R2 as {object_name}...")

        s3.upload_file(
            file_path,
            config.R2_BUCKET_NAME,
            object_name,
            ExtraArgs={'ContentType': content_type}
        )

        public_url = f"{config.R2_PUBLIC_URL}/{object_name}"
        logger.info(f"Upload successful: {public_url}")
        return public_url

    except Exception as e:
        logger.error(f"R2 Upload Failed: {e}")
        return None
|
||||||
|
|
||||||
|
def cleanup_temp(max_age_seconds: int = 3600):
    """Remove plain files in the temp directory older than max_age_seconds."""
    logger.info("Running cleanup_temp...")
    now = time.time()
    if not config.TEMP_DIR.exists():
        return

    cutoff = now - max_age_seconds
    for entry in config.TEMP_DIR.iterdir():
        try:
            # Sub-directories are deliberately left alone.
            if entry.is_file() and entry.stat().st_mtime < cutoff:
                entry.unlink()
        except Exception as e:
            logger.warning(f"Failed to delete {entry}: {e}")
|
||||||
76
modules/styles.py
Normal file
76
modules/styles.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""
|
||||||
|
花字样式预设库
|
||||||
|
供 Design Agent 和 Renderer 使用
|
||||||
|
"""
|
||||||
|
|
||||||
|
STYLES = {
    # 1. Eye-catching emphasis (yellow highlight) — default, general selling points.
    "highlight": {
        "font_size": 60,
        "font_color": "#FFE66D",  # bright yellow
        "stroke": {"color": "#000000", "width": 4},
        "shadow": {"color": "#000000", "blur": 8, "offset": [4, 4], "opacity": 0.6}
    },

    # 2. Warning / pain point (white text on a red panel).
    "warning": {
        "font_size": 55,
        "font_color": "#FFFFFF",
        "stroke": {"color": "#FF0000", "width": 0},  # no outline
        "background": {
            "type": "box",
            "color": "#FF4D4F",  # red panel
            "corner_radius": 12,
            "padding": [15, 25, 15, 25]  # t, r, b, l
        },
        "shadow": {"color": "#990000", "blur": 0, "offset": [0, 6], "opacity": 0.4}  # hard offset shadow for depth
    },

    # 3. Price / promotion (large red digits with glow).
    "price": {
        "font_size": 90,
        "font_color": "#FF2E2E",  # vivid red
        "stroke": {"color": "#FFFFFF", "width": 6},  # white outline
        "shadow": {"color": "#FF9999", "blur": 15, "offset": [0, 0], "opacity": 0.8}  # glow effect
    },

    # 4. Dialogue / speech bubble (dark text on rounded white box).
    "bubble": {
        "font_size": 48,
        "font_color": "#333333",
        "background": {
            "type": "box",
            "color": "#FFFFFF",
            "corner_radius": 40,  # large radius for the bubble look
            "padding": [20, 40, 20, 40]
        },
        "shadow": {"color": "#000000", "blur": 10, "offset": [2, 5], "opacity": 0.2}
    },

    # 5. Fashion / minimal (thin white text, light outline).
    "minimal": {
        "font_size": 65,
        "font_color": "#FFFFFF",
        "stroke": {"color": "#000000", "width": 2},
        "shadow": {"color": "#000000", "blur": 2, "offset": [2, 2], "opacity": 0.8},
        "font_family": "NotoSansSC-Regular.otf"  # may be absent; the renderer falls back
    },

    # 6. Tech / futuristic (cyan with glow).
    "tech": {
        "font_size": 60,
        "font_color": "#00FFFF",
        "stroke": {"color": "#003333", "width": 3},
        "shadow": {"color": "#00FFFF", "blur": 20, "offset": [0, 0], "opacity": 0.9}
    }
}


def get_style(style_name: str) -> dict:
    """Return the preset style dict; unknown names fall back to "highlight"."""
    return STYLES.get(style_name, STYLES["highlight"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
251
modules/text_renderer.py
Normal file
251
modules/text_renderer.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
"""
|
||||||
|
通用文本渲染引擎
|
||||||
|
支持原子化设计参数,供上游 Design Agent 灵活调用
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Tuple, Union, Optional
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw, ImageFont, ImageFilter, ImageColor
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules.styles import get_style
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cache directory for rendered text PNGs.
CACHE_DIR = config.TEMP_DIR / "text_renderer_cache"
# parents=True: also create TEMP_DIR when it does not exist yet, so merely
# importing this module cannot fail with FileNotFoundError on a fresh checkout.
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TextRenderer:
    """
    General-purpose text renderer.

    Renders a text string to a transparent PNG based on atomic style
    parameters (font, stroke, shadow, background), for use as a video
    overlay. Results are cached on disk keyed by (text, style).
    """

    def __init__(self):
        # Resolved once at construction; per-call styles may still override it.
        self.default_font_path = self._resolve_font_path(None)

    def _resolve_font_path(self, font_family: Optional[str]) -> Optional[str]:
        """Resolve a usable font file path with multi-level fallback.

        Search order: the given value as an absolute path, then under
        assets/fonts (with .ttf/.otf suffixes appended when missing), then
        bundled project fonts, then common macOS/Windows system fonts.
        Returns None when nothing usable is found (caller then uses
        Pillow's built-in default font).
        """
        candidates = []
        if font_family:
            # 1. Try the value as an absolute path.
            candidates.append(font_family)
            # 2. Try it under assets/fonts.
            candidates.append(str(config.FONTS_DIR / font_family))
            if not font_family.endswith(".ttf") and not font_family.endswith(".otf"):
                candidates.append(str(config.FONTS_DIR / f"{font_family}.ttf"))
                candidates.append(str(config.FONTS_DIR / f"{font_family}.otf"))

        # 3. Bundled project fonts.
        candidates.extend([
            str(config.FONTS_DIR / "SmileySans-Oblique.ttf"),
            str(config.FONTS_DIR / "AlibabaPuHuiTi-Bold.ttf"),
            str(config.FONTS_DIR / "AlibabaPuHuiTi-Regular.ttf"),
            str(config.FONTS_DIR / "NotoSansSC-Bold.otf"),  # if present and valid
        ])

        # 4. System font fallbacks (macOS, then Windows).
        candidates.extend([
            "/System/Library/Fonts/PingFang.ttc",
            "/System/Library/Fonts/STHeiti Medium.ttc",
            "C:/Windows/Fonts/msyh.ttc",
            "C:/Windows/Fonts/simhei.ttf",
        ])

        for path in candidates:
            if path and os.path.exists(path):
                # Crude sanity check: a real font file should be > 10 KB
                # (placeholder/LFS-pointer files are tiny).
                try:
                    if os.path.getsize(path) > 10000:
                        return path
                except:
                    continue

        logger.warning("No valid font found, using default load_default()")
        return None

    def _get_font(self, font_path: Optional[str], size: int) -> ImageFont.FreeTypeFont:
        # Load the TrueType font; on any failure (or no path) fall back to
        # Pillow's built-in bitmap font so rendering never hard-fails.
        try:
            if font_path:
                return ImageFont.truetype(font_path, size)
        except Exception as e:
            logger.warning(f"Failed to load font {font_path}: {e}")
        return ImageFont.load_default()

    def _parse_color(self, color: Union[str, Tuple]) -> Tuple[int, int, int, int]:
        """Normalize a color spec ("#RRGGBB" string or RGB/RGBA tuple) to RGBA."""
        if isinstance(color, str):
            if color.startswith("#"):
                rgb = ImageColor.getrgb(color)
                return rgb + (255,)
            # TODO: support 'rgba(r,g,b,a)' string format
        if isinstance(color, tuple):
            if len(color) == 3:
                return color + (255,)
            return color
        # Unrecognized input: opaque black.
        return (0, 0, 0, 255)

    def render(self, text: str, style: Union[Dict[str, Any], str], cache: bool = True) -> str:
        """
        Render *text* with *style* and return the path to the resulting PNG.

        style may be a preset name (looked up via get_style) or a dict:
        {
            "font_family": str,
            "font_size": int,
            "font_color": str,
            "stroke": [{"color": str, "width": int}, ...],  # or a single dict
            "shadow": {"color": str, "blur": int, "offset": [x, y], "opacity": float},
            "background": {
                "type": "box", "color": str/list, "corner_radius": int, "padding": [t, r, b, l]
            }
        }
        """
        # 0. Resolve preset style names to their dicts.
        if isinstance(style, str):
            style = get_style(style)

        # 1. Disk-cache check, keyed by text + stringified style.
        cache_key = hashlib.md5(f"{text}_{str(style)}".encode()).hexdigest()
        if cache:
            cache_path = CACHE_DIR / f"{cache_key}.png"
            if cache_path.exists():
                return str(cache_path)

        # 2. Resolve basic parameters.
        font_path = self._resolve_font_path(style.get("font_family"))
        font_size = style.get("font_size", 60)
        font = self._get_font(font_path, font_size)
        font_color = self._parse_color(style.get("font_color", "#FFFFFF"))

        # 3. Measure the text on a throwaway 1x1 canvas.
        dummy_draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
        bbox = dummy_draw.textbbox((0, 0), text, font=font)
        text_w = bbox[2] - bbox[0]
        text_h = bbox[3] - bbox[1]

        # 4. Compute total canvas size (padding + stroke + shadow margins).
        strokes = style.get("stroke", [])
        if isinstance(strokes, dict): strokes = [strokes]  # accept legacy single-dict format

        max_stroke = 0
        for s in strokes:
            max_stroke = max(max_stroke, s.get("width", 0))

        shadow = style.get("shadow", {})
        shadow_blur = shadow.get("blur", 0)
        shadow_offset = shadow.get("offset", [0, 0])

        bg = style.get("background", {})
        padding = bg.get("padding", [0, 0, 0, 0])
        if isinstance(padding, int): padding = [padding] * 4
        if len(padding) == 2: padding = [padding[0], padding[1], padding[0], padding[1]]  # v, h -> t, r, b, l

        # Content area (text + padding).
        content_w = text_w + padding[1] + padding[3]
        content_h = text_h + padding[0] + padding[2]

        # Extra margin so stroke/shadow are never clipped (+10 px safety).
        extra_margin = max_stroke + shadow_blur + max(abs(shadow_offset[0]), abs(shadow_offset[1])) + 10

        canvas_w = content_w + extra_margin * 2
        canvas_h = content_h + extra_margin * 2

        # 5. Create the transparent canvas.
        img = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
        draw = ImageDraw.Draw(img)

        # Anchor point (center of the canvas).
        center_x = canvas_w // 2
        center_y = canvas_h // 2

        # 6. Draw order: shadow -> background -> stroke -> text.

        # --- Shadow (for the whole block) ---
        if shadow:
            shadow_color = self._parse_color(shadow.get("color", "#000000"))
            opacity = shadow.get("opacity", 0.5)
            shadow_color = (shadow_color[0], shadow_color[1], shadow_color[2], int(255 * opacity))

            # Draw the shape on a scratch layer, then blur/offset it.
            shadow_layer = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
            shadow_draw = ImageDraw.Draw(shadow_layer)

            # With a background, the shadow follows the background shape;
            # otherwise it follows the text itself.
            if bg and bg.get("type") != "none":
                self._draw_background(shadow_draw, bg, center_x, center_y, content_w, content_h, shadow_color)
            else:
                # Text shadow.
                txt_x = center_x - text_w / 2
                txt_y = center_y - text_h / 2
                shadow_draw.text((txt_x, txt_y), text, font=font, fill=shadow_color)
                # Stroke shadow — intentionally a no-op loop:
                for s in strokes:
                    width = s.get("width", 0)
                    # A faithful stroke shadow needs repeated offset draws,
                    # which is expensive; only the text shadow is drawn for now.

            # Apply blur.
            if shadow_blur > 0:
                shadow_layer = shadow_layer.filter(ImageFilter.GaussianBlur(shadow_blur))

            # Apply offset by pasting onto a fresh layer.
            final_shadow = Image.new("RGBA", (int(canvas_w), int(canvas_h)), (0, 0, 0, 0))
            final_shadow.paste(shadow_layer, (int(shadow_offset[0]), int(shadow_offset[1])), mask=shadow_layer)

            img = Image.alpha_composite(final_shadow, img)
            draw = ImageDraw.Draw(img)  # re-bind draw to the composited image

        # --- Background ---
        if bg and bg.get("type") in ["box", "circle"]:
            bg_color = self._parse_color(bg.get("color", "#000000"))
            # TODO: support gradient backgrounds
            self._draw_background(draw, bg, center_x, center_y, content_w, content_h, bg_color)

        # --- Stroke (text only) ---
        # Drawn outermost-first (reversed), each pass overdrawn by the next.
        txt_x = center_x - text_w / 2
        txt_y = center_y - text_h / 2

        for s in reversed(strokes):
            color = self._parse_color(s.get("color", "#000000"))
            width = s.get("width", 0)
            if width > 0:
                # Uses Pillow's native stroke parameters; fill is also set to the
                # stroke color so the final text pass below supplies the real fill.
                draw.text((txt_x, txt_y), text, font=font, fill=color, stroke_width=width, stroke_fill=color)

        # --- Text ---
        draw.text((txt_x, txt_y), text, font=font, fill=font_color)

        # 7. Crop away surplus transparent margins.
        bbox = img.getbbox()
        if bbox:
            img = img.crop(bbox)

        # 8. Save into the cache directory (even when cache=False the file
        # lands at the cache path and is returned).
        output_path = str(CACHE_DIR / f"{cache_key}.png")
        img.save(output_path)
        logger.info(f"Rendered text: {text} -> {output_path}")

        return output_path

    def _draw_background(self, draw, bg, cx, cy, w, h, color):
        """Draw the background shape (rounded box or ellipse) centered at (cx, cy)."""
        corner_radius = bg.get("corner_radius", 0)
        x0 = cx - w / 2
        y0 = cy - h / 2
        x1 = cx + w / 2
        y1 = cy + h / 2

        if bg.get("type") == "box":
            draw.rounded_rectangle([x0, y0, x1, y1], radius=corner_radius, fill=color)
        elif bg.get("type") == "circle":
            draw.ellipse([x0, y0, x1, y1], fill=color)
|
||||||
|
|
||||||
|
# Module-level singleton shared by all callers.
renderer = TextRenderer()
|
||||||
177
modules/utils.py
Normal file
177
modules/utils.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
Gloda Video Factory - Utility Functions
|
||||||
|
Handles font management, Auto-QC, and helper effects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import urllib.request
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from moviepy.editor import ImageClip, VideoFileClip, AudioFileClip
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Font download sources (Google-hosted GitHub repositories).
ROBOTO_BOLD_URL = "https://github.com/googlefonts/roboto/raw/main/src/hinted/Roboto-Bold.ttf"
NOTO_SC_BOLD_URL = "https://raw.githubusercontent.com/google/fonts/main/ofl/notosanssc/NotoSansSC-Bold.ttf"

# Local cache locations for the downloaded font files.
FONT_PATH_EN = config.FONTS_DIR / "Roboto-Bold.ttf"
FONT_PATH_CN = config.FONTS_DIR / "NotoSansSC-Bold.ttf"
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_fonts() -> Path:
    """
    Ensure required fonts (EN & CN) are available locally.

    Downloads Roboto-Bold and NotoSansSC-Bold into config.FONTS_DIR on first
    use. Download failures are logged but non-fatal.

    Returns:
        Path to the preferred font file: the CN font when present (it covers
        mixed CJK + Latin text), otherwise the EN font. Note the returned
        path may not exist if both downloads failed.
    """
    config.FONTS_DIR.mkdir(parents=True, exist_ok=True)

    # English font
    if not FONT_PATH_EN.exists():
        # Plain strings here: the originals were f-strings with no placeholders.
        logger.info("Downloading Roboto-Bold font...")
        try:
            urllib.request.urlretrieve(ROBOTO_BOLD_URL, FONT_PATH_EN)
        except Exception as e:
            logger.error(f"Failed to download EN font: {e}")

    # Chinese font
    if not FONT_PATH_CN.exists():
        logger.info("Downloading NotoSansSC-Bold font...")
        try:
            # GitHub raw is the primary source; swap in a mirror if it proves flaky.
            urllib.request.urlretrieve(NOTO_SC_BOLD_URL, FONT_PATH_CN)
        except Exception as e:
            logger.error(f"Failed to download CN font: {e}")

    # Return CN font as default for mixed text.
    if FONT_PATH_CN.exists():
        return FONT_PATH_CN
    return FONT_PATH_EN
|
||||||
|
|
||||||
|
|
||||||
|
def check_imagemagick() -> bool:
    """
    Check whether ImageMagick is installed and on PATH.

    Looks for the legacy ``convert`` binary (ImageMagick 6) and also the
    unified ``magick`` binary, which replaced ``convert`` in ImageMagick 7 —
    checking only ``convert`` falsely reports IM7 installs as missing.

    Returns:
        True if either binary is found; otherwise logs a warning and
        returns False.
    """
    import shutil
    if shutil.which("convert") or shutil.which("magick"):
        return True
    logger.warning("ImageMagick not found. Text overlays may fail.")
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def verify_assets(video_path: str, audio_path: str) -> Tuple[bool, str]:
    """
    Auto-QC: Verify generated assets quality.

    Checks:
    1. File size sanity check (tiny files indicate a failed generation)
    2. Duration matching (+/- 2s tolerance between video and audio)
    3. Audio silence check (RMS over the decoded samples)

    Returns:
        (Passed: bool, Reason: str)
    """
    logger.info(f"Running Auto-QC on:\nVideo: {video_path}\nAudio: {audio_path}")

    try:
        # 1. File Size Check
        vid_size = os.path.getsize(video_path)
        if vid_size < 50 * 1024:  # < 50KB
            return False, f"Video file too small ({vid_size/1024:.1f}KB). Likely error/black screen."

        aud_size = os.path.getsize(audio_path)
        if aud_size < 5 * 1024:  # < 5KB
            return False, f"Audio file too small ({aud_size/1024:.1f}KB)."

        # 2. Duration Check
        try:
            v_clip = VideoFileClip(video_path)
            a_clip = AudioFileClip(audio_path)

            v_dur = v_clip.duration
            a_dur = a_clip.duration

            # 3. Silence check via RMS. NOTE: to_soundarray decodes the ENTIRE
            # audio track (the buffersize argument only tunes chunked reading),
            # which is acceptable for short promo clips.
            chunk = a_clip.to_soundarray(fps=44100, nbytes=2, buffersize=1000)
            if chunk is not None:
                rms = np.sqrt(np.mean(chunk**2))
                if rms < 0.001:
                    v_clip.close()
                    a_clip.close()
                    return False, "Audio appears to be silent (RMS < 0.001)"

            # Release decoder resources before the final comparison.
            v_clip.close()
            a_clip.close()

            # Durations must agree within 2 seconds.
            if abs(v_dur - a_dur) > 2.0:
                return False, f"Duration mismatch: Video={v_dur:.1f}s, Audio={a_dur:.1f}s"

        except Exception as e:
            # Media could not be opened/decoded at all.
            return False, f"Media analysis failed: {str(e)}"

        return True, "QC Passed"

    except Exception as e:
        # Unexpected failure of the QC machinery itself (e.g. missing file).
        logger.error(f"Auto-QC Error: {e}")
        return False, f"QC System Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def apply_ken_burns(
    image_path: str,
    duration: float = 5.0,
    zoom_ratio: float = 1.2,
    output_path: Optional[str] = None
) -> str:
    """Apply Ken Burns effect (slow zoom in) to a static image.

    Pre-scales the image so the fully zoomed-in frame still covers the target
    resolution, then per-frame crops a shrinking window and resizes it back to
    the target size, with cosine easing for a smooth ramp.

    Args:
        image_path: Source still image.
        duration: Clip length in seconds.
        zoom_ratio: Final zoom factor relative to the starting framing (>1).
        output_path: Destination mp4; defaults to OUTPUT_DIR/<stem>_ken_burns.mp4.

    Returns:
        Path to the written mp4 (no audio track).
    """
    if output_path is None:
        base_name = Path(image_path).stem
        output_path = str(config.OUTPUT_DIR / f"{base_name}_ken_burns.mp4")

    logger.info(f"Applying Ken Burns effect to {image_path}")

    img = Image.open(image_path)
    img_width, img_height = img.size
    target_width = config.VIDEO_SETTINGS["width"]
    target_height = config.VIDEO_SETTINGS["height"]
    fps = config.VIDEO_SETTINGS["fps"]

    # Scale so the image covers the target frame even at maximum zoom
    # (max of the two axis scales = cover, not contain).
    scale_w = (target_width * zoom_ratio) / img_width
    scale_h = (target_height * zoom_ratio) / img_height
    base_scale = max(scale_w, scale_h)

    new_width = int(img_width * base_scale)
    new_height = int(img_height * base_scale)
    img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    # Work on a numpy array so per-frame cropping is a cheap slice.
    img_array = np.array(img_resized)

    def make_frame(t):
        # Cosine ease-in-out: 0 -> 1 over the clip duration.
        progress = t / duration
        eased_progress = 0.5 - 0.5 * np.cos(np.pi * progress)
        current_zoom = 1 + (zoom_ratio - 1) * eased_progress

        # Crop window shrinks as zoom grows, keeping the target aspect ratio.
        crop_width = int(target_width / current_zoom * (new_width / target_width))
        crop_height = int(target_height / current_zoom * (new_height / target_height))

        crop_width = min(crop_width, new_width)
        crop_height = min(crop_height, new_height)

        # Center the crop.
        x_start = (new_width - crop_width) // 2
        y_start = (new_height - crop_height) // 2

        cropped = img_array[y_start:y_start + crop_height, x_start:x_start + crop_width]
        cropped_pil = Image.fromarray(cropped)
        resized = cropped_pil.resize((target_width, target_height), Image.Resampling.LANCZOS)
        return np.array(resized)

    clip = ImageClip(make_frame, duration=duration)
    clip = clip.set_fps(fps)
    clip.write_videofile(output_path, fps=fps, codec=config.VIDEO_SETTINGS["codec"], audio=False, logger=None)
    clip.close()

    return output_path
|
||||||
269
modules/video_gen.py
Normal file
269
modules/video_gen.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""
|
||||||
|
图生视频模块 (Volcengine Doubao-SeedDance)
|
||||||
|
负责将分镜图片转换为视频片段
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules import storage
|
||||||
|
from modules.db_manager import db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class VideoGenerator:
    """Image-to-video generator backed by Volcengine Doubao-SeedDance.

    Wraps the async task API: submit a generation task for a scene image,
    poll its status, and download the finished clip. Task state is mirrored
    into the project database via ``db.save_asset``.
    """

    def __init__(self):
        self.api_key = config.VOLC_API_KEY
        self.base_url = config.VOLC_BASE_URL
        self.model_id = config.VIDEO_MODEL_ID

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

    def submit_scene_video_task(
        self,
        project_id: str,
        scene_id: int,
        image_path: str,
        prompt: str
    ) -> Optional[str]:
        """Submit a video-generation task for a single scene.

        Uploads the scene image to object storage first (the API requires a
        publicly reachable URL), then creates the remote task and records it
        in the database with status ``processing``.

        Returns:
            The task id, or None if the image is missing or any step fails.
        """
        if not image_path or not os.path.exists(image_path):
            logger.warning(f"Skipping video generation for Scene {scene_id}: Image not found")
            return None

        # Upload the image to R2 to obtain a URL the video API can fetch.
        logger.info(f"Uploading image for Scene {scene_id}...")
        image_url = storage.upload_file(image_path)

        if not image_url:
            logger.error(f"Failed to upload image for Scene {scene_id}")
            return None

        logger.info(f"Submitting video task for Scene {scene_id}...")
        task_id = self._submit_task(image_url, prompt)

        if task_id:
            # Persist the task id immediately (status=processing) so the task
            # can be recovered later even if this process dies mid-flight.
            db.save_asset(
                project_id=project_id,
                scene_id=scene_id,
                asset_type="video",
                status="processing",
                task_id=task_id,
                local_path=None
            )

        return task_id

    def recover_video_from_task(self, task_id: str, output_path: str) -> bool:
        """Try to recover a video from an existing task id.

        Queries the remote task status and, if it succeeded, downloads the
        clip and moves it to *output_path*.

        Returns:
            True if the video was recovered to *output_path*, False otherwise.
        """
        try:
            status, video_url = self._check_task(task_id)
            logger.info(f"Recovering task {task_id}: status={status}")

            if status == "succeeded" and video_url:
                downloaded_path = self._download_video(video_url, os.path.basename(output_path))
                if downloaded_path:
                    # _download_video saves under config.TEMP_DIR; move the
                    # file when the caller asked for a different destination.
                    if os.path.abspath(downloaded_path) != os.path.abspath(output_path):
                        import shutil
                        shutil.move(downloaded_path, output_path)
                    return True
            return False
        except Exception as e:
            logger.error(f"Failed to recover video task {task_id}: {e}")
            return False

    def check_task_status(self, task_id: str) -> tuple[str, Optional[str]]:
        """Query a task's status.

        Returns:
            ``(status, video_url)`` — *video_url* is None unless the task
            has succeeded.
        """
        return self._check_task(task_id)

    def generate_scene_videos(
        self,
        project_id: str,
        script: Dict[str, Any],
        scene_images: Dict[int, str]
    ) -> Dict[int, str]:
        """Generate videos for all scenes (legacy blocking/polling flow).

        Submits one task per scene, then polls every 5 seconds until all
        tasks finish or a 10-minute timeout elapses. Results are recorded
        in the database as they complete.

        Returns:
            Mapping of scene_id -> local video path for scenes that
            completed successfully within the timeout.
        """
        generated_videos: Dict[int, str] = {}
        tasks: Dict[int, str] = {}  # scene_id -> task_id

        scenes = script.get("scenes", [])

        # 1. Submit all tasks up front.
        for scene in scenes:
            scene_id = scene["id"]
            image_path = scene_images.get(scene_id)
            prompt = scene.get("video_prompt", "High quality video")

            task_id = self.submit_scene_video_task(project_id, scene_id, image_path, prompt)

            if task_id:
                tasks[scene_id] = task_id
                logger.info(f"Task submitted: {task_id}")
            else:
                logger.error(f"Failed to submit task for Scene {scene_id}")

        # 2. Poll task status until everything settles or we time out.
        pending_tasks = list(tasks.keys())

        start_time = time.time()
        timeout = 600  # seconds (10 minutes)

        while pending_tasks and (time.time() - start_time < timeout):
            logger.info(f"Polling status for {len(pending_tasks)} tasks...")

            still_pending = []
            for scene_id in pending_tasks:
                task_id = tasks[scene_id]
                status, result_url = self._check_task(task_id)

                if status == "succeeded":
                    logger.info(f"Scene {scene_id} video generated successfully")
                    video_path = self._download_video(result_url, f"scene_{scene_id}_video.mp4")
                    if video_path:
                        generated_videos[scene_id] = video_path
                        db.save_asset(
                            project_id=project_id,
                            scene_id=scene_id,
                            asset_type="video",
                            status="completed",
                            local_path=video_path,
                            task_id=task_id
                        )
                elif status in ("failed", "cancelled"):
                    logger.error(f"Scene {scene_id} task failed/cancelled")
                    db.save_asset(
                        project_id=project_id,
                        scene_id=scene_id,
                        asset_type="video",
                        status="failed",
                        task_id=task_id
                    )
                else:
                    # Still queued/running (or transient "unknown") — retry.
                    still_pending.append(scene_id)

            pending_tasks = still_pending
            if pending_tasks:
                time.sleep(5)  # polling interval

        return generated_videos

    def _submit_task(self, image_url: str, prompt: str) -> Optional[str]:
        """Create a generation task; return its id, or None on failure."""
        url = f"{self.base_url}/contents/generations/tasks"

        payload = {
            "model": self.model_id,
            "content": [
                {
                    "type": "text",
                    "text": f"{prompt} --resolution 1080p --duration 3 --camerafixed false --watermark false"
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url}
                }
            ]
        }

        response = None
        try:
            response = requests.post(url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            data = response.json()
            # The task id may be at the top level or nested under "data",
            # depending on the exact API version's response shape.
            task_id = data.get("id")
            if not task_id and "data" in data:
                task_id = data.get("data", {}).get("id")

            return task_id
        except Exception as e:
            logger.error(f"Task submission failed: {e}")
            if response is not None:
                logger.error(f"Response: {response.text}")
            return None

    def _check_task(self, task_id: str) -> tuple[str, Optional[str]]:
        """Check a task's remote status.

        Returns:
            ``(status, content_url)``; status is one of queued, running,
            succeeded, failed, cancelled — or "unknown" on request failure.
        """
        url = f"{self.base_url}/contents/generations/tasks/{task_id}"

        try:
            response = requests.get(url, headers=self.headers, timeout=30)
            response.raise_for_status()
            data = response.json()

            # Response may be flat or wrapped in a "data" envelope:
            # { "id": ..., "status": "succeeded", "content": [ {"url"/"video_url": ...} ] }
            result = data
            if "data" in data and "status" not in data:
                result = data["data"]

            status = result.get("status")
            content_url = None

            if status == "succeeded":
                if "content" in result:
                    content = result["content"]
                    if isinstance(content, list) and len(content) > 0:
                        item = content[0]
                        content_url = item.get("video_url") or item.get("url")
                    elif isinstance(content, dict):
                        content_url = content.get("video_url") or content.get("url")

            return status, content_url

        except Exception as e:
            logger.error(f"Check task failed: {e}")
            return "unknown", None

    def _download_video(self, url: str, filename: str) -> Optional[str]:
        """Stream-download a video into TEMP_DIR; return its path or None."""
        if not url:
            return None

        try:
            response = requests.get(url, stream=True, timeout=60)
            response.raise_for_status()

            output_path = config.TEMP_DIR / filename
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            return str(output_path)
        except Exception as e:
            logger.error(f"Download video failed: {e}")
            return None
|
||||||
31
requirements.txt
Normal file
31
requirements.txt
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# Gloda Video Factory - Dependencies
|
||||||
|
# Python 3.10+
|
||||||
|
|
||||||
|
# Core LLM
|
||||||
|
openai>=1.0.0
|
||||||
|
|
||||||
|
# Image Generation
|
||||||
|
fal-client>=0.4.0
|
||||||
|
|
||||||
|
# Video Generation (Real Mode)
|
||||||
|
PyJWT>=2.8.0
|
||||||
|
requests>=2.31.0
|
||||||
|
|
||||||
|
# Audio Generation
|
||||||
|
elevenlabs>=1.0.0
|
||||||
|
gTTS>=2.4.0
|
||||||
|
|
||||||
|
# Video Processing
|
||||||
|
moviepy==1.0.3
|
||||||
|
imageio[ffmpeg]>=2.33.0
|
||||||
|
Pillow>=10.0.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
|
||||||
|
# Web UI
|
||||||
|
streamlit>=1.29.0
|
||||||
|
|
||||||
|
# Config
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
PyYAML>=6.0.1
|
||||||
|
boto3>=1.34.0
|
||||||
|
|
||||||
116
volcengine_binary_demo/examples/volcengine/binary.py
Normal file
116
volcengine_binary_demo/examples/volcengine/binary.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from protocols import MsgType, full_client_request, receive_message
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cluster(voice: str) -> str:
    """Map a voice type to its TTS cluster.

    Voice ids prefixed with ``S_`` (voice-clone voices) route to the ICL
    cluster; everything else uses the standard TTS cluster.
    """
    return "volcano_icl" if voice.startswith("S_") else "volcano_tts"
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """CLI entry point: synthesize ``--text`` via the binary WebSocket TTS API.

    Connects to the endpoint, sends one full-client request, collects the
    streamed audio frames, and writes them to ``<voice_type>.<encoding>``
    in the current directory.

    Raises:
        RuntimeError: if the server reports a failure or returns no audio.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--appid", required=True, help="APP ID")
    parser.add_argument("--access_token", required=True, help="Access Token")
    parser.add_argument("--voice_type", required=True, help="Voice type")
    parser.add_argument("--cluster", default="", help="Cluster name")
    parser.add_argument("--text", required=True, help="Text to convert")
    parser.add_argument("--encoding", default="wav", help="Output file encoding")
    parser.add_argument(
        "--endpoint",
        default="wss://openspeech.bytedance.com/api/v1/tts/ws_binary",
        help="WebSocket endpoint URL",
    )

    args = parser.parse_args()

    # Determine cluster: explicit flag wins, otherwise infer from voice type.
    cluster = args.cluster if args.cluster else get_cluster(args.voice_type)

    # Connect to server. NOTE: the semicolon in "Bearer;<token>" is the
    # scheme this endpoint expects, not a typo.
    headers = {
        "Authorization": f"Bearer;{args.access_token}",
    }

    logger.info(f"Connecting to {args.endpoint} with headers: {headers}")
    websocket = await websockets.connect(
        args.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
    )
    logger.info(
        f"Connected to WebSocket server, Logid: {websocket.response.headers['x-tt-logid']}",
    )

    try:
        # Prepare request payload.
        request = {
            "app": {
                "appid": args.appid,
                "token": args.access_token,
                "cluster": cluster,
            },
            "user": {
                "uid": str(uuid.uuid4()),
            },
            "audio": {
                "voice_type": args.voice_type,
                "encoding": args.encoding,
            },
            "request": {
                "reqid": str(uuid.uuid4()),
                "text": args.text,
                "operation": "submit",
                "with_timestamp": "1",
                "extra_param": json.dumps(
                    {
                        "disable_markdown_filter": False,
                    }
                ),
            },
        }

        # Send request.
        await full_client_request(websocket, json.dumps(request).encode())

        # Receive audio data until the server marks the last frame.
        audio_data = bytearray()
        while True:
            msg = await receive_message(websocket)

            if msg.type == MsgType.FrontEndResultServer:
                # Timestamp/front-end metadata frame — not audio, skip it.
                continue
            elif msg.type == MsgType.AudioOnlyServer:
                audio_data.extend(msg.payload)
                if msg.sequence < 0:  # Negative sequence marks the last frame.
                    break
            else:
                raise RuntimeError(f"TTS conversion failed: {msg}")

        # Check if we received any audio data.
        if not audio_data:
            raise RuntimeError("No audio data received")

        # Save audio file.
        filename = f"{args.voice_type}.{args.encoding}"
        with open(filename, "wb") as f:
            f.write(audio_data)
        # BUGFIX: the log previously printed a literal placeholder instead of
        # the actual output path.
        logger.info(f"Audio received: {len(audio_data)}, saved to {filename}")

    finally:
        await websocket.close()
        logger.info("Connection closed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: run the async main() under a fresh event loop.
    import asyncio

    asyncio.run(main())
|
||||||
41
volcengine_binary_demo/protocols/__init__.py
Normal file
41
volcengine_binary_demo/protocols/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from .protocols import (
|
||||||
|
CompressionBits,
|
||||||
|
EventType,
|
||||||
|
HeaderSizeBits,
|
||||||
|
Message,
|
||||||
|
MsgType,
|
||||||
|
MsgTypeFlagBits,
|
||||||
|
SerializationBits,
|
||||||
|
VersionBits,
|
||||||
|
audio_only_client,
|
||||||
|
cancel_session,
|
||||||
|
finish_connection,
|
||||||
|
finish_session,
|
||||||
|
full_client_request,
|
||||||
|
receive_message,
|
||||||
|
start_connection,
|
||||||
|
start_session,
|
||||||
|
task_request,
|
||||||
|
wait_for_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CompressionBits",
|
||||||
|
"EventType",
|
||||||
|
"HeaderSizeBits",
|
||||||
|
"Message",
|
||||||
|
"MsgType",
|
||||||
|
"MsgTypeFlagBits",
|
||||||
|
"SerializationBits",
|
||||||
|
"VersionBits",
|
||||||
|
"audio_only_client",
|
||||||
|
"cancel_session",
|
||||||
|
"finish_connection",
|
||||||
|
"finish_session",
|
||||||
|
"full_client_request",
|
||||||
|
"receive_message",
|
||||||
|
"start_connection",
|
||||||
|
"start_session",
|
||||||
|
"task_request",
|
||||||
|
"wait_for_event",
|
||||||
|
]
|
||||||
543
volcengine_binary_demo/protocols/protocols.py
Normal file
543
volcengine_binary_demo/protocols/protocols.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Callable, List
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MsgType(IntEnum):
    """Message type enumeration.

    Encoded in the high nibble of the frame's second byte
    (see ``Message.from_bytes``).
    """

    Invalid = 0
    FullClientRequest = 0b1
    AudioOnlyClient = 0b10
    FullServerResponse = 0b1001
    AudioOnlyServer = 0b1011
    FrontEndResultServer = 0b1100
    Error = 0b1111

    # Alias
    ServerACK = AudioOnlyServer

    def __str__(self) -> str:
        return self.name if self.name else f"MsgType({self.value})"
|
||||||
|
|
||||||
|
|
||||||
|
class MsgTypeFlagBits(IntEnum):
    """Message type flag bits.

    Encoded in the low nibble of the frame's second byte
    (see ``Message.from_bytes``).
    """

    NoSeq = 0  # Non-terminal packet with no sequence
    PositiveSeq = 0b1  # Non-terminal packet with sequence > 0
    LastNoSeq = 0b10  # Last packet with no sequence
    NegativeSeq = 0b11  # Last packet with sequence < 0
    WithEvent = 0b100  # Payload contains event number (int32)
|
||||||
|
|
||||||
|
|
||||||
|
class VersionBits(IntEnum):
    """Protocol version bits (high nibble of the frame's first byte)."""

    Version1 = 1
    Version2 = 2
    Version3 = 3
    Version4 = 4
|
||||||
|
|
||||||
|
|
||||||
|
class HeaderSizeBits(IntEnum):
    """Header size bits.

    The actual header length in bytes is ``4 * value``
    (see ``Message.marshal``/``Message.unmarshal``).
    """

    HeaderSize4 = 1
    HeaderSize8 = 2
    HeaderSize12 = 3
    HeaderSize16 = 4
|
||||||
|
|
||||||
|
|
||||||
|
class SerializationBits(IntEnum):
    """Serialization method bits (high nibble of the frame's third byte)."""

    Raw = 0
    JSON = 0b1
    Thrift = 0b11
    Custom = 0b1111
|
||||||
|
|
||||||
|
|
||||||
|
class CompressionBits(IntEnum):
    """Compression method bits (low nibble of the frame's third byte)."""

    None_ = 0
    Gzip = 0b1
    Custom = 0b1111
|
||||||
|
|
||||||
|
|
||||||
|
class EventType(IntEnum):
    """Event type enumeration.

    Carried as a signed 32-bit integer when a frame's flag is ``WithEvent``
    (see ``Message._write_event``/``_read_event``). Value ranges group the
    events by direction (upstream/downstream) and scope
    (connection/session/TTS/ASR/dialogue).
    """

    None_ = 0  # Default event

    # 1 ~ 49 Upstream Connection events
    StartConnection = 1
    StartTask = 1  # Alias of StartConnection
    FinishConnection = 2
    FinishTask = 2  # Alias of FinishConnection

    # 50 ~ 99 Downstream Connection events
    ConnectionStarted = 50  # Connection established successfully
    TaskStarted = 50  # Alias of ConnectionStarted
    ConnectionFailed = 51  # Connection failed (possibly due to authentication failure)
    TaskFailed = 51  # Alias of ConnectionFailed
    ConnectionFinished = 52  # Connection ended
    TaskFinished = 52  # Alias of ConnectionFinished

    # 100 ~ 149 Upstream Session events
    StartSession = 100
    CancelSession = 101
    FinishSession = 102

    # 150 ~ 199 Downstream Session events
    SessionStarted = 150
    SessionCanceled = 151
    SessionFinished = 152
    SessionFailed = 153
    UsageResponse = 154  # Usage response
    ChargeData = 154  # Alias of UsageResponse

    # 200 ~ 249 Upstream general events
    TaskRequest = 200
    UpdateConfig = 201

    # 250 ~ 299 Downstream general events
    AudioMuted = 250

    # 300 ~ 349 Upstream TTS events
    SayHello = 300

    # 350 ~ 399 Downstream TTS events
    TTSSentenceStart = 350
    TTSSentenceEnd = 351
    TTSResponse = 352
    TTSEnded = 359
    PodcastRoundStart = 360
    PodcastRoundResponse = 361
    PodcastRoundEnd = 362

    # 450 ~ 499 Downstream ASR events
    ASRInfo = 450
    ASRResponse = 451
    ASREnded = 459

    # 500 ~ 549 Upstream dialogue events
    ChatTTSText = 500  # (Ground-Truth-Alignment) text for speech synthesis

    # 550 ~ 599 Downstream dialogue events
    ChatResponse = 550
    ChatEnded = 559

    # 650 ~ 699 Downstream dialogue events
    # Events for source (original) language subtitle
    SourceSubtitleStart = 650
    SourceSubtitleResponse = 651
    SourceSubtitleEnd = 652
    # Events for target (translation) language subtitle
    TranslationSubtitleStart = 653
    TranslationSubtitleResponse = 654
    TranslationSubtitleEnd = 655

    def __str__(self) -> str:
        return self.name if self.name else f"EventType({self.value})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Message:
    """A single frame of the binary WebSocket protocol.

    Wire format (all multi-byte integers are big-endian)::

          0                   1                   2                   3
        | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 |
        +-----------------+-----------------+-----------------+-----------------+
        |  Version (4b)  |  Header Size (4b)  |  Msg Type (4b)  |  Flags (4b)   |
        +-----------------------------------------------------------------------+
        |  Serialization (4b)  |  Compression (4b)  |       Reserved (8b)       |
        +-----------------------------------------------------------------------+
        |                        Optional Header Extensions                     |
        |                          (if Header Size > 1)                         |
        +-----------------------------------------------------------------------+
        |                         Payload (variable length)                     |
        +-----------------------------------------------------------------------+

    Which optional fields (event, session id, sequence, error code) appear
    after the header depends on ``type``/``flag``/``event`` — see
    ``_get_writers`` and ``_get_readers``.
    """

    # Fixed header nibbles.
    version: VersionBits = VersionBits.Version1
    header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4  # header bytes = 4 * value
    type: MsgType = MsgType.Invalid
    flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
    serialization: SerializationBits = SerializationBits.JSON
    compression: CompressionBits = CompressionBits.None_

    # Optional post-header fields; presence on the wire is conditional.
    event: EventType = EventType.None_
    session_id: str = ""
    connect_id: str = ""
    sequence: int = 0
    error_code: int = 0

    payload: bytes = b""

    @classmethod
    def from_bytes(cls, data: bytes) -> "Message":
        """Create message object from bytes.

        Raises:
            ValueError: if *data* is shorter than the 3-byte minimal header
                or contains trailing bytes after the message.
        """
        if len(data) < 3:
            raise ValueError(
                f"Data too short: expected at least 3 bytes, got {len(data)}"
            )

        # Byte 1 packs the message type (high nibble) and flags (low nibble);
        # both are needed up front to pick the right field readers.
        type_and_flag = data[1]
        msg_type = MsgType(type_and_flag >> 4)
        flag = MsgTypeFlagBits(type_and_flag & 0b00001111)

        msg = cls(type=msg_type, flag=flag)
        msg.unmarshal(data)
        return msg

    def marshal(self) -> bytes:
        """Serialize message to bytes."""
        buffer = io.BytesIO()

        # Write header: three packed bytes, then zero-padding up to the
        # declared header size.
        header = [
            (self.version << 4) | self.header_size,
            (self.type << 4) | self.flag,
            (self.serialization << 4) | self.compression,
        ]

        header_size = 4 * self.header_size
        if padding := header_size - len(header):
            header.extend([0] * padding)

        buffer.write(bytes(header))

        # Write other fields (event/session id/sequence/error code/payload),
        # as dictated by type and flag.
        writers = self._get_writers()
        for writer in writers:
            writer(buffer)

        return buffer.getvalue()

    def unmarshal(self, data: bytes) -> None:
        """Deserialize message from bytes.

        Assumes ``type`` and ``flag`` were already set (see ``from_bytes``).

        Raises:
            ValueError: if unexpected bytes remain after the message.
        """
        buffer = io.BytesIO(data)

        # Read version and header size.
        version_and_header_size = buffer.read(1)[0]
        self.version = VersionBits(version_and_header_size >> 4)
        self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)

        # Skip second byte (type/flag — already consumed by from_bytes).
        buffer.read(1)

        # Read serialization and compression methods.
        serialization_compression = buffer.read(1)[0]
        self.serialization = SerializationBits(serialization_compression >> 4)
        self.compression = CompressionBits(serialization_compression & 0b00001111)

        # Skip header padding.
        header_size = 4 * self.header_size
        read_size = 3
        if padding_size := header_size - read_size:
            buffer.read(padding_size)

        # Read other fields.
        readers = self._get_readers()
        for reader in readers:
            reader(buffer)

        # Check for remaining data — any leftover means a framing error.
        remaining = buffer.read()
        if remaining:
            raise ValueError(f"Unexpected data after message: {remaining}")

    def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get list of writer functions, in on-the-wire field order."""
        writers = []

        if self.flag == MsgTypeFlagBits.WithEvent:
            writers.extend([self._write_event, self._write_session_id])

        if self.type in [
            MsgType.FullClientRequest,
            MsgType.FullServerResponse,
            MsgType.FrontEndResultServer,
            MsgType.AudioOnlyClient,
            MsgType.AudioOnlyServer,
        ]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                writers.append(self._write_sequence)
        elif self.type == MsgType.Error:
            writers.append(self._write_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")

        writers.append(self._write_payload)
        return writers

    def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get list of reader functions, in on-the-wire field order.

        NOTE(review): unlike ``_get_writers``, the event/session fields are
        read *after* the sequence field, and a connect-id reader is included
        — the upstream and downstream layouts differ.
        """
        readers = []

        if self.type in [
            MsgType.FullClientRequest,
            MsgType.FullServerResponse,
            MsgType.FrontEndResultServer,
            MsgType.AudioOnlyClient,
            MsgType.AudioOnlyServer,
        ]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                readers.append(self._read_sequence)
        elif self.type == MsgType.Error:
            readers.append(self._read_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")

        if self.flag == MsgTypeFlagBits.WithEvent:
            readers.extend(
                [self._read_event, self._read_session_id, self._read_connect_id]
            )

        readers.append(self._read_payload)
        return readers

    def _write_event(self, buffer: io.BytesIO) -> None:
        """Write event as a big-endian signed int32."""
        buffer.write(struct.pack(">i", self.event))

    def _write_session_id(self, buffer: io.BytesIO) -> None:
        """Write session ID as a uint32 length prefix + UTF-8 bytes.

        Connection-level events carry no session id, so nothing is written
        for them.
        """
        if self.event in [
            EventType.StartConnection,
            EventType.FinishConnection,
            EventType.ConnectionStarted,
            EventType.ConnectionFailed,
        ]:
            return

        session_id_bytes = self.session_id.encode("utf-8")
        size = len(session_id_bytes)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")

        buffer.write(struct.pack(">I", size))
        if size > 0:
            buffer.write(session_id_bytes)

    def _write_sequence(self, buffer: io.BytesIO) -> None:
        """Write sequence number as a big-endian signed int32."""
        buffer.write(struct.pack(">i", self.sequence))

    def _write_error_code(self, buffer: io.BytesIO) -> None:
        """Write error code as a big-endian uint32."""
        buffer.write(struct.pack(">I", self.error_code))

    def _write_payload(self, buffer: io.BytesIO) -> None:
        """Write payload as a uint32 length prefix + raw bytes."""
        size = len(self.payload)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Payload size ({size}) exceeds max(uint32)")

        buffer.write(struct.pack(">I", size))
        buffer.write(self.payload)

    def _read_event(self, buffer: io.BytesIO) -> None:
        """Read event (big-endian signed int32), if present."""
        event_bytes = buffer.read(4)
        if event_bytes:
            self.event = EventType(struct.unpack(">i", event_bytes)[0])

    def _read_session_id(self, buffer: io.BytesIO) -> None:
        """Read length-prefixed session ID (skipped for connection events)."""
        if self.event in [
            EventType.StartConnection,
            EventType.FinishConnection,
            EventType.ConnectionStarted,
            EventType.ConnectionFailed,
            EventType.ConnectionFinished,
        ]:
            return

        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                session_id_bytes = buffer.read(size)
                if len(session_id_bytes) == size:
                    self.session_id = session_id_bytes.decode("utf-8")

    def _read_connect_id(self, buffer: io.BytesIO) -> None:
        """Read length-prefixed connection ID (connection events only)."""
        if self.event in [
            EventType.ConnectionStarted,
            EventType.ConnectionFailed,
            EventType.ConnectionFinished,
        ]:
            size_bytes = buffer.read(4)
            if size_bytes:
                size = struct.unpack(">I", size_bytes)[0]
                if size > 0:
                    self.connect_id = buffer.read(size).decode("utf-8")

    def _read_sequence(self, buffer: io.BytesIO) -> None:
        """Read sequence number (big-endian signed int32), if present."""
        sequence_bytes = buffer.read(4)
        if sequence_bytes:
            self.sequence = struct.unpack(">i", sequence_bytes)[0]

    def _read_error_code(self, buffer: io.BytesIO) -> None:
        """Read error code (big-endian uint32), if present."""
        error_code_bytes = buffer.read(4)
        if error_code_bytes:
            self.error_code = struct.unpack(">I", error_code_bytes)[0]

    def _read_payload(self, buffer: io.BytesIO) -> None:
        """Read length-prefixed payload bytes, if present."""
        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                self.payload = buffer.read(size)

    def __str__(self) -> str:
        """Human-readable summary; audio payloads show size, text shows content."""
        if self.type in [MsgType.AudioOnlyServer, MsgType.AudioOnlyClient]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, PayloadSize: {len(self.payload)}"
            return f"MsgType: {self.type}, EventType:{self.event}, PayloadSize: {len(self.payload)}"
        elif self.type == MsgType.Error:
            return f"MsgType: {self.type}, EventType:{self.event}, ErrorCode: {self.error_code}, Payload: {self.payload.decode('utf-8', 'ignore')}"
        else:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                return f"MsgType: {self.type}, EventType:{self.event}, Sequence: {self.sequence}, Payload: {self.payload.decode('utf-8', 'ignore')}"
            return f"MsgType: {self.type}, EventType:{self.event}, Payload: {self.payload.decode('utf-8', 'ignore')}"
|
||||||
|
|
||||||
|
|
||||||
|
async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
    """Receive one frame from the websocket and parse it into a Message.

    Only binary frames are valid; anything else raises ValueError. All
    failures (receive or parse) are logged before being re-raised.
    """
    try:
        data = await websocket.recv()
        if isinstance(data, bytes):
            msg = Message.from_bytes(data)
            logger.info(f"Received: {msg}")
            return msg
        if isinstance(data, str):
            raise ValueError(f"Unexpected text message: {data}")
        raise ValueError(f"Unexpected message type: {type(data)}")
    except Exception as e:
        logger.error(f"Failed to receive message: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_event(
    websocket: websockets.WebSocketClientProtocol,
    msg_type: MsgType,
    event_type: EventType,
) -> Message:
    """Receive the next message and require it to match the expected type/event.

    Args:
        websocket: Open websocket connection to read from.
        msg_type: Expected message type of the next frame.
        event_type: Expected event type of the next frame.

    Returns:
        The received message when it matches both expectations.

    Raises:
        ValueError: If the next received message does not match.
    """
    # NOTE: the previous `while True:` loop was dead code — its first
    # iteration either raised on a mismatch or returned on a match, so it
    # never looped. The redundant second condition check is dropped too;
    # observable behavior is unchanged.
    msg = await receive_message(websocket)
    if msg.type != msg_type or msg.event != event_type:
        raise ValueError(f"Unexpected message: {msg}")
    return msg
|
|
||||||
|
|
||||||
|
async def full_client_request(
    websocket: websockets.WebSocketClientProtocol, payload: bytes
) -> None:
    """Send a FullClientRequest frame carrying *payload*, unsequenced (NoSeq)."""
    request = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.NoSeq)
    request.payload = payload
    logger.info(f"Sending: {request}")
    await websocket.send(request.marshal())
|
||||||
|
|
||||||
|
|
||||||
|
async def audio_only_client(
    websocket: websockets.WebSocketClientProtocol, payload: bytes, flag: MsgTypeFlagBits
) -> None:
    """Send an AudioOnlyClient frame with the caller-provided flag bits."""
    frame = Message(type=MsgType.AudioOnlyClient, flag=flag)
    frame.payload = payload
    logger.info(f"Sending: {frame}")
    await websocket.send(frame.marshal())
|
||||||
|
|
||||||
|
|
||||||
|
async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Send the StartConnection control event with an empty JSON payload."""
    event_msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    event_msg.event = EventType.StartConnection
    event_msg.payload = b"{}"
    logger.info(f"Sending: {event_msg}")
    await websocket.send(event_msg.marshal())
|
||||||
|
|
||||||
|
|
||||||
|
async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Send the FinishConnection control event with an empty JSON payload."""
    close_msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    close_msg.event = EventType.FinishConnection
    close_msg.payload = b"{}"
    logger.info(f"Sending: {close_msg}")
    await websocket.send(close_msg.marshal())
|
||||||
|
|
||||||
|
|
||||||
|
async def start_session(
    websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str
) -> None:
    """Send a StartSession event for *session_id* carrying *payload*."""
    session_msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    session_msg.event = EventType.StartSession
    session_msg.session_id = session_id
    session_msg.payload = payload
    logger.info(f"Sending: {session_msg}")
    await websocket.send(session_msg.marshal())
|
||||||
|
|
||||||
|
|
||||||
|
async def finish_session(
    websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
    """Send a FinishSession event for *session_id* (empty JSON payload)."""
    end_msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    end_msg.event = EventType.FinishSession
    end_msg.session_id = session_id
    end_msg.payload = b"{}"
    logger.info(f"Sending: {end_msg}")
    await websocket.send(end_msg.marshal())
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_session(
    websocket: websockets.WebSocketClientProtocol, session_id: str
) -> None:
    """Send a CancelSession event for *session_id* (empty JSON payload)."""
    cancel_msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    cancel_msg.event = EventType.CancelSession
    cancel_msg.session_id = session_id
    cancel_msg.payload = b"{}"
    logger.info(f"Sending: {cancel_msg}")
    await websocket.send(cancel_msg.marshal())
|
||||||
|
|
||||||
|
|
||||||
|
async def task_request(
    websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str
) -> None:
    """Send a TaskRequest event for *session_id* carrying *payload*."""
    task_msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    task_msg.event = EventType.TaskRequest
    task_msg.session_id = session_id
    task_msg.payload = payload
    logger.info(f"Sending: {task_msg}")
    await websocket.send(task_msg.marshal())
|
||||||
11
volcengine_binary_demo/pyproject.toml
Normal file
11
volcengine_binary_demo/pyproject.toml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Build backend configuration (PEP 517/518).
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

# Package metadata (PEP 621).
[project]
name = "volc-speech-python-sdk"
version = "0.1.0"
requires-python = ">=3.9"
dependencies = [
    "websockets>=14.0",
]
|
||||||
11
volcengine_binary_demo/setup.py
Normal file
11
volcengine_binary_demo/setup.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Legacy setuptools entry point; metadata matches pyproject.toml."""
from setuptools import find_packages, setup

setup(
    name="volc-speech-python-sdk",
    version="0.1.0",
    # Only the protocols package is shipped.
    packages=find_packages(include=["protocols"]),
    install_requires=[
        "websockets>=14.0",
    ],
    python_requires=">=3.9",
)
|
||||||
593
web_app.py
Normal file
593
web_app.py
Normal file
@@ -0,0 +1,593 @@
|
|||||||
|
"""
|
||||||
|
MatchMe Studio - 6-Step Video Creation Wizard (v2)
|
||||||
|
"""
|
||||||
|
import streamlit as st
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import config
|
||||||
|
from modules import brain, factory, editor, storage, ingest, asr, project
|
||||||
|
|
||||||
|
# INFO-level root logger so module-level progress messages are visible.
logging.basicConfig(level=logging.INFO)

# Global Streamlit page settings (wide layout for the multi-column wizard).
st.set_page_config(
    page_title="MatchMe 视频工场",
    page_icon="🎬",
    layout="wide"
)

# Custom CSS: styling for the step headers, gradient buttons, and the
# scene/question cards rendered by the wizard steps below.
st.markdown("""
<style>
/* Fix for file uploader */
section[data-testid="stFileUploader"] {
    width: 100%;
}

.step-header {
    background: linear-gradient(90deg, #FF4B4B, #FF914D);
    padding: 10px 20px;
    border-radius: 10px;
    color: white;
    font-weight: bold;
    margin-bottom: 20px;
}
.stButton>button {
    border-radius: 20px;
    background: linear-gradient(45deg, #FF4B4B, #FF914D);
    color: white;
    border: none;
    padding: 10px 30px;
}
.scene-card {
    background: #262730;
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}
.question-card {
    background: #1e1e2e;
    padding: 15px;
    border-radius: 8px;
    margin: 10px 0;
    border-left: 3px solid #FF4B4B;
}
</style>
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
|
||||||
|
def init_session():
    """Seed st.session_state with the project, wizard step, and brief defaults."""
    state = st.session_state
    if "proj" not in state:
        # Only create a fresh project when none exists yet (avoids clobbering
        # an in-progress session on rerun).
        state.proj = project.create_project()
    if "step" not in state:
        state.step = 0
    if "brief" not in state:
        state.brief = {}
|
||||||
|
|
||||||
|
|
||||||
|
def render_sidebar():
    """Render the sidebar: project info, load/reset controls, and step progress."""
    with st.sidebar:
        st.header("项目控制台")

        proj = st.session_state.proj
        st.text(f"项目 ID: {proj.id}")
        st.text(f"状态: {proj.status}")

        st.divider()

        # Resume an existing project by ID.
        load_id = st.text_input("恢复项目 (输入ID)")
        if st.button("加载"):
            loaded = project.load_project(load_id)
            if loaded:
                st.session_state.proj = loaded
                st.success(f"已加载项目 {load_id}")
                st.rerun()
            else:
                st.error("项目不存在")

        st.divider()

        # Hard reset: new project, back to step 0, empty brief.
        if st.button("重置项目"):
            st.session_state.proj = project.create_project()
            st.session_state.step = 0
            st.session_state.brief = {}
            st.rerun()

        st.divider()
        # Progress indicator: arrow = current step, check = done, circle = pending.
        steps = ["素材提交", "AI分析", "脚本生成", "画面生成", "视频生成", "最终合成"]
        for i, name in enumerate(steps):
            if i == st.session_state.step:
                st.markdown(f"**→ {i}. {name}**")
            elif i < st.session_state.step:
                st.markdown(f"✅ {i}. {name}")
            else:
                st.markdown(f"○ {i}. {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def step0_ingest():
    """Step 0: Material submission.

    Lets the user pick one of three input modes (text only, images + text,
    video + text), uploads any media to remote storage, extracts frames and
    an ASR transcript from uploaded video, and stores everything on the
    session project before advancing to step 1.
    """
    st.markdown('<div class="step-header">Step 0: 素材提交</div>', unsafe_allow_html=True)

    proj = st.session_state.proj

    mode = st.radio(
        "选择输入方式",
        ["纯文本创意", "图片 + 描述", "视频 + 描述"],
        horizontal=True
    )

    prompt = st.text_area("创意描述 / 产品卖点", height=100, placeholder="描述你想要的视频内容...")

    if mode == "纯文本创意":
        proj.input_mode = "text"

    elif mode == "图片 + 描述":
        proj.input_mode = "images"
        uploaded = st.file_uploader("上传图片 (支持多张)", type=["jpg", "png", "jpeg"], accept_multiple_files=True)

        if uploaded:
            urls = []
            with st.spinner("上传图片中..."):
                for f in uploaded:
                    # Spool each upload to a temp file, then push to storage.
                    temp_path = config.TEMP_DIR / f.name
                    with open(temp_path, "wb") as fp:
                        fp.write(f.getbuffer())
                    url = storage.upload_file(str(temp_path))
                    if url:
                        urls.append(url)
                    else:
                        st.error(f"上传失败: {f.name}")

            if urls:
                proj.image_urls = urls
                st.image(urls, width=150)
                st.success(f"成功上传 {len(urls)} 张图片")

    elif mode == "视频 + 描述":
        proj.input_mode = "video"
        uploaded = st.file_uploader("上传视频", type=["mp4"])

        if uploaded:
            with st.spinner("处理视频中..."):
                temp_path = config.TEMP_DIR / uploaded.name
                with open(temp_path, "wb") as f:
                    f.write(uploaded.getbuffer())

                # Frame extraction + upload; failure here is shown but does
                # not stop the ASR attempt below.
                try:
                    frame_urls, video_url = ingest.process_uploaded_video(str(temp_path))
                    proj.image_urls = frame_urls
                    proj.video_url = video_url
                    st.image(frame_urls, width=150, caption=["帧1", "帧2", "帧3"])
                except Exception as e:
                    st.error(f"视频处理失败: {e}")

                # Speech-to-text is best-effort: a failure only warns.
                try:
                    asr_text = asr.transcribe_video(str(temp_path))
                    proj.asr_text = asr_text
                    st.info(f"语音识别: {asr_text[:100]}...")
                except Exception as e:
                    st.warning(f"语音识别失败: {e}")

    proj.prompt = prompt

    # Advance only when a prompt was entered.
    if st.button("下一步: AI 分析", disabled=not prompt):
        proj.status = "analyzing"
        project.save_project(proj)
        st.session_state.step = 1
        st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
def step1_analyze():
    """Step 1: AI analysis & clarifying questions.

    Runs the material analysis once (cached on the project), renders the
    model's clarifying questions as multi-select/radio widgets with an
    optional free-text supplement, then refines the answers into the
    creative brief stored in session state before advancing to step 2.
    """
    st.markdown('<div class="step-header">Step 1: AI 深度分析</div>', unsafe_allow_html=True)

    proj = st.session_state.proj

    # Run analysis if not done (proj.analysis doubles as the "done" flag).
    if not proj.analysis:
        with st.spinner("AI 正在分析素材..."):
            result = brain.analyze_materials(
                prompt=proj.prompt,
                image_urls=proj.image_urls if proj.image_urls else None,
                asr_text=proj.asr_text
            )
            proj.analysis = result.get("analysis", "")
            proj.questions = result.get("questions", [])
            project.save_project(proj)

    st.subheader("分析结果")
    st.write(proj.analysis)

    # Show questions with multi-select and custom input
    if proj.questions:
        st.subheader("补充信息")
        st.caption("请回答以下问题,帮助 AI 更好地理解你的需求")

        answers = {}
        for q in proj.questions:
            qid = q["id"]
            st.markdown(f'<div class="question-card">', unsafe_allow_html=True)

            # Per-question flags from the model: multi-select vs single,
            # and whether a free-text supplement is offered (default True).
            allow_multiple = q.get("allow_multiple", False)
            allow_custom = q.get("allow_custom", True)

            if allow_multiple:
                selected = st.multiselect(
                    q["text"],
                    q["options"],
                    key=f"q_{qid}"
                )
                answers[qid] = {"selected": selected}
            else:
                selected = st.radio(
                    q["text"],
                    q["options"],
                    key=f"q_{qid}"
                )
                # Normalize the single choice to a list for a uniform shape.
                answers[qid] = {"selected": [selected] if selected else []}

            # Custom input for additional context.
            # NOTE(review): when allow_custom is False the answer dict has no
            # "custom" key — confirm downstream readers tolerate that.
            if allow_custom:
                custom = st.text_input(
                    "补充说明 (选填)",
                    key=f"custom_{qid}",
                    placeholder="如有其他想法,请在此补充..."
                )
                answers[qid]["custom"] = custom

            st.markdown('</div>', unsafe_allow_html=True)

        if st.button("确认回答,生成创意简报"):
            proj.answers = answers

            # Refine the brief using the collected answers.
            with st.spinner("整合创意简报中..."):
                brief_result = brain.refine_brief(
                    proj.prompt,
                    {"analysis": proj.analysis},
                    answers,
                    proj.image_urls
                )
                st.session_state.brief = brief_result.get("brief", {})

                # Carry the creative summary along with the brief if present.
                if "creative_summary" in brief_result:
                    st.session_state.brief["creative_summary"] = brief_result["creative_summary"]

            project.save_project(proj)
            st.session_state.step = 2
            st.rerun()
    else:
        # No questions from the model: build a minimal brief and move on.
        if st.button("下一步: 生成脚本"):
            st.session_state.brief = {
                "product": proj.prompt,
                "selling_points": [],
                "style": "现代广告"
            }
            st.session_state.step = 2
            st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
def step2_script():
    """Step 2: Script generation.

    Generates the hook / scenes / CTA script once (cached on the project),
    renders each scene in an expander with its key-frame details and image
    prompt, and offers per-scene and whole-script regeneration with user
    feedback before advancing to step 3.
    """
    st.markdown('<div class="step-header">Step 2: 脚本生成</div>', unsafe_allow_html=True)

    proj = st.session_state.proj
    brief = st.session_state.brief

    # Show creative summary
    if brief.get("creative_summary"):
        st.info(f"🎯 创意方向: {brief['creative_summary']}")

    if brief.get("style"):
        st.caption(f"视频风格: {brief['style']}")

    # Generate script if not done (proj.scenes doubles as the "done" flag).
    if not proj.scenes:
        with st.spinner("AI 正在创作脚本..."):
            script = brain.generate_script(brief, proj.image_urls)
            proj.hook = script.get("hook", "")
            proj.scenes = script.get("scenes", [])
            proj.cta = script.get("cta", "")

            # Store creative summary from script if available
            if script.get("creative_summary"):
                brief["creative_summary"] = script["creative_summary"]
                st.session_state.brief = brief

            proj.status = "scripting"
            project.save_project(proj)

    # Display script
    st.subheader(f"🎣 Hook: {proj.hook}")

    # Creative summary
    if brief.get("creative_summary"):
        st.markdown(f"**整体创意**: {brief['creative_summary']}")

    for i, scene in enumerate(proj.scenes):
        with st.expander(f"分镜 {scene.get('id', i+1)}: {scene.get('timeline', '')}"):
            col1, col2 = st.columns(2)

            # Left column: timing / camera / narrative beats.
            with col1:
                st.write(f"**时长**: {scene.get('duration', 5)}秒")
                st.write(f"**运镜**: {scene.get('camera_movement', '')}")
                st.write(f"**故事节拍**: {scene.get('story_beat', '')}")
                st.write(f"**音效设计**: {scene.get('sound_design', '')}")

            # Right column: key-frame visual attributes.
            with col2:
                kf = scene.get("keyframe", {})
                st.write(f"**色调**: {kf.get('color_tone', '')}")
                st.write(f"**环境**: {kf.get('environment', '')}")
                st.write(f"**焦点**: {kf.get('focus', '')}")
                st.write(f"**构图**: {kf.get('composition', '')}")

            # Image prompt (key input for the image-generation step).
            st.write("**生图Prompt**:")
            st.code(scene.get("image_prompt", "(未生成)"), language=None)

            st.write(f"**旁白**: {scene.get('voiceover', '(无)')}")

            # Per-scene regeneration with optional feedback text.
            feedback = st.text_input(f"修改意见", key=f"fb_{i}")
            if st.button(f"重新生成此分镜", key=f"regen_{i}"):
                with st.spinner("重新生成中..."):
                    new_scene = brain.regenerate_scene(
                        {"hook": proj.hook, "scenes": proj.scenes, "cta": proj.cta},
                        scene.get("id", i+1),
                        feedback,
                        brief
                    )
                    proj.scenes[i] = new_scene
                    project.save_project(proj)
                    st.rerun()

    # CTA - ensure it's a string (the model sometimes returns a dict).
    cta_text = proj.cta
    if isinstance(cta_text, dict):
        cta_text = cta_text.get("text", str(cta_text))
    st.subheader(f"📢 CTA: {cta_text}")

    col1, col2 = st.columns(2)
    with col1:
        # Whole-script regeneration with overall feedback.
        regen_feedback = st.text_input("整体修改意见")
        if st.button("重新生成整个脚本"):
            with st.spinner("重新生成中..."):
                script = brain.generate_script(brief, proj.image_urls, regen_feedback)
                proj.hook = script.get("hook", "")
                proj.scenes = script.get("scenes", [])
                proj.cta = script.get("cta", "")
                project.save_project(proj)
                st.rerun()

    with col2:
        if st.button("确认脚本,下一步"):
            st.session_state.step = 3
            st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
def step3_images():
    """Step 3: Concurrent key-frame image generation (Gemini Image).

    Generates one image per scene in parallel (passing the user's uploaded
    reference images for product consistency), then shows a grid where each
    scene image can be regenerated, replaced by upload, or have its
    voiceover edited, before advancing to step 4.
    """
    st.markdown('<div class="step-header">Step 3: 画面生成 (Gemini Image)</div>', unsafe_allow_html=True)

    proj = st.session_state.proj
    brief = st.session_state.brief

    # Show reference images if available
    if proj.image_urls:
        st.caption("参考素材(用于保持产品一致性):")
        st.image(proj.image_urls[:3], width=100)

    # "Done" means every scene already has an image URL.
    has_images = all(s.get("image_url") for s in proj.scenes)

    if not has_images:
        if st.button("开始生成所有画面 (并发)"):
            progress = st.progress(0)
            status = st.empty()

            try:
                status.text("正在并发生成所有分镜画面...")
                # Pass user's reference images for product consistency
                image_urls = factory.generate_all_scene_images_concurrent(
                    proj.scenes,
                    brief,
                    reference_images=proj.image_urls,  # the user's uploaded material
                    max_workers=3
                )

                for i, url in enumerate(image_urls):
                    if url:
                        proj.scenes[i]["image_url"] = url
                    progress.progress((i + 1) / len(proj.scenes))

                proj.status = "imaging"
                project.save_project(proj)
                st.rerun()

            except Exception as e:
                st.error(f"生成失败: {e}")
                import traceback  # lazy import: only needed on failure
                st.code(traceback.format_exc())

    # Display images in a grid capped at 4 columns; i % 4 stays in range
    # because fewer than 4 scenes also means i < len(cols).
    cols = st.columns(min(4, len(proj.scenes)))
    for i, scene in enumerate(proj.scenes):
        with cols[i % 4]:
            img_url = scene.get("image_url", "")
            if img_url:
                st.image(img_url, caption=f"分镜 {scene.get('id', i+1)}")

            # Regenerate this scene's image from its prompt.
            if st.button(f"重新生成", key=f"img_regen_{i}"):
                with st.spinner("生成中..."):
                    url = factory.generate_scene_image(scene, brief, proj.image_urls)
                    proj.scenes[i]["image_url"] = url
                    project.save_project(proj)
                    st.rerun()

            # Or replace it entirely with a user upload.
            custom = st.file_uploader(f"替换", key=f"img_up_{i}", type=["jpg", "png"])
            if custom:
                temp_path = config.TEMP_DIR / custom.name
                with open(temp_path, "wb") as f:
                    f.write(custom.getbuffer())
                url = storage.upload_file(str(temp_path))
                if url:
                    proj.scenes[i]["image_url"] = url
                    project.save_project(proj)
                    st.rerun()

            # Inline voiceover editing; persisted only when it changed.
            vo = st.text_area(f"旁白", scene.get("voiceover", ""), key=f"vo_{i}", height=80)
            if vo != scene.get("voiceover", ""):
                proj.scenes[i]["voiceover"] = vo
                project.save_project(proj)

    if has_images and st.button("下一步: 生成视频"):
        st.session_state.step = 4
        st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
def step4_videos():
    """Step 4: Concurrent per-scene video generation (Sora 2).

    Animates each scene's key-frame image into a clip in parallel, then
    lists the clips with a per-scene regenerate option before advancing
    to final composition.
    """
    st.markdown('<div class="step-header">Step 4: 分镜视频生成 (Sora 2)</div>', unsafe_allow_html=True)

    proj = st.session_state.proj

    # "Done" means every scene already has a video URL.
    has_videos = all(s.get("video_url") for s in proj.scenes)

    if not has_videos:
        if st.button("开始生成所有视频 (并发)"):
            progress = st.progress(0)
            status = st.empty()

            try:
                image_urls = [s.get("image_url") for s in proj.scenes]

                status.text("正在并发生成所有分镜视频 (Sora 2)...")
                video_urls = factory.generate_all_scene_videos_concurrent(
                    proj.scenes,
                    image_urls,
                    max_workers=2
                )

                for i, url in enumerate(video_urls):
                    if url:
                        proj.scenes[i]["video_url"] = url
                    progress.progress((i + 1) / len(proj.scenes))

                proj.status = "video"
                project.save_project(proj)
                st.rerun()

            except Exception as e:
                st.error(f"视频生成失败: {e}")
                import traceback  # lazy import: only needed on failure
                st.code(traceback.format_exc())

    # Display videos: player on the left, metadata + regenerate on the right.
    for i, scene in enumerate(proj.scenes):
        vid_url = scene.get("video_url", "")
        if vid_url:
            col1, col2 = st.columns([3, 1])
            with col1:
                st.video(vid_url)
            with col2:
                st.write(f"分镜 {scene.get('id', i+1)}")
                st.write(f"{scene.get('duration', 5)}秒")

                if st.button(f"重新生成", key=f"vid_regen_{i}"):
                    with st.spinner("生成中..."):
                        image_url = scene.get("image_url", "")
                        url = factory.generate_scene_video(
                            image_url,
                            scene.get("camera_movement", "slow zoom"),
                            scene.get("duration", 5)
                        )
                        proj.scenes[i]["video_url"] = url
                        project.save_project(proj)
                        st.rerun()

    if has_videos and st.button("下一步: 合成"):
        st.session_state.step = 5
        st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
def step5_render():
    """Step 5: Final rendering.

    Lets the user toggle subtitles, voiceover, and optional BGM, then
    assembles the scene clips into the final video, marks the project
    done, shows a player/download link, and cleans up temp files.
    """
    st.markdown('<div class="step-header">Step 5: 最终合成</div>', unsafe_allow_html=True)

    proj = st.session_state.proj
    brief = st.session_state.brief

    col1, col2 = st.columns(2)

    with col1:
        add_subtitles = st.checkbox("烧录字幕", value=True)
        add_voiceover = st.checkbox("添加旁白配音", value=True)

    with col2:
        add_bgm = st.checkbox("添加背景音乐", value=False)
        bgm_file = None
        if add_bgm:
            bgm_file = st.file_uploader("上传 BGM", type=["mp3", "wav"])

    if st.button("开始合成"):
        with st.spinner("合成中,请稍候..."):
            video_urls = [s.get("video_url") for s in proj.scenes]

            # Optional narration track synthesized from the scene voiceovers.
            vo_url = ""
            if add_voiceover:
                style = brief.get("style", "")
                vo_url = factory.generate_full_voiceover(proj.scenes, style)

            # Optional BGM: spool the upload to a temp file, push to storage.
            bgm_url = ""
            if bgm_file:
                temp_path = config.TEMP_DIR / bgm_file.name
                with open(temp_path, "wb") as f:
                    f.write(bgm_file.getbuffer())
                bgm_url = storage.upload_file(str(temp_path))

            # Passing an empty scenes list disables subtitle burn-in.
            final_url = editor.assemble_final_video(
                video_urls=video_urls,
                scenes=proj.scenes if add_subtitles else [],
                voiceover_url=vo_url,
                bgm_url=bgm_url
            )

            proj.final_video_url = final_url
            proj.status = "done"
            project.save_project(proj)

            st.success("🎉 视频合成完成!")
            st.video(final_url)
            st.markdown(f"### [📥 下载高清视频]({final_url})")

            storage.cleanup_temp()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """App entry point: initialize state, draw the sidebar, dispatch to the
    wizard step selected in session state."""
    init_session()
    render_sidebar()

    st.title("MatchMe 视频工场 🎬")
    st.caption("AI 驱动的短视频创作平台")

    # Dispatch table instead of an if/elif chain; unknown steps render nothing.
    handlers = {
        0: step0_ingest,
        1: step1_analyze,
        2: step2_script,
        3: step3_images,
        4: step4_videos,
        5: step5_render,
    }
    handler = handlers.get(st.session_state.step)
    if handler is not None:
        handler()


if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user