Files
video-flow/modules/image_gen.py

565 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
连贯生图模块 (Volcengine Doubao)
负责根据分镜脚本和原始素材生成一系列连贯的分镜图片
"""
import base64
import logging
import os
import time
import requests
import json
from pathlib import Path
from typing import List, Dict, Any, Optional
from PIL import Image
import io
from modules import storage
import config
from modules import path_utils
logger = logging.getLogger(__name__)
def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        name: Environment variable to look up.
        default: Value used when the variable is unset or is not a valid integer.

    Returns:
        The parsed integer, or ``default`` when parsing fails for any reason.
    """
    try:
        raw_value = os.getenv(name, str(default))
        value = int(raw_value)
    except Exception:
        value = default
    return value
# Tunables for the image channels. Slow providers can run hot, so the defaults
# are conservative but can be overridden through environment variables.
# IMPORTANT: every value is clamped to a floor so that an accidental
# misconfiguration (e.g. a 120s submit timeout) cannot cause flaky UX.
IMG_SUBMIT_TIMEOUT_S = max(180, _env_int("IMG_SUBMIT_TIMEOUT_S", 180))
IMG_POLL_TIMEOUT_S = max(10, _env_int("IMG_POLL_TIMEOUT_S", 30))
IMG_MAX_RETRIES = max(3, _env_int("IMG_MAX_RETRIES", 3))
IMG_POLL_INTERVAL_S = max(1, _env_int("IMG_POLL_INTERVAL_S", 2))
# With the defaults this is 90 polls * 2s interval ~= 180s of polling.
IMG_POLL_MAX_RETRIES = max(90, _env_int("IMG_POLL_MAX_RETRIES", 90))
def _is_retryable_exception(e: Exception) -> bool:
    """Decide whether a failed provider call is worth retrying.

    An exception is treated as transient when any of the following holds:

    * it is a ``requests`` timeout or connection error;
    * its lower-cased message contains a transient keyword ("timeout",
      "temporarily", "gateway", "try again", or "rate" as in "rate limit");
    * its message carries an HTTP status code that is 408, 409, 425, 429 or
      any 5xx. A code in parentheses, e.g. ``"Doubao Image Failed (502): ..."``,
      is preferred; otherwise the first standalone three-digit number is used.

    Args:
        e: The exception raised by the provider call.

    Returns:
        True if the call should be retried, False otherwise.
    """
    import re  # function-scope import, as in the original implementation

    # Network / transient transport errors.
    if isinstance(e, (requests.Timeout, requests.ConnectionError)):
        return True

    msg = str(e).lower()

    # Transient provider errors often contain these keywords.
    if any(k in msg for k in ("timeout", "temporarily", "gateway", "try again")):
        return True
    # "rate" (rate limit, rate_limit, ratelimit...) must not be preceded by a
    # letter; a plain substring test also fires on "generate" / "moderate" and
    # would make permanent errors look retryable.
    if re.search(r"(?<![a-z])rate", msg):
        return True

    # HTTP status codes that bubble up inside RuntimeError text, e.g.
    # "Shubiaobiao 提交失败 (429): ..." or "Doubao Image Failed (502): ...".
    # Some messages are formatted like "429:" without parentheses, hence the
    # fallback to the first standalone three-digit number.
    m = re.search(r"\((\d{3})\)", msg) or re.search(r"\b(\d{3})\b", msg)
    if m is None:
        return False
    code = int(m.group(1))
    return code in (408, 409, 425, 429) or 500 <= code <= 599
