diff --git a/modules/script_gen.py b/modules/script_gen.py index abdfa75..1a3f018 100644 --- a/modules/script_gen.py +++ b/modules/script_gen.py @@ -9,10 +9,10 @@ import os import requests from typing import Dict, Any, List, Optional from pathlib import Path +from openai import OpenAI import config from modules.db_manager import db - logger = logging.getLogger(__name__) class ScriptGenerator: @@ -24,6 +24,11 @@ class ScriptGenerator: # 根据 demo: https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent # 这里我们先假设 base_url 是 v1beta/models/ self.endpoint = "https://api.shubiaobiao.cn/v1beta/models/gemini-3-pro-preview:generateContent" + # OpenAI-compatible client for ShuBiaoBiao (supports multiple models incl. GPT) + self.shubiaobiao_client = OpenAI( + api_key=config.SHUBIAOBIAO_KEY, + base_url=config.SHUBIAOBIAO_BASE_URL + ) # Default System Prompt self.default_system_prompt = """ @@ -145,11 +150,28 @@ class ScriptGenerator: system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt) user_prompt = self._build_user_prompt(product_name, product_info) - # Branch for Doubao + # Branch for Doubao (Volcengine) if model_provider == "doubao": - return self._generate_script_doubao(system_prompt, user_prompt, image_paths) + script = self._generate_script_doubao(system_prompt, user_prompt, image_paths) + if script: + script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points")) + return script - # ... Existing Shubiaobiao Logic ... + # Branch for ShuBiaoBiao GPT (OpenAI-compatible multimodal) + if model_provider == "shubiaobiao_gpt": + script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name="gpt-5.2") + if script: + script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points")) + return script + + # Branch for ShuBiaoBiao Gemini (OpenAI-compatible; use image URLs instead of base64) + if model_provider == "shubiaobiao": + script = self._generate_script_shubiaobiao_openai(system_prompt, user_prompt, image_paths, model_name=config.SHUBIAOBIAO_MODEL_TEXT) + if script: + script["selling_points"] = self._postprocess_selling_points(product_info, script.get("selling_points")) + return script + + # Fallback (should not normally reach here) # 调试: 检查是否使用了自定义 Prompt if system_prompt != self.default_system_prompt: @@ -218,6 +240,8 @@ class ScriptGenerator: return None final_script = self._validate_and_fix_script(script_json) + # 不改 prompt 的前提下:对卖点做轻量规则化(更具体、更可执行) + final_script["selling_points"] = self._postprocess_selling_points(product_info, final_script.get("selling_points")) # Add Debug Info (包含原始输出) final_script["_debug"] = { @@ -237,6 +261,165 @@ class ScriptGenerator: logger.error(f"Response content: {response.text}") return None + def _upload_images_to_r2(self, image_paths: List[str], limit: int = 10) -> List[str]: + urls: List[str] = [] + if not image_paths: + return urls + # NOTE: avoid hard import dependency at app startup. + # If boto3 / storage is not installed on the runtime, we should not crash Streamlit. + try: + from modules import storage # lazy import + except Exception as e: + logger.warning(f"R2 upload disabled (storage/boto3 unavailable): {e}") + return urls + for p in image_paths[:limit]: + try: + if p and Path(p).exists(): + url = storage.upload_file(str(p)) + if url: + urls.append(url) + except Exception as e: + logger.warning(f"Failed to upload script image to R2: {p} ({e})") + return urls + + def _generate_script_shubiaobiao_openai( + self, + system_prompt: str, + user_prompt: str, + image_paths: List[str], + model_name: str, + ) -> Optional[Dict[str, Any]]: + """ + ShuBiaoBiao OpenAI-compatible multimodal chat. + IMPORTANT: For ShuBiaoBiao models, we pass image URLs (R2 public URLs), not base64. + """ + messages = [{"role": "system", "content": system_prompt}] + user_content: List[Dict[str, Any]] = [] + # Images first (URL), then text + urls = self._upload_images_to_r2(image_paths or [], limit=10) + for url in urls: + user_content.append({"type": "image_url", "image_url": {"url": url}}) + user_content.append({"type": "text", "text": user_prompt}) + messages.append({"role": "user", "content": user_content}) + + try: + resp = self.shubiaobiao_client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.7, + ) + content_text = (resp.choices[0].message.content or "").strip() + script_json = self._extract_json_from_response(content_text) + if script_json is None: + logger.error(f"Failed to extract JSON from shubiaobiao response({model_name}): {content_text[:500]}...") + return None + final_script = self._validate_and_fix_script(script_json) + final_script["_debug"] = { + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "raw_output": content_text, + "provider": f"shubiaobiao:{model_name}", + "image_urls": urls, + } + return final_script + except Exception as e: + logger.error(f"shubiaobiao script generation failed ({model_name}): {e}") + return None + + def _postprocess_selling_points(self, product_info: Dict[str, Any], selling_points: Any) -> List[str]: + """ + Engineering-only postprocess (NO prompt change): + - De-duplicate + - Prefer specific points derived from tags/params when LLM points are too generic + """ + def _norm_list(v: Any) -> List[str]: + if isinstance(v, list): + return [str(x).strip() for x in v if str(x).strip()] + if isinstance(v, str) and v.strip(): + return [v.strip()] + return [] + + tags = str((product_info or {}).get("tags", "") or "") + params = str((product_info or {}).get("params", "") or "") + category = str((product_info or {}).get("category", "") or "") + # Candidates from tags/params/category + raw = " ".join([tags, params, category]) + parts = [] + for sep in ["|", ";", ";", "、", ",", ",", "\n"]: + raw = raw.replace(sep, "|") + for p in raw.split("|"): + p = p.strip() + if not p: + continue + # params like "key:value" + if ":" in p: + kv = [x.strip() for x in p.split(":", 1)] + if len(kv) == 2 and kv[1]: + parts.append(kv[1]) + else: + parts.append(p) + else: + parts.append(p) + + # Heuristic: keep more concrete phrases + generic_words = [ + "好看", "百搭", "高级", "气质", "显白", "好用", "耐看", "时尚", "精致", "很棒", "不错", "喜欢", "推荐", + "性价比", "划算", "超值", "必入", "绝了", + ] + concrete_hints = [ + "不掉", "牢固", "防滑", "加厚", "大号", "小号", "强力", "耐用", "稳固", "不勒", "不伤", + "树脂", "金属", "水钻", "夹", "材质", "尺寸", "弹簧", "夹力", "发量", "马尾", + ] + candidates: List[str] = [] + seen = set() + for p in parts: + if not p or p in seen: + continue + seen.add(p) + if any(h in p for h in concrete_hints) and not any(g in p for g in ["喜欢", "推荐"]): + candidates.append(p) + # fallback: keep non-empty tags + if not candidates: + candidates = [p for p in parts if p][:8] + + points = _norm_list(selling_points) + # de-dup preserving order + out: List[str] = [] + used = set() + for p in points: + if p in used: + continue + used.add(p) + out.append(p) + + def _is_generic(p: str) -> bool: + # treat as generic if only contains generic words and lacks any concrete hints + if any(h in p for h in concrete_hints): + return False + return any(g in p for g in generic_words) + + # Replace overly generic points with better candidates (keep length <= 3-5 ideally) + cand_iter = (c for c in candidates if c not in used) + refined: List[str] = [] + for p in out: + if _is_generic(p): + c = next(cand_iter, None) + if c: + refined.append(c) + used.add(c) + continue + refined.append(p) + + # Ensure at least 3 selling points if possible + while len(refined) < 3: + c = next(cand_iter, None) + if not c: + break + refined.append(c) + used.add(c) + + return refined[:5] + def _generate_script_doubao( self, system_prompt: str,