Files
video-flow/modules/video_gen.py
Tony Zhang 33a165a615 feat: video-flow initial commit
- app.py: Streamlit UI for video generation workflow
- main_flow.py: CLI tool with argparse support
- modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.)
- config.py: Configuration with API keys and paths
- requirements.txt: Python dependencies
- docs/: System prompt documentation
2025-12-12 19:18:27 +08:00

270 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图生视频模块 (Volcengine Doubao-Seedance)
负责将分镜图片转换为视频片段
"""
import logging
import time
import requests
import os
from typing import Dict, Any, List, Optional
from pathlib import Path
import config
from modules import storage
from modules.db_manager import db
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
class VideoGenerator:
    """Image-to-video generator built on the Volcengine video-generation task API.

    Per-scene workflow:

    1. The storyboard image is uploaded with ``storage.upload_file`` to obtain a URL.
    2. An asynchronous generation task is submitted to
       ``{VOLC_BASE_URL}/contents/generations/tasks``.
    3. The task is polled until it reaches a terminal status.
    4. The resulting clip is downloaded into ``config.TEMP_DIR``.

    Task progress is recorded with ``db.save_asset``.
    """

    def __init__(self):
        self.api_key = config.VOLC_API_KEY
        self.base_url = config.VOLC_BASE_URL
        self.model_id = config.VIDEO_MODEL_ID
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def submit_scene_video_task(
        self,
        project_id: str,
        scene_id: int,
        image_path: Optional[str],
        prompt: str,
    ) -> Optional[str]:
        """Upload one scene image and submit an image-to-video task for it.

        On successful submission the task id is saved to the database right
        away with status ``"processing"``.

        Args:
            project_id: Project the scene belongs to.
            scene_id: Scene identifier, used for DB bookkeeping and log messages.
            image_path: Local path of the scene's storyboard image (may be None).
            prompt: Text prompt sent along with the image.

        Returns:
            The remote task id, or ``None`` in any of these cases:

            - the image is missing;
            - the upload failed;
            - the API returned no task id.
        """
        # isfile (not exists) so a directory path is rejected before uploading.
        if not image_path or not os.path.isfile(image_path):
            logger.warning(f"Skipping video generation for Scene {scene_id}: Image not found")
            return None

        # The generation API consumes an image URL, so upload the file first (R2).
        logger.info(f"Uploading image for Scene {scene_id}...")
        image_url = storage.upload_file(image_path)
        if not image_url:
            logger.error(f"Failed to upload image for Scene {scene_id}")
            return None

        logger.info(f"Submitting video task for Scene {scene_id}...")
        task_id = self._submit_task(image_url, prompt)
        if task_id:
            # Persist the task id immediately so it is not lost if this process
            # stops before the remote task finishes.
            db.save_asset(
                project_id=project_id,
                scene_id=scene_id,
                asset_type="video",
                status="processing",
                task_id=task_id,
                local_path=None,
            )
        return task_id

    def recover_video_from_task(self, task_id: str, output_path: str) -> bool:
        """Try to fetch the result of an already-submitted task.

        The task is queried once; there is no polling. If it has succeeded,
        the clip is downloaded and placed at ``output_path``.

        Args:
            task_id: Id returned earlier by the submission API.
            output_path: Where the recovered video should end up.

        Returns:
            True if the video was written to ``output_path``. False if the task
            has not succeeded, returned no URL, or any step raised.
        """
        try:
            status, video_url = self._check_task(task_id)
            logger.info(f"Recovering task {task_id}: status={status}")
            if status != "succeeded" or not video_url:
                return False

            # _download_video always writes into TEMP_DIR under the given file
            # name, so move the result if the caller asked for another location.
            downloaded_path = self._download_video(video_url, os.path.basename(output_path))
            if not downloaded_path:
                return False
            if os.path.abspath(downloaded_path) != os.path.abspath(output_path):
                import shutil

                target_dir = os.path.dirname(os.path.abspath(output_path))
                os.makedirs(target_dir, exist_ok=True)
                shutil.move(downloaded_path, output_path)
            return True
        except Exception as e:
            logger.error(f"Failed to recover video task {task_id}: {e}")
            return False

    def check_task_status(self, task_id: str) -> tuple[str, Optional[str]]:
        """Query a task once.

        Returns:
            ``(status, video_url)``. ``video_url`` is only set when the status
            is ``"succeeded"``. The status is ``"unknown"`` when the query itself
            failed.
        """
        return self._check_task(task_id)

    def generate_scene_videos(
        self,
        project_id: str,
        script: Dict[str, Any],
        scene_images: Dict[int, str],
        timeout: float = 600,
        poll_interval: float = 5,
    ) -> Dict[int, str]:
        """Generate videos for every scene and block until done (legacy path).

        Every task is submitted first, then all of them are polled together
        until each one reaches a terminal status or ``timeout`` expires.

        Args:
            project_id: Project the scenes belong to.
            script: Script dict whose ``"scenes"`` list holds dicts with an
                ``"id"`` and an optional ``"video_prompt"``.
            scene_images: Mapping of scene id to local image path.
            timeout: Maximum seconds to spend polling. The default is 600,
                matching the previous hard-coded value.
            poll_interval: Seconds to wait between polling rounds. The default
                is 5, matching the previous hard-coded value.

        Returns:
            Mapping of scene id to local video path, for scenes that finished
            and downloaded successfully.
        """
        generated_videos: Dict[int, str] = {}
        tasks: Dict[int, str] = {}  # scene_id -> task_id

        # 1. Submit everything up front so the remote jobs can run concurrently.
        for scene in script.get("scenes") or []:
            scene_id = scene["id"]
            image_path = scene_images.get(scene_id)
            prompt = scene.get("video_prompt", "High quality video")
            task_id = self.submit_scene_video_task(project_id, scene_id, image_path, prompt)
            if task_id:
                tasks[scene_id] = task_id
                logger.info(f"Task submitted: {task_id}")
            else:
                logger.error(f"Failed to submit task for Scene {scene_id}")

        # 2. Poll until every task is terminal or the deadline passes.
        #    monotonic() is immune to wall-clock adjustments.
        pending_tasks = list(tasks)
        deadline = time.monotonic() + timeout
        while pending_tasks and time.monotonic() < deadline:
            logger.info(f"Polling status for {len(pending_tasks)} tasks...")
            still_pending = []
            for scene_id in pending_tasks:
                task_id = tasks[scene_id]
                status, result_url = self._check_task(task_id)
                if status == "succeeded":
                    logger.info(f"Scene {scene_id} video generated successfully")
                    video_path = self._download_video(result_url, f"scene_{scene_id}_video.mp4")
                    if video_path:
                        generated_videos[scene_id] = video_path
                        db.save_asset(
                            project_id=project_id,
                            scene_id=scene_id,
                            asset_type="video",
                            status="completed",
                            local_path=video_path,
                            task_id=task_id,
                        )
                    else:
                        # Previously dropped silently. The DB row saved at
                        # submission still holds the task id, so the clip can be
                        # fetched again later (e.g. via recover_video_from_task).
                        logger.error(
                            f"Scene {scene_id} task {task_id} succeeded but the video download failed"
                        )
                elif status in ("failed", "cancelled"):
                    logger.error(f"Scene {scene_id} task failed/cancelled")
                    db.save_asset(
                        project_id=project_id,
                        scene_id=scene_id,
                        asset_type="video",
                        status="failed",
                        task_id=task_id,
                    )
                else:
                    # queued / running, or "unknown" after a transient query error
                    still_pending.append(scene_id)
            pending_tasks = still_pending
            if pending_tasks:
                # Never sleep past the deadline.
                time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))

        if pending_tasks:
            logger.warning(
                f"Timed out after {timeout}s; scenes still pending: {pending_tasks} "
                f"(their task ids remain in the DB as 'processing')"
            )
        return generated_videos

    def _submit_task(self, image_url: str, prompt: str) -> Optional[str]:
        """Create one image-to-video generation task.

        Generation options are appended to the text prompt as ``--key value``
        flags: 1080p, 3 s, camera not fixed, no watermark.

        Returns:
            The task id, or ``None`` if the request failed or the response
            contained no id.
        """
        url = f"{self.base_url}/contents/generations/tasks"
        payload = {
            "model": self.model_id,
            "content": [
                {
                    "type": "text",
                    "text": f"{prompt} --resolution 1080p --duration 3 --camerafixed false --watermark false",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                },
            ],
        }
        # Bound before the try so the error path can inspect it without locals().
        response = None
        try:
            response = requests.post(url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            data = response.json()
            # Depending on the API version, the id is top-level or nested under "data".
            task_id = data.get("id")
            nested = data.get("data")
            if not task_id and isinstance(nested, dict):
                task_id = nested.get("id")
            return task_id
        except Exception as e:
            logger.error(f"Task submission failed: {e}")
            # Compare with None explicitly: a 4xx/5xx Response is falsy.
            if response is not None:
                logger.error(f"Response: {response.text}")
            return None

    def _check_task(self, task_id: str) -> tuple[str, Optional[str]]:
        """Fetch a task's status and, if it has succeeded, its video URL.

        Observed statuses: queued, running, succeeded, failed, cancelled.

        Returns:
            ``(status, content_url)``. On any request or parse error it returns
            ``("unknown", None)``, which callers treat as still pending.
        """
        url = f"{self.base_url}/contents/generations/tasks/{task_id}"
        try:
            response = requests.get(url, headers=self.headers, timeout=30)
            response.raise_for_status()
            data = response.json()

            # Some responses wrap the task object in a "data" key.
            result = data
            if "data" in data and "status" not in data:
                result = data["data"]

            status = result.get("status")
            content_url = None
            if status == "succeeded":
                # "content" may be a single dict or a list whose first item holds the URL.
                content = result.get("content")
                if isinstance(content, list) and content:
                    content = content[0]
                if isinstance(content, dict):
                    content_url = content.get("video_url") or content.get("url")
            return status, content_url
        except Exception as e:
            logger.error(f"Check task failed: {e}")
            return "unknown", None

    def _download_video(self, url: Optional[str], filename: str) -> Optional[str]:
        """Download ``url`` to ``config.TEMP_DIR / filename``.

        Returns:
            The local path as a string. Returns ``None`` if ``url`` is empty or
            the download failed; in that case no truncated file is left behind.
        """
        if not url:
            return None
        part_path = None
        try:
            temp_dir = Path(config.TEMP_DIR)
            temp_dir.mkdir(parents=True, exist_ok=True)
            output_path = temp_dir / filename
            # Stream into a sibling ".part" file and rename only on success, so an
            # interrupted download never leaves a truncated clip at output_path.
            part_path = output_path.with_name(output_path.name + ".part")
            # The context manager releases the streamed connection even if
            # writing fails midway.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(part_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            part_path.replace(output_path)
            return str(output_path)
        except Exception as e:
            logger.error(f"Download video failed: {e}")
            if part_path is not None:
                try:
                    part_path.unlink(missing_ok=True)
                except OSError as cleanup_err:
                    # Best-effort cleanup; the download has already failed.
                    logger.warning(f"Could not remove partial download {part_path}: {cleanup_err}")
            return None