"""
Coherent image generation module (Volcengine Doubao).

Generates a series of visually consistent storyboard images from a
storyboard script and the original source material.
"""
import base64
import io
import json
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from PIL import Image

from modules import storage
import config
from modules import path_utils

logger = logging.getLogger(__name__)


def _env_int(name: str, default: int) -> int:
    """Read an integer from environment variable *name*.

    Returns *default* when the variable is unset or is not a valid integer.
    """
    try:
        return int(os.getenv(name, str(default)))
    except ValueError:
        return default


# Tunables: slow channels can be hot; default conservative but adjustable.
IMG_SUBMIT_TIMEOUT_S = _env_int("IMG_SUBMIT_TIMEOUT_S", 180)
IMG_POLL_TIMEOUT_S = _env_int("IMG_POLL_TIMEOUT_S", 30)
IMG_MAX_RETRIES = _env_int("IMG_MAX_RETRIES", 3)
IMG_POLL_INTERVAL_S = _env_int("IMG_POLL_INTERVAL_S", 2)
IMG_POLL_MAX_RETRIES = _env_int("IMG_POLL_MAX_RETRIES", 90)  # 90*2s ~= 180s

# Keywords that mark a provider error message as transient.
# "rate" is matched only as "rate limit"/"ratelimit"/"rate_limit" or as a
# standalone word: a bare substring test also hits "generate", "separate",
# "moderate", ... and would make permanent errors look retryable.
_RETRYABLE_MESSAGE_RE = re.compile(
    r"timeout|temporarily|gateway|try again"
    r"|rate[\s_-]*limit|(?<![a-z])rate(?![a-z])"
)


def _is_retryable_exception(e: Exception) -> bool:
    """Return True when *e* looks like a transient failure worth retrying.

    Network timeouts and connection errors from ``requests`` are always
    retryable; any other exception is retryable only if its lower-cased
    message contains one of the transient-error keywords.
    """
    # Network / transient errors
    if isinstance(e, (requests.Timeout, requests.ConnectionError)):
        return True
    # Transient provider errors often contain these keywords
    return _RETRYABLE_MESSAGE_RE.search(str(e).lower()) is not None


def _with_retries(fn, *, max_retries: int, label: str):
    """Call ``fn()`` and retry transient failures with a small backoff.

    Args:
        fn: Zero-argument callable to execute.
        max_retries: Maximum number of attempts. Values below 1 are treated
            as 1 so that ``fn`` always runs at least once.
        label: Tag used in the warning log lines.

    Returns:
        Whatever ``fn()`` returns on the first successful attempt.

    Raises:
        The exception raised by ``fn`` when it is not retryable or when the
        last attempt fails.
    """
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as e:
            retryable = _is_retryable_exception(e)
            logger.warning(f"[{label}] attempt {attempt}/{attempts} failed: {e} (retryable={retryable})")
            if not retryable or attempt >= attempts:
                raise
            # small backoff: 1s, 2s, then capped at 4s
            time.sleep(min(2 ** (attempt - 1), 4))


