Files
video-flow/modules/video_gen.py

322 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Image-to-video module (Volcengine Doubao-SeedDance)
Responsible for converting storyboard-shot images into video clips
"""
import logging
import time
import requests
import os
from os import stat
from typing import Dict, Any, List, Optional
from pathlib import Path
import config
from modules import storage
from modules.db_manager import db
from modules import path_utils
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class VideoGenerator:
    """Image-to-video generator.

    Uploads a scene image, submits an image+prompt generation task to the
    HTTP API rooted at ``config.VOLC_BASE_URL``, polls the task status and
    downloads the resulting video file.
    """

    # Options appended verbatim to every prompt sent by ``_submit_task``.
    PROMPT_OPTIONS = "--resolution 1080p --duration 3 --camerafixed false --watermark false"
    # ``timeout`` (seconds) passed to requests for submit/status calls.
    REQUEST_TIMEOUT = 30
    # ``timeout`` (seconds) passed to requests for video downloads.
    DOWNLOAD_TIMEOUT = 60
    # Chunk size (bytes) used when streaming a download to disk.
    DOWNLOAD_CHUNK_SIZE = 8192
    # Maximum time (seconds) ``generate_scene_videos`` keeps polling.
    POLL_TIMEOUT = 600
    # Pause (seconds) between two polling rounds.
    POLL_INTERVAL = 5

    def __init__(self):
        self.api_key = config.VOLC_API_KEY
        self.base_url = config.VOLC_BASE_URL
        self.model_id = config.VIDEO_MODEL_ID
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def submit_scene_video_task(
        self,
        project_id: str,
        scene_id: int,
        image_path: str,
        prompt: str
    ) -> Optional[str]:
        """Submit a video generation task for a single scene.

        The image is uploaded through ``storage.upload_file`` and the
        returned URL is sent to the API together with ``prompt``. On success
        the task id is saved right away via ``db.save_asset`` with status
        ``"processing"``.

        Returns:
            The task id, or ``None`` when the image is missing, the upload
            fails, or the submission fails.
        """
        if not image_path or not os.path.exists(image_path):
            logger.warning("Skipping video generation for Scene %s: Image not found", scene_id)
            return None

        # Upload the image to obtain a URL the API can fetch.
        logger.info("Uploading image for Scene %s...", scene_id)
        image_url = storage.upload_file(image_path)
        if not image_url:
            logger.error("Failed to upload image for Scene %s", scene_id)
            return None

        logger.info("Submitting video task for Scene %s...", scene_id)
        task_id = self._submit_task(image_url, prompt)
        if not task_id:
            return None

        # Record a signature of the source image next to the task. Size and
        # mtime are best-effort: the file may have vanished since the
        # existence check above, in which case they are simply omitted.
        source_sig = {"source_image_local_path": image_path}
        try:
            st = os.stat(image_path)
        except OSError:
            pass
        else:
            source_sig["source_image_size"] = int(st.st_size or 0)
            source_sig["source_image_mtime"] = float(st.st_mtime or 0.0)
        source_sig["source_image_r2_url"] = image_url
        source_sig["submitted_at"] = time.time()

        # Persist the task id immediately, with status "processing".
        db.save_asset(
            project_id=project_id,
            scene_id=scene_id,
            asset_type="video",
            status="processing",
            task_id=task_id,
            local_path=None,
            metadata=source_sig,
        )
        return task_id

    def recover_video_from_task(self, task_id: str, output_path: str) -> bool:
        """Try to recover a video from an existing task id.

        Queries the task status and, if it succeeded with a usable URL,
        downloads the video to ``output_path``.

        Returns:
            ``True`` if the video was downloaded, ``False`` otherwise.
        """
        try:
            status, video_url = self._check_task(task_id)
            logger.info("Recovering task %s: status=%s", task_id, status)
            if status == "succeeded" and video_url:
                return self._download_video_to(video_url, output_path)
            return False
        except Exception as e:
            logger.error("Failed to recover video task %s: %s", task_id, e)
            return False

    def check_task_status(self, task_id: str) -> tuple[Optional[str], Optional[str]]:
        """Query a task's status.

        Returns:
            ``(status, video_url)`` exactly as produced by ``_check_task``.
        """
        return self._check_task(task_id)

    def generate_scene_videos(
        self,
        project_id: str,
        script: Dict[str, Any],
        scene_images: Dict[int, str]
    ) -> Dict[int, str]:
        """Generate videos for all scenes of a script (legacy, blocking poll).

        Submits one task per entry of ``script["scenes"]``, then polls every
        ``POLL_INTERVAL`` seconds until all tasks are finished or
        ``POLL_TIMEOUT`` seconds have elapsed.

        Returns:
            Mapping of scene id to the local path of each downloaded video.
            Scenes that failed, could not be downloaded, or were still
            pending at the timeout are absent from the mapping.
        """
        generated_videos: Dict[int, str] = {}
        tasks: Dict[int, str] = {}  # scene_id -> task_id

        # 1. Submit every task.
        for scene in script.get("scenes") or []:
            scene_id = scene["id"]
            image_path = scene_images.get(scene_id)
            prompt = scene.get("video_prompt", "High quality video")
            task_id = self.submit_scene_video_task(project_id, scene_id, image_path, prompt)
            if task_id:
                tasks[scene_id] = task_id
                logger.info("Task submitted: %s", task_id)
            else:
                logger.error("Failed to submit task for Scene %s", scene_id)

        # 2. Poll task status. A monotonic clock keeps the deadline immune
        # to wall-clock adjustments.
        pending_tasks = list(tasks)
        deadline = time.monotonic() + self.POLL_TIMEOUT
        while pending_tasks and time.monotonic() < deadline:
            logger.info("Polling status for %s tasks...", len(pending_tasks))
            still_pending = []
            for scene_id in pending_tasks:
                task_id = tasks[scene_id]
                status, result_url = self._check_task(task_id)
                if status == "succeeded":
                    logger.info("Scene %s video generated successfully", scene_id)
                    # Download the video.
                    out_dir = path_utils.project_videos_dir(project_id) if project_id else config.TEMP_DIR
                    fname = path_utils.unique_filename(
                        prefix="scene_video",
                        ext="mp4",
                        project_id=project_id,
                        scene_id=scene_id,
                        extra=(task_id[-8:] if isinstance(task_id, str) else None),
                    )
                    video_path = self._download_video(result_url, fname, output_dir=out_dir)
                    if video_path:
                        generated_videos[scene_id] = video_path
                        # Update DB
                        db.save_asset(
                            project_id=project_id,
                            scene_id=scene_id,
                            asset_type="video",
                            status="completed",
                            local_path=video_path,
                            task_id=task_id,
                        )
                    else:
                        # The DB record is left untouched here, so the task id
                        # saved at submission time stays available to
                        # recover_video_from_task.
                        logger.error(
                            "Scene %s task %s succeeded but the video could not be downloaded",
                            scene_id,
                            task_id,
                        )
                elif status == "failed" or status == "cancelled":
                    logger.error("Scene %s task failed/cancelled", scene_id)
                    db.save_asset(
                        project_id=project_id,
                        scene_id=scene_id,
                        asset_type="video",
                        status="failed",
                        task_id=task_id,
                    )
                else:
                    # running, queued, or a status check that errored out
                    still_pending.append(scene_id)
            pending_tasks = still_pending
            if pending_tasks:
                time.sleep(self.POLL_INTERVAL)

        if pending_tasks:
            logger.warning(
                "Stopped polling after %s seconds; scenes still pending: %s",
                self.POLL_TIMEOUT,
                pending_tasks,
            )
        return generated_videos

    def _submit_task(self, image_url: str, prompt: str) -> Optional[str]:
        """Submit a generation task.

        Returns:
            The task id from the response, or ``None`` if the request fails
            or the response carries no id.
        """
        url = f"{self.base_url}/contents/generations/tasks"
        payload = {
            "model": self.model_id,
            "content": [
                {
                    "type": "text",
                    "text": f"{prompt} {self.PROMPT_OPTIONS}",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                },
            ],
        }
        response = None
        try:
            response = requests.post(
                url, headers=self.headers, json=payload, timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()
            data = response.json()
            # The id may sit at the top level or inside a "data" object.
            task_id = data.get("id")
            if not task_id and "data" in data:
                task_id = (data.get("data") or {}).get("id")
            return task_id
        except Exception as e:
            logger.error("Task submission failed: %s", e)
            if response is not None:
                logger.error("Response: %s", response.text)
            return None

    @staticmethod
    def _parse_task_result(data: Any) -> tuple[Optional[str], Optional[str]]:
        """Extract ``(status, content_url)`` from a decoded status response.

        Accepted shapes:
            ``{"status": ..., "content": [...]}``, the same object wrapped in
            a top-level ``"data"`` key, and ``content`` nested under
            ``"result"``. ``content`` may be a dict or a list of dicts; the
            first ``video_url`` (falling back to ``url``) found is used.

        Returns:
            ``("unknown", None)`` when the payload is not a dict; otherwise
            the payload's status (``None`` if absent) and, only for status
            ``"succeeded"``, the extracted URL (or ``None`` if there is none).
        """
        result = data
        if isinstance(data, dict) and "data" in data and "status" not in data:
            result = data["data"]  # unwrap the "data" envelope
        if not isinstance(result, dict):
            return "unknown", None

        status = result.get("status")
        if status != "succeeded":
            return status, None

        content = result.get("content")
        if not content and isinstance(result.get("result"), dict):
            content = result["result"].get("content")

        candidates = content if isinstance(content, list) else [content]
        for item in candidates:
            if isinstance(item, dict):
                found = item.get("video_url") or item.get("url")
                if found:
                    return status, found
        return status, None

    def _check_task(self, task_id: str) -> tuple[Optional[str], Optional[str]]:
        """Check a task's status.

        Returns:
            ``(status, content_url)``. Statuses named by the original code
            are queued, running, succeeded, failed and cancelled;
            ``("unknown", None)`` is returned when the request or decoding
            fails.
        """
        url = f"{self.base_url}/contents/generations/tasks/{task_id}"
        try:
            response = requests.get(url, headers=self.headers, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            return self._parse_task_result(response.json())
        except Exception as e:
            logger.error("Check task failed: %s", e)
            return "unknown", None

    def _stream_to_file(self, url: str, *path_parts) -> Optional[Path]:
        """Stream ``url`` to the file at ``Path(*path_parts)``.

        The body is written to a ``.part`` sibling first and moved into place
        only after the download completes, so a failed download never leaves
        a truncated file at the final path.

        Returns:
            The destination path on success, ``None`` on any failure.
        """
        partial = None
        try:
            destination = Path(*path_parts)
            partial = destination.parent / (destination.name + ".part")
            destination.parent.mkdir(parents=True, exist_ok=True)
            # The context manager closes the streamed response even on error.
            with requests.get(url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as response:
                response.raise_for_status()
                with open(partial, "wb") as f:
                    for chunk in response.iter_content(chunk_size=self.DOWNLOAD_CHUNK_SIZE):
                        if chunk:
                            f.write(chunk)
            os.replace(partial, destination)
            return destination
        except Exception as e:
            logger.error("Download video failed: %s", e)
            if partial is not None:
                # Best-effort cleanup; the partial file may not exist.
                try:
                    partial.unlink()
                except OSError:
                    pass
            return None

    def _download_video_to(self, url: str, output_path: str) -> bool:
        """Download a video to an explicit path.

        Returns:
            ``True`` on success; ``False`` if either argument is empty or
            the download fails.
        """
        if not url or not output_path:
            return False
        return self._stream_to_file(url, output_path) is not None

    def _download_video(self, url: str, filename: str, output_dir: Optional[Path] = None) -> Optional[str]:
        """Download a video as ``filename`` inside a directory.

        The directory is ``output_dir`` when given, ``config.TEMP_DIR``
        otherwise.

        Returns:
            The path of the downloaded file as a string, or ``None`` if
            ``url`` is empty or the download fails.
        """
        if not url:
            return None
        saved = self._stream_to_file(url, output_dir or config.TEMP_DIR, filename)
        return str(saved) if saved is not None else None