def _with_retries(fn, *, max_retries: int, label: str):
    """Call ``fn`` until it succeeds, retrying transient failures.

    Args:
        fn: Zero-argument callable to invoke.
        max_retries: Maximum number of attempts. Values below 1 are treated as
            1 so that ``fn`` is always attempted at least once.
        label: Short tag used in the warning logged for each failed attempt.

    Returns:
        Whatever ``fn`` returns on its first successful call.

    Raises:
        Exception: The exception raised by ``fn`` is re-raised unchanged when
            it is not retryable or when the final attempt fails.
    """
    # Guard against max_retries <= 0: previously the loop body never ran, fn
    # was never called, and a confusing TypeError was raised instead.
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as e:
            retryable = _is_retryable_exception(e)
            logger.warning(
                "[%s] attempt %s/%s failed: %s (retryable=%s)",
                label, attempt, attempts, e, retryable,
            )
            if not retryable or attempt >= attempts:
                raise
            # Small capped exponential backoff: 1s, 2s, 4s, 4s, ...
            time.sleep(min(2 ** (attempt - 1), 4))
    # The final iteration always returns or raises.
    raise RuntimeError(f"[{label}] retry loop exited unexpectedly")  # pragma: no cover
class ImageGenerator:
    """Storyboard image generator.

    Generates scene images through one of three provider channels:
    Shubiaobiao (inline base64, synchronous), Gemini (submit + poll), or
    Volcengine Doubao (URL response, single image or grouped batch).
    """

    def __init__(self):
        self.api_key = config.VOLC_API_KEY
        # Endpoint: https://ark.cn-beijing.volces.com/api/v3/images/generations
        self.endpoint = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        self.model = config.IMAGE_MODEL_ID

    @staticmethod
    def _download_image(image_url: str, output_path: Path) -> str:
        """Download ``image_url`` into ``output_path`` and return the path as a string.

        Raises an HTTP error for a non-success status so that an error page is
        never written to disk as if it were an image.
        """
        img_resp = requests.get(image_url, timeout=60)
        img_resp.raise_for_status()
        with open(output_path, "wb") as f:
            f.write(img_resp.content)
        return str(output_path)

    def _encode_image(self, image_path: str) -> str:
        """Load an image, downscale it if needed, and return it as base64 JPEG.

        The image is converted to RGB, shrunk so its longest side is at most
        1024 px, and saved as JPEG (quality 80).

        Returns:
            The base64 text, or an empty string when the image cannot be
            processed (the error is logged, not raised).
        """
        try:
            with Image.open(image_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                max_size = 1024
                if max(img.size) > max_size:
                    img.thumbnail((max_size, max_size), Image.LANCZOS)
                buffer = io.BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                return base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return ""

    def generate_single_scene_image(
        self,
        scene: Dict[str, Any],
        original_image_path: Any,
        previous_image_path: Optional[str] = None,
        model_provider: str = "shubiaobiao",  # "shubiaobiao", "gemini", "doubao"
        visual_anchor: str = "",  # visual anchor, force-prepended to the prompt
        project_id: Optional[str] = None,
    ) -> Optional[str]:
        """Generate the image for a single storyboard scene (public entry point).

        Args:
            scene: Scene dict; ``scene["id"]`` is required and
                ``scene["visual_prompt"]`` is used as the prompt when present.
            original_image_path: Reference image path, or a list of paths.
            previous_image_path: Optional path of the previous scene's image,
                appended as the last reference image.
            model_provider: "doubao" or "gemini"; any other value uses the
                Shubiaobiao channel.
            visual_anchor: Text prepended to the prompt (in square brackets)
                unless the prompt already contains it.
            project_id: When given, output goes to the project's image
                directory; otherwise to ``config.TEMP_DIR``.

        Returns:
            Path of the generated image file.

        Raises:
            RuntimeError: If the provider returned no output path.
            Exception: Any provider error is logged and re-raised.
        """
        scene_id = scene["id"]
        visual_prompt = scene.get("visual_prompt", "")

        # Force-prepend the visual anchor to keep generated images consistent.
        if visual_anchor and visual_anchor not in visual_prompt:
            visual_prompt = f"[{visual_anchor}] {visual_prompt}"
            logger.info(f"Scene {scene_id}: Prepended visual_anchor to prompt")

        logger.info(f"Generating image for Scene {scene_id} (Provider: {model_provider})...")

        input_images = []
        # original_image_path may be a single path or a list of paths.
        if isinstance(original_image_path, list):
            input_images.extend(original_image_path)
        elif isinstance(original_image_path, str) and original_image_path:
            input_images.append(original_image_path)
        if previous_image_path:
            input_images.append(previous_image_path)

        try:
            out_dir = path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR
            out_name = path_utils.unique_filename(
                prefix="scene_image",
                ext="png",
                project_id=project_id,
                scene_id=scene_id,
            )
            output_path = self._generate_single_image(
                prompt=visual_prompt,
                reference_images=input_images,
                output_filename=out_name,
                provider=model_provider,
                output_dir=out_dir,
            )
            if not output_path:
                raise RuntimeError(f"Image generation returned empty for Scene {scene_id}")
            return output_path
        except PermissionError as e:
            logger.error(f"Critical API Error for Scene {scene_id}: {e}")
            raise
        except Exception as e:
            logger.error(f"Image generation failed for Scene {scene_id}: {e}")
            raise

    def generate_group_images_doubao(
        self,
        scenes: List[Dict[str, Any]],
        reference_images: List[str],
        visual_anchor: str = "",  # visual anchor
        project_id: Optional[str] = None,
    ) -> Dict[int, str]:
        """Generate all scene images in one Doubao "sequential" (group) request.

        The per-scene prompts are joined into one combined prompt and the
        returned images are mapped back to scenes by position.

        Args:
            scenes: Scene dicts; each needs ``"id"`` and may carry
                ``"visual_prompt"``.
            reference_images: Local paths; those that exist are uploaded via
                ``storage.upload_file`` and their URLs attached to the request.
            visual_anchor: Optional text placed in the "Global" prompt section.
            project_id: When given, output goes to the project's image
                directory; otherwise to ``config.TEMP_DIR``.

        Returns:
            Mapping of scene id to saved image path. Empty when ``scenes`` is
            empty or the response carried no images.

        Raises:
            Exception: Request, HTTP, or download errors are logged and re-raised.
        """
        # Nothing to generate: skip the uploads and the API call.
        if not scenes:
            return {}

        logger.info("Starting Doubao Group Image Generation...")

        # 1. Build the combined prompt.
        # Format: "Global: [Visual Anchor] ... | S1: ... | S2: ..."
        scene_prompts = [
            f"S{scene['id']}:{scene.get('visual_prompt', '')}" for scene in scenes
        ]
        combined_scenes_text = " | ".join(scene_prompts)

        # The visual anchor goes into the Global section of the prompt.
        global_context = f"[{visual_anchor}] Consistent product appearance & style." if visual_anchor else "Consistent product appearance & style."
        combined_prompt = (
            f"Global: {global_context}\n"
            f"{combined_scenes_text}\n"
            "Req: 1 img per scene. Follow specific angles."
        )
        logger.info(f"Visual Anchor applied to group prompt: {visual_anchor[:50]}..." if visual_anchor else "No visual_anchor")
        # Log the prompt length for reference.
        logger.info(f"Doubao Group Prompt Length: {len(combined_prompt)} chars")

        # 2. Build the payload.
        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": combined_prompt,
            "sequential_image_generation": "auto",  # enable group generation
            "sequential_image_generation_options": {
                "max_images": len(scenes)  # cap the number of images
            },
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        # 3. Upload reference images (best effort: failures are only logged).
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                    except Exception as e:
                        logger.warning(f"Failed to upload ref image {ref_path}: {e}")
        if img_urls:
            # NOTE(review): confirm that the provider accepts the key
            # "image_urls" for reference images - TODO verify against the API reference.
            payload["image_urls"] = img_urls

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting Doubao Group Request (Scenes: {len(scenes)})...")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=240)
            resp.raise_for_status()
            data = resp.json()

            results = {}
            items = data.get("data") or []
            if not items:
                logger.warning("Doubao group response contained no images.")
                return results
            logger.info(f"Doubao returned {len(items)} images.")

            # Map returned images back to scenes, assuming the order matches.
            # Images beyond the number of scenes are ignored.
            for scene, item in zip(scenes, items):
                scene_id = scene["id"]
                image_url = item.get("url")
                if not image_url:
                    continue
                out_dir = path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR
                out_name = path_utils.unique_filename(
                    prefix="scene_image",
                    ext="png",
                    project_id=project_id,
                    scene_id=scene_id,
                    extra="group",
                )
                # The download helper raises on HTTP errors so that an error
                # body is never saved as a .png.
                results[scene_id] = self._download_image(image_url, out_dir / out_name)
            return results
        except Exception as e:
            logger.error(f"Doubao Group Generation Failed: {e}")
            raise

    def _generate_single_image(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        provider: str = "shubiaobiao",
        output_dir: Optional[Path] = None,
    ) -> Optional[str]:
        """Unified entry point: dispatch to the provider-specific generator.

        "doubao" and "gemini" select their channels; any other provider value
        uses Shubiaobiao. ``output_dir`` defaults to ``config.TEMP_DIR``.
        """
        out_dir = output_dir or config.TEMP_DIR
        if provider == "doubao":
            return self._generate_single_image_doubao(prompt, reference_images, output_filename, out_dir)
        elif provider == "gemini":
            return self._generate_single_image_gemini(prompt, reference_images, output_filename, out_dir)
        else:
            return self._generate_single_image_shubiao(prompt, reference_images, output_filename, out_dir)

    def _generate_single_image_doubao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image through the Volcengine Doubao image API.

        Existing reference images are uploaded via ``storage.upload_file`` and
        their URLs attached. The request (and the download of the result) is
        retried through ``_with_retries``.

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: On a non-200 response or when no image URL is returned.
        """
        # 1. Upload all reference images (best effort: failures are only logged).
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded Doubao ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Failed to upload Doubao ref image {ref_path}: {e}")

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }
        if img_urls:
            # NOTE(review): confirm that the provider accepts the key
            # "image_urls" for reference images - TODO verify against the API reference.
            payload["image_urls"] = img_urls
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', image_urls={len(img_urls)}")
        else:
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', no reference images")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        def _call():
            logger.info(f"Submitting to Doubao Image: {self.endpoint}")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
            if resp.status_code != 200:
                msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)
            data = resp.json()
            # "data" may be missing, null, or empty; all mean "no image".
            items = data.get("data") or []
            image_url = items[0].get("url") if items else None
            if image_url:
                return self._download_image(image_url, output_dir / output_filename)
            raise RuntimeError(f"No image URL in Doubao response: {data}")

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="doubao_image")

    def _generate_single_image_shubiao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image through the Shubiaobiao channel.

        Reference images are sent inline as base64 JPEG and the response
        returns the generated image as base64 in the same call.

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: On a non-200 response or when the response has no
                image data.
        """
        # Reference images are sent inline as base64.
        parts = [{"text": prompt}]

        # Keep only existing, non-duplicate paths, preserving input order.
        valid_refs = []
        if reference_images:
            for p in reference_images:
                if p and os.path.exists(p) and p not in valid_refs:
                    valid_refs.append(p)
        logger.info(f"[Shubiaobiao] Input reference images ({len(valid_refs)}): {valid_refs}")

        for ref_path in valid_refs:
            try:
                encoded = self._encode_image(ref_path)
                if encoded:
                    parts.append({
                        "inlineData": {
                            "mimeType": "image/jpeg",
                            "data": encoded
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to encode image {ref_path}: {e}")
        logger.info(f"[Shubiaobiao] Final payload parts count: {len(parts)} (1 prompt + {len(parts)-1} images)")

        payload = {
            "contents": [{
                "role": "user",
                "parts": parts
            }],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "imageConfig": {
                    "aspectRatio": "9:16",
                    "imageSize": "2K"
                }
            }
        }
        endpoint = f"{config.SHUBIAOBIAO_IMG_BASE_URL}/v1beta/models/{config.SHUBIAOBIAO_IMG_MODEL_NAME}:generateContent"
        headers = {
            "x-goog-api-key": config.SHUBIAOBIAO_IMG_KEY,
            "Content-Type": "application/json"
        }

        def _call():
            logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
            resp = requests.post(endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
            if resp.status_code != 200:
                msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)
            data = resp.json()

            # Find the first inline base64 image in the first candidate.
            img_b64 = None
            candidates = data.get("candidates") or []
            if candidates:
                content_parts = candidates[0].get("content", {}).get("parts", [])
                for part in content_parts:
                    inline = part.get("inlineData") if isinstance(part, dict) else None
                    if inline and inline.get("data"):
                        img_b64 = inline["data"]
                        break
            if not img_b64:
                msg = f"Shubiaobiao 响应缺少图片数据: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            output_path = output_dir / output_filename
            with open(output_path, "wb") as f:
                f.write(base64.b64decode(img_b64))
            logger.info(f"Shubiaobiao Generation Success: {output_path}")
            return str(output_path)

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="shubiaobiao_image")

    def _generate_single_image_gemini(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image through the Gemini channel (Wuyin Keji / NanoBanana-Pro).

        Submits a task, then polls the detail endpoint until the task succeeds,
        fails, or the poll budget (``IMG_POLL_MAX_RETRIES``) is exhausted.

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: On submit or poll errors, task failure, a missing
                image URL, or polling timeout.
        """
        # 1. Build the payload.
        payload = {
            "prompt": prompt,
            "aspectRatio": "9:16",
            "imageSize": "2K"
        }

        # Reference images (image-to-image): keep existing, unique paths.
        if reference_images:
            valid_paths = []
            seen = set()
            for p in reference_images:
                if p and os.path.exists(p) and p not in seen:
                    valid_paths.append(p)
                    seen.add(p)
            if valid_paths:
                img_urls = []
                for ref_path in valid_paths:
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Error uploading ref image {ref_path}: {e}")
                if img_urls:
                    payload["img_url"] = img_urls
                    logger.info(f"Using {len(img_urls)} reference images for Gemini Img2Img")

        headers = {
            "Authorization": config.GEMINI_IMG_KEY,
            # NOTE(review): "charset:utf-8" is kept exactly as in the original;
            # presumably what the provider expects - confirm before changing.
            "Content-Type": "application/json;charset:utf-8"
        }

        def _call():
            # 2. Submit the task.
            logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
            resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
            if resp.status_code != 200:
                msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)
            data = resp.json()
            if data.get("code") != 200:
                msg = f"Gemini 返回错误: {data}"
                logger.error(msg)
                raise RuntimeError(msg)
            # "data" may be null; treat that the same as a missing task id.
            task_id = (data.get("data") or {}).get("id")
            if not task_id:
                raise RuntimeError(f"Gemini 响应缺少 task id: {data}")
            logger.info(f"Gemini Task Submitted, ID: {task_id}")

            # 3. Poll the task status.
            poll_url = f"{config.GEMINI_IMG_DETAIL_URL}?key={config.GEMINI_IMG_KEY}&id={task_id}"
            for _ in range(IMG_POLL_MAX_RETRIES):
                time.sleep(IMG_POLL_INTERVAL_S)
                try:
                    poll_resp = requests.get(poll_url, headers=headers, timeout=IMG_POLL_TIMEOUT_S)
                except Exception as poll_err:
                    # Best effort: one failed poll must not abort the task.
                    logger.debug(f"Gemini poll request failed, polling again: {poll_err}")
                    continue
                if poll_resp.status_code != 200:
                    logger.debug(f"Gemini poll returned HTTP {poll_resp.status_code}, polling again")
                    continue
                poll_data = poll_resp.json()
                if poll_data.get("code") != 200:
                    raise RuntimeError(f"Gemini 轮询返回错误: {poll_data}")
                result_data = poll_data.get("data", {}) or {}
                status = result_data.get("status")  # 0: queued, 1: generating, 2: succeeded, 3: failed
                if status == 2:
                    image_url = result_data.get("image_url")
                    if not image_url:
                        raise RuntimeError("Gemini 成功但缺少 image_url")
                    logger.info(f"Gemini Generation Success: {image_url}")
                    return self._download_image(image_url, output_dir / output_filename)
                if status == 3:
                    fail_reason = result_data.get("fail_reason", "Unknown")
                    raise RuntimeError(f"Gemini 生成失败: {fail_reason}")
            raise RuntimeError("Gemini 生成超时")

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="gemini_image")