# Repository layout:
# - app.py: Streamlit UI for video generation workflow
# - main_flow.py: CLI tool with argparse support
# - modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.)
# - config.py: Configuration with API keys and paths
# - requirements.txt: Python dependencies
# - docs/: System prompt documentation
#
# 270 lines, 9.7 KiB, Python
"""
|
||
图生视频模块 (Volcengine Doubao-SeedDance)
|
||
负责将分镜图片转换为视频片段
|
||
"""
|
||
import logging
import os
import shutil
import time
from pathlib import Path
from typing import Dict, Any, List, Optional

import requests

import config
from modules import storage
from modules.db_manager import db

logger = logging.getLogger(__name__)


class VideoGenerator:
    """Image-to-video generator.

    For each scene the image is uploaded with ``storage.upload_file``, a
    generation task is submitted to ``{base_url}/contents/generations/tasks``,
    and the finished video is downloaded into ``config.TEMP_DIR``. Task state
    is persisted through ``db.save_asset``.
    """

    # Generation flags appended verbatim to every prompt sent to the model.
    PROMPT_FLAGS = "--resolution 1080p --duration 3 --camerafixed false --watermark false"

    def __init__(self):
        """Read credentials/endpoints from ``config`` and build the request headers."""
        self.api_key = config.VOLC_API_KEY
        self.base_url = config.VOLC_BASE_URL
        self.model_id = config.VIDEO_MODEL_ID

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def submit_scene_video_task(
        self,
        project_id: str,
        scene_id: int,
        image_path: str,
        prompt: str,
    ) -> Optional[str]:
        """Upload one scene image and submit a video generation task for it.

        When the submission succeeds, the task id is saved immediately via
        ``db.save_asset`` with status ``"processing"``.

        Args:
            project_id: Project the scene belongs to (stored with the asset).
            scene_id: Scene identifier (stored with the asset, used in logs).
            image_path: Local path of the scene image.
            prompt: Text prompt for the video model.

        Returns:
            The task id, or ``None`` when the image does not exist, the upload
            returns no URL, or the submission fails.
        """
        if not image_path or not os.path.exists(image_path):
            logger.warning("Skipping video generation for Scene %s: Image not found", scene_id)
            return None

        # The API takes an image URL, so the local file is uploaded first.
        logger.info("Uploading image for Scene %s...", scene_id)
        image_url = storage.upload_file(image_path)
        if not image_url:
            logger.error("Failed to upload image for Scene %s", scene_id)
            return None

        logger.info("Submitting video task for Scene %s...", scene_id)
        task_id = self._submit_task(image_url, prompt)
        if not task_id:
            return None

        # Persist the task id right away so the task can still be looked up
        # later if this process stops before the video is downloaded.
        db.save_asset(
            project_id=project_id,
            scene_id=scene_id,
            asset_type="video",
            status="processing",
            task_id=task_id,
            local_path=None,
        )
        return task_id

    def recover_video_from_task(self, task_id: str, output_path: str) -> bool:
        """Try to fetch the video of an existing task and store it at ``output_path``.

        The task status is queried; if it is ``"succeeded"`` and a URL is
        present, the video is downloaded and moved to ``output_path``.

        Returns:
            ``True`` when the video ends up at ``output_path``; ``False`` when
            the task is not finished, has no URL, or any step fails (errors
            are logged, never raised).
        """
        try:
            status, video_url = self._check_task(task_id)
            logger.info("Recovering task %s: status=%s", task_id, status)

            if status != "succeeded" or not video_url:
                return False

            downloaded_path = self._download_video(video_url, os.path.basename(output_path))
            if not downloaded_path:
                return False

            # _download_video always writes into TEMP_DIR; relocate the file
            # when the requested destination is a different path.
            destination = os.path.abspath(output_path)
            if os.path.abspath(downloaded_path) != destination:
                # The destination directory may not exist yet.
                os.makedirs(os.path.dirname(destination), exist_ok=True)
                shutil.move(downloaded_path, output_path)
            return True
        except Exception as e:
            logger.error("Failed to recover video task %s: %s", task_id, e)
            return False

    def check_task_status(self, task_id: str) -> tuple[Optional[str], Optional[str]]:
        """Query a task's status.

        Returns:
            ``(status, video_url)`` exactly as produced by ``_check_task``.
        """
        return self._check_task(task_id)

    def generate_scene_videos(
        self,
        project_id: str,
        script: Dict[str, Any],
        scene_images: Dict[int, str],
        timeout: float = 600,
        poll_interval: float = 5,
    ) -> Dict[int, str]:
        """Generate videos for all scenes of ``script`` (legacy: blocking polling).

        All tasks are submitted first, then polled together until every task
        reaches a final state or ``timeout`` elapses.

        Args:
            project_id: Project id stored with every asset record.
            script: Script dict; scenes are read from ``script["scenes"]`` and
                each scene provides ``"id"`` and optionally ``"video_prompt"``.
            scene_images: Mapping of scene id to local image path.
            timeout: Maximum total polling time in seconds (default 10 minutes).
            poll_interval: Pause between polling rounds in seconds.

        Returns:
            Mapping of scene id to downloaded video path, containing only the
            scenes whose video was generated and downloaded.
        """
        generated_videos: Dict[int, str] = {}
        tasks: Dict[int, str] = {}  # scene_id -> task_id

        # 1. Submit every task.
        for scene in script.get("scenes", []):
            scene_id = scene.get("id")
            if scene_id is None:
                # A malformed scene must not abort the batch after other
                # tasks have already been submitted.
                logger.warning("Skipping scene without an id: %r", scene)
                continue

            image_path = scene_images.get(scene_id)
            prompt = scene.get("video_prompt", "High quality video")
            task_id = self.submit_scene_video_task(project_id, scene_id, image_path, prompt)

            if task_id:
                tasks[scene_id] = task_id
                logger.info("Task submitted: %s", task_id)
            else:
                logger.error("Failed to submit task for Scene %s", scene_id)

        # 2. Poll until all tasks are final or the deadline passes.
        pending_tasks = list(tasks)
        # Monotonic clock: elapsed time is unaffected by system clock changes.
        deadline = time.monotonic() + timeout

        while pending_tasks and time.monotonic() < deadline:
            logger.info("Polling status for %d tasks...", len(pending_tasks))

            still_pending = []
            for scene_id in pending_tasks:
                task_id = tasks[scene_id]
                status, result_url = self._check_task(task_id)

                if status == "succeeded":
                    logger.info("Scene %s video generated successfully", scene_id)
                    video_path = self._download_video(result_url, f"scene_{scene_id}_video.mp4")
                    if video_path:
                        generated_videos[scene_id] = video_path
                        db.save_asset(
                            project_id=project_id,
                            scene_id=scene_id,
                            asset_type="video",
                            status="completed",
                            local_path=video_path,
                            task_id=task_id,
                        )
                    else:
                        # The asset record is left untouched (still
                        # "processing") so the video can be fetched later with
                        # recover_video_from_task.
                        logger.error(
                            "Scene %s: task %s succeeded but the video could not be downloaded",
                            scene_id,
                            task_id,
                        )
                elif status in ("failed", "cancelled"):
                    logger.error("Scene %s task failed/cancelled", scene_id)
                    db.save_asset(
                        project_id=project_id,
                        scene_id=scene_id,
                        asset_type="video",
                        status="failed",
                        task_id=task_id,
                    )
                else:
                    # queued, running, or a failed status check ("unknown").
                    still_pending.append(scene_id)

            pending_tasks = still_pending
            if pending_tasks:
                time.sleep(poll_interval)

        if pending_tasks:
            logger.warning(
                "Polling timed out after %s seconds; scenes still pending: %s",
                timeout,
                pending_tasks,
            )

        return generated_videos

    @staticmethod
    def _extract_task_id(data: Any) -> Optional[str]:
        """Return the task id from a submission response, or ``None``.

        The id is read from the top level (``{"id": ...}``) and, failing
        that, from a nested ``"data"`` object. Non-dict payloads yield ``None``.
        """
        if not isinstance(data, dict):
            return None
        task_id = data.get("id")
        if not task_id:
            nested = data.get("data")
            if isinstance(nested, dict):
                task_id = nested.get("id")
        return task_id or None

    @staticmethod
    def _parse_task_result(data: Any) -> tuple[Optional[str], Optional[str]]:
        """Extract ``(status, content_url)`` from a task-status response.

        The task object is either the payload itself or, when the payload has
        no ``"status"`` key, the value of its ``"data"`` key. A URL is looked
        up only for status ``"succeeded"``: ``"content"`` may be a dict or a
        list (first element used), with ``"video_url"`` preferred over
        ``"url"``. Payloads that are not dicts yield ``("unknown", None)``.
        """
        if not isinstance(data, dict):
            return "unknown", None

        result = data
        if "status" not in data and "data" in data:
            result = data["data"]
            if not isinstance(result, dict):
                return "unknown", None

        status = result.get("status")
        if status != "succeeded":
            return status, None

        content = result.get("content")
        if isinstance(content, list):
            content = content[0] if content else None
        if isinstance(content, dict):
            return status, content.get("video_url") or content.get("url")
        return status, None

    def _submit_task(self, image_url: str, prompt: str) -> Optional[str]:
        """Submit a generation task for ``image_url`` with the given prompt.

        Returns:
            The task id, or ``None`` when the request fails or the response
            contains no id (errors are logged, never raised).
        """
        url = f"{self.base_url}/contents/generations/tasks"
        payload = {
            "model": self.model_id,
            "content": [
                {
                    "type": "text",
                    "text": f"{prompt} {self.PROMPT_FLAGS}",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                },
            ],
        }

        response = None
        try:
            response = requests.post(url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            data = response.json()
        except Exception as e:
            # Boundary: callers rely on None rather than an exception.
            logger.error("Task submission failed: %s", e)
            if response is not None:
                logger.error("Response: %s", response.text)
            return None

        task_id = self._extract_task_id(data)
        if not task_id:
            logger.error("Task submission response contains no task id: %s", data)
        return task_id

    def _check_task(self, task_id: str) -> tuple[Optional[str], Optional[str]]:
        """Query the status of ``task_id``.

        Returns:
            ``(status, content_url)``. Documented statuses are queued,
            running, succeeded, failed and cancelled; ``("unknown", None)``
            is returned when the request or response parsing fails.
        """
        url = f"{self.base_url}/contents/generations/tasks/{task_id}"
        try:
            response = requests.get(url, headers=self.headers, timeout=30)
            response.raise_for_status()
            return self._parse_task_result(response.json())
        except Exception as e:
            logger.error("Check task failed: %s", e)
            return "unknown", None

    def _download_video(self, url: str, filename: str) -> Optional[str]:
        """Download ``url`` into ``config.TEMP_DIR / filename``.

        The data is streamed to a ``.part`` file that is renamed onto the
        final path only after the download completes, so a failed download
        never leaves a truncated file under the final name.

        Returns:
            The path of the downloaded file as a string, or ``None`` when
            ``url`` is empty or the download fails (errors are logged).
        """
        if not url:
            return None

        part_path = None
        try:
            output_path = Path(config.TEMP_DIR) / filename
            part_path = output_path.with_name(output_path.name + ".part")
            output_path.parent.mkdir(parents=True, exist_ok=True)

            # The context manager releases the streamed connection even when
            # the transfer fails midway.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(part_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)

            os.replace(part_path, output_path)
            return str(output_path)
        except Exception as e:
            logger.error("Download video failed: %s", e)
            if part_path is not None:
                try:
                    part_path.unlink()
                except OSError:
                    pass  # Nothing was written, or it is already gone.
            return None