class ImageGenerator:
    """Coherent image generator (Volcengine provider)."""

    def __init__(self):
        self.api_key = config.VOLC_API_KEY
        # Endpoint: https://ark.cn-beijing.volces.com/api/v3/images/generations
        self.endpoint = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        self.model = config.IMAGE_MODEL_ID

    @staticmethod
    def _download_image(image_url: str, output_path: Path, timeout: int = 60) -> str:
        """Download *image_url* to *output_path* and return the path as a string.

        Raises:
            requests.HTTPError: if the download responds with an error status,
                so an error body is never written to disk as an image.
        """
        img_resp = requests.get(image_url, timeout=timeout)
        img_resp.raise_for_status()
        with open(output_path, "wb") as f:
            f.write(img_resp.content)
        return str(output_path)

    def _encode_image(self, image_path: str) -> str:
        """Read an image, downscale it and return it as a Base64 JPEG string.

        The image is converted to RGB, shrunk so its longest side is at most
        1024 px, and saved as JPEG with quality 80. Returns an empty string
        if the image cannot be processed.
        """
        try:
            with Image.open(image_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                max_size = 1024
                if max(img.size) > max_size:
                    img.thumbnail((max_size, max_size), Image.LANCZOS)

                buffer = io.BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                return base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return ""

    def generate_single_scene_image(
        self,
        scene: Dict[str, Any],
        original_image_path: Any,
        previous_image_path: Optional[str] = None,
        model_provider: str = "shubiaobiao",  # "shubiaobiao", "gemini", "doubao"
        visual_anchor: str = "",  # visual anchor, force-prepended to the prompt
        project_id: Optional[str] = None,
    ) -> Optional[str]:
        """Generate the image for a single storyboard scene (public API).

        Args:
            scene: Scene dict; ``scene["id"]`` is required and
                ``scene["visual_prompt"]`` is used as the prompt.
            original_image_path: Reference image path, either a single
                non-empty string or a list of paths.
            previous_image_path: Optional path of the previous scene's image,
                appended to the reference images.
            model_provider: One of "shubiaobiao", "gemini", "doubao"; any
                other value falls back to the shubiaobiao channel.
            visual_anchor: Text prepended (in brackets) to the prompt unless
                the prompt already contains it.
            project_id: When given, output goes to the project's image
                directory; otherwise to ``config.TEMP_DIR``.

        Returns:
            Path of the generated image file.

        Raises:
            RuntimeError: if the provider returns no image.
            Exception: any provider error is logged and re-raised.
        """
        scene_id = scene["id"]
        visual_prompt = scene.get("visual_prompt", "")

        # Force-prepend the visual anchor to keep generated images consistent
        if visual_anchor and visual_anchor not in visual_prompt:
            visual_prompt = f"[{visual_anchor}] {visual_prompt}"
            logger.info(f"Scene {scene_id}: Prepended visual_anchor to prompt")

        logger.info(f"Generating image for Scene {scene_id} (Provider: {model_provider})...")

        input_images = []
        # Handle original_image_path (can be str or list)
        if isinstance(original_image_path, list):
            input_images.extend(original_image_path)
        elif isinstance(original_image_path, str) and original_image_path:
            input_images.append(original_image_path)

        if previous_image_path:
            input_images.append(previous_image_path)

        try:
            out_dir = path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR
            out_name = path_utils.unique_filename(
                prefix="scene_image",
                ext="png",
                project_id=project_id,
                scene_id=scene_id,
            )
            output_path = self._generate_single_image(
                prompt=visual_prompt,
                reference_images=input_images,
                output_filename=out_name,
                provider=model_provider,
                output_dir=out_dir,
            )
            if not output_path:
                raise RuntimeError(f"Image generation returned empty for Scene {scene_id}")
            return output_path
        except PermissionError as e:
            logger.error(f"Critical API Error for Scene {scene_id}: {e}")
            raise
        except Exception as e:
            logger.error(f"Image generation failed for Scene {scene_id}: {e}")
            raise

    def generate_group_images_doubao(
        self,
        scenes: List[Dict[str, Any]],
        reference_images: List[str],
        visual_anchor: str = "",  # visual anchor
        project_id: Optional[str] = None,
    ) -> Dict[int, str]:
        """Doubao batch generation: one combined prompt, several images.

        All scene prompts are joined into a single prompt and submitted in
        one request with sequential image generation enabled. Returned images
        are mapped back to scenes by position.

        Args:
            scenes: Scene dicts; each needs ``"id"`` and may carry
                ``"visual_prompt"``.
            reference_images: Local paths uploaded via ``storage.upload_file``
                and sent as ``image_urls``; missing files are skipped.
            visual_anchor: Optional text placed in the "Global" prompt section.
            project_id: When given, output goes to the project's image
                directory; otherwise to ``config.TEMP_DIR``.

        Returns:
            Mapping of scene id to the saved image path. Empty when *scenes*
            is empty or the response carries no ``data``.

        Raises:
            Exception: request/download failures are logged and re-raised.
        """
        logger.info("Starting Doubao Group Image Generation...")

        # Nothing to generate: avoid a pointless request with max_images=0.
        if not scenes:
            logger.warning("Doubao Group Generation skipped: no scenes provided")
            return {}

        # 1. Join the prompts
        # Format: "Global: [Visual Anchor] ... | S1: ... | S2: ..."
        scene_prompts = [
            f"S{scene['id']}:{scene.get('visual_prompt', '')}" for scene in scenes
        ]
        combined_scenes_text = " | ".join(scene_prompts)

        # Build the combined prompt; the visual anchor goes in the Global part
        global_context = f"[{visual_anchor}] Consistent product appearance & style." if visual_anchor else "Consistent product appearance & style."

        combined_prompt = (
            f"Global: {global_context}\n"
            f"{combined_scenes_text}\n"
            "Req: 1 img per scene. Follow specific angles."
        )

        logger.info(f"Visual Anchor applied to group prompt: {visual_anchor[:50]}..." if visual_anchor else "No visual_anchor")

        # Log the prompt length for reference
        logger.info(f"Doubao Group Prompt Length: {len(combined_prompt)} chars")

        # 2. Prepare the payload
        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": combined_prompt,
            "sequential_image_generation": "auto",  # enable grouped generation
            "sequential_image_generation_options": {
                "max_images": len(scenes)  # cap the number of images
            },
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        # 3. Upload reference images (best effort: failures are only logged)
        img_urls = []
        for ref_path in reference_images or []:
            # Falsy entries are skipped: os.path.exists(None) raises TypeError.
            if ref_path and os.path.exists(ref_path):
                try:
                    url = storage.upload_file(ref_path)
                    if url:
                        img_urls.append(url)
                except Exception as e:
                    logger.warning(f"Failed to upload ref image {ref_path}: {e}")

        if img_urls:
            payload["image_urls"] = img_urls

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting Doubao Group Request (Scenes: {len(scenes)})...")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=240)
            resp.raise_for_status()

            data = resp.json()

            results = {}
            if "data" in data:
                items = data["data"]
                logger.info(f"Doubao returned {len(items)} images.")

                out_dir = path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR

                # Map returned images back to scenes, assuming the API keeps
                # scene order; surplus images beyond len(scenes) are ignored.
                for scene, item in zip(scenes, items):
                    scene_id = scene["id"]
                    image_url = item.get("url")
                    if not image_url:
                        continue
                    out_name = path_utils.unique_filename(
                        prefix="scene_image",
                        ext="png",
                        project_id=project_id,
                        scene_id=scene_id,
                        extra="group",
                    )
                    results[scene_id] = self._download_image(image_url, out_dir / out_name)

            return results

        except Exception as e:
            logger.error(f"Doubao Group Generation Failed: {e}")
            raise

    def _generate_single_image(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        provider: str = "shubiaobiao",
        output_dir: Optional[Path] = None,
    ) -> Optional[str]:
        """Unified entry point: dispatch to the channel named by *provider*.

        "doubao" and "gemini" select their channels; any other value uses the
        shubiaobiao channel. *output_dir* defaults to ``config.TEMP_DIR``.
        """
        out_dir = output_dir or config.TEMP_DIR
        if provider == "doubao":
            return self._generate_single_image_doubao(prompt, reference_images, output_filename, out_dir)
        elif provider == "gemini":
            return self._generate_single_image_gemini(prompt, reference_images, output_filename, out_dir)
        else:
            return self._generate_single_image_shubiao(prompt, reference_images, output_filename, out_dir)

    def _generate_single_image_doubao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image through the Volcengine Doubao image API.

        Reference images that exist locally are uploaded via
        ``storage.upload_file`` and passed as ``image_urls``. The request is
        retried on transient errors (see ``_with_retries``).

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: on a non-200 response or a response without an
                image URL.
        """
        # 1. Upload all reference images (best effort: failures are only logged)
        img_urls = []
        for ref_path in reference_images or []:
            # Falsy entries are skipped: os.path.exists(None) raises TypeError.
            if ref_path and os.path.exists(ref_path):
                try:
                    url = storage.upload_file(ref_path)
                    if url:
                        img_urls.append(url)
                        logger.info(f"Uploaded Doubao ref image: {url}")
                except Exception as e:
                    logger.warning(f"Failed to upload Doubao ref image {ref_path}: {e}")

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        if img_urls:
            payload["image_urls"] = img_urls
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', image_urls={len(img_urls)}")
        else:
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', no reference images")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        def _call():
            logger.info(f"Submitting to Doubao Image: {self.endpoint}")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
            if resp.status_code != 200:
                msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            # `or []` also covers a present-but-null "data" field.
            items = data.get("data") or []
            if items:
                image_url = items[0].get("url")
                if image_url:
                    return self._download_image(image_url, output_dir / output_filename)
            raise RuntimeError(f"No image URL in Doubao response: {data}")

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="doubao_image")

    def _generate_single_image_shubiao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image through the api2img.shubiaobiao.com channel.

        The channel answers synchronously with a Base64 image. Reference
        images are sent inline as Base64 JPEG parts after the text prompt.

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: on a non-200 response or a response without image
                data.
        """
        # Reference images are inlined as Base64 after the text part
        parts = [{"text": prompt}]

        # Keep only existing paths, de-duplicated, in their original order
        valid_refs = []
        for p in reference_images or []:
            if p and os.path.exists(p) and p not in valid_refs:
                valid_refs.append(p)

        logger.info(f"[Shubiaobiao] Input reference images ({len(valid_refs)}): {valid_refs}")

        for ref_path in valid_refs:
            try:
                encoded = self._encode_image(ref_path)
                if encoded:
                    parts.append({
                        "inlineData": {
                            "mimeType": "image/jpeg",
                            "data": encoded
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to encode image {ref_path}: {e}")

        logger.info(f"[Shubiaobiao] Final payload parts count: {len(parts)} (1 prompt + {len(parts)-1} images)")

        payload = {
            "contents": [{
                "role": "user",
                "parts": parts
            }],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "imageConfig": {
                    "aspectRatio": "9:16",
                    "imageSize": "2K"
                }
            }
        }

        endpoint = f"{config.SHUBIAOBIAO_IMG_BASE_URL}/v1beta/models/{config.SHUBIAOBIAO_IMG_MODEL_NAME}:generateContent"
        headers = {
            "x-goog-api-key": config.SHUBIAOBIAO_IMG_KEY,
            "Content-Type": "application/json"
        }

        def _call():
            logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
            resp = requests.post(endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
            if resp.status_code != 200:
                msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            # Find the first inline Base64 image in the first candidate
            img_b64 = None
            candidates = data.get("candidates") or []
            if candidates:
                content_parts = candidates[0].get("content", {}).get("parts", [])
                for part in content_parts:
                    inline = part.get("inlineData") if isinstance(part, dict) else None
                    if inline and inline.get("data"):
                        img_b64 = inline["data"]
                        break

            if not img_b64:
                msg = f"Shubiaobiao 响应缺少图片数据: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            output_path = output_dir / output_filename
            with open(output_path, "wb") as f:
                f.write(base64.b64decode(img_b64))

            logger.info(f"Shubiaobiao Generation Success: {output_path}")
            return str(output_path)

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="shubiaobiao_image")

    def _generate_single_image_gemini(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        output_dir: Path
    ) -> Optional[str]:
        """Generate one image through Gemini (Wuyin Keji / NanoBanana-Pro).

        Submits a task, then polls the detail endpoint every
        ``IMG_POLL_INTERVAL_S`` seconds, up to ``IMG_POLL_MAX_RETRIES`` times,
        until the task succeeds or fails. Transient poll errors are logged
        and the loop keeps polling.

        Returns:
            Path of the saved image.

        Raises:
            RuntimeError: on submit/poll errors reported by the API, a failed
                task, or when polling runs out of attempts.
        """
        # 1. Build the payload
        payload = {
            "prompt": prompt,
            "aspectRatio": "9:16",
            "imageSize": "2K"
        }

        # Reference images (image-to-image): existing paths, de-duplicated
        if reference_images:
            valid_paths = []
            seen = set()
            for p in reference_images:
                if p and os.path.exists(p) and p not in seen:
                    valid_paths.append(p)
                    seen.add(p)

            if valid_paths:
                img_urls = []
                for ref_path in valid_paths:
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Error uploading ref image {ref_path}: {e}")

                if img_urls:
                    payload["img_url"] = img_urls
                    logger.info(f"Using {len(img_urls)} reference images for Gemini Img2Img")

        headers = {
            "Authorization": config.GEMINI_IMG_KEY,
            "Content-Type": "application/json;charset:utf-8"
        }

        def _call():
            # 2. Submit the task
            logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
            resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
            if resp.status_code != 200:
                msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            if data.get("code") != 200:
                msg = f"Gemini 返回错误: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            task_id = data.get("data", {}).get("id")
            if not task_id:
                raise RuntimeError(f"Gemini 响应缺少 task id: {data}")

            logger.info(f"Gemini Task Submitted, ID: {task_id}")

            # 3. Poll the task status
            poll_url = f"{config.GEMINI_IMG_DETAIL_URL}?key={config.GEMINI_IMG_KEY}&id={task_id}"
            for _ in range(IMG_POLL_MAX_RETRIES):
                time.sleep(IMG_POLL_INTERVAL_S)

                try:
                    poll_resp = requests.get(poll_url, headers=headers, timeout=IMG_POLL_TIMEOUT_S)
                except Exception as e:
                    # Best effort: keep polling. Log only the exception type,
                    # because requests' messages can include the poll URL,
                    # which carries the API key.
                    logger.warning(f"Gemini poll request failed ({type(e).__name__}), retrying")
                    continue

                if poll_resp.status_code != 200:
                    logger.warning(f"Gemini poll returned HTTP {poll_resp.status_code}, retrying")
                    continue

                poll_data = poll_resp.json()
                if poll_data.get("code") != 200:
                    raise RuntimeError(f"Gemini 轮询返回错误: {poll_data}")

                result_data = poll_data.get("data", {}) or {}
                status = result_data.get("status")
                # 0: queued, 1: generating, 2: succeeded, 3: failed

                if status == 2:
                    image_url = result_data.get("image_url")
                    if not image_url:
                        raise RuntimeError("Gemini 成功但缺少 image_url")
                    logger.info(f"Gemini Generation Success: {image_url}")
                    return self._download_image(image_url, output_dir / output_filename)

                if status == 3:
                    fail_reason = result_data.get("fail_reason", "Unknown")
                    raise RuntimeError(f"Gemini 生成失败: {fail_reason}")

            raise RuntimeError("Gemini 生成超时")

        return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="gemini_image")