perf(8502): 并行生图(6并发)+超时重试;视频URL直连预览/下载;路径隔离
This commit is contained in:
@@ -15,9 +15,52 @@ import io
|
||||
from modules import storage
|
||||
|
||||
import config
|
||||
from modules import path_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
try:
|
||||
return int(os.getenv(name, str(default)))
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
# Tunables: slow channels can be hot; default conservative but adjustable.
|
||||
IMG_SUBMIT_TIMEOUT_S = _env_int("IMG_SUBMIT_TIMEOUT_S", 180)
|
||||
IMG_POLL_TIMEOUT_S = _env_int("IMG_POLL_TIMEOUT_S", 30)
|
||||
IMG_MAX_RETRIES = _env_int("IMG_MAX_RETRIES", 3)
|
||||
IMG_POLL_INTERVAL_S = _env_int("IMG_POLL_INTERVAL_S", 2)
|
||||
IMG_POLL_MAX_RETRIES = _env_int("IMG_POLL_MAX_RETRIES", 90) # 90*2s ~= 180s
|
||||
|
||||
|
||||
def _is_retryable_exception(e: Exception) -> bool:
|
||||
# Network / transient errors
|
||||
if isinstance(e, (requests.Timeout, requests.ConnectionError)):
|
||||
return True
|
||||
msg = str(e).lower()
|
||||
# Transient provider errors often contain these keywords
|
||||
if any(k in msg for k in ["timeout", "temporarily", "temporarily unavailable", "gateway", "rate", "try again"]):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _with_retries(fn, *, max_retries: int, label: str):
|
||||
last = None
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
return fn()
|
||||
except Exception as e:
|
||||
last = e
|
||||
retryable = _is_retryable_exception(e)
|
||||
logger.warning(f"[{label}] attempt {attempt}/{max_retries} failed: {e} (retryable={retryable})")
|
||||
if not retryable or attempt >= max_retries:
|
||||
raise
|
||||
# small backoff
|
||||
time.sleep(min(2 ** (attempt - 1), 4))
|
||||
raise last # pragma: no cover
|
||||
|
||||
class ImageGenerator:
|
||||
"""连贯图片生成器 (Volcengine Provider)"""
|
||||
|
||||
@@ -51,7 +94,8 @@ class ImageGenerator:
|
||||
original_image_path: Any,
|
||||
previous_image_path: Optional[str] = None,
|
||||
model_provider: str = "shubiaobiao", # "shubiaobiao", "gemini", "doubao"
|
||||
visual_anchor: str = "" # 视觉锚点,强制拼接到 prompt 前
|
||||
visual_anchor: str = "", # 视觉锚点,强制拼接到 prompt 前
|
||||
project_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
生成单张分镜图片 (Public)
|
||||
@@ -78,11 +122,19 @@ class ImageGenerator:
|
||||
input_images.append(previous_image_path)
|
||||
|
||||
try:
|
||||
out_dir = path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR
|
||||
out_name = path_utils.unique_filename(
|
||||
prefix="scene_image",
|
||||
ext="png",
|
||||
project_id=project_id,
|
||||
scene_id=scene_id,
|
||||
)
|
||||
output_path = self._generate_single_image(
|
||||
prompt=visual_prompt,
|
||||
reference_images=input_images,
|
||||
output_filename=f"scene_{scene_id}_{int(time.time())}.png",
|
||||
provider=model_provider
|
||||
output_filename=out_name,
|
||||
provider=model_provider,
|
||||
output_dir=out_dir,
|
||||
)
|
||||
|
||||
if output_path:
|
||||
@@ -101,7 +153,8 @@ class ImageGenerator:
|
||||
self,
|
||||
scenes: List[Dict[str, Any]],
|
||||
reference_images: List[str],
|
||||
visual_anchor: str = "" # 视觉锚点
|
||||
visual_anchor: str = "", # 视觉锚点
|
||||
project_id: Optional[str] = None,
|
||||
) -> Dict[int, str]:
|
||||
"""
|
||||
Doubao 组图生成 (Batch) - 拼接 Prompt 一次生成多张
|
||||
@@ -187,7 +240,15 @@ class ImageGenerator:
|
||||
if image_url:
|
||||
# Download
|
||||
img_resp = requests.get(image_url, timeout=60)
|
||||
output_path = config.TEMP_DIR / f"scene_{scene_id}_{int(time.time())}.png"
|
||||
out_dir = path_utils.project_images_dir(project_id) if project_id else config.TEMP_DIR
|
||||
out_name = path_utils.unique_filename(
|
||||
prefix="scene_image",
|
||||
ext="png",
|
||||
project_id=project_id,
|
||||
scene_id=scene_id,
|
||||
extra="group",
|
||||
)
|
||||
output_path = out_dir / out_name
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(img_resp.content)
|
||||
results[scene_id] = str(output_path)
|
||||
@@ -203,21 +264,24 @@ class ImageGenerator:
|
||||
prompt: str,
|
||||
reference_images: List[str],
|
||||
output_filename: str,
|
||||
provider: str = "shubiaobiao"
|
||||
provider: str = "shubiaobiao",
|
||||
output_dir: Optional[Path] = None,
|
||||
) -> Optional[str]:
|
||||
"""统一入口"""
|
||||
out_dir = output_dir or config.TEMP_DIR
|
||||
if provider == "doubao":
|
||||
return self._generate_single_image_doubao(prompt, reference_images, output_filename)
|
||||
return self._generate_single_image_doubao(prompt, reference_images, output_filename, out_dir)
|
||||
elif provider == "gemini":
|
||||
return self._generate_single_image_gemini(prompt, reference_images, output_filename)
|
||||
return self._generate_single_image_gemini(prompt, reference_images, output_filename, out_dir)
|
||||
else:
|
||||
return self._generate_single_image_shubiao(prompt, reference_images, output_filename)
|
||||
return self._generate_single_image_shubiao(prompt, reference_images, output_filename, out_dir)
|
||||
|
||||
def _generate_single_image_doubao(
|
||||
self,
|
||||
prompt: str,
|
||||
reference_images: List[str],
|
||||
output_filename: str
|
||||
output_filename: str,
|
||||
output_dir: Path
|
||||
) -> Optional[str]:
|
||||
"""调用 Volcengine Doubao (Image API)"""
|
||||
|
||||
@@ -255,9 +319,9 @@ class ImageGenerator:
|
||||
"Authorization": f"Bearer {config.VOLC_API_KEY}"
|
||||
}
|
||||
|
||||
try:
|
||||
def _call():
|
||||
logger.info(f"Submitting to Doubao Image: {self.endpoint}")
|
||||
resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=180)
|
||||
resp = requests.post(self.endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
|
||||
|
||||
if resp.status_code != 200:
|
||||
msg = f"Doubao Image Failed ({resp.status_code}): {resp.text}"
|
||||
@@ -272,22 +336,20 @@ class ImageGenerator:
|
||||
img_resp = requests.get(image_url, timeout=60)
|
||||
img_resp.raise_for_status()
|
||||
|
||||
output_path = config.TEMP_DIR / output_filename
|
||||
output_path = output_dir / output_filename
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(img_resp.content)
|
||||
return str(output_path)
|
||||
|
||||
raise RuntimeError(f"No image URL in Doubao response: {data}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Doubao Gen Failed: {e}")
|
||||
raise e
|
||||
return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="doubao_image")
|
||||
|
||||
def _generate_single_image_shubiao(
|
||||
self,
|
||||
prompt: str,
|
||||
reference_images: List[str],
|
||||
output_filename: str
|
||||
output_filename: str,
|
||||
output_dir: Path
|
||||
) -> Optional[str]:
|
||||
"""调用 api2img.shubiaobiao.com 通道生成图片(同步返回 base64)"""
|
||||
# 准备参考图,内联 base64 方式
|
||||
@@ -338,9 +400,9 @@ class ImageGenerator:
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
def _call():
|
||||
logger.info(f"Submitting to Shubiaobiao Img: {endpoint}")
|
||||
resp = requests.post(endpoint, json=payload, headers=headers, timeout=120)
|
||||
resp = requests.post(endpoint, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
|
||||
|
||||
if resp.status_code != 200:
|
||||
msg = f"Shubiaobiao 提交失败 ({resp.status_code}): {resp.text}"
|
||||
@@ -365,22 +427,20 @@ class ImageGenerator:
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
output_path = config.TEMP_DIR / output_filename
|
||||
output_path = output_dir / output_filename
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(base64.b64decode(img_b64))
|
||||
|
||||
logger.info(f"Shubiaobiao Generation Success: {output_path}")
|
||||
return str(output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Shubiaobiao Generation Exception: {e}")
|
||||
raise
|
||||
return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="shubiaobiao_image")
|
||||
|
||||
def _generate_single_image_gemini(
|
||||
self,
|
||||
prompt: str,
|
||||
reference_images: List[str],
|
||||
output_filename: str
|
||||
output_filename: str,
|
||||
output_dir: Path
|
||||
) -> Optional[str]:
|
||||
"""调用 Gemini (Wuyin Keji / NanoBanana-Pro) 生成单张图片"""
|
||||
|
||||
@@ -420,10 +480,10 @@ class ImageGenerator:
|
||||
"Content-Type": "application/json;charset:utf-8"
|
||||
}
|
||||
|
||||
# 2. 提交任务
|
||||
try:
|
||||
def _call():
|
||||
# 2. 提交任务
|
||||
logger.info(f"Submitting to Gemini: {config.GEMINI_IMG_API_URL}")
|
||||
resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=30)
|
||||
resp = requests.post(config.GEMINI_IMG_API_URL, json=payload, headers=headers, timeout=IMG_SUBMIT_TIMEOUT_S)
|
||||
|
||||
if resp.status_code != 200:
|
||||
msg = f"Gemini 提交失败 ({resp.status_code}): {resp.text}"
|
||||
@@ -443,13 +503,12 @@ class ImageGenerator:
|
||||
logger.info(f"Gemini Task Submitted, ID: {task_id}")
|
||||
|
||||
# 3. 轮询状态
|
||||
max_retries = 60
|
||||
for i in range(max_retries):
|
||||
time.sleep(2)
|
||||
for _ in range(IMG_POLL_MAX_RETRIES):
|
||||
time.sleep(IMG_POLL_INTERVAL_S)
|
||||
|
||||
poll_url = f"{config.GEMINI_IMG_DETAIL_URL}?key={config.GEMINI_IMG_KEY}&id={task_id}"
|
||||
try:
|
||||
poll_resp = requests.get(poll_url, headers=headers, timeout=30)
|
||||
poll_resp = requests.get(poll_url, headers=headers, timeout=IMG_POLL_TIMEOUT_S)
|
||||
except requests.Timeout:
|
||||
continue
|
||||
except Exception as e:
|
||||
@@ -474,7 +533,7 @@ class ImageGenerator:
|
||||
img_resp = requests.get(image_url, timeout=60)
|
||||
img_resp.raise_for_status()
|
||||
|
||||
output_path = config.TEMP_DIR / output_filename
|
||||
output_path = output_dir / output_filename
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(img_resp.content)
|
||||
|
||||
@@ -485,7 +544,4 @@ class ImageGenerator:
|
||||
raise RuntimeError(f"Gemini 生成失败: {fail_reason}")
|
||||
|
||||
raise RuntimeError("Gemini 生成超时")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini Generation Exception: {e}")
|
||||
raise
|
||||
return _with_retries(_call, max_retries=IMG_MAX_RETRIES, label="gemini_image")
|
||||
|
||||
Reference in New Issue
Block a user