Files
video-flow/modules/image_gen.py
Tony Zhang 33a165a615 feat: video-flow initial commit
- app.py: Streamlit UI for video generation workflow
- main_flow.py: CLI tool with argparse support
- modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.)
- config.py: Configuration with API keys and paths
- requirements.txt: Python dependencies
- docs/: System prompt documentation
2025-12-12 19:18:27 +08:00

492 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
连贯生图模块 (Volcengine Doubao)
负责根据分镜脚本和原始素材生成一系列连贯的分镜图片
"""
import base64
import logging
import os
import time
import requests
import json
from pathlib import Path
from typing import List, Dict, Any, Optional
from PIL import Image
import io
from modules import storage
import config
# Module-level logger named after this module; all ImageGenerator diagnostics go through it.
logger = logging.getLogger(__name__)
class ImageGenerator:
    """Generate storyboard images that stay visually consistent across scenes.

    Three back-ends are supported, selected by a provider string:

    * ``"doubao"`` - Volcengine Ark ``images/generations`` endpoint. Reference
      images are uploaded with ``storage.upload_file`` and sent as URLs.
    * ``"gemini"`` - task API at ``config.GEMINI_IMG_API_URL``. The task is
      submitted, then polled at ``config.GEMINI_IMG_DETAIL_URL``.
    * ``"shubiaobiao"`` - ``generateContent`` endpoint under
      ``config.SHUBIAOBIAO_IMG_BASE_URL``. Reference images are sent inline as
      base64 JPEG, and the result is returned inline as base64.

    Every generated image is written to ``config.TEMP_DIR``, and its path is
    returned as a ``str``.
    """

    def __init__(self):
        # NOTE(review): self.api_key and self.model are not read by any method in
        # this class (the Doubao paths use config.VOLC_API_KEY and
        # config.DOUBAO_IMG_MODEL directly). They are kept for external callers.
        self.api_key = config.VOLC_API_KEY
        # Volcengine Ark image-generation endpoint, used by both Doubao code paths.
        self.endpoint = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        self.model = config.IMAGE_MODEL_ID

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _valid_ref_paths(paths: Optional[List[str]]) -> List[str]:
        """Return the usable reference-image paths from *paths*, in order.

        Falsy entries (``None``, ``""``) and paths that do not exist on disk
        are dropped. Duplicates are removed, and the first occurrence is kept.
        ``None`` or an empty list yields ``[]``.
        """
        seen = set()
        valid: List[str] = []
        for p in paths or []:
            # Check truthiness first so os.path.exists never sees None (TypeError).
            if p and p not in seen and os.path.exists(p):
                seen.add(p)
                valid.append(p)
        return valid

    @classmethod
    def _upload_reference_images(
        cls,
        paths: Optional[List[str]],
        success_msg: Optional[str] = None,
        failure_msg: str = "Failed to upload ref image",
    ) -> List[str]:
        """Upload each usable path in *paths* via ``storage.upload_file``.

        Args:
            paths: Candidate local image paths. They are filtered with
                ``_valid_ref_paths``.
            success_msg: If given, each successful upload is logged as
                ``"<success_msg>: <url>"``.
            failure_msg: Each failed upload is logged as a warning:
                ``"<failure_msg> <path>: <error>"``.

        Returns:
            The truthy URLs returned by the uploader, in input order. Upload
            failures are best-effort: they are logged and skipped, not raised.
        """
        urls: List[str] = []
        for ref_path in cls._valid_ref_paths(paths):
            try:
                url = storage.upload_file(ref_path)
            except Exception as e:  # storage backend errors are not typed here
                logger.warning(f"{failure_msg} {ref_path}: {e}")
                continue
            if url:
                urls.append(url)
                if success_msg:
                    logger.info(f"{success_msg}: {url}")
        return urls

    @staticmethod
    def _download_image(image_url: str, output_path) -> str:
        """Download *image_url* to *output_path* and return the path as ``str``.

        Raises:
            requests.HTTPError: On a non-2xx response. This prevents an error
                body from being saved to disk as if it were an image.
            requests.RequestException: On network failure.
        """
        resp = requests.get(image_url, timeout=60)
        resp.raise_for_status()
        with open(output_path, "wb") as f:
            f.write(resp.content)
        return str(output_path)

    @staticmethod
    def _doubao_headers() -> Dict[str, str]:
        """HTTP headers for the Volcengine Ark image API (Bearer auth)."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}",
        }

    @staticmethod
    def _build_group_prompt(scenes: List[Dict[str, Any]], visual_anchor: str = "") -> str:
        """Build the combined Doubao group prompt.

        Format::

            Global: [<visual_anchor>] Consistent product appearance & style.
            S<id>:<visual_prompt> | S<id>:<visual_prompt> | ...
            Req: 1 img per scene. Follow specific angles.

        The ``[<visual_anchor>] `` prefix is omitted when *visual_anchor* is
        empty. A scene without ``visual_prompt`` contributes ``S<id>:``.
        """
        scene_text = " | ".join(
            f"S{scene['id']}:{scene.get('visual_prompt', '')}" for scene in scenes
        )
        anchor_prefix = f"[{visual_anchor}] " if visual_anchor else ""
        global_context = f"{anchor_prefix}Consistent product appearance & style."
        return (
            f"Global: {global_context}\n"
            f"{scene_text}\n"
            "Req: 1 img per scene. Follow specific angles."
        )

    # ------------------------------------------------------------ image encode

    def _encode_image(self, image_path: str) -> str:
        """Return *image_path* re-encoded as a base64 JPEG string.

        The image is converted to RGB. If its longer side exceeds 1024 px, it is
        shrunk with ``thumbnail`` (aspect ratio preserved). It is then saved as
        JPEG at quality 80. On any error the problem is logged and ``""`` is
        returned, so callers can skip the image instead of failing.
        """
        try:
            with Image.open(image_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                max_size = 1024
                if max(img.size) > max_size:
                    img.thumbnail((max_size, max_size), Image.LANCZOS)
                buffer = io.BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                return base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return ""

    # ----------------------------------------------------------- public API

    def generate_single_scene_image(
        self,
        scene: Dict[str, Any],
        original_image_path: Any,
        previous_image_path: Optional[str] = None,
        model_provider: str = "shubiaobiao",  # "shubiaobiao", "gemini", "doubao"
        visual_anchor: str = ""  # visual anchor, prepended to the prompt
    ) -> Optional[str]:
        """Generate one storyboard image for *scene*.

        Args:
            scene: Must contain ``"id"``. ``"visual_prompt"`` is optional and
                defaults to ``""``.
            original_image_path: A single path string, or a list/tuple of paths,
                used as reference images.
            previous_image_path: The previous scene's image. If given, it is
                appended as an extra reference for continuity.
            model_provider: ``"doubao"``, ``"gemini"``; any other value selects
                shubiaobiao.
            visual_anchor: If non-empty and not already present in the prompt,
                it is prepended as ``"[<anchor>] "`` to keep images consistent.

        Returns:
            The path of the generated image.

        Raises:
            RuntimeError: If the provider returns no path. Any provider
                exception (including ``PermissionError``) is logged and
                re-raised.
        """
        scene_id = scene["id"]
        visual_prompt = scene.get("visual_prompt", "")

        # Force the visual anchor into the prompt to keep images consistent.
        if visual_anchor and visual_anchor not in visual_prompt:
            visual_prompt = f"[{visual_anchor}] {visual_prompt}"
            logger.info(f"Scene {scene_id}: Prepended visual_anchor to prompt")

        logger.info(f"Generating image for Scene {scene_id} (Provider: {model_provider})...")

        input_images: List[str] = []
        if isinstance(original_image_path, (list, tuple)):
            input_images.extend(original_image_path)
        elif isinstance(original_image_path, str) and original_image_path:
            input_images.append(original_image_path)
        if previous_image_path:
            input_images.append(previous_image_path)

        try:
            output_path = self._generate_single_image(
                prompt=visual_prompt,
                reference_images=input_images,
                output_filename=f"scene_{scene_id}_{int(time.time())}.png",
                provider=model_provider
            )
            if not output_path:
                # Raised inside the try so it is logged by the generic handler below.
                raise RuntimeError(f"Image generation returned empty for Scene {scene_id}")
            return output_path
        except PermissionError as e:
            logger.error(f"Critical API Error for Scene {scene_id}: {e}")
            raise
        except Exception as e:
            logger.error(f"Image generation failed for Scene {scene_id}: {e}")
            raise

    def generate_group_images_doubao(
        self,
        scenes: List[Dict[str, Any]],
        reference_images: List[str],
        visual_anchor: str = ""  # visual anchor, placed in the Global section
    ) -> Dict[int, str]:
        """Generate images for all *scenes* in one Doubao group (sequential) request.

        The per-scene prompts are merged into one prompt (see
        ``_build_group_prompt``). ``max_images`` is set to ``len(scenes)``.
        Returned images are mapped back to scenes by position. This assumes
        Doubao returns them in request order. Extra images are ignored, and
        scenes with no returned image are missing from the result.

        Returns:
            Mapping of ``scene["id"]`` to the local path of the downloaded image.

        Raises:
            requests.HTTPError / requests.RequestException: If the submit
                request or any image download fails. The error is logged, then
                re-raised.
        """
        logger.info("Starting Doubao Group Image Generation...")

        combined_prompt = self._build_group_prompt(scenes, visual_anchor)
        logger.info(f"Visual Anchor applied to group prompt: {visual_anchor[:50]}..." if visual_anchor else "No visual_anchor")
        logger.info(f"Doubao Group Prompt Length: {len(combined_prompt)} chars")

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": combined_prompt,
            "sequential_image_generation": "auto",  # enable group (multi-image) output
            "sequential_image_generation_options": {
                "max_images": len(scenes)  # cap the number of images
            },
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        img_urls = self._upload_reference_images(
            reference_images, failure_msg="Failed to upload ref image"
        )
        if img_urls:
            payload["image_urls"] = img_urls

        try:
            logger.info(f"Submitting Doubao Group Request (Scenes: {len(scenes)})...")
            resp = requests.post(self.endpoint, json=payload, headers=self._doubao_headers(), timeout=240)
            resp.raise_for_status()
            data = resp.json()

            results: Dict[int, str] = {}
            if "data" in data:
                items = data["data"] or []
                logger.info(f"Doubao returned {len(items)} images.")
                if len(items) != len(scenes):
                    logger.warning(
                        f"Doubao returned {len(items)} images for {len(scenes)} scenes; mapping by position."
                    )
                # Assume the order matches the request; zip drops surplus images.
                for scene, item in zip(scenes, items):
                    image_url = item.get("url")
                    if not image_url:
                        continue
                    scene_id = scene["id"]
                    output_path = config.TEMP_DIR / f"scene_{scene_id}_{int(time.time())}.png"
                    results[scene_id] = self._download_image(image_url, output_path)
            return results
        except Exception as e:
            logger.error(f"Doubao Group Generation Failed: {e}")
            raise

    # -------------------------------------------------------- provider routing

    def _generate_single_image(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        provider: str = "shubiaobiao"
    ) -> Optional[str]:
        """Dispatch to the provider-specific generator.

        ``"doubao"`` and ``"gemini"`` select those back-ends. Any other value
        (including the default) falls back to shubiaobiao.
        """
        if provider == "doubao":
            return self._generate_single_image_doubao(prompt, reference_images, output_filename)
        if provider == "gemini":
            return self._generate_single_image_gemini(prompt, reference_images, output_filename)
        return self._generate_single_image_shubiao(prompt, reference_images, output_filename)

    def _generate_single_image_doubao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image with Volcengine Doubao (group generation disabled).

        Reference images are uploaded first and passed as ``image_urls``.

        Returns:
            The path under ``config.TEMP_DIR`` where the image was saved.

        Raises:
            RuntimeError: On a non-200 response or a response without an image URL.
            requests.RequestException: On network or download failure.
        """
        img_urls = self._upload_reference_images(
            reference_images,
            success_msg="Uploaded Doubao ref image",
            failure_msg="Failed to upload Doubao ref image",
        )

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }
        if img_urls:
            payload["image_urls"] = img_urls
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', image_urls={len(img_urls)}")
        else:
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', no reference images")

        try:
            logger.info(f"Submitting to Doubao Image: {self.endpoint}")
            resp = requests.post(self.endpoint, json=payload, headers=self._doubao_headers(), timeout=180)
            if resp.status_code != 200:
                msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            # `or []` also covers an explicit "data": null.
            items = data.get("data") or []
            image_url = items[0].get("url") if items else None
            if image_url:
                return self._download_image(image_url, config.TEMP_DIR / output_filename)
            raise RuntimeError(f"No image URL in Doubao response: {data}")
        except Exception as e:
            logger.error(f"Doubao Gen Failed: {e}")
            raise

    def _generate_single_image_shubiao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image via the api2img.shubiaobiao.com channel.

        The response is synchronous. Reference images are inlined as base64
        JPEG parts after the text prompt. The first inline image in the first
        candidate is decoded and written to disk.

        Returns:
            The path under ``config.TEMP_DIR`` where the image was saved.

        Raises:
            RuntimeError: On a non-200 response or a response without image data.
        """
        valid_refs = self._valid_ref_paths(reference_images)
        logger.info(f"[Shubiaobiao] Input reference images ({len(valid_refs)}): {valid_refs}")

        parts: List[Dict[str, Any]] = [{"text": prompt}]
        for ref_path in valid_refs:
            # _encode_image never raises; it returns "" on failure, which is skipped.
            encoded = self._encode_image(ref_path)
            if encoded:
                parts.append({
                    "inlineData": {
                        "mimeType": "image/jpeg",
                        "data": encoded
                    }
                })
        logger.info(f"[Shubiaobiao] Final payload parts count: {len(parts)} (1 prompt + {len(parts)-1} images)")

        payload = {
            "contents": [{
                "role": "user",
                "parts": parts
            }],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "imageConfig": {
                    "aspectRatio": "9:16",
                    "imageSize": "2K"
                }
            }
        }
        endpoint = f"{config.SHUBIAOBIAO_IMG_BASE_URL}/v1beta/models/{config.SHUBIAOBIAO_IMG_MODEL_NAME}:generateContent"
        headers = {
            "x-goog-api-key": config.SHUBIAOBIAO_IMG_KEY,
            "Content-Type": "application/json"
        }

        try:
            logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
            resp = requests.post(endpoint, json=payload, headers=headers, timeout=120)
            if resp.status_code != 200:
                msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            # Find the first inline base64 image in the first candidate.
            img_b64 = None
            candidates = data.get("candidates") or []
            if candidates:
                content = candidates[0].get("content") or {}
                for part in content.get("parts") or []:
                    inline = part.get("inlineData") if isinstance(part, dict) else None
                    if inline and inline.get("data"):
                        img_b64 = inline["data"]
                        break

            if not img_b64:
                msg = f"Shubiaobiao 响应缺少图片数据: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            output_path = config.TEMP_DIR / output_filename
            with open(output_path, "wb") as f:
                f.write(base64.b64decode(img_b64))
            logger.info(f"Shubiaobiao Generation Success: {output_path}")
            return str(output_path)
        except Exception as e:
            logger.error(f"Shubiaobiao Generation Exception: {e}")
            raise

    def _generate_single_image_gemini(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image via the Gemini task API (Wuyin Keji / NanoBanana-Pro).

        The task is submitted, then polled every 2 s, up to 60 times. Transient
        poll failures (network errors, non-200 status, non-JSON body) are
        retried. A non-200 ``code`` in a parsed poll body is treated as fatal.

        Returns:
            The path under ``config.TEMP_DIR`` where the image was saved.

        Raises:
            RuntimeError: On a submit error, a task failure (status 3), a missing
                image URL, or a timeout after all polls.
        """
        payload = {
            "prompt": prompt,
            "aspectRatio": "9:16",
            "imageSize": "2K"
        }

        # Reference images (image-to-image), passed as uploaded URLs.
        if reference_images:
            img_urls = self._upload_reference_images(
                reference_images,
                success_msg="Uploaded ref image",
                failure_msg="Error uploading ref image",
            )
            if img_urls:
                payload["img_url"] = img_urls
                logger.info(f"Using {len(img_urls)} reference images for Gemini Img2Img")

        headers = {
            "Authorization": config.GEMINI_IMG_KEY,
            # Media-type parameters use "=", not ":" (the original header was malformed).
            "Content-Type": "application/json;charset=utf-8"
        }

        try:
            # Submit the task.
            logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
            resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=30)
            if resp.status_code != 200:
                msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            if data.get("code") != 200:
                msg = f"Gemini 返回错误: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            task_id = (data.get("data") or {}).get("id")
            if not task_id:
                raise RuntimeError(f"Gemini 响应缺少 task id: {data}")
            logger.info(f"Gemini Task Submitted, ID: {task_id}")

            # Poll the task status.
            max_retries = 60
            for _ in range(max_retries):
                time.sleep(2)
                try:
                    # `params` URL-encodes the key and id, unlike manual string building.
                    poll_resp = requests.get(
                        config.GEMINI_IMG_DETAIL_URL,
                        params={"key": config.GEMINI_IMG_KEY, "id": task_id},
                        headers=headers,
                        timeout=30,
                    )
                except requests.RequestException as poll_err:
                    logger.debug(f"Gemini poll request failed, retrying: {poll_err}")
                    continue

                if poll_resp.status_code != 200:
                    continue

                try:
                    poll_data = poll_resp.json()
                except ValueError:
                    logger.debug("Gemini poll returned non-JSON body, retrying")
                    continue

                if poll_data.get("code") != 200:
                    raise RuntimeError(f"Gemini 轮询返回错误: {poll_data}")

                result_data = poll_data.get("data", {}) or {}
                status = result_data.get("status")  # 0: queued, 1: generating, 2: success, 3: failed

                if status == 2:
                    image_url = result_data.get("image_url")
                    if not image_url:
                        raise RuntimeError("Gemini 成功但缺少 image_url")
                    logger.info(f"Gemini Generation Success: {image_url}")
                    return self._download_image(image_url, config.TEMP_DIR / output_filename)

                if status == 3:
                    fail_reason = result_data.get("fail_reason", "Unknown")
                    raise RuntimeError(f"Gemini 生成失败: {fail_reason}")

            raise RuntimeError("Gemini 生成超时")
        except Exception as e:
            logger.error(f"Gemini Generation Exception: {e}")
            raise