"""
MatchMe Studio - Factory Module (Concurrent Scene Generation)

Using Volcengine (Doubao) API for Image and Video
"""
|
||
import base64
import json
import logging
import os
import re
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Any, List, Optional

import requests
from elevenlabs import ElevenLabs, VoiceSettings
from openai import OpenAI

import config
from modules import storage
|
||
|
||
logger = logging.getLogger(__name__)

# Shared OpenAI-compatible client pointed at the Volcengine (Doubao) endpoint.
# NOTE(review): image generation below uses raw `requests` instead; this client
# appears unused in this chunk — confirm other parts of the file rely on it.
client = OpenAI(
    api_key=config.VOLC_API_KEY,
    base_url=config.VOLC_BASE_URL
)

# ============================================================
# Helper Functions
# ============================================================
|
||
|
||
def _download_as_base64(url: str, timeout: int = 30) -> str:
    """Download an image from *url* and return it Base64-encoded.

    Args:
        url: HTTP(S) URL of the image to fetch.
        timeout: Request timeout in seconds (new, defaulted — prevents a
            stalled download from blocking a worker thread indefinitely).

    Returns:
        Base64-encoded image bytes as an ASCII string, or "" on any failure
        (the error is logged; callers treat "" as "no data").
    """
    try:
        # The original call had no timeout and could hang forever.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return base64.b64encode(response.content).decode('utf-8')
    except Exception as e:
        logger.error(f"Failed to download/encode image: {e}")
        return ""
|
||
|
||
# ============================================================
|
||
# Image Generation (Doubao / Volcengine)
|
||
# ============================================================
|
||
|
||
def _build_image_prompt(
    scene: Dict[str, Any],
    brief: Optional[Dict[str, Any]] = None
) -> str:
    """Assemble the final text prompt for one scene image.

    Prefers the scene's own ``image_prompt``; otherwise builds a fallback
    from the keyframe fields. The brief's product description is prepended
    when absent so every scene renders a visually consistent product.
    """
    image_prompt = scene.get("image_prompt", "")
    if not image_prompt:
        # Fallback prompt construction from keyframe metadata.
        keyframe = scene.get("keyframe", {})
        # Stronger style consistency intro
        parts = ["Cinematic shot, 8k, photorealistic"]
        if brief:
            if brief.get("product_visual_description"):
                parts.append(f"Product: {brief['product_visual_description']}")
        parts.extend([
            f"Subject: {keyframe.get('subject', 'product')}",
            f"Environment: {keyframe.get('environment', 'studio')}",
            f"Action: {keyframe.get('focus', '')}"
        ])
        image_prompt = ", ".join(parts)

    # Append explicit consistency enforcement to prompt
    if brief and brief.get("product_visual_description"):
        if brief['product_visual_description'] not in image_prompt:
            image_prompt = f"{brief['product_visual_description']}, {image_prompt}"
    return image_prompt


def generate_scene_image(
    scene: Dict[str, Any],
    brief: Optional[Dict[str, Any]] = None,
    reference_images: Optional[List[str]] = None
) -> str:
    """
    Generate one scene image using the Volcengine API (Doubao Image).

    Uses raw ``requests`` to match the documented curl example exactly.

    Args:
        scene: Scene dict; reads ``image_prompt``, ``keyframe`` and ``id``.
        brief: Optional campaign brief; ``product_visual_description`` is
            injected into the prompt for cross-scene visual consistency.
        reference_images: Accepted for interface compatibility; not used by
            the current payload.

    Returns:
        Public R2 URL of the uploaded image.

    Raises:
        ValueError: On a non-200 API response or when no image data comes back.
    """
    image_prompt = _build_image_prompt(scene, brief)
    logger.info(f"Generating image (Volcengine): {image_prompt[:50]}...")

    url = f"{config.VOLC_BASE_URL}/images/generations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.VOLC_API_KEY}"
    }

    # Payload matching user's curl example
    payload = {
        "model": config.IMAGE_MODEL_ID,
        "prompt": image_prompt,
        "sequential_image_generation": "disabled",
        "response_format": "b64_json",  # Use base64 to avoid temp url expiration issues
        "size": "2K",  # User specified 2K
        "stream": False,
        "watermark": True
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)

        if response.status_code != 200:
            logger.error(f"Image API Error: {response.text}")
            raise ValueError(f"Image API failed: {response.status_code} - {response.text}")

        data = response.json()

        # Extract image data: prefer inline base64, fall back to URL download.
        image_data = None
        if "data" in data and len(data["data"]) > 0:
            image_data = data["data"][0].get("b64_json")
            if not image_data:
                # Fallback to URL download if b64 not present
                img_url = data["data"][0].get("url")
                if img_url:
                    # Download the image to ensure we have it locally
                    image_data = _download_as_base64(img_url)

        if not image_data:
            raise ValueError("No image data returned")

        # Decode and save to a temp file before uploading.
        filename = f"scene_{scene.get('id', 0)}_{int(time.time())}.jpg"
        local_path = config.TEMP_DIR / filename

        with open(local_path, "wb") as f:
            f.write(base64.b64decode(image_data))

        # Upload to R2
        r2_url = storage.upload_file(str(local_path))
        logger.info(f"Scene {scene.get('id', '?')} image uploaded: {r2_url}")
        return r2_url

    except Exception as e:
        logger.error(f"Image Generation Failed: {e}")
        raise
