chore: sync code and project files
This commit is contained in:
@@ -6,6 +6,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
@@ -27,7 +28,10 @@ class ScriptGenerator:
|
||||
# OpenAI-compatible client for ShuBiaoBiao (supports multiple models incl. GPT)
|
||||
self.shubiaobiao_client = OpenAI(
|
||||
api_key=config.SHUBIAOBIAO_KEY,
|
||||
base_url=config.SHUBIAOBIAO_BASE_URL
|
||||
base_url=config.SHUBIAOBIAO_BASE_URL,
|
||||
# IMPORTANT: OpenAI SDK default timeout is 10 minutes; cap it to keep UX responsive.
|
||||
timeout=config.SHUBIAOBIAO_CHAT_TIMEOUT_S,
|
||||
max_retries=config.SHUBIAOBIAO_CHAT_MAX_RETRIES,
|
||||
)
|
||||
|
||||
# Default System Prompt
|
||||
@@ -139,15 +143,23 @@ class ScriptGenerator:
|
||||
product_name: str,
|
||||
product_info: Dict[str, Any],
|
||||
image_paths: List[str] = None,
|
||||
model_provider: str = "shubiaobiao" # "shubiaobiao" or "doubao"
|
||||
model_provider: str = "shubiaobiao", # "shubiaobiao" or "doubao"
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成分镜脚本
|
||||
"""
|
||||
logger.info(f"Generating script for: {product_name} (Provider: {model_provider})")
|
||||
|
||||
# 1. 构造 Prompt (优先从数据库读取配置)
|
||||
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
|
||||
# 1. 构造 Prompt (优先按 user_id 读取;否则回退到全局配置,再回退默认)
|
||||
system_prompt = None
|
||||
if user_id:
|
||||
try:
|
||||
system_prompt = db.get_user_prompt(user_id, "prompt_script_gen")
|
||||
except Exception:
|
||||
system_prompt = None
|
||||
if not system_prompt:
|
||||
system_prompt = db.get_config("prompt_script_gen", self.default_system_prompt)
|
||||
user_prompt = self._build_user_prompt(product_name, product_info)
|
||||
|
||||
# Branch for Doubao (Volcengine)
|
||||
@@ -293,21 +305,40 @@ class ScriptGenerator:
|
||||
ShuBiaoBiao OpenAI-compatible multimodal chat.
|
||||
IMPORTANT: For ShuBiaoBiao models, we pass image URLs (R2 public URLs), not base64.
|
||||
"""
|
||||
t0 = time.time()
|
||||
# Use WARNING level so it shows up even if Streamlit/root logger is not configured to INFO.
|
||||
logger.warning(
|
||||
f"[script_gen] start shubiaobiao chat model={model_name} images={len(image_paths or [])} "
|
||||
f"timeout_s={getattr(config, 'SHUBIAOBIAO_CHAT_TIMEOUT_S', 'n/a')} "
|
||||
f"max_retries={getattr(config, 'SHUBIAOBIAO_CHAT_MAX_RETRIES', 'n/a')}"
|
||||
)
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
user_content: List[Dict[str, Any]] = []
|
||||
# Images first (URL), then text
|
||||
t_upload0 = time.time()
|
||||
urls = self._upload_images_to_r2(image_paths or [], limit=10)
|
||||
logger.warning(
|
||||
f"[script_gen] r2_upload done urls={len(urls)} elapsed_s={time.time() - t_upload0:.2f}"
|
||||
)
|
||||
for url in urls:
|
||||
user_content.append({"type": "image_url", "image_url": {"url": url}})
|
||||
user_content.append({"type": "text", "text": user_prompt})
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
try:
|
||||
resp = self.shubiaobiao_client.chat.completions.create(
|
||||
client = self.shubiaobiao_client.with_options(
|
||||
timeout=config.SHUBIAOBIAO_CHAT_TIMEOUT_S,
|
||||
max_retries=config.SHUBIAOBIAO_CHAT_MAX_RETRIES,
|
||||
)
|
||||
t_call0 = time.time()
|
||||
resp = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
)
|
||||
logger.warning(
|
||||
f"[script_gen] shubiaobiao chat done elapsed_s={time.time() - t_call0:.2f} total_s={time.time() - t0:.2f}"
|
||||
)
|
||||
content_text = (resp.choices[0].message.content or "").strip()
|
||||
script_json = self._extract_json_from_response(content_text)
|
||||
if script_json is None:
|
||||
@@ -323,7 +354,9 @@ class ScriptGenerator:
|
||||
}
|
||||
return final_script
|
||||
except Exception as e:
|
||||
logger.error(f"shubiaobiao script generation failed ({model_name}): {e}")
|
||||
logger.error(
|
||||
f"shubiaobiao script generation failed ({model_name}) after {time.time() - t0:.2f}s: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _postprocess_selling_points(self, product_info: Dict[str, Any], selling_points: Any) -> List[str]:
|
||||
@@ -582,7 +615,33 @@ class ScriptGenerator:
|
||||
|
||||
def _validate_and_fix_script(self, script: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""校验并修复脚本结构"""
|
||||
# 简单校验,确保必要字段存在
|
||||
if "scenes" not in script:
|
||||
if not isinstance(script, dict):
|
||||
return {"scenes": []}
|
||||
|
||||
# Ensure fields exist
|
||||
if "scenes" not in script or not isinstance(script.get("scenes"), list):
|
||||
script["scenes"] = []
|
||||
|
||||
# Normalize: keep visual_anchor at top-level, but avoid repeating the full anchor in every scene.visual_prompt.
|
||||
# Reason: repeating a long anchor 4-5 times explodes tokens and makes UI look like "only three sections",
|
||||
# while image generation already supports passing visual_anchor separately and prepending it at runtime.
|
||||
visual_anchor = script.get("visual_anchor") or ""
|
||||
if isinstance(visual_anchor, str) and visual_anchor.strip() and script["scenes"]:
|
||||
anchor = visual_anchor.strip()
|
||||
prefix = f"[{anchor}]"
|
||||
for scene in script["scenes"]:
|
||||
if not isinstance(scene, dict):
|
||||
continue
|
||||
vp = scene.get("visual_prompt")
|
||||
if not isinstance(vp, str) or not vp.strip():
|
||||
continue
|
||||
s = vp.strip()
|
||||
# Strip exact "[anchor]" prefix if present
|
||||
if s.startswith(prefix):
|
||||
s = s[len(prefix):].lstrip()
|
||||
# If the model output copied the raw anchor without brackets, strip it too
|
||||
elif s.startswith(anchor):
|
||||
s = s[len(anchor):].lstrip()
|
||||
scene["visual_prompt"] = s
|
||||
|
||||
return script
|
||||
|
||||
Reference in New Issue
Block a user