feat: video-flow initial commit
- app.py: Streamlit UI for video generation workflow - main_flow.py: CLI tool with argparse support - modules/: Business logic modules (script_gen, image_gen, video_gen, composer, etc.) - config.py: Configuration with API keys and paths - requirements.txt: Python dependencies - docs/: System prompt documentation
This commit is contained in:
491
modules/image_gen.py
Normal file
491
modules/image_gen.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
连贯生图模块 (Volcengine Doubao)
|
||||
负责根据分镜脚本和原始素材生成一系列连贯的分镜图片
|
||||
"""
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from PIL import Image
|
||||
import io
|
||||
from modules import storage
|
||||
|
||||
import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ImageGenerator:
    """Coherent storyboard image generator (Volcengine provider).

    Produces one image per storyboard scene through one of three HTTP back
    ends ("shubiaobiao", "gemini" or "doubao"), optionally conditioned on
    reference images. Every generated file is written below
    ``config.TEMP_DIR`` and returned to the caller as a string path.
    """

    def __init__(self):
        self.api_key = config.VOLC_API_KEY
        # Endpoint: https://ark.cn-beijing.volces.com/api/v3/images/generations
        self.endpoint = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        self.model = config.IMAGE_MODEL_ID

    def _encode_image(self, image_path: str) -> str:
        """Read an image, shrink it and return it as Base64-encoded JPEG text.

        The image is converted to RGB, downscaled in place so that its longest
        side is at most 1024 px, and re-encoded as JPEG (quality 80).

        Args:
            image_path: Path of the image file to encode.

        Returns:
            The Base64 string, or ``""`` when the file cannot be processed
            (the error is logged, not raised).
        """
        try:
            with Image.open(image_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                max_size = 1024
                if max(img.size) > max_size:
                    img.thumbnail((max_size, max_size), Image.LANCZOS)

                buffer = io.BytesIO()
                img.save(buffer, format="JPEG", quality=80)
                return base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            # Best effort: callers treat an empty string as "skip this image".
            logger.error(f"Error processing image {image_path}: {e}")
            return ""

    def _download_image(self, image_url: str, output_filename: str) -> str:
        """Download ``image_url`` into ``config.TEMP_DIR / output_filename``.

        Args:
            image_url: URL of the generated image.
            output_filename: File name to create inside ``config.TEMP_DIR``.

        Returns:
            The path of the written file, as a string.

        Raises:
            requests.HTTPError: If the download answers with an error status,
                so that an error page is never saved as an image.
        """
        img_resp = requests.get(image_url, timeout=60)
        img_resp.raise_for_status()

        output_path = config.TEMP_DIR / output_filename
        with open(output_path, "wb") as f:
            f.write(img_resp.content)
        return str(output_path)

    def generate_single_scene_image(
        self,
        scene: Dict[str, Any],
        original_image_path: Any,
        previous_image_path: Optional[str] = None,
        model_provider: str = "shubiaobiao",  # "shubiaobiao", "gemini", "doubao"
        visual_anchor: str = ""  # visual anchor, force-prepended to the prompt
    ) -> Optional[str]:
        """Generate the image for a single storyboard scene (public API).

        Args:
            scene: Scene description; must contain ``"id"`` and may contain
                ``"visual_prompt"``.
            original_image_path: One reference image path (``str``) or a list
                of them. Any other type is ignored.
            previous_image_path: Optional path of the previous scene's image,
                appended after the original references.
            model_provider: Back end to use; anything other than ``"doubao"``
                or ``"gemini"`` falls back to the Shubiaobiao channel.
            visual_anchor: Text prepended to the prompt (in square brackets)
                unless the prompt already contains it.

        Returns:
            Path of the generated image file.

        Raises:
            RuntimeError: If the provider returned no image.
            Exception: Any provider error is logged and re-raised.
        """
        scene_id = scene["id"]
        visual_prompt = scene.get("visual_prompt", "")

        # Force the visual anchor into the prompt to keep images consistent.
        if visual_anchor and visual_anchor not in visual_prompt:
            visual_prompt = f"[{visual_anchor}] {visual_prompt}"
            logger.info(f"Scene {scene_id}: Prepended visual_anchor to prompt")

        logger.info(f"Generating image for Scene {scene_id} (Provider: {model_provider})...")

        input_images = []

        # original_image_path can be a str or a list
        if isinstance(original_image_path, list):
            input_images.extend(original_image_path)
        elif isinstance(original_image_path, str) and original_image_path:
            input_images.append(original_image_path)

        if previous_image_path:
            input_images.append(previous_image_path)

        try:
            output_path = self._generate_single_image(
                prompt=visual_prompt,
                reference_images=input_images,
                output_filename=f"scene_{scene_id}_{int(time.time())}.png",
                provider=model_provider
            )

            if output_path:
                return output_path
            raise RuntimeError(f"Image generation returned empty for Scene {scene_id}")

        except PermissionError as e:
            logger.error(f"Critical API Error for Scene {scene_id}: {e}")
            raise
        except Exception as e:
            logger.error(f"Image generation failed for Scene {scene_id}: {e}")
            raise

    def generate_group_images_doubao(
        self,
        scenes: List[Dict[str, Any]],
        reference_images: List[str],
        visual_anchor: str = ""  # visual anchor
    ) -> Dict[int, str]:
        """Generate all scene images in one Doubao "group image" request.

        The prompts of all scenes are concatenated into a single prompt and the
        API is asked for up to ``len(scenes)`` sequential images. Returned
        images are mapped back to scenes by position.

        Args:
            scenes: Scene dicts; each must contain ``"id"`` and may contain
                ``"visual_prompt"``.
            reference_images: Local paths of reference images; existing files
                are uploaded through ``storage.upload_file`` and passed by URL.
            visual_anchor: Optional text placed in the global part of the
                prompt.

        Returns:
            Mapping of scene id to the downloaded image path. Scenes for which
            the API returned no URL are absent. Empty when ``scenes`` is empty
            or the response carries no ``"data"``.

        Raises:
            Exception: Request, HTTP or download errors are logged and
                re-raised.
        """
        logger.info("Starting Doubao Group Image Generation...")

        if not scenes:
            # Nothing to map results onto; avoid a request with max_images=0.
            logger.warning("Doubao Group Generation skipped: no scenes provided")
            return {}

        # 1. Concatenate prompts
        # Format: "Global: [Visual Anchor] ... | S1: ... | S2: ..."
        scene_prompts = []
        for scene in scenes:
            # Extract the scene's visual prompt
            p = scene.get("visual_prompt", "")
            scene_prompts.append(f"S{scene['id']}:{p}")

        combined_scenes_text = " | ".join(scene_prompts)

        # Build the combined prompt - the visual anchor goes into the Global part
        global_context = f"[{visual_anchor}] Consistent product appearance & style." if visual_anchor else "Consistent product appearance & style."
        combined_prompt = (
            f"Global: {global_context}\n"
            f"{combined_scenes_text}\n"
            "Req: 1 img per scene. Follow specific angles."
        )

        logger.info(f"Visual Anchor applied to group prompt: {visual_anchor[:50]}..." if visual_anchor else "No visual_anchor")

        # Log the prompt length for reference
        logger.info(f"Doubao Group Prompt Length: {len(combined_prompt)} chars")

        # 2. Prepare the payload
        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": combined_prompt,
            "sequential_image_generation": "auto",  # enable group images
            "sequential_image_generation_options": {
                "max_images": len(scenes)  # cap the number of images
            },
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        # 3. Handle reference images
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                    except Exception as e:
                        logger.warning(f"Failed to upload ref image {ref_path}: {e}")

        if img_urls:
            payload["image_urls"] = img_urls

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting Doubao Group Request (Scenes: {len(scenes)})...")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=240)
            resp.raise_for_status()

            data = resp.json()
            results = {}

            if "data" in data:
                items = data["data"]
                logger.info(f"Doubao returned {len(items)} images.")

                # Map returned images back to scenes, assuming the API keeps
                # the scene order; surplus images are ignored by zip().
                for scene, item in zip(scenes, items):
                    scene_id = scene["id"]
                    image_url = item.get("url")

                    if image_url:
                        # Download (raises on HTTP errors instead of saving
                        # the error body as a .png)
                        results[scene_id] = self._download_image(
                            image_url,
                            f"scene_{scene_id}_{int(time.time())}.png"
                        )
            else:
                logger.warning("Doubao group response contained no 'data' field")

            return results

        except Exception as e:
            logger.error(f"Doubao Group Generation Failed: {e}")
            raise

    def _generate_single_image(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str,
        provider: str = "shubiaobiao"
    ) -> Optional[str]:
        """Single entry point that dispatches to the selected provider.

        Args:
            prompt: Text prompt for the image.
            reference_images: Local paths of reference images.
            output_filename: File name to create inside ``config.TEMP_DIR``.
            provider: ``"doubao"`` or ``"gemini"``; any other value selects
                the Shubiaobiao channel.

        Returns:
            Whatever the selected provider method returns (the output path).
        """
        if provider == "doubao":
            return self._generate_single_image_doubao(prompt, reference_images, output_filename)
        elif provider == "gemini":
            return self._generate_single_image_gemini(prompt, reference_images, output_filename)
        else:
            return self._generate_single_image_shubiao(prompt, reference_images, output_filename)

    def _generate_single_image_doubao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image through the Volcengine Doubao image API.

        Args:
            prompt: Text prompt for the image.
            reference_images: Local paths; existing files are uploaded through
                ``storage.upload_file`` and sent as ``image_urls``.
            output_filename: File name to create inside ``config.TEMP_DIR``.

        Returns:
            Path of the downloaded image.

        Raises:
            RuntimeError: On a non-200 response or a response without an
                image URL.
            Exception: Any other failure is logged and re-raised.
        """
        # 1. Upload all reference images to R2
        img_urls = []
        if reference_images:
            for ref_path in reference_images:
                if os.path.exists(ref_path):
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded Doubao ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Failed to upload Doubao ref image {ref_path}: {e}")

        payload = {
            "model": config.DOUBAO_IMG_MODEL,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            "size": "1440x2560",
            "stream": False,
            "watermark": False
        }

        if img_urls:
            payload["image_urls"] = img_urls
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', image_urls={len(img_urls)}")
        else:
            logger.info(f"Doubao Image Payload: prompt='{prompt[:20]}...', no reference images")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {config.VOLC_API_KEY}"
        }

        try:
            logger.info(f"Submitting to Doubao Image: {self.endpoint}")
            resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=180)

            if resp.status_code != 200:
                msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            if "data" in data and len(data["data"]) > 0:
                image_url = data["data"][0].get("url")
                if image_url:
                    return self._download_image(image_url, output_filename)

            raise RuntimeError(f"No image URL in Doubao response: {data}")

        except Exception as e:
            logger.error(f"Doubao Gen Failed: {e}")
            raise

    def _generate_single_image_shubiao(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image through the api2img.shubiaobiao.com channel.

        The call is synchronous: reference images are sent inline as Base64
        JPEG parts and the generated image comes back as Base64 in the
        response body.

        Args:
            prompt: Text prompt for the image.
            reference_images: Local paths; missing files and duplicates are
                dropped, order is preserved.
            output_filename: File name to create inside ``config.TEMP_DIR``.

        Returns:
            Path of the written image.

        Raises:
            RuntimeError: On a non-200 response or a response without inline
                image data.
            Exception: Any other failure is logged and re-raised.
        """
        # Reference images are passed inline as Base64
        parts = [{"text": prompt}]

        # Strictly filter reference images, keeping first-seen order
        valid_refs = []
        seen = set()
        if reference_images:
            for p in reference_images:
                if p and os.path.exists(p) and p not in seen:
                    valid_refs.append(p)
                    seen.add(p)

        logger.info(f"[Shubiaobiao] Input reference images ({len(valid_refs)}): {valid_refs}")

        for ref_path in valid_refs:
            try:
                encoded = self._encode_image(ref_path)
                if encoded:
                    parts.append({
                        "inlineData": {
                            "mimeType": "image/jpeg",
                            "data": encoded
                        }
                    })
            except Exception as e:
                logger.error(f"Failed to encode image {ref_path}: {e}")

        logger.info(f"[Shubiaobiao] Final payload parts count: {len(parts)} (1 prompt + {len(parts)-1} images)")

        payload = {
            "contents": [{
                "role": "user",
                "parts": parts
            }],
            "generationConfig": {
                "responseModalities": ["IMAGE"],
                "imageConfig": {
                    "aspectRatio": "9:16",
                    "imageSize": "2K"
                }
            }
        }

        endpoint = f"{config.SHUBIAOBIAO_IMG_BASE_URL}/v1beta/models/{config.SHUBIAOBIAO_IMG_MODEL_NAME}:generateContent"
        headers = {
            "x-goog-api-key": config.SHUBIAOBIAO_IMG_KEY,
            "Content-Type": "application/json"
        }

        try:
            logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
            resp = requests.post(endpoint, json=payload, headers=headers, timeout=120)

            if resp.status_code != 200:
                msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()

            # Look for the first Base64 image in the first candidate
            img_b64 = None
            candidates = data.get("candidates") or []
            if candidates:
                content_parts = candidates[0].get("content", {}).get("parts", [])
                for part in content_parts:
                    inline = part.get("inlineData") if isinstance(part, dict) else None
                    if inline and inline.get("data"):
                        img_b64 = inline["data"]
                        break

            if not img_b64:
                msg = f"Shubiaobiao 响应缺少图片数据: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            output_path = config.TEMP_DIR / output_filename
            with open(output_path, "wb") as f:
                f.write(base64.b64decode(img_b64))

            logger.info(f"Shubiaobiao Generation Success: {output_path}")
            return str(output_path)

        except Exception as e:
            logger.error(f"Shubiaobiao Generation Exception: {e}")
            raise

    def _generate_single_image_gemini(
        self,
        prompt: str,
        reference_images: List[str],
        output_filename: str
    ) -> Optional[str]:
        """Generate one image through Gemini (Wuyin Keji / NanoBanana-Pro).

        The task is submitted, then polled every 2 seconds for at most 60
        attempts until it succeeds or fails.

        Args:
            prompt: Text prompt for the image.
            reference_images: Local paths; missing files and duplicates are
                dropped, the rest are uploaded through ``storage.upload_file``
                and sent as ``img_url``.
            output_filename: File name to create inside ``config.TEMP_DIR``.

        Returns:
            Path of the downloaded image.

        Raises:
            RuntimeError: On submit/poll errors, a failed task, a missing
                image URL, or when polling runs out of attempts.
            Exception: Any other failure is logged and re-raised.
        """
        # 1. Build the payload
        payload = {
            "prompt": prompt,
            "aspectRatio": "9:16",
            "imageSize": "2K"
        }

        # Handle reference images (image-to-image)
        if reference_images:
            valid_paths = []
            seen = set()
            for p in reference_images:
                if p and os.path.exists(p) and p not in seen:
                    valid_paths.append(p)
                    seen.add(p)

            if valid_paths:
                img_urls = []
                for ref_path in valid_paths:
                    try:
                        url = storage.upload_file(ref_path)
                        if url:
                            img_urls.append(url)
                            logger.info(f"Uploaded ref image: {url}")
                    except Exception as e:
                        logger.warning(f"Error uploading ref image {ref_path}: {e}")

                if img_urls:
                    payload["img_url"] = img_urls
                    logger.info(f"Using {len(img_urls)} reference images for Gemini Img2Img")

        headers = {
            "Authorization": config.GEMINI_IMG_KEY,
            # NOTE(review): "charset:utf-8" is not the standard "charset=utf-8"
            # form; kept unchanged - confirm against the provider's docs.
            "Content-Type": "application/json;charset:utf-8"
        }

        # 2. Submit the task
        try:
            logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
            resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=30)

            if resp.status_code != 200:
                msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
                logger.error(msg)
                raise RuntimeError(msg)

            data = resp.json()
            if data.get("code") != 200:
                msg = f"Gemini 返回错误: {data}"
                logger.error(msg)
                raise RuntimeError(msg)

            task_id = data.get("data", {}).get("id")
            if not task_id:
                raise RuntimeError(f"Gemini 响应缺少 task id: {data}")

            logger.info(f"Gemini Task Submitted, ID: {task_id}")

            # 3. Poll the task status
            max_retries = 60
            for i in range(max_retries):
                time.sleep(2)

                poll_url = f"{config.GEMINI_IMG_DETAIL_URL}?key={config.GEMINI_IMG_KEY}&id={task_id}"
                try:
                    poll_resp = requests.get(poll_url, headers=headers, timeout=30)
                except Exception as e:
                    # Transient poll failures (timeouts included) are retried.
                    logger.warning(f"Gemini poll attempt {i + 1} failed: {e}")
                    continue

                if poll_resp.status_code != 200:
                    continue

                poll_data = poll_resp.json()
                if poll_data.get("code") != 200:
                    raise RuntimeError(f"Gemini 轮询返回错误: {poll_data}")

                result_data = poll_data.get("data", {}) or {}
                status = result_data.get("status")  # 0: queued, 1: generating, 2: success, 3: failed

                if status == 2:
                    image_url = result_data.get("image_url")
                    if not image_url:
                        raise RuntimeError("Gemini 成功但缺少 image_url")

                    logger.info(f"Gemini Generation Success: {image_url}")
                    return self._download_image(image_url, output_filename)

                if status == 3:
                    fail_reason = result_data.get("fail_reason", "Unknown")
                    raise RuntimeError(f"Gemini 生成失败: {fail_reason}")

            raise RuntimeError("Gemini 生成超时")

        except Exception as e:
            logger.error(f"Gemini Generation Exception: {e}")
            raise
|
||||
Reference in New Issue
Block a user