|
||
|
||
|
||
def generate_all_scene_images_concurrent(
    scenes: List[Dict[str, Any]],
    brief: Dict[str, Any] = None,
    reference_images: List[str] = None,
    max_workers: int = 3
) -> List[str]:
    """Render every scene's image in parallel.

    Results land at the index of their scene; a scene whose generation raised
    is left as None (the failure is logged, never re-raised).
    """
    logger.info(f"Generating {len(scenes)} images concurrently...")
    results: List[str] = [None] * len(scenes)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for idx, sc in enumerate(scenes):
            fut = pool.submit(generate_scene_image, sc, brief, reference_images)
            pending[fut] = idx

        for fut in as_completed(pending):
            idx = pending[fut]
            try:
                results[idx] = fut.result()
            except Exception as e:
                logger.error(f"Scene {idx+1} failed: {e}")

    return results
|
||
|
||
|
||
# ============================================================
|
||
# Video Generation (Doubao Video / PixelDance)
|
||
# ============================================================
|
||
|
||
def generate_scene_video(
    start_frame_url: str,
    motion_prompt: str,
    duration: int = 5
) -> str:
    """
    Generate video using Volcengine API (Async Task Flow).

    Flow: create an async generation task, poll its status every 5 seconds
    (up to ~5 minutes), then download the finished clip and upload it to R2.

    Args:
        start_frame_url: Image URL used as the first frame; when falsy the
            request is pure text-to-video.
        motion_prompt: Text prompt; rendering flags (resolution, duration,
            camera, watermark) are appended inline as "--flag value" tokens.
        duration: Clip length in seconds, embedded in the prompt directives.

    Returns:
        Public R2 URL of the uploaded video.

    Raises:
        ValueError: Task creation failed, the task reported "failed", or the
            final download returned a non-200 status.
        TimeoutError: Polling exhausted without obtaining a video URL.
    """
    logger.info(f"Generating video (Volcengine): {motion_prompt[:50]}...")

    # 1. Create Task
    create_url = f"{config.VOLC_BASE_URL}/contents/generations/tasks"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.VOLC_API_KEY}"
    }

    # Construct Content List (Text + Optional Image)
    content_list = [
        {
            "type": "text",
            "text": f"{motion_prompt} --resolution 1080p --duration {duration} --camerafixed false --watermark true"
        }
    ]

    if start_frame_url:
        content_list.append({
            "type": "image_url",
            "image_url": {"url": start_frame_url}
        })

    payload = {
        "model": config.VIDEO_MODEL_ID,
        "content": content_list
    }

    try:
        response = requests.post(create_url, headers=headers, json=payload, timeout=30)
        if response.status_code != 200:
            # 202 Accepted is also possible for async tasks
            if response.status_code != 202:
                logger.error(f"Video Task Creation Error: {response.text}")
                raise ValueError(f"Video Task failed: {response.status_code} - {response.text}")

        data = response.json()
        task_id = data.get("id")
        if not task_id:
            # Sometimes ID is in data.id or similar
            task_id = data.get("data", {}).get("id")

        if not task_id:
            raise ValueError(f"No Task ID returned: {data}")

        logger.info(f"Video Task Created: {task_id}. Polling for result...")

        # 2. Poll for Result
        # GET /contents/generations/tasks/{id}
        max_retries = 60  # 5 mins max (5s interval)
        video_url = None

        for _ in range(max_retries):
            # Sleep first: the task is never done immediately after creation.
            time.sleep(5)
            status_url = f"{config.VOLC_BASE_URL}/contents/generations/tasks/{task_id}"
            resp = requests.get(status_url, headers=headers, timeout=30)

            # Non-200 poll responses are silently retried until max_retries.
            if resp.status_code == 200:
                res_data = resp.json()
                # Check status
                # Structure usually: data.status = "succeeded" / "running" / "failed"
                # Or top level status

                status = res_data.get("status")
                if not status and "data" in res_data:
                    status = res_data["data"].get("status")

                if status == "succeeded" or status == "SUCCEEDED":
                    # Extract URL: content may live under data.content or at top level.
                    content = res_data.get("data", {}).get("content", [])
                    if not content and "content" in res_data:
                        content = res_data["content"]

                    # Find video url in content
                    # Content is usually list of dicts with type='video' or 'video_url'
                    for item in content:
                        if item.get("video_url"):
                            video_url = item["video_url"]
                            break
                        if item.get("url"):  # sometimes just url
                            video_url = item["url"]
                            break

                    if video_url:
                        break
                elif status == "failed" or status == "FAILED":
                    reason = res_data.get("data", {}).get("error", "Unknown error")
                    raise ValueError(f"Video Generation Failed: {reason}")

            # If running/queued, continue waiting

        if not video_url:
            raise TimeoutError("Video generation timed out or failed to return URL.")

        # 3. Download and Upload to R2
        logger.info(f"Video Generated. Downloading: {video_url}")
        filename = f"vid_doubao_{int(time.time())}.mp4"
        local_path = config.TEMP_DIR / filename

        # Stream to disk in chunks so large videos don't sit in memory.
        resp = requests.get(video_url, stream=True)
        if resp.status_code != 200:
            raise ValueError(f"Failed to download generated video: {resp.status_code}")

        with open(local_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                f.write(chunk)

        r2_url = storage.upload_file(str(local_path))
        return r2_url

    except Exception as e:
        # Log and re-raise so concurrent callers can record per-scene failures.
        logger.error(f"Video Generation Error: {e}")
        raise
|
||
|
||
|
||
def generate_all_scene_videos_concurrent(
    scenes: List[Dict[str, Any]],
    image_urls: List[str],
    max_workers: int = 2
) -> List[str]:
    """Animate every scene's keyframe image in parallel.

    The motion prompt is the scene's image_prompt (when present) followed by
    its camera_movement. Failed scenes are logged and left as None.
    """
    logger.info(f"Generating {len(scenes)} videos concurrently...")
    results: List[str] = [None] * len(scenes)

    def _render(idx: int) -> str:
        sc = scenes[idx]
        motion = sc.get("camera_movement", "slow zoom")
        if sc.get("image_prompt"):
            motion = f"{sc['image_prompt']}. {motion}"
        return generate_scene_video(image_urls[idx], motion, sc.get("duration", 5))

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {pool.submit(_render, i): i for i in range(len(scenes))}

        for fut in as_completed(pending):
            i = pending[fut]
            try:
                results[i] = fut.result()
            except Exception as e:
                logger.error(f"Scene {i+1} video failed: {e}")

    return results
|
||
|
||
|
||
# ============================================================
|
||
# Audio Generation (ElevenLabs)
|
||
# ============================================================
|
||
|
||
def generate_voiceover(text: str, style: str = "") -> str:
    """Synthesize a voiceover with ElevenLabs and upload it. Returns R2 URL.

    ASMR-style briefs get lower stability / higher similarity for a softer
    read. Empty text or any synthesis failure yields "".
    """
    if not text or not text.strip():
        return ""

    is_asmr = "ASMR" in style
    stability = 0.3 if is_asmr else 0.5
    similarity = 0.9 if is_asmr else 0.8

    logger.info(f"Generating voiceover ({len(text)} chars, style={style})...")

    try:
        settings = VoiceSettings(stability=stability, similarity_boost=similarity)
        stream = ElevenLabs(api_key=config.XI_KEY).text_to_speech.convert(
            voice_id=config.ELEVENLABS_VOICE_ID,
            text=text,
            model_id=config.ELEVENLABS_MODEL,
            voice_settings=settings
        )

        local_path = config.TEMP_DIR / f"vo_{int(time.time())}.mp3"
        with open(local_path, "wb") as out:
            for piece in stream:
                out.write(piece)

        return storage.upload_file(str(local_path))
    except Exception as e:
        logger.error(f"Voiceover failed: {e}")
        return ""
|
||
|
||
|
||
def generate_full_voiceover(scenes: List[Dict[str, Any]], style: str = "") -> str:
    """Join every scene's spoken line into one narration and synthesize it.

    Lines that are empty, whitespace-only, or parenthesized stage directions
    (starting with "(") are skipped; returns "" when nothing remains.
    """
    spoken = []
    for sc in scenes:
        line = sc.get("voiceover", "")
        if line and line.strip() and not line.startswith("("):
            spoken.append(line.strip())

    if not spoken:
        return ""

    return generate_voiceover(" ".join(spoken), style)
|
||
|
||
|
||
# ============================================================
# Audio Generation (Edge TTS - free Chinese speech synthesis)
# ============================================================

# Edge TTS Chinese voice presets (free, good quality).
# Keys are the friendly aliases used throughout this module; values are
# Microsoft Edge neural voice identifiers.
EDGE_TTS_VOICES = {
    # Female voices
    "sweet_female": "zh-CN-XiaoxiaoNeural",     # Xiaoxiao - sweet & lively (recommended)
    "gentle_female": "zh-CN-XiaoyiNeural",      # Xiaoyi - gentle & refined
    "lively_female": "zh-CN-XiaochenNeural",    # Xiaochen - lively & cute
    "broadcast_female": "zh-CN-XiaoqiuNeural",  # Xiaoqiu - news broadcast
    # Male voices
    "general_male": "zh-CN-YunxiNeural",        # Yunxi - warm male voice
    "broadcast_male": "zh-CN-YunjianNeural",    # Yunjian - professional broadcast
}
|
||
|
||
# Volcengine TTS voice presets (the paid service must be enabled) -
# voices chosen to suit Douyin live-commerce style narration.
VOLC_TTS_VOICES = {
    # Douyin-commerce-friendly female voices
    "sweet_female": "zh_female_vv_uranus_bigtts",           # viv 2.0 general female (sweet)
    "lively_female": "zh_female_jitangnv_saturn_bigtts",    # "Jitangnv" (energetic)
    # NOTE(review): this id is prefixed zh_male_ yet sits under a *_female key -
    # per the original comment, swap to zh_female_meilinyou_saturn_bigtts
    # if a female broadcaster is required. Confirm intent.
    "broadcast_female": "zh_male_ruyaichen_saturn_bigtts",  # "Ruyaichen" (news broadcast)
    "meilinvyou": "zh_female_meilinvyou_saturn_bigtts",
    # Male voices
    "general_male": "zh_male_dayi_saturn_bigtts",           # "Dayi" (steady male voice)
}
|
||
|
||
|
||
def generate_voiceover_edge(
    text: str,
    voice_type: str = "sweet_female",
    rate: str = "+0%",
    volume: str = "+0%",
    output_path: Optional[str] = None
) -> str:
    """
    Generate Chinese narration with Edge TTS (free, good quality).

    Args:
        text: Narration text.
        voice_type: Preset alias (see EDGE_TTS_VOICES) or a raw voice name.
        rate: Speech-rate adjustment, e.g. "+10%", "-20%".
        volume: Volume adjustment, e.g. "+10%", "-20%".
        output_path: Target file path; auto-generated under TEMP_DIR if omitted.

    Returns:
        Path of the generated audio file, or "" after all retries fail.
    """
    import asyncio
    import edge_tts

    if not text or not text.strip():
        logger.warning("Empty text provided for TTS")
        return ""

    # Resolve preset alias; unknown aliases fall through as raw voice names.
    voice = EDGE_TTS_VOICES.get(voice_type, voice_type)

    logger.info(f"Generating voiceover (Edge TTS): {len(text)} chars, voice={voice}")

    if not output_path:
        filename = f"vo_edge_{int(time.time())}.mp3"
        output_path = str(config.TEMP_DIR / filename)

    async def _generate():
        communicate = edge_tts.Communicate(text, voice, rate=rate, volume=volume)
        await communicate.save(output_path)

    # Simple retry loop. Edge TTS occasionally fails transiently or writes a
    # missing/empty file; previously the empty-file case retried immediately
    # (no backoff, no log) because the sleep sat in the except path only.
    max_retries = 3
    for attempt in range(max_retries):
        try:
            asyncio.run(_generate())
            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                logger.info(f"Edge TTS voiceover generated: {output_path}")
                return output_path
            logger.warning(f"Edge TTS attempt {attempt+1} produced no audio")
        except Exception as e:
            logger.warning(f"Edge TTS attempt {attempt+1} failed: {e}")
        time.sleep(1.0)  # wait before retry

    logger.error("Edge TTS failed after retries.")
    return ""
|
||
|
||
|
||
def generate_voiceover_volcengine_ws(
    text: str,
    voice_type: str = "sweet_female",
    output_path: Optional[str] = None,
    timeout: int = 120
) -> str:
    """
    Generate TTS audio via the Volcengine WebSocket binary demo.

    Shells out to the vendored demo script, which must be installed (with its
    own virtualenv) under {PROJECT_ROOT}/volcengine_binary_demo/.venv.
    NOTE(review): the ".venv/bin/python" layout is POSIX-only — confirm this
    path never needs to run on Windows.

    Args:
        text: Narration text to synthesize.
        voice_type: Preset alias (see VOLC_TTS_VOICES) or a raw voice id.
        output_path: Destination mp3 path; auto-generated when omitted.
        timeout: Max seconds to wait for the subprocess.

    Returns:
        Path of the generated mp3, or "" on any failure (errors are logged,
        never raised — callers use "" to fall back to other backends).
    """
    if not text or not text.strip():
        logger.warning("Empty text provided for TTS (ws)")
        return ""

    voice_id = VOLC_TTS_VOICES.get(voice_type, voice_type)

    # Resolve paths relative to the project root (portable across checkouts).
    volc_demo_dir = config.BASE_DIR / "volcengine_binary_demo"
    venv_python = volc_demo_dir / ".venv" / "bin" / "python"
    demo_script = volc_demo_dir / "examples" / "volcengine" / "binary.py"

    if not venv_python.exists() or not demo_script.exists():
        logger.error("Volcengine WS demo or venv not found. Please install under volcengine_binary_demo/.venv")
        return ""

    if not output_path:
        output_path = str(config.TEMP_DIR / f"vo_volc_ws_{int(time.time())}.mp3")

    cmd = [
        str(venv_python),
        str(demo_script),
        "--appid", config.VOLC_TTS_APPID,
        "--access_token", config.VOLC_TTS_ACCESS_TOKEN,
        "--voice_type", voice_id,
        "--text", text,
        "--encoding", "mp3",
    ]

    logger.info(f"Calling Volcengine WS TTS: voice={voice_id}, len={len(text)}")
    try:
        # List-form argv (shell=False) keeps arbitrary narration text safe.
        result = subprocess.run(
            cmd,
            cwd=str(volc_demo_dir),
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            logger.error(f"Volc WS TTS failed: {result.stderr}")
            return ""

        # The demo writes its output as {voice_id}.mp3 in its own cwd.
        demo_out = volc_demo_dir / f"{voice_id}.mp3"
        if not demo_out.exists():
            logger.error("Volc WS TTS output not found")
            return ""

        # Copy the demo's output to the requested location.
        Path(output_path).write_bytes(demo_out.read_bytes())
        logger.info(f"Volc WS TTS saved to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Volc WS TTS error: {e}")
        return ""
|
||
|
||
|
||
def generate_voiceover_volcengine(
    text: str,
    voice_type: str = "sweet_female",
    speed_ratio: float = 1.0,
    volume_ratio: float = 1.0,
    pitch_ratio: float = 1.0,
    output_path: Optional[str] = None
) -> str:
    """
    Generate Chinese narration with Volcengine TTS.

    Tries backends in order: the WebSocket binary demo (verified working),
    the Volcengine HTTP TTS API, and finally Edge TTS as a free fallback.

    Args:
        text: Narration text.
        voice_type: Preset alias (see VOLC_TTS_VOICES) or a raw voice id.
        speed_ratio: Speech speed (0.5-2.0, default 1.0).
        volume_ratio: Volume (0.5-2.0, default 1.0).
        pitch_ratio: Pitch (0.5-2.0, default 1.0).
        output_path: Destination path (auto-generated when omitted).

    Returns:
        Path of the generated audio file ("" only if every backend fails).
    """
    import uuid

    if not text or not text.strip():
        logger.warning("Empty text provided for TTS")
        return ""

    def _edge_fallback() -> str:
        # Fall back to Edge TTS with a safe default voice when the alias is
        # not a known Edge preset (was duplicated at three call sites).
        fallback_voice = "sweet_female" if voice_type not in EDGE_TTS_VOICES else voice_type
        return generate_voiceover_edge(text, fallback_voice, output_path=output_path)

    # Resolve voice id (Volcengine preset table, falling back to a custom id).
    voice_id = VOLC_TTS_VOICES.get(voice_type, voice_type)

    logger.info(f"Generating voiceover (Volcengine TTS): {len(text)} chars, voice={voice_id}")

    # Try the WebSocket binary demo first.
    ws_path = generate_voiceover_volcengine_ws(text, voice_type, output_path)
    if ws_path:
        return ws_path

    # WS failed - try the HTTP API next.
    url = "https://openspeech.bytedance.com/api/v1/tts"

    headers = {
        "Content-Type": "application/json",
        # Volcengine expects the unusual "Bearer;<token>" form (semicolon).
        "Authorization": f"Bearer;{config.VOLC_TTS_ACCESS_TOKEN}"
    }

    payload = {
        "app": {
            "appid": config.VOLC_TTS_APPID,
            "token": config.VOLC_TTS_ACCESS_TOKEN,
            "cluster": "volcano_tts"
        },
        "user": {
            "uid": "video_flow_user"
        },
        "audio": {
            "voice_type": voice_id,
            "encoding": "mp3",
            "speed_ratio": speed_ratio,
            "volume_ratio": volume_ratio,
            "pitch_ratio": pitch_ratio
        },
        "request": {
            "reqid": str(uuid.uuid4()),
            "text": text,
            "text_type": "plain",
            "operation": "query",
            "with_timestamp": "1",
            "extra_param": json.dumps({
                "disable_markdown_filter": False
            })
        }
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)

        if response.status_code != 200:
            logger.error(f"Volcengine TTS Error: {response.status_code} - {response.text}")
            return _edge_fallback()

        data = response.json()

        # Volcengine success codes observed: 0, 3000, 20000000.
        ret_code = data.get("code")
        if ret_code not in (0, 3000, 20000000):
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Volcengine TTS Error: {error_msg}")
            return _edge_fallback()

        # The mp3 payload comes back base64-encoded in "data".
        audio_data = data.get("data", "")
        if not audio_data:
            raise ValueError("No audio data returned")

        if not output_path:
            filename = f"vo_volc_{int(time.time())}.mp3"
            output_path = str(config.TEMP_DIR / filename)

        with open(output_path, "wb") as f:
            f.write(base64.b64decode(audio_data))

        logger.info(f"Voiceover generated (HTTP): {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Volcengine TTS HTTP error: {e}")
        return _edge_fallback()
|
||
|
||
|
||
def _split_text_into_chunks(text: str, max_chunk_length: int) -> List[str]:
    """Split *text* into chunks of at most ~max_chunk_length characters.

    Splitting happens on CJK/ASCII sentence punctuation so every chunk ends on
    a sentence boundary; a single sentence longer than the limit becomes its
    own (oversized) chunk.
    """
    # The capture group keeps punctuation tokens at odd indices, so
    # (sentence, punctuation) pairs can be rejoined below.
    sentences = re.split(r'([。!?;.!?;])', text)

    chunks: List[str] = []
    current_chunk = ""

    # Greedily pack sentence+punctuation pairs into chunks.
    for i in range(0, len(sentences) - 1, 2):
        sentence = sentences[i] + (sentences[i + 1] if i + 1 < len(sentences) else "")

        if len(current_chunk) + len(sentence) <= max_chunk_length:
            current_chunk += sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    # A trailing fragment with no closing punctuation is the final odd-index
    # element; append it to the last chunk (or make it the only chunk).
    if len(sentences) % 2 == 1 and sentences[-1]:
        if chunks:
            chunks[-1] += sentences[-1]
        else:
            chunks.append(sentences[-1])

    return chunks


def generate_voiceover_volcengine_long(
    text: str,
    voice_type: str = "sweet_female",
    speed_ratio: float = 1.0,
    output_path: Optional[str] = None,
    max_chunk_length: int = 300
) -> str:
    """
    Volcengine TTS for long text (automatic chunked synthesis).

    Text longer than max_chunk_length is split on sentence boundaries,
    synthesized chunk by chunk, and concatenated with FFmpeg (stream copy).

    Args:
        text: Narration text of any length.
        voice_type: Preset alias or raw voice id.
        speed_ratio: Speech speed passed through to the TTS backend.
        output_path: Final audio path (auto-generated when omitted).
        max_chunk_length: Max characters per synthesis request.

    Returns:
        Path of the final audio file.

    Raises:
        ValueError: When every chunk fails to synthesize.
        subprocess.CalledProcessError: When the FFmpeg merge fails.
    """
    # Short text: single-shot synthesis, no splitting needed.
    if len(text) <= max_chunk_length:
        return generate_voiceover_volcengine(
            text=text,
            voice_type=voice_type,
            speed_ratio=speed_ratio,
            output_path=output_path
        )

    logger.info(f"Long text ({len(text)} chars), splitting into chunks...")

    chunks = _split_text_into_chunks(text, max_chunk_length)
    logger.info(f"Split into {len(chunks)} chunks")

    # Synthesize each chunk; individual failures are logged and skipped.
    chunk_files = []
    for i, chunk in enumerate(chunks):
        chunk_path = str(config.TEMP_DIR / f"vo_chunk_{i}_{int(time.time())}.mp3")
        try:
            path = generate_voiceover_volcengine(
                text=chunk,
                voice_type=voice_type,
                speed_ratio=speed_ratio,
                output_path=chunk_path
            )
            chunk_files.append(path)
        except Exception as e:
            logger.error(f"Chunk {i} failed: {e}")
            # Keep going with the remaining chunks.

    if not chunk_files:
        raise ValueError("All TTS chunks failed")

    # Single surviving chunk: no merge needed.
    if len(chunk_files) == 1:
        if output_path:
            import shutil
            shutil.move(chunk_files[0], output_path)
            return output_path
        return chunk_files[0]

    # Write the FFmpeg concat list. Paths are TEMP_DIR-generated, so no quote
    # escaping is expected to be necessary inside the single-quoted entries.
    concat_list = config.TEMP_DIR / f"concat_audio_{os.getpid()}.txt"
    with open(concat_list, "w") as f:
        for cf in chunk_files:
            f.write(f"file '{cf}'\n")

    if not output_path:
        output_path = str(config.TEMP_DIR / f"vo_volc_merged_{int(time.time())}.mp3")

    # Lossless merge via the concat demuxer (stream copy, no re-encode).
    cmd = [
        "ffmpeg", "-y",
        "-f", "concat",
        "-safe", "0",
        "-i", str(concat_list),
        "-c", "copy",
        output_path
    ]

    subprocess.run(cmd, capture_output=True, check=True)

    # Clean up intermediate files (best effort).
    for cf in chunk_files:
        try:
            os.remove(cf)
        except OSError:
            pass
    concat_list.unlink(missing_ok=True)

    logger.info(f"Merged voiceover: {output_path}")
    return output_path
|
||
|
||
|
||
def generate_scene_voiceovers_volcengine(
    scenes: List[Dict[str, Any]],
    voice_type: str = "sweet_female",
    output_dir: str = None
) -> List[str]:
    """
    Synthesize a separate narration file for each scene.

    Args:
        scenes: Scene dicts; each may carry a "voiceover" text field.
        voice_type: Voice preset alias.
        output_dir: Directory for the mp3 files (TEMP_DIR when omitted).

    Returns:
        One audio path per scene; "" marks scenes with no usable narration
        or a failed synthesis.
    """
    target_dir = Path(output_dir) if output_dir else config.TEMP_DIR
    if output_dir:
        target_dir.mkdir(exist_ok=True)

    results: List[str] = []

    for idx, sc in enumerate(scenes):
        vo_text = sc.get("voiceover", "")

        # Skip blanks and parenthesized stage directions.
        if not vo_text or not vo_text.strip() or vo_text.startswith("("):
            results.append("")
            continue

        try:
            dest = str(target_dir / f"scene_{idx+1}_vo.mp3")
            results.append(generate_voiceover_volcengine(
                text=vo_text.strip(),
                voice_type=voice_type,
                output_path=dest
            ))
        except Exception as e:
            logger.error(f"Scene {idx+1} voiceover failed: {e}")
            results.append("")

    return results
|