fix(stability): R2上传改为懒加载,缺boto3时降级不阻塞启动

This commit is contained in:
Tony Zhang
2025-12-17 10:51:19 +08:00
parent e365a94dd1
commit ebcf165c3f

View File

@@ -9,10 +9,10 @@ import os
import requests
from typing import Dict, Any, List, Optional
from pathlib import Path
from openai import OpenAI
import config
from modules.db_manager import db
logger = logging.getLogger(__name__)
class ScriptGenerator:
@@ -24,6 +24,11 @@ class ScriptGenerator:
# 根据 demo: https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent
# 这里我们先假设 base_url 是 v1beta/models/
self.endpoint = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent"
# OpenAI-compatible client for ShuBiaoBiao (supports multiple models incl. GPT)
self.shubiaobiao_client = OpenAI(
api_key=config.SHUBIAOBIAO_KEY,
base_url=config.SHUBIAOBIAO_BASE_URL
)
# Default System Prompt
self.default_system_prompt = """
@@ -145,11 +150,28 @@ class ScriptGenerator:
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
user_prompt = self._build_user_prompt(product_name, product_info)
# Branch for Doubao
# Branch for Doubao (Volcengine)
if model_provider == "doubao":
return self._generate_script_doubao(system_prompt, user_prompt, image_paths)
script = self._generate_script_doubao(system_prompt, user_prompt, image_paths)
if script:
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
return script
# ... Existing Shubiaobiao Logic ...
# Branch for ShuBiaoBiao GPT (OpenAI-compatible multimodal)
if model_provider == "shubiaobiao_gpt":
script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name="gpt-5.2")
if script:
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
return script
# Branch for ShuBiaoBiao Gemini (OpenAI-compatible; use image URLs instead of base64)
if model_provider == "shubiaobiao":
script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name=config.SHUBIAOBIAO_MODEL_TEXT)
if script:
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
return script
# Fallback (should not normally reach here)
# 调试: 检查是否使用了自定义 Prompt
if system_prompt != self.default_system_prompt:
@@ -218,6 +240,8 @@ class ScriptGenerator:
return None
final_script = self._validate_and_fix_script(script_json)
# 不改 prompt 的前提下:对卖点做轻量规则化(更具体、更可执行)
final_script["selling_points"] = self._postprocess_selling_points(product_info, final_script.get("selling_points"))
# Add Debug Info (包含原始输出)
final_script["_debug"] = {
@@ -237,6 +261,165 @@ class ScriptGenerator:
logger.error(f"Response content: {response.text}")
return None
def _upload_images_to_r2(self, image_paths: List[str], limit: int = 10) -> List[str]:
urls: List[str] = []
if not image_paths:
return urls
# NOTE: avoid hard import dependency at app startup.
# If boto3 / storage is not installed on the runtime, we should not crash Streamlit.
try:
from modules import storage # lazy import
except Exception as e:
logger.warning(f"R2 upload disabled (storage/boto3 unavailable): {e}")
return urls
for p in image_paths[:limit]:
try:
if p and Path(p).exists():
url = storage.upload_file(str(p))
if url:
urls.append(url)
except Exception as e:
logger.warning(f"Failed to upload script image to R2: {p} ({e})")
return urls
def _generate_script_shubiaobiao_openai(
    self,
    system_prompt: str,
    user_prompt: str,
    image_paths: List[str],
    model_name: str,
) -> Optional[Dict[str, Any]]:
    """
    Generate a script through the ShuBiaoBiao OpenAI-compatible chat endpoint.

    Images are passed to the model as URLs obtained from
    ``_upload_images_to_r2`` (not as base64 payloads). When no URL is
    available the request carries the text prompt only.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Text part of the user message.
        image_paths: Local image paths to upload and attach (may be empty/None).
        model_name: Model identifier forwarded to the chat completion call.

    Returns:
        The validated script dict with an added ``_debug`` entry, or
        ``None`` when no JSON could be extracted from the reply or the
        request raised.
    """
    image_urls = self._upload_images_to_r2(image_paths or [], limit=10)

    # User message layout: every image URL part first, then the text part.
    user_parts: List[Dict[str, Any]] = [
        {"type": "image_url", "image_url": {"url": link}} for link in image_urls
    ]
    user_parts.append({"type": "text", "text": user_prompt})

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_parts},
    ]

    try:
        completion = self.shubiaobiao_client.chat.completions.create(
            model=model_name,
            messages=conversation,
            temperature=0.7,
        )
        # The message content may be None; normalise to a stripped string.
        raw_text = (completion.choices[0].message.content or "").strip()

        parsed = self._extract_json_from_response(raw_text)
        if parsed is None:
            logger.error(f"Failed to extract JSON from shubiaobiao response({model_name}): {raw_text[:500]}...")
            return None

        script = self._validate_and_fix_script(parsed)
        # Keep prompts, raw output and image URLs alongside the result for debugging.
        script["_debug"] = {
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "raw_output": raw_text,
            "provider": f"shubiaobiao:{model_name}",
            "image_urls": image_urls,
        }
        return script
    except Exception as e:
        # Any API/parsing failure is reported to the caller as "no script".
        logger.error(f"shubiaobiao script generation failed ({model_name}): {e}")
        return None
def _postprocess_selling_points(self, product_info: Dict[str, Any], selling_points: Any) -> List[str]:
"""
Engineering-only postprocess (NO prompt change):
- De-duplicate
- Prefer specific points derived from tags/params when LLM points are too generic
"""
def _norm_list(v: Any) -> List[str]:
if isinstance(v, list):
return [str(x).strip() for x in v if str(x).strip()]
if isinstance(v, str) and v.strip():
return [v.strip()]
return []
tags = str((product_info or {}).get("tags", "") or "")
params = str((product_info or {}).get("params", "") or "")
category = str((product_info or {}).get("category", "") or "")
# Candidates from tags/params/category
raw = " ".join([tags, params, category])
parts = []
for sep in ["|", "", ";", "", ",", "", "\n"]:
raw = raw.replace(sep, "|")
for p in raw.split("|"):
p = p.strip()
if not p:
continue
# params like "key:value"
if ":" in p:
kv = [x.strip() for x in p.split(":", 1)]
if len(kv) == 2 and kv[1]:
parts.append(kv[1])
else:
parts.append(p)
else:
parts.append(p)
# Heuristic: keep more concrete phrases
generic_words = [
"好看", "百搭", "高级", "气质", "显白", "好用", "耐看", "时尚", "精致", "很棒", "不错", "喜欢", "推荐",
"性价比", "划算", "超值", "必入", "绝了",
]
concrete_hints = [
"不掉", "牢固", "防滑", "加厚", "大号", "小号", "强力", "耐用", "稳固", "不勒", "不伤",
"树脂", "金属", "水钻", "", "材质", "尺寸", "弹簧", "夹力", "发量", "马尾",
]
candidates: List[str] = []
seen = set()
for p in parts:
if not p or p in seen:
continue
seen.add(p)
if any(h in p for h in concrete_hints) and not any(g in p for g in ["喜欢", "推荐"]):
candidates.append(p)
# fallback: keep non-empty tags
if not candidates:
candidates = [p for p in parts if p][:8]
points = _norm_list(selling_points)
# de-dup preserving order
out: List[str] = []
used = set()
for p in points:
if p in used:
continue
used.add(p)
out.append(p)
def _is_generic(p: str) -> bool:
# treat as generic if only contains generic words and lacks any concrete hints
if any(h in p for h in concrete_hints):
return False
return any(g in p for g in generic_words)
# Replace overly generic points with better candidates (keep length <= 3-5 ideally)
cand_iter = (c for c in candidates if c not in used)
refined: List[str] = []
for p in out:
if _is_generic(p):
c = next(cand_iter, None)
if c:
refined.append(c)
used.add(c)
continue
refined.append(p)
# Ensure at least 3 selling points if possible
while len(refined) < 3:
c = next(cand_iter, None)
if not c:
break
refined.append(c)
used.add(c)
return refined[:5]
def _generate_script_doubao(
self,
system_prompt: str,