fix(stability): R2上传改为懒加载,缺boto3时降级不阻塞启动

This commit is contained in:
Tony Zhang
2025-12-17 10:51:19 +08:00
parent e365a94dd1
commit ebcf165c3f

View File

@@ -9,10 +9,10 @@ import os
import requests import requests
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from pathlib import Path from pathlib import Path
from openai import OpenAI
import config import config
from modules.db_manager import db from modules.db_manager import db
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ScriptGenerator: class ScriptGenerator:
@@ -24,6 +24,11 @@ class ScriptGenerator:
# 根据 demo: https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent # 根据 demo: https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent
# 这里我们先假设 base_url 是 v1beta/models/ # 这里我们先假设 base_url 是 v1beta/models/
self.endpoint = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent" self.endpoint = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent"
# OpenAI-compatible client for ShuBiaoBiao (supports multiple models incl. GPT)
self.shubiaobiao_client = OpenAI(
api_key=config.SHUBIAOBIAO_KEY,
base_url=config.SHUBIAOBIAO_BASE_URL
)
# Default System Prompt # Default System Prompt
self.default_system_prompt = """ self.default_system_prompt = """
@@ -145,11 +150,28 @@ class ScriptGenerator:
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt) system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
user_prompt = self._build_user_prompt(product_name, product_info) user_prompt = self._build_user_prompt(product_name, product_info)
# Branch for Doubao # Branch for Doubao (Volcengine)
if model_provider == "doubao": if model_provider == "doubao":
return self._generate_script_doubao(system_prompt, user_prompt, image_paths) script = self._generate_script_doubao(system_prompt, user_prompt, image_paths)
if script:
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
return script
# ... Existing Shubiaobiao Logic ... # Branch for ShuBiaoBiao GPT (OpenAI-compatible multimodal)
if model_provider == "shubiaobiao_gpt":
script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name="gpt-5.2")
if script:
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
return script
# Branch for ShuBiaoBiao Gemini (OpenAI-compatible; use image URLs instead of base64)
if model_provider == "shubiaobiao":
script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name=config.SHUBIAOBIAO_MODEL_TEXT)
if script:
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
return script
# Fallback (should not normally reach here)
# 调试: 检查是否使用了自定义 Prompt # 调试: 检查是否使用了自定义 Prompt
if system_prompt != self.default_system_prompt: if system_prompt != self.default_system_prompt:
@@ -218,6 +240,8 @@ class ScriptGenerator:
return None return None
final_script = self._validate_and_fix_script(script_json) final_script = self._validate_and_fix_script(script_json)
# 不改 prompt 的前提下:对卖点做轻量规则化(更具体、更可执行)
final_script["selling_points"] = self._postprocess_selling_points(product_info, final_script.get("selling_points"))
# Add Debug Info (包含原始输出) # Add Debug Info (包含原始输出)
final_script["_debug"] = { final_script["_debug"] = {
@@ -237,6 +261,165 @@ class ScriptGenerator:
logger.error(f"Response content: {response.text}") logger.error(f"Response content: {response.text}")
return None return None
def _upload_images_to_r2(self, image_paths: List[str], limit: int = 10) -> List[str]:
urls: List[str] = []
if not image_paths:
return urls
# NOTE: avoid hard import dependency at app startup.
# If boto3 / storage is not installed on the runtime, we should not crash Streamlit.
try:
from modules import storage # lazy import
except Exception as e:
logger.warning(f"R2 upload disabled (storage/boto3 unavailable): {e}")
return urls
for p in image_paths[:limit]:
try:
if p and Path(p).exists():
url = storage.upload_file(str(p))
if url:
urls.append(url)
except Exception as e:
logger.warning(f"Failed to upload script image to R2: {p} ({e})")
return urls
def _generate_script_shubiaobiao_openai(
    self,
    system_prompt: str,
    user_prompt: str,
    image_paths: List[str],
    model_name: str,
) -> Optional[Dict[str, Any]]:
    """
    Generate a script through the ShuBiaoBiao OpenAI-compatible chat endpoint.

    Images are sent as URLs (obtained from ``_upload_images_to_r2``), not as
    base64 payloads. The user message lists the image parts first, then the text.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Text part of the user message.
        image_paths: Local image paths; at most 10 are uploaded. May be empty/None.
        model_name: Model identifier passed to the chat-completions call.

    Returns:
        The validated script dict with an extra ``_debug`` entry, or ``None``
        when no JSON could be extracted or any step raised.
    """
    # Upload first so the request can reference the images by URL.
    urls = self._upload_images_to_r2(image_paths or [], limit=10)

    user_content: List[Dict[str, Any]] = [
        {"type": "image_url", "image_url": {"url": url}} for url in urls
    ]
    user_content.append({"type": "text", "text": user_prompt})

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]

    try:
        resp = self.shubiaobiao_client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0.7,
        )
        # The message content may be None; normalise to an empty string.
        content_text = (resp.choices[0].message.content or "").strip()

        script_json = self._extract_json_from_response(content_text)
        if script_json is not None:
            final_script = self._validate_and_fix_script(script_json)
            # Keep the prompts and raw output alongside the script for debugging.
            final_script["_debug"] = {
                "system_prompt": system_prompt,
                "user_prompt": user_prompt,
                "raw_output": content_text,
                "provider": f"shubiaobiao:{model_name}",
                "image_urls": urls,
            }
            return final_script

        logger.error(f"Failed to extract JSON from shubiaobiao response({model_name}): {content_text[:500]}...")
        return None
    except Exception as e:
        logger.error(f"shubiaobiao script generation failed ({model_name}): {e}")
        return None
def _postprocess_selling_points(self, product_info: Dict[str, Any], selling_points: Any) -> List[str]:
"""
Engineering-only postprocess (NO prompt change):
- De-duplicate
- Prefer specific points derived from tags/params when LLM points are too generic
"""
def _norm_list(v: Any) -> List[str]:
if isinstance(v, list):
return [str(x).strip() for x in v if str(x).strip()]
if isinstance(v, str) and v.strip():
return [v.strip()]
return []
tags = str((product_info or {}).get("tags", "") or "")
params = str((product_info or {}).get("params", "") or "")
category = str((product_info or {}).get("category", "") or "")
# Candidates from tags/params/category
raw = " ".join([tags, params, category])
parts = []
for sep in ["|", "", ";", "", ",", "", "\n"]:
raw = raw.replace(sep, "|")
for p in raw.split("|"):
p = p.strip()
if not p:
continue
# params like "key:value"
if ":" in p:
kv = [x.strip() for x in p.split(":", 1)]
if len(kv) == 2 and kv[1]:
parts.append(kv[1])
else:
parts.append(p)
else:
parts.append(p)
# Heuristic: keep more concrete phrases
generic_words = [
"好看", "百搭", "高级", "气质", "显白", "好用", "耐看", "时尚", "精致", "很棒", "不错", "喜欢", "推荐",
"性价比", "划算", "超值", "必入", "绝了",
]
concrete_hints = [
"不掉", "牢固", "防滑", "加厚", "大号", "小号", "强力", "耐用", "稳固", "不勒", "不伤",
"树脂", "金属", "水钻", "", "材质", "尺寸", "弹簧", "夹力", "发量", "马尾",
]
candidates: List[str] = []
seen = set()
for p in parts:
if not p or p in seen:
continue
seen.add(p)
if any(h in p for h in concrete_hints) and not any(g in p for g in ["喜欢", "推荐"]):
candidates.append(p)
# fallback: keep non-empty tags
if not candidates:
candidates = [p for p in parts if p][:8]
points = _norm_list(selling_points)
# de-dup preserving order
out: List[str] = []
used = set()
for p in points:
if p in used:
continue
used.add(p)
out.append(p)
def _is_generic(p: str) -> bool:
# treat as generic if only contains generic words and lacks any concrete hints
if any(h in p for h in concrete_hints):
return False
return any(g in p for g in generic_words)
# Replace overly generic points with better candidates (keep length <= 3-5 ideally)
cand_iter = (c for c in candidates if c not in used)
refined: List[str] = []
for p in out:
if _is_generic(p):
c = next(cand_iter, None)
if c:
refined.append(c)
used.add(c)
continue
refined.append(p)
# Ensure at least 3 selling points if possible
while len(refined) < 3:
c = next(cand_iter, None)
if not c:
break
refined.append(c)
used.add(c)
return refined[:5]
def _generate_script_doubao( def _generate_script_doubao(
self, self,
system_prompt: str, system_prompt: str,