548 lines
21 KiB
Python
548 lines
21 KiB
Python
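"""
连贯生图模块 (Volcengine Doubao)
Coherent storyboard image generation module.

负责根据分镜脚本和原始素材生成一系列连贯的分镜图片
Generates a series of visually consistent storyboard images from a storyboard
script and the original source images, using the Doubao (Volcengine), Gemini
or Shubiaobiao image providers.
"""
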
import base64
import io
import json
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from PIL import Image

import config
from modules import path_utils
from modules import storage

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

def _env_int(name: str, default: int) -> int:
    """Read an integer setting from environment variable *name*.

    Returns *default* when the variable is unset or its value does not parse
    as an ``int`` (e.g. ``""``, ``"2.5"``, ``"abc"``). An unparsable value is
    logged as a warning so a typo in deployment config is visible instead of
    being silently ignored.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        # int() tolerates surrounding whitespace, e.g. " 12 ".
        return int(raw)
    except ValueError:
        logger.warning(f"Ignoring invalid integer for env {name}={raw!r}; using default {default}")
        return default

# Tunables: slow channels can be hot; default conservative but adjustable.
# Each value can be overridden by the environment variable of the same name;
# values that do not parse as int fall back to the default (see _env_int).
IMG_SUBMIT_TIMEOUT_S = _env_int("IMG_SUBMIT_TIMEOUT_S", 180)  # HTTP timeout (s) of the single-image submit/generate request
IMG_POLL_TIMEOUT_S = _env_int("IMG_POLL_TIMEOUT_S", 30)  # HTTP timeout (s) of each Gemini status poll
IMG_MAX_RETRIES = _env_int("IMG_MAX_RETRIES", 3)  # max attempts per single-image provider call (passed to _with_retries)
IMG_POLL_INTERVAL_S = _env_int("IMG_POLL_INTERVAL_S", 2)  # sleep (s) before each Gemini status poll
IMG_POLL_MAX_RETRIES = _env_int("IMG_POLL_MAX_RETRIES", 90) # 90*2s ~= 180s

# Lower-cased message fragments that indicate a transient provider failure.
_TRANSIENT_MESSAGE_MARKERS = (
    "timeout",
    "temporarily",  # also covers "temporarily unavailable"
    "gateway",
    "try again",
    "too many requests",
)

# "rate" must stand alone ("rate limit", "rate_limit", "Rate exceeded") or appear
# as "ratelimit" ("RateLimitExceeded"). A bare substring test also matches
# ordinary words such as "generate", "accurate" or "separate", which made
# permanent errors look retryable.
_RATE_LIMIT_PATTERN = re.compile(r"(?<![a-z])rate(?![a-z])|ratelimit")


def _is_retryable_exception(e: Exception) -> bool:
    """Return True if *e* looks like a transient failure worth retrying.

    Retryable:
      * ``requests`` timeouts and connection errors;
      * ``requests.HTTPError`` whose response status is 429 or 5xx;
      * any exception whose message mentions a timeout, temporary
        unavailability, a gateway problem, rate limiting, or "try again".
    """
    # Network / transient errors
    if isinstance(e, (requests.Timeout, requests.ConnectionError)):
        return True

    # Errors from raise_for_status(): throttling and server-side failures.
    if isinstance(e, requests.HTTPError):
        status = getattr(e.response, "status_code", None)
        if status is not None and (status == 429 or status >= 500):
            return True

    # Transient provider errors often contain these keywords.
    msg = str(e).lower()
    if any(marker in msg for marker in _TRANSIENT_MESSAGE_MARKERS):
        return True
    return bool(_RATE_LIMIT_PATTERN.search(msg))

def _with_retries(fn, *, max_retries: int, label: str):
    """Call ``fn()``, retrying on transient failures with a short backoff.

    Args:
        fn: Zero-argument callable to invoke.
        max_retries: Maximum number of attempts, including the first one.
            Values below 1 are treated as 1, so ``fn`` always runs at least once.
        label: Tag used in log messages.

    Returns:
        Whatever ``fn()`` returns on the first successful attempt.

    Raises:
        The exception of the final attempt, or immediately the exception of any
        attempt that ``_is_retryable_exception`` does not consider transient.
    """
    # Guard against IMG_MAX_RETRIES=0 (or negative) from the environment:
    # otherwise the loop never runs, fn is never called, and the function
    # ends up executing `raise None` (a TypeError).
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as e:
            retryable = _is_retryable_exception(e)
            logger.warning(f"[{label}] attempt {attempt}/{attempts} failed: {e} (retryable={retryable})")
            if not retryable or attempt >= attempts:
                raise
            # Small backoff: 1s, 2s, then capped at 4s.
            time.sleep(min(2 ** (attempt - 1), 4))
    # Unreachable: every iteration either returns or raises.
    raise RuntimeError(f"[{label}] retry loop exited without a result")  # pragma: no cover

class ImageGenerator:
    """Coherent storyboard image generator.

    Supports three providers, chosen per call: "doubao" (Volcengine Ark image
    API), "gemini" (submit-then-poll task API) and "shubiaobiao"
    (Gemini-style ``generateContent`` endpoint; the default).
    """

    def __init__(self):
        # NOTE(review): the request headers in this class read
        # config.VOLC_API_KEY directly rather than this attribute.
        self.api_key = config.VOLC_API_KEY
        # Volcengine Ark image-generation endpoint used by the Doubao paths.
        self.endpoint = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        # NOTE(review): not read by any method in this class (the Doubao paths
        # send config.DOUBAO_IMG_MODEL); kept in case other code reads it.
        self.model = config.IMAGE_MODEL_ID

    @staticmethod
    def _valid_reference_paths(paths: Optional[List[str]]) -> List[str]:
        """Return the entries of *paths* that are non-empty and exist on disk.

        Duplicates are dropped and the caller's order is kept. ``None`` entries
        (or a ``None`` list) are skipped instead of crashing ``os.path.exists``.
        """
        result: List[str] = []
        seen = set()
        for p in paths or []:
            if p and p not in seen and os.path.exists(p):
                result.append(p)
                seen.add(p)
        return result

    @staticmethod
    def _download_to(url: str, output_path: Path) -> str:
        """Download *url* and write the response body to *output_path*.

        Raises ``requests.HTTPError`` on a non-2xx response, so an error page
        is never written out as an image. Returns ``str(output_path)``.
        """
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
        with open(output_path, "wb") as f:
            f.write(resp.content)
        return str(output_path)

    def _encode_image(self, image_path: str) -> str:
        """Load an image, shrink it to fit 1024x1024 and return it as base64 JPEG.

        Non-RGB images are converted to RGB before JPEG encoding (quality 80).
        Returns an empty string if the image cannot be read or encoded; the
        error is logged.
        """
        try:
            with Image.open(image_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                max_size = 1024
                if max(img.size) > max_size:
                    # thumbnail() resizes in place and keeps the aspect ratio.
                    img.thumbnail((max_size, max_size), Image.LANCZOS)

                buffer = io.BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                return base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return ""

    def generate_single_scene_image(
        self,
        scene: Dict[str, Any],
        original_image_path: Any,
        previous_image_path: Optional[str] = None,
        model_provider: str = "shubiaobiao",  # "shubiaobiao", "gemini", "doubao"
        visual_anchor: str = "",  # visual anchor, forcibly prepended to the prompt
        project_id: Optional[str] = None,
    ) -> Optional[str]:
        """Generate the image for one storyboard scene (public entry point).

        Args:
            scene: Scene dict; must contain "id", may contain "visual_prompt".
            original_image_path: A single path, a list of paths, or a falsy
                value; used as reference images.
            previous_image_path: Optional image of the previous scene, appended
                as an extra reference for continuity.
            model_provider: "doubao" or "gemini"; any other value selects the
                shubiaobiao channel.
            visual_anchor: Text prepended as "[anchor] " unless the prompt
                already contains it.
            project_id: When given, the image is written to the project's
                images directory; otherwise to config.TEMP_DIR.

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: if the provider returned no image path.
            Any exception raised by the provider call is logged and re-raised.
        """
        scene_id = scene["id"]
        # `or ""` also covers an explicit None value.
        visual_prompt = scene.get("visual_prompt") or ""

        # Force the visual anchor into the prompt to keep images consistent.
        if visual_anchor and visual_anchor not in visual_prompt:
            visual_prompt = f"[{visual_anchor}] {visual_prompt}"
            logger.info(f"Scene {scene_id}: Prepended visual_anchor to prompt")

        logger.info(f"Generating image for Scene {scene_id} (Provider: {model_provider})...")

        input_images = []
        # original_image_path may be a single path or a list of paths.
        if isinstance(original_image_path, list):
            input_images.extend(original_image_path)
        elif isinstance(original_image_path, str) and original_image_path:
            input_images.append(original_image_path)
        if previous_image_path:
            input_images.append(previous_image_path)

        try:
            out_dir = path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR
            out_name = path_utils.unique_filename(
                prefix="scene_image",
                ext="png",
                project_id=project_id,
                scene_id=scene_id,
            )
            output_path = self._generate_single_image(
                prompt=visual_prompt,
                reference_images=input_images,
                output_filename=out_name,
                provider=model_provider,
                output_dir=out_dir,
            )
            if not output_path:
                raise RuntimeError(f"Image generation returned empty for Scene {scene_id}")
            return output_path
        except PermissionError as e:
            logger.error(f"Critical API Error for Scene {scene_id}: {e}")
            raise
        except Exception as e:
            logger.error(f"Image generation failed for Scene {scene_id}: {e}")
            raise

    def generate_group_images_doubao(
        self,
        scenes: List[Dict[str, Any]],
        reference_images: List[str],
        visual_anchor: str = "",  # visual anchor
        project_id: Optional[str] = None,
    ) -> Dict[int, str]:
        """Doubao group (batch) generation: one combined prompt, several images.

        All scene prompts are joined into a single request with sequential
        image generation enabled and ``max_images`` set to the scene count.
        Returned images are mapped back to scenes by position (the API is
        assumed to return them in scene order); extra images are ignored.

        Returns:
            Mapping of scene id -> saved image path. Scenes whose returned item
            has no URL are omitted; an empty ``scenes`` list returns ``{}``
            without calling the API.

        Raises:
            Any submit or download error (including ``requests.HTTPError`` on a
            non-2xx response) is logged and re-raised.
        """
        if not scenes:
            # Nothing to generate; avoid sending max_images=0 to the API.
            return {}

        logger.info("Starting Doubao Group Image Generation...")

        # 1. Build the combined prompt.
        # Format: "Global: [Visual Anchor] ... | S1: ... | S2: ..."
        combined_scenes_text = " | ".join(
            f"S{scene['id']}:{scene.get('visual_prompt') or ''}" for scene in scenes
        )

        # The visual anchor goes into the Global section.
        base_context = "Consistent product appearance & style."
        global_context = f"[{visual_anchor}] {base_context}" if visual_anchor else base_context
        combined_prompt = (
            f"Global: {global_context}\n"
            f"{combined_scenes_text}\n"
            "Req: 1 img per scene. Follow specific angles."
        )

        logger.info(f"Visual Anchor applied to group prompt: {visual_anchor[:50]}..." if visual_anchor else "No visual_anchor")
        # Prompt length is logged for reference.
        logger.info(f"Doubao Group Prompt Length: {len(combined_prompt)} chars")

        # 2. Build the payload.
        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": combined_prompt,
            "sequential_image_generation": "auto",  # enable group generation
            "sequential_image_generation_options": {
                "max_images": len(scenes)  # cap the number of images
            },
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        # 3. Upload reference images and pass their URLs.
        img_urls = []
        for ref_path in self._valid_reference_paths(reference_images):
            try:
                url = storage.upload_file(ref_path)
                if url:
                    img_urls.append(url)
            except Exception as e:
                logger.warning(f"Failed to upload ref image {ref_path}: {e}")
        if img_urls:
            payload["image_urls"] = img_urls

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting Doubao Group Request (Scenes: {len(scenes)})...")
            # Longer than IMG_SUBMIT_TIMEOUT_S's default: one request yields several images.
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=240)
            resp.raise_for_status()

            data = resp.json()
            results: Dict[int, str] = {}
            if "data" not in data:
                logger.warning(f"Doubao group response has no 'data' field: {data}")
                return results

            items = data["data"] or []
            logger.info(f"Doubao returned {len(items)} images.")

            out_dir = Path(path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR)
            # Map images back to scenes by position; zip() drops extra images.
            for scene, item in zip(scenes, items):
                image_url = item.get("url")
                if not image_url:
                    continue
                scene_id = scene["id"]
                out_name = path_utils.unique_filename(
                    prefix="scene_image",
                    ext="png",
                    project_id=project_id,
                    scene_id=scene_id,
                    extra="group",
                )
                results[scene_id] = self._download_to(image_url, out_dir / out_name)

            return results

        except Exception as e:
            logger.error(f"Doubao Group Generation Failed: {e}")
            raise

    def _generate_single_image(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        provider: str = "shubiaobiao",
        output_dir: Optional[Path] = None,
    ) -> Optional[str]:
        """Dispatch a single-image request to the provider named by *provider*.

        "doubao" and "gemini" select those providers; anything else uses the
        shubiaobiao channel. *output_dir* defaults to config.TEMP_DIR.
        """
        out_dir = output_dir or config.TEMP_DIR
        if provider == "doubao":
            return self._generate_single_image_doubao(prompt, reference_images, output_filename, out_dir)
        elif provider == "gemini":
            return self._generate_single_image_gemini(prompt, reference_images, output_filename, out_dir)
        else:
            return self._generate_single_image_shubiao(prompt, reference_images, output_filename, out_dir)

    def _generate_single_image_doubao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image via the Volcengine Doubao image API.

        Reference images are uploaded with ``storage.upload_file`` and passed as
        ``image_urls``. The request (submit + download) runs through
        ``_with_retries``.

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: on a non-200 response or a response without an image URL.
        """
        # 1. Upload reference images (existing, de-duplicated paths only).
        img_urls = []
        for ref_path in self._valid_reference_paths(reference_images):
            try:
                url = storage.upload_file(ref_path)
                if url:
                    img_urls.append(url)
                    logger.info(f"Uploaded Doubao ref image: {url}")
            except Exception as e:
                logger.warning(f"Failed to upload Doubao ref image {ref_path}: {e}")

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        if img_urls:
            payload["image_urls"] = img_urls
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', image_urls={len(img_urls)}")
        else:
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', no reference images")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        def _call():
            logger.info(f"Submitting to Doubao Image: {self.endpoint}")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)

            if resp.status_code != 200:
                msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            # `or []` also covers an explicit null "data" field.
            items = data.get("data") or []
            image_url = items[0].get("url") if items else None
            if image_url:
                return self._download_to(image_url, Path(output_dir) / output_filename)

            raise RuntimeError(f"No image URL in Doubao response: {data}")

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="doubao_image")

    def _generate_single_image_shubiao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image via the api2img.shubiaobiao.com channel.

        The endpoint takes a Gemini ``generateContent``-shaped request and
        returns the image synchronously as base64 inline data. Reference images
        are inlined as base64 JPEG parts after the text prompt. The request
        runs through ``_with_retries``.

        Raises:
            RuntimeError: on a non-200 response or a response without image data.
        """
        parts = [{"text": prompt}]

        # Existing, de-duplicated reference paths, in the caller's order.
        valid_refs = self._valid_reference_paths(reference_images)
        logger.info(f"[Shubiaobiao] Input reference images ({len(valid_refs)}): {valid_refs}")

        for ref_path in valid_refs:
            # _encode_image logs and returns "" on failure; skip those.
            encoded = self._encode_image(ref_path)
            if encoded:
                parts.append({
                    "inlineData": {
                        "mimeType": "image/jpeg",
                        "data": encoded
                    }
                })

        logger.info(f"[Shubiaobiao] Final payload parts count: {len(parts)} (1 prompt + {len(parts)-1} images)")

        payload = {
            "contents": [{
                "role": "user",
                "parts": parts
            }],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "imageConfig": {
                    "aspectRatio": "9:16",
                    "imageSize": "2K"
                }
            }
        }

        endpoint = f"{config.SHUBIAOBIAO_IMG_BASE_URL}/v1beta/models/{config.SHUBIAOBIAO_IMG_MODEL_NAME}:generateContent"
        headers = {
            "x-goog-api-key": config.SHUBIAOBIAO_IMG_KEY,
            "Content-Type": "application/json"
        }

        def _call():
            logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
            resp = requests.post(endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)

            if resp.status_code != 200:
                msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            # Take the first inline (base64) image part of the first candidate.
            img_b64 = None
            candidates = data.get("candidates") or []
            if candidates:
                content_parts = candidates[0].get("content", {}).get("parts", [])
                for part in content_parts:
                    inline = part.get("inlineData") if isinstance(part, dict) else None
                    if inline and inline.get("data"):
                        img_b64 = inline["data"]
                        break

            if not img_b64:
                msg = f"Shubiaobiao 响应缺少图片数据: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            output_path = Path(output_dir) / output_filename
            with open(output_path, "wb") as f:
                f.write(base64.b64decode(img_b64))

            logger.info(f"Shubiaobiao Generation Success: {output_path}")
            return str(output_path)

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="shubiaobiao_image")

    def _generate_single_image_gemini(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image via the Gemini (Wuyin Keji / NanoBanana-Pro) task API.

        Submits a task, then polls ``config.GEMINI_IMG_DETAIL_URL`` every
        IMG_POLL_INTERVAL_S seconds, up to IMG_POLL_MAX_RETRIES times. Poll
        requests that fail at the network level, return non-200, or return a
        non-JSON body are logged and skipped. The whole submit+poll sequence
        runs through ``_with_retries``, so a retry submits a new task.

        Raises:
            RuntimeError: on a rejected submit, an error code while polling,
                task failure (status 3), a success without ``image_url``, or
                when polling runs out.
        """
        # 1. Build the payload.
        payload = {
            "prompt": prompt,
            "aspectRatio": "9:16",
            "imageSize": "2K"
        }

        # Reference images (image-to-image): upload and pass their URLs.
        valid_paths = self._valid_reference_paths(reference_images)
        if valid_paths:
            img_urls = []
            for ref_path in valid_paths:
                try:
                    url = storage.upload_file(ref_path)
                    if url:
                        img_urls.append(url)
                        logger.info(f"Uploaded ref image: {url}")
                except Exception as e:
                    logger.warning(f"Error uploading ref image {ref_path}: {e}")

            if img_urls:
                payload["img_url"] = img_urls
                logger.info(f"Using {len(img_urls)} reference images for Gemini Img2Img")

        headers = {
            "Authorization": config.GEMINI_IMG_KEY,
            # Media-type parameters use "=" (the original "charset:utf-8" was malformed).
            "Content-Type": "application/json;charset=utf-8"
        }

        def _call():
            # 2. Submit the task.
            logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
            resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)

            if resp.status_code != 200:
                msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            if data.get("code") != 200:
                msg = f"Gemini 返回错误: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            # `or {}` also covers an explicit null "data" field.
            task_id = (data.get("data") or {}).get("id")
            if not task_id:
                raise RuntimeError(f"Gemini 响应缺少 task id: {data}")

            logger.info(f"Gemini Task Submitted, ID: {task_id}")

            # 3. Poll the task status. params= URL-encodes the key and id.
            poll_params = {"key": config.GEMINI_IMG_KEY, "id": task_id}
            for _ in range(IMG_POLL_MAX_RETRIES):
                time.sleep(IMG_POLL_INTERVAL_S)

                try:
                    poll_resp = requests.get(
                        config.GEMINI_IMG_DETAIL_URL,
                        params=poll_params,
                        headers=headers,
                        timeout=IMG_POLL_TIMEOUT_S,
                    )
                except requests.RequestException as e:
                    # A failed poll is simply retried on the next tick.
                    logger.warning(f"Gemini poll request failed (task {task_id}): {e}")
                    continue

                if poll_resp.status_code != 200:
                    logger.warning(f"Gemini poll returned HTTP {poll_resp.status_code} (task {task_id})")
                    continue

                try:
                    poll_data = poll_resp.json()
                except ValueError as e:
                    logger.warning(f"Gemini poll returned a non-JSON body (task {task_id}): {e}")
                    continue

                if poll_data.get("code") != 200:
                    raise RuntimeError(f"Gemini 轮询返回错误: {poll_data}")

                result_data = poll_data.get("data") or {}
                status = result_data.get("status")  # 0: queued, 1: generating, 2: success, 3: failed

                if status == 2:
                    image_url = result_data.get("image_url")
                    if not image_url:
                        raise RuntimeError("Gemini 成功但缺少 image_url")

                    logger.info(f"Gemini Generation Success: {image_url}")
                    return self._download_to(image_url, Path(output_dir) / output_filename)

                if status == 3:
                    fail_reason = result_data.get("fail_reason", "Unknown")
                    raise RuntimeError(f"Gemini 生成失败: {fail_reason}")

            raise RuntimeError("Gemini 生成超时")

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="gemini_image")