fix(stability): R2上传改为懒加载,缺boto3时降级不阻塞启动
This commit is contained in:
@@ -9,10 +9,10 @@ import os
|
||||
import requests
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
from openai import OpenAI
|
||||
|
||||
import config
|
||||
from modules.db_manager import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ScriptGenerator:
|
||||
@@ -24,6 +24,11 @@ class ScriptGenerator:
|
||||
# 根据 demo: https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent
|
||||
# 这里我们先假设 base_url 是 v1beta/models/
|
||||
self.endpoint = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent"
|
||||
# OpenAI-compatible client for ShuBiaoBiao (supports multiple models incl. GPT)
|
||||
self.shubiaobiao_client = OpenAI(
|
||||
api_key=config.SHUBIAOBIAO_KEY,
|
||||
base_url=config.SHUBIAOBIAO_BASE_URL
|
||||
)
|
||||
|
||||
# Default System Prompt
|
||||
self.default_system_prompt = """
|
||||
@@ -145,11 +150,28 @@ class ScriptGenerator:
|
||||
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
|
||||
user_prompt = self._build_user_prompt(product_name, product_info)
|
||||
|
||||
# Branch for Doubao
|
||||
# Branch for Doubao (Volcengine)
|
||||
if model_provider == "doubao":
|
||||
return self._generate_script_doubao(system_prompt, user_prompt, image_paths)
|
||||
script = self._generate_script_doubao(system_prompt, user_prompt, image_paths)
|
||||
if script:
|
||||
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
|
||||
return script
|
||||
|
||||
# ... Existing Shubiaobiao Logic ...
|
||||
# Branch for ShuBiaoBiao GPT (OpenAI-compatible multimodal)
|
||||
if model_provider == "shubiaobiao_gpt":
|
||||
script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name="gpt-5.2")
|
||||
if script:
|
||||
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
|
||||
return script
|
||||
|
||||
# Branch for ShuBiaoBiao Gemini (OpenAI-compatible; use image URLs instead of base64)
|
||||
if model_provider == "shubiaobiao":
|
||||
script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name=config.SHUBIAOBIAO_MODEL_TEXT)
|
||||
if script:
|
||||
script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points"))
|
||||
return script
|
||||
|
||||
# Fallback (should not normally reach here)
|
||||
|
||||
# 调试: 检查是否使用了自定义 Prompt
|
||||
if system_prompt != self.default_system_prompt:
|
||||
@@ -218,6 +240,8 @@ class ScriptGenerator:
|
||||
return None
|
||||
|
||||
final_script = self._validate_and_fix_script(script_json)
|
||||
# 不改 prompt 的前提下:对卖点做轻量规则化(更具体、更可执行)
|
||||
final_script["selling_points"] = self._postprocess_selling_points(product_info, final_script.get("selling_points"))
|
||||
|
||||
# Add Debug Info (包含原始输出)
|
||||
final_script["_debug"] = {
|
||||
@@ -237,6 +261,165 @@ class ScriptGenerator:
|
||||
logger.error(f"Response content: {response.text}")
|
||||
return None
|
||||
|
||||
def _upload_images_to_r2(self, image_paths: List[str], limit: int = 10) -> List[str]:
|
||||
urls: List[str] = []
|
||||
if not image_paths:
|
||||
return urls
|
||||
# NOTE: avoid hard import dependency at app startup.
|
||||
# If boto3 / storage is not installed on the runtime, we should not crash Streamlit.
|
||||
try:
|
||||
from modules import storage # lazy import
|
||||
except Exception as e:
|
||||
logger.warning(f"R2 upload disabled (storage/boto3 unavailable): {e}")
|
||||
return urls
|
||||
for p in image_paths[:limit]:
|
||||
try:
|
||||
if p and Path(p).exists():
|
||||
url = storage.upload_file(str(p))
|
||||
if url:
|
||||
urls.append(url)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to upload script image to R2: {p} ({e})")
|
||||
return urls
|
||||
|
||||
def _generate_script_shubiaobiao_openai(
|
||||
self,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
image_paths: List[str],
|
||||
model_name: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
ShuBiaoBiao OpenAI-compatible multimodal chat.
|
||||
IMPORTANT: For ShuBiaoBiao models, we pass image URLs (R2 public URLs), not base64.
|
||||
"""
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
user_content: List[Dict[str, Any]] = []
|
||||
# Images first (URL), then text
|
||||
urls = self._upload_images_to_r2(image_paths or [], limit=10)
|
||||
for url in urls:
|
||||
user_content.append({"type": "image_url", "image_url": {"url": url}})
|
||||
user_content.append({"type": "text", "text": user_prompt})
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
try:
|
||||
resp = self.shubiaobiao_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
)
|
||||
content_text = (resp.choices[0].message.content or "").strip()
|
||||
script_json = self._extract_json_from_response(content_text)
|
||||
if script_json is None:
|
||||
logger.error(f"Failed to extract JSON from shubiaobiao response({model_name}): {content_text[:500]}...")
|
||||
return None
|
||||
final_script = self._validate_and_fix_script(script_json)
|
||||
final_script["_debug"] = {
|
||||
"system_prompt": system_prompt,
|
||||
"user_prompt": user_prompt,
|
||||
"raw_output": content_text,
|
||||
"provider": f"shubiaobiao:{model_name}",
|
||||
"image_urls": urls,
|
||||
}
|
||||
return final_script
|
||||
except Exception as e:
|
||||
logger.error(f"shubiaobiao script generation failed ({model_name}): {e}")
|
||||
return None
|
||||
|
||||
def _postprocess_selling_points(self, product_info: Dict[str, Any], selling_points: Any) -> List[str]:
|
||||
"""
|
||||
Engineering-only postprocess (NO prompt change):
|
||||
- De-duplicate
|
||||
- Prefer specific points derived from tags/params when LLM points are too generic
|
||||
"""
|
||||
def _norm_list(v: Any) -> List[str]:
|
||||
if isinstance(v, list):
|
||||
return [str(x).strip() for x in v if str(x).strip()]
|
||||
if isinstance(v, str) and v.strip():
|
||||
return [v.strip()]
|
||||
return []
|
||||
|
||||
tags = str((product_info or {}).get("tags", "") or "")
|
||||
params = str((product_info or {}).get("params", "") or "")
|
||||
category = str((product_info or {}).get("category", "") or "")
|
||||
# Candidates from tags/params/category
|
||||
raw = " ".join([tags, params, category])
|
||||
parts = []
|
||||
for sep in ["|", ";", ";", "、", ",", ",", "\n"]:
|
||||
raw = raw.replace(sep, "|")
|
||||
for p in raw.split("|"):
|
||||
p = p.strip()
|
||||
if not p:
|
||||
continue
|
||||
# params like "key:value"
|
||||
if ":" in p:
|
||||
kv = [x.strip() for x in p.split(":", 1)]
|
||||
if len(kv) == 2 and kv[1]:
|
||||
parts.append(kv[1])
|
||||
else:
|
||||
parts.append(p)
|
||||
else:
|
||||
parts.append(p)
|
||||
|
||||
# Heuristic: keep more concrete phrases
|
||||
generic_words = [
|
||||
"好看", "百搭", "高级", "气质", "显白", "好用", "耐看", "时尚", "精致", "很棒", "不错", "喜欢", "推荐",
|
||||
"性价比", "划算", "超值", "必入", "绝了",
|
||||
]
|
||||
concrete_hints = [
|
||||
"不掉", "牢固", "防滑", "加厚", "大号", "小号", "强力", "耐用", "稳固", "不勒", "不伤",
|
||||
"树脂", "金属", "水钻", "夹", "材质", "尺寸", "弹簧", "夹力", "发量", "马尾",
|
||||
]
|
||||
candidates: List[str] = []
|
||||
seen = set()
|
||||
for p in parts:
|
||||
if not p or p in seen:
|
||||
continue
|
||||
seen.add(p)
|
||||
if any(h in p for h in concrete_hints) and not any(g in p for g in ["喜欢", "推荐"]):
|
||||
candidates.append(p)
|
||||
# fallback: keep non-empty tags
|
||||
if not candidates:
|
||||
candidates = [p for p in parts if p][:8]
|
||||
|
||||
points = _norm_list(selling_points)
|
||||
# de-dup preserving order
|
||||
out: List[str] = []
|
||||
used = set()
|
||||
for p in points:
|
||||
if p in used:
|
||||
continue
|
||||
used.add(p)
|
||||
out.append(p)
|
||||
|
||||
def _is_generic(p: str) -> bool:
|
||||
# treat as generic if only contains generic words and lacks any concrete hints
|
||||
if any(h in p for h in concrete_hints):
|
||||
return False
|
||||
return any(g in p for g in generic_words)
|
||||
|
||||
# Replace overly generic points with better candidates (keep length <= 3-5 ideally)
|
||||
cand_iter = (c for c in candidates if c not in used)
|
||||
refined: List[str] = []
|
||||
for p in out:
|
||||
if _is_generic(p):
|
||||
c = next(cand_iter, None)
|
||||
if c:
|
||||
refined.append(c)
|
||||
used.add(c)
|
||||
continue
|
||||
refined.append(p)
|
||||
|
||||
# Ensure at least 3 selling points if possible
|
||||
while len(refined) < 3:
|
||||
c = next(cand_iter, None)
|
||||
if not c:
|
||||
break
|
||||
refined.append(c)
|
||||
used.add(c)
|
||||
|
||||
return refined[:5]
|
||||
|
||||
def _generate_script_doubao(
|
||||
self,
|
||||
system_prompt: str,
|
||||
|
||||
Reference in New Issue
Block a user