1684 lines
84 KiB
Python
1684 lines
84 KiB
Python
"""
|
||
MatchMe Studio - UI (Streamlit)
|
||
Style: Kaogujia (Clean, Data-heavy, Professional)
|
||
"""
|
||
import streamlit as st
|
||
import json
|
||
import time
|
||
import os
|
||
import random
|
||
from pathlib import Path
|
||
import pandas as pd
|
||
from time import perf_counter
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from os import stat
|
||
|
||
# Import Backend Modules
|
||
import config
|
||
from modules.script_gen import ScriptGenerator
|
||
from modules.image_gen import ImageGenerator
|
||
from modules.video_gen import VideoGenerator
|
||
from modules.composer import VideoComposer, VideoComposer as Composer # alias
|
||
from modules.text_renderer import renderer
|
||
from modules import export_utils
|
||
from modules.db_manager import db
|
||
from modules import path_utils
|
||
from modules import limits
|
||
from modules.legacy_path_mapper import map_legacy_local_path
|
||
from modules.legacy_normalizer import normalize_legacy_project
|
||
import extra_streamlit_components as stx
|
||
|
||
# Page Config
|
||
# Configure the browser tab and overall layout before anything else is rendered.
_PAGE_CONFIG = {
    "page_title": "Video Flow Console",
    "page_icon": "🎬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)

# ============================================================
# Auth (login + remember cookie)
# ============================================================
# Name of the browser cookie that carries the login session token.
COOKIE_NAME = "vf_session"
# Single cookie component instance used to read / set / delete COOKIE_NAME.
cookie_manager = stx.CookieManager(key="vf_cookie_mgr")
|
||
|
||
def _current_user() -> "dict | None":
    """Return the logged-in user stored in session state, or ``None``.

    Anything stored under ``current_user`` that is not a dict is treated as
    "no user logged in". The annotation is a string so it is never evaluated
    at definition time (works on Python versions without ``X | None`` syntax).
    """
    user = st.session_state.get("current_user")
    return user if isinstance(user, dict) else None
|
||
|
||
def _set_current_user(u: dict):
    """Record *u* as the logged-in user in Streamlit session state."""
    st.session_state["current_user"] = u
|
||
|
||
def _try_restore_login():
    """Populate ``current_user`` from the remember-me cookie when possible.

    Returns without changes when:
    - a user is already logged in,
    - the cookie cannot be read or is empty, or
    - ``db.validate_session`` rejects the token.
    """
    if _current_user():
        return

    try:
        session_token = cookie_manager.get(COOKIE_NAME)
    except Exception:
        return  # unreadable cookie: nothing to restore
    if not session_token:
        return

    restored_user = db.validate_session(session_token)
    if restored_user:
        _set_current_user(restored_user)
|
||
|
||
def _logout():
    """Log the current user out and restart the script run.

    Steps:
    1. Revoke the server-side session for the cookie token, if one can be read.
    2. Ask the cookie component to delete the cookie; failures are ignored.
    3. Remove ``current_user`` from session state and call ``st.rerun()``.
    """
    session_token = None
    try:
        session_token = cookie_manager.get(COOKIE_NAME)
    except Exception:
        pass  # unreadable cookie: treat as "no token"

    if session_token:
        db.revoke_session(session_token)

    try:
        cookie_manager.delete(COOKIE_NAME)
    except Exception:
        pass  # cookie removal is best-effort

    st.session_state.pop("current_user", None)
    st.rerun()
|
||
|
||
def _login_gate():
    """Block the rest of the script until a user is authenticated.

    Returns immediately if a user is already in session state or can be
    restored from the session cookie. Otherwise it renders the login form and
    calls ``st.stop()`` so nothing below this call is rendered.
    """
    _try_restore_login()
    if _current_user():
        return

    st.title("登录")
    # The login page renders before the global CSS block further down, so the
    # rule that hides the password "eye" toggle is injected here as well.
    st.markdown(
        """
<style>
button[aria-label="Show password text"],
button[aria-label="Hide password text"] {
display: none !important;
}
</style>
""",
        unsafe_allow_html=True,
    )

    form_col, _unused_col = st.columns([1, 2])
    with form_col:
        username = st.text_input("用户名", value="", key="login_user")
        password = st.text_input("密码", value="", type="password", key="login_pass")
        if st.button("登录", type="primary"):
            account = db.authenticate_user(username.strip(), password)
            if not account:
                st.error("用户名或密码错误,或账号已禁用")
                st.stop()
            session_token = db.create_session(account["id"])
            try:
                # Remember the login for 7 days (7 * 24 * 3600 seconds).
                cookie_manager.set(COOKIE_NAME, session_token, max_age=7 * 24 * 3600)
            except Exception:
                # Cookie could not be set: the login lives only in this session.
                pass
            _set_current_user(account)
            st.rerun()
    st.stop()


_login_gate()
|
||
|
||
# ============================================================
|
||
# BGM 智能匹配函数
|
||
# ============================================================
|
||
# Style keywords recognised in a script's ``bgm_style`` and looked for in BGM file names.
_BGM_STYLE_KEYWORDS = ("活泼", "欢快", "轻松", "舒缓", "休闲", "温柔", "随性", "百搭", "bling", "节奏")


def match_bgm_by_style(bgm_style: str, bgm_dir: Path, keywords=None) -> "str | None":
    """Pick a BGM file from *bgm_dir* that suits the script's ``bgm_style``.

    Candidates are regular, non-hidden files whose extension is .mp3 or .mp4
    (any letter case).

    Args:
        bgm_style: Free-text style from the script. May be empty or None.
        bgm_dir: Directory to search (``Path`` or ``str``).
        keywords: Optional iterable of style keywords. Defaults to
            ``_BGM_STYLE_KEYWORDS``.

    Returns:
        - The path (as ``str``) of a random candidate whose name contains a
          keyword also found in ``bgm_style``. Matching is case-insensitive.
        - Otherwise, the path of a random candidate.
        - ``None`` when there are no candidates.
    """
    directory = Path(bgm_dir)
    # One pattern already covers .mp3 and .mp4 in any case. A second "*.mp3"
    # glob would list each .mp3 twice and double its odds in random.choice.
    # Sort so the candidate order (and thus seeded choices) is deterministic.
    candidates = sorted(
        f for f in directory.glob("*.[mM][pP][34]")
        if f.is_file() and not f.name.startswith(".")
    )
    if not candidates:
        return None

    if bgm_style:
        style_lower = bgm_style.lower()
        vocabulary = _BGM_STYLE_KEYWORDS if keywords is None else keywords
        active = [kw.lower() for kw in vocabulary if kw.lower() in style_lower]
        # Lowercase file names too, so e.g. "Bling_Beat.mp3" matches "bling".
        matched = [f for f in candidates if any(kw in f.name.lower() for kw in active)]
        if matched:
            return str(random.choice(matched))

    # No style, no recognised keyword, or no matching file: pick any candidate.
    return str(random.choice(candidates))
|
||
|
||
# ============================================================
|
||
# CSS Styling (Kaogujia Style)
|
||
# ============================================================
|
||
# Global look & feel: colours, fonts and button styling only (no layout rules),
# plus hiding the password reveal toggle on every page.
_APP_CSS = """
<style>
/* 只保留颜色和字体样式,完全不干预布局和滚动 */
.stApp {
background-color: #F4F5F7;
font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
}

section[data-testid="stSidebar"] {
background-color: #FFFFFF;
border-right: 1px solid #E5E7EB;
}

.stButton button {
border-radius: 4px;
font-weight: 500;
}

div[data-testid="stButton"] > button[kind="primary"] {
background-color: #1677FF;
color: white;
border: none;
}
div[data-testid="stButton"] > button[kind="primary"]:hover {
background-color: #4096FF;
}

h1, h2, h3 {
color: #1F2329;
font-weight: 600;
}
h2 { border-left: 4px solid #1677FF; padding-left: 10px; }

.stTextInput input, .stTextArea textarea {
border-radius: 4px;
}

/* Hide Streamlit password reveal (eye) buttons to avoid accidental exposure */
button[aria-label="Show password text"],
button[aria-label="Hide password text"] {
display: none !important;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
|
||
|
||
# ============================================================
|
||
# Session State Management
|
||
# ============================================================
|
||
# Seed per-session defaults on the first run; keys that already exist are left untouched.
# Each default is produced by a factory so mutable containers are fresh objects.
_SESSION_DEFAULT_FACTORIES = (
    ("project_id", lambda: None),
    ("current_step", lambda: 0),
    ("script_data", lambda: None),
    ("scene_images", dict),
    ("scene_videos", dict),
    ("final_video", lambda: None),
    ("uploaded_images", list),
    ("view_mode", lambda: "workspace"),  # workspace, history, settings
    ("selected_img_provider", lambda: "shubiaobiao"),
    # Scopes Streamlit widget keys per project load to avoid cross-project state pollution.
    ("ui_rev", lambda: int(time.time() * 1000)),
)
for _state_key, _make_default in _SESSION_DEFAULT_FACTORIES:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
||
|
||
|
||
def _ui_key(suffix: str) -> str:
    """Return a widget key namespaced by the active project and UI revision.

    Streamlit keys widget state globally. Without this namespacing, switching
    projects can reuse stale widget values and even trigger frontend DOM errors
    (e.g. removeChild NotFoundError).

    The key has the form ``p:<project_id or NEW>|r:<ui_rev or 0>|<suffix>``.
    """
    state = st.session_state
    project_part = state.get("project_id") or "NEW"
    revision_part = state.get("ui_rev") or 0
    return f"p:{project_part}|r:{revision_part}|{suffix}"
|
||
|
||
def load_project(project_id):
|
||
"""Load project state from DB"""
|
||
data = db.get_project_for_user(project_id, _current_user())
|
||
if not data:
|
||
st.error("Project not found / no permission")
|
||
return
|
||
|
||
st.session_state.project_id = project_id
|
||
# bump UI revision to ensure all widget keys are isolated per project load
|
||
st.session_state.ui_rev = int(time.time() * 1000)
|
||
st.session_state.script_data = data.get("script_data")
|
||
st.session_state.view_mode = "workspace"
|
||
|
||
# Fallback: 如果 DB 中的 script_data 是旧结构/缺字段,则从 legacy JSON 重新规范化一次
|
||
try:
|
||
script_data = st.session_state.script_data
|
||
legacy_json = Path(config.TEMP_DIR) / f"project_{project_id}.json"
|
||
|
||
def _needs_normalize(sd: Any) -> bool:
|
||
if not isinstance(sd, dict):
|
||
return True
|
||
if "_legacy_schema" not in sd:
|
||
return True
|
||
scenes = sd.get("scenes") or []
|
||
if scenes and isinstance(scenes, list) and isinstance(scenes[0], dict):
|
||
if "visual_prompt" not in scenes[0] or "video_prompt" not in scenes[0]:
|
||
return True
|
||
return False
|
||
|
||
if legacy_json.exists() and _needs_normalize(script_data):
|
||
raw = json.loads(legacy_json.read_text(encoding="utf-8"))
|
||
normalized = normalize_legacy_project(raw)
|
||
st.session_state.script_data = normalized
|
||
# 写回 DB,避免每次 load 都重新算
|
||
db.update_project_script(project_id, normalized)
|
||
st.info("已从 legacy JSON 重新规范化脚本字段(兼容旧版项目)。")
|
||
except Exception as e:
|
||
st.warning(f"legacy 规范化失败(将继续使用 DB 数据): {e}")
|
||
|
||
# Restore product info for Step 1 display
|
||
product_info = data.get("product_info", {})
|
||
st.session_state.loaded_product_name = data.get("name", "")
|
||
st.session_state.loaded_product_info = product_info
|
||
|
||
# Restore uploaded images from product_info
|
||
if product_info and "uploaded_images" in product_info:
|
||
# 检查文件是否仍存在
|
||
valid_paths = [p for p in product_info["uploaded_images"] if os.path.exists(p)]
|
||
st.session_state.uploaded_images = valid_paths
|
||
if len(valid_paths) < len(product_info["uploaded_images"]):
|
||
st.warning(f"部分原始图片已丢失 ({len(product_info['uploaded_images']) - len(valid_paths)} 张),建议重新上传")
|
||
else:
|
||
st.session_state.uploaded_images = []
|
||
|
||
# Restore assets
|
||
# RBAC: assets are project-scoped, so permission already checked above.
|
||
assets = db.get_assets(project_id)
|
||
images = {}
|
||
videos = {}
|
||
final_vid = None
|
||
|
||
for asset in assets:
|
||
sid = asset["scene_id"]
|
||
source_path, _mapped_url = map_legacy_local_path(asset.get("local_path"))
|
||
# 假设 scene_id 0 或 -1 用于 final video
|
||
if asset["asset_type"] == "image" and asset["status"] == "completed":
|
||
if source_path:
|
||
images[sid] = source_path
|
||
elif asset["asset_type"] == "video" and asset["status"] == "completed":
|
||
if source_path:
|
||
videos[sid] = source_path
|
||
elif asset["asset_type"] == "final_video" and asset["status"] == "completed":
|
||
if source_path:
|
||
final_vid = source_path
|
||
|
||
st.session_state.scene_images = images
|
||
st.session_state.scene_videos = videos
|
||
st.session_state.final_video = final_vid
|
||
|
||
# Determine step
|
||
if final_vid:
|
||
st.session_state.current_step = 4
|
||
elif videos:
|
||
st.session_state.current_step = 4
|
||
elif images:
|
||
st.session_state.current_step = 3
|
||
elif st.session_state.script_data:
|
||
st.session_state.current_step = 1
|
||
else:
|
||
st.session_state.current_step = 0
|
||
|
||
# ============================================================
|
||
# Helper Functions (must be defined before Sidebar uses them)
|
||
# ============================================================
|
||
def _record_metrics(project_id: str, patch: dict):
|
||
"""Persist lightweight timing/diagnostic metrics into project.product_info['_metrics']."""
|
||
if not project_id or not isinstance(patch, dict) or not patch:
|
||
return
|
||
try:
|
||
proj = db.get_project(project_id) or {}
|
||
product_info = proj.get("product_info") or {}
|
||
metrics = product_info.get("_metrics") if isinstance(product_info.get("_metrics"), dict) else {}
|
||
metrics.update(patch)
|
||
metrics["updated_at"] = time.time()
|
||
product_info["_metrics"] = metrics
|
||
db.update_project_product_info(project_id, product_info)
|
||
except Exception:
|
||
# metrics must never break UX
|
||
pass
|
||
|
||
|
||
def _get_metrics(project_id: str) -> dict:
|
||
try:
|
||
proj = db.get_project(project_id) or {}
|
||
product_info = proj.get("product_info") or {}
|
||
m = product_info.get("_metrics")
|
||
return m if isinstance(m, dict) else {}
|
||
except Exception:
|
||
return {}
|
||
|
||
|
||
def _ensure_local_videos(project_id: str, scenes: list):
|
||
"""
|
||
Step 5 合成前的保障逻辑:
|
||
检查分镜视频是否已下载到服务器本地。如果只有 URL 没有本地文件,则后台静默下载。
|
||
"""
|
||
if not project_id or not scenes:
|
||
return
|
||
|
||
vid_gen = VideoGenerator()
|
||
downloaded = 0
|
||
missing_assets = []
|
||
|
||
for scene in scenes:
|
||
scene_id = scene["id"]
|
||
local_path = st.session_state.scene_videos.get(scene_id)
|
||
|
||
# 如果本地文件不存在,尝试补全
|
||
if not local_path or not os.path.exists(local_path):
|
||
asset = db.get_asset(project_id, scene_id, "video")
|
||
if not asset:
|
||
missing_assets.append(f"Scene {scene_id}")
|
||
continue
|
||
|
||
meta = asset.get("metadata") or {}
|
||
video_url = meta.get("video_url")
|
||
task_id = asset.get("task_id")
|
||
|
||
# 如果 DB 里没有 URL,但有 task_id,尝试现场查询一次
|
||
if not video_url and task_id:
|
||
logger.info(f"Checking task {task_id} status for Scene {scene_id} during composition...")
|
||
status, url = vid_gen.check_task_status(task_id)
|
||
if status == "succeeded" and url:
|
||
video_url = url
|
||
db.update_asset_metadata(project_id, scene_id, "video", {"video_url": url, "volc_status": status})
|
||
|
||
if video_url:
|
||
out_name = path_utils.unique_filename(
|
||
prefix="scene_video",
|
||
ext="mp4",
|
||
project_id=project_id,
|
||
scene_id=scene_id,
|
||
extra=f"auto_{int(time.time())}"
|
||
)
|
||
logger.info(f"Auto-downloading missing video for Scene {scene_id} from {video_url}")
|
||
target_dir = path_utils.project_videos_dir(project_id)
|
||
target_path = str(target_dir / out_name)
|
||
if vid_gen._download_video_to(video_url, target_path):
|
||
st.session_state.scene_videos[scene_id] = target_path
|
||
db.save_asset(project_id, scene_id, "video", "completed", local_path=target_path)
|
||
downloaded += 1
|
||
else:
|
||
missing_assets.append(f"Scene {scene_id} (未生成/无URL)")
|
||
|
||
if downloaded > 0:
|
||
logger.info(f"Successfully auto-downloaded {downloaded} missing videos for project {project_id}")
|
||
|
||
if missing_assets:
|
||
msg = f"无法合成:以下分镜缺少视频素材,请先在第 4 步生成:{', '.join(missing_assets)}"
|
||
logger.warning(msg)
|
||
raise ValueError(msg)
|
||
|
||
# ============================================================
|
||
# Sidebar
|
||
# ============================================================
|
||
def _reset_session_state(keep=("view_mode", "current_user")):
    """Delete every session-state key except those listed in *keep*.

    ``current_user`` is kept so that resetting the workflow does not also log
    the user out. A login whose cookie could not be set lives only in session
    state and would otherwise be lost.
    """
    for key in list(st.session_state.keys()):
        if key not in keep:
            del st.session_state[key]


with st.sidebar:
    st.title("📽️ Video Flow")
    cu = _current_user()
    if cu:
        st.caption(f"登录用户: {cu.get('username')} ({cu.get('role')})")
        if st.button("退出登录", type="secondary"):
            _logout()

    # Mode selection: derive the radio index from the stored view_mode.
    mode_options = ["🛠️ 工作台", "📜 历史任务", "⚙️ 设置"]
    mode_map = {"workspace": 0, "history": 1, "settings": 2}
    current_index = mode_map.get(st.session_state.view_mode, 0)

    mode = st.radio("模式", mode_options, index=current_index)
    if mode == "🛠️ 工作台":
        st.session_state.view_mode = "workspace"
    elif mode == "📜 历史任务":
        st.session_state.view_mode = "history"
    elif mode == "⚙️ 设置":
        st.session_state.view_mode = "settings"

    st.markdown("---")

    if st.session_state.view_mode == "workspace":
        # Project selection
        st.subheader("Current Project")
        projects = db.list_projects_for_user(_current_user())
        proj_options = {p['id']: f"{p.get('name', 'Untitled')} ({p['id']})" for p in projects}

        selected_proj_id = st.selectbox(
            "Select Project",
            options=["New Project"] + list(proj_options.keys()),
            format_func=lambda x: "➕ New Project" if x == "New Project" else proj_options[x],
        )

        if selected_proj_id == "New Project":
            # Only offer a reset when switching away from an existing project.
            if st.session_state.project_id and st.session_state.project_id in proj_options:
                if st.button("Start New"):
                    _reset_session_state()
                    st.rerun()
        elif st.session_state.project_id != selected_proj_id:
            if st.button(f"Load {selected_proj_id}"):
                load_project(selected_proj_id)
                # load_project only switches project_id after a successful load.
                # Rerun only then, so its error message stays visible on failure.
                if st.session_state.project_id == selected_proj_id:
                    st.rerun()

        if st.session_state.project_id:
            st.caption(f"Current ID: {st.session_state.project_id}")
            with st.expander("⏱️ 性能与诊断", expanded=False):
                m = _get_metrics(st.session_state.project_id)
                if not m:
                    st.caption("暂无指标(执行一次脚本/生图/生视频/合成后会出现)。")
                else:
                    metric_keys = [
                        "script_gen_s",
                        "image_gen_total_s",
                        "video_submit_s",
                        "video_recover_s",
                        "compose_s",
                        "script_model",
                        "image_provider",
                        "image_generated",
                        "video_submitted",
                        "video_recovered",
                        "bgm_used",
                    ]
                    for k in metric_keys:
                        if k in m:
                            st.caption(f"{k}: {m.get(k)}")
        st.markdown("---")

        # Navigation / progress: completed steps, current step, pending steps.
        steps = ["1. 输入信息", "2. 生成脚本", "3. 画面生成", "4. 视频生成", "5. 合成输出"]
        current = st.session_state.current_step
        for i, step in enumerate(steps):
            if i < current:
                st.markdown(f"✅ **{step}**")
            elif i == current:
                st.markdown(f"🔵 **{step}**")
            else:
                st.markdown(f"⚪ {step}")

        st.markdown("---")

        if st.button("🔄 重置状态", type="secondary"):
            _reset_session_state()
            st.rerun()
|
||
|
||
def save_uploaded_file(project_id: str, uploaded_file):
    """Write an uploaded file into the project's upload directory under a unique name.

    The sanitized original stem is kept for readability, and the original
    extension is preserved (``bin`` when there is none).

    Returns:
        The saved path as ``str``, or ``None`` when *uploaded_file* is ``None``.

    Raises:
        ValueError: if *project_id* is falsy. Files must stay project-scoped so
            uploads from different projects never overwrite each other.
    """
    if uploaded_file is None:
        return None
    if not project_id:
        raise ValueError("project_id is required to save uploaded files safely")

    target_dir = path_utils.project_upload_dir(project_id)
    safe_name = Path(path_utils.sanitize_filename(getattr(uploaded_file, "name", "upload")))
    extension = safe_name.suffix.lstrip(".") or "bin"
    base = safe_name.stem or "upload"
    destination = target_dir / path_utils.unique_filename(
        prefix=f"upload_{base}", ext=extension, project_id=project_id
    )
    destination.write_bytes(uploaded_file.getbuffer())
    return str(destination)
|
||
|
||
# ============================================================
|
||
# Main Content: Workspace
|
||
# ============================================================
|
||
if st.session_state.view_mode == "workspace":
|
||
st.header("商品短视频自动生成控制台")
|
||
|
||
# --- Step 1: Input ---
|
||
with st.expander("📦 1. 商品信息输入", expanded=(st.session_state.current_step == 0)):
|
||
# 从 session_state 读取已加载项目的信息,否则使用默认值
|
||
loaded_info = st.session_state.get("loaded_product_info", {})
|
||
default_name = st.session_state.get("loaded_product_name", "网红气质大号发量多!高马尾香蕉夹")
|
||
default_category = loaded_info.get("category", "钟表配饰-时尚饰品-发饰")
|
||
default_price = loaded_info.get("price", "3.99元")
|
||
default_tags = loaded_info.get("tags", "回头客|款式好看|材质好|尺寸合适|颜色好看|很好用|做工好|质感不错|很牢固")
|
||
default_params = loaded_info.get("params", "金属材质:非金属; 非金属材质:树脂; 发夹分类:香蕉夹")
|
||
|
||
col1, col2 = st.columns([1, 1])
|
||
|
||
with col1:
|
||
product_name = st.text_input("商品标题", value=default_name, key=_ui_key("product_name"))
|
||
category = st.text_input("商品类目", value=default_category, key=_ui_key("category"))
|
||
price = st.text_input("价格", value=default_price, key=_ui_key("price"))
|
||
|
||
with col2:
|
||
tags = st.text_area("评价标签 (用于提炼卖点)", value=default_tags, height=100, key=_ui_key("tags"))
|
||
params = st.text_area("商品参数", value=default_params, height=100, key=_ui_key("params"))
|
||
|
||
# 商家自定义风格提示
|
||
style_hint = st.text_area(
|
||
"商品视频重点增强提示 (可选)",
|
||
value=loaded_info.get("style_hint", ""),
|
||
placeholder="例如:韩风、高级感、活力青春、简约日系...",
|
||
height=80,
|
||
key=_ui_key("style_hint"),
|
||
)
|
||
|
||
st.markdown("### 上传素材")
|
||
|
||
# 显示已有的上传图片(如果从历史项目加载)
|
||
if st.session_state.uploaded_images:
|
||
st.info(f"已有 {len(st.session_state.uploaded_images)} 张参考图片")
|
||
with st.expander("查看已有图片"):
|
||
img_cols = st.columns(min(len(st.session_state.uploaded_images), 4))
|
||
for i, img_path in enumerate(st.session_state.uploaded_images[:4]):
|
||
if os.path.exists(img_path):
|
||
img_cols[i % 4].image(img_path, width=150)
|
||
|
||
uploaded_files = st.file_uploader("上传商品主图 (建议 3-5 张)", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
|
||
|
||
# 允许在没有上传新图片但有历史图片的情况下继续
|
||
can_submit = uploaded_files or st.session_state.uploaded_images
|
||
|
||
# Model Selection (all support images; user explicitly chooses model)
|
||
model_options = ["GPT-5.2", "Gemini 3 Pro", "Doubao Pro (Vision)"]
|
||
selected_model_label = st.radio("选择脚本生成模型", model_options, horizontal=True, index=0, key=_ui_key("script_model"))
|
||
# Map label to provider key
|
||
if selected_model_label == "GPT-5.2":
|
||
model_provider = "shubiaobiao_gpt"
|
||
elif "Doubao" in selected_model_label:
|
||
model_provider = "doubao"
|
||
else:
|
||
model_provider = "shubiaobiao"
|
||
|
||
if st.button("提交任务 & 生成脚本", type="primary", disabled=(not can_submit)):
|
||
# 处理图片路径
|
||
image_paths = list(st.session_state.uploaded_images) if st.session_state.uploaded_images else []
|
||
|
||
# 如果有新上传的文件,添加它们
|
||
if uploaded_files:
|
||
# Create Project ID
|
||
if not st.session_state.project_id:
|
||
st.session_state.project_id = f"PROJ-{int(time.time())}"
|
||
|
||
for uf in uploaded_files:
|
||
path = save_uploaded_file(st.session_state.project_id, uf)
|
||
if path: image_paths.append(path)
|
||
|
||
st.session_state.uploaded_images = image_paths
|
||
|
||
# 确保有项目 ID
|
||
if not st.session_state.project_id:
|
||
st.session_state.project_id = f"PROJ-{int(time.time())}"
|
||
|
||
# DB: Create Project
|
||
# 将 uploaded_images 保存到 product_info 以便持久化
|
||
product_info = {"category": category, "price": price, "tags": tags, "params": params, "style_hint": style_hint, "uploaded_images": image_paths}
|
||
db.create_project(
|
||
st.session_state.project_id,
|
||
product_name,
|
||
product_info,
|
||
owner_user_id=(_current_user() or {}).get("id"),
|
||
)
|
||
|
||
# Call Script Generator
|
||
with st.spinner(f"正在分析商品信息并生成脚本 ({selected_model_label})..."):
|
||
gen = ScriptGenerator()
|
||
t0 = perf_counter()
|
||
script = gen.generate_script(
|
||
product_name,
|
||
product_info,
|
||
image_paths,
|
||
model_provider=model_provider,
|
||
user_id=(_current_user() or {}).get("id"),
|
||
)
|
||
_record_metrics(st.session_state.project_id, {
|
||
"script_gen_s": round(perf_counter() - t0, 3),
|
||
"script_model": model_provider,
|
||
})
|
||
|
||
if script:
|
||
st.session_state.script_data = script
|
||
# DB: Save Script
|
||
db.update_project_script(st.session_state.project_id, script)
|
||
|
||
st.session_state.current_step = 1
|
||
st.success("脚本生成成功!")
|
||
st.rerun()
|
||
else:
|
||
st.error("脚本生成失败,请检查日志。")
|
||
|
||
# --- Step 2: Script Review ---
|
||
if st.session_state.current_step >= 1:
|
||
with st.expander("📝 2. 脚本分镜确认", expanded=(st.session_state.current_step == 1)):
|
||
if st.session_state.script_data:
|
||
script = st.session_state.script_data
|
||
|
||
# Display Basic Info (兼容 legacy schema)
|
||
selling_points = script.get("selling_points", []) or []
|
||
target_audience = script.get("target_audience", "") or ""
|
||
analysis_text = script.get("analysis", "") or ""
|
||
legacy_schema = script.get("_legacy_schema", "") or ""
|
||
|
||
c1, c2 = st.columns(2)
|
||
if selling_points:
|
||
c1.write(f"**核心卖点**: {', '.join(selling_points)}")
|
||
else:
|
||
c1.write("**核心卖点**: (legacy 项目可能未生成该字段)")
|
||
if analysis_text:
|
||
with st.expander("查看 legacy analysis(用于补齐信息)"):
|
||
st.write(analysis_text)
|
||
|
||
if target_audience:
|
||
c2.write(f"**目标人群**: {target_audience}")
|
||
else:
|
||
c2.write("**目标人群**: (legacy 项目可能未生成该字段)")
|
||
|
||
# Hook / CTA / Schema
|
||
hook = script.get("hook", "") or ""
|
||
if hook:
|
||
st.markdown(f"**Hook**: {hook}")
|
||
cta = script.get("cta", "")
|
||
if cta:
|
||
if isinstance(cta, dict):
|
||
st.markdown("**CTA(legacy object)**")
|
||
st.json(cta)
|
||
else:
|
||
st.markdown(f"**CTA**: {cta}")
|
||
if legacy_schema:
|
||
st.caption(f"Legacy Schema: {legacy_schema}")
|
||
|
||
# Prompt Visualization
|
||
if "_debug" in script:
|
||
with st.expander("🔍 查看 AI Prompt (Debug)"):
|
||
debug_info = script["_debug"]
|
||
st.markdown("**System Prompt:**")
|
||
st.code(debug_info.get("system_prompt", ""), language="markdown")
|
||
st.markdown("**User Prompt:**")
|
||
st.code(debug_info.get("user_prompt", ""), language="markdown")
|
||
|
||
# 显示原始输出
|
||
if "raw_output" in debug_info:
|
||
st.markdown("**Raw AI Output:**")
|
||
st.code(debug_info.get("raw_output", ""), language="markdown")
|
||
|
||
# Editable Scenes Table
|
||
st.markdown("### 分镜列表")
|
||
scenes = script.get("scenes", [])
|
||
|
||
# Global Voiceover Timeline (New)
|
||
st.markdown("### 🎙️ 整体旁白与字幕时间轴")
|
||
with st.expander("编辑旁白时间轴 (Voiceover Timeline)", expanded=True):
|
||
timeline = script.get("voiceover_timeline", []) or []
|
||
if not timeline:
|
||
# 对于历史项目:如果没有 scenes 也没有 timeline,不要强行塞“示例旁白”,避免污染数据
|
||
if not scenes and analysis_text:
|
||
st.info("该历史项目暂无旁白时间轴(可能停留在分析/提问阶段)。")
|
||
timeline = []
|
||
else:
|
||
# Init with default if empty (使用秒)
|
||
timeline = [{"text": "示例旁白", "subtitle": "示例字幕", "start_time": 0.0, "duration": 3.0}]
|
||
|
||
updated_timeline = []
|
||
for i, item in enumerate(timeline):
|
||
c1, c2, c3, c4 = st.columns([3, 3, 1, 1])
|
||
with c1:
|
||
item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=_ui_key(f"tl_vo_{i}"))
|
||
with c2:
|
||
item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=_ui_key(f"tl_sub_{i}"))
|
||
with c3:
|
||
# 兼容旧格式: 如果有 start_ratio 则转换为 start_time
|
||
default_start = item.get("start_time", item.get("start_ratio", 0) * 12) # 假设总时长12秒
|
||
item["start_time"] = st.number_input(f"开始(秒) #{i+1}", value=float(default_start), min_value=0.0, max_value=30.0, step=0.5, key=_ui_key(f"tl_start_{i}"))
|
||
with c4:
|
||
# 兼容旧格式: 如果有 duration_ratio 则转换为 duration
|
||
default_dur = item.get("duration", item.get("duration_ratio", 0.25) * 12) # 假设总时长12秒
|
||
item["duration"] = st.number_input(f"时长(秒) #{i+1}", value=float(default_dur), min_value=0.5, max_value=15.0, step=0.5, key=_ui_key(f"tl_dur_{i}"))
|
||
# 清理旧字段
|
||
item.pop("start_ratio", None)
|
||
item.pop("duration_ratio", None)
|
||
updated_timeline.append(item)
|
||
|
||
if st.button("➕ 添加旁白段落"):
|
||
updated_timeline.append({"text": "", "subtitle": "", "start_time": 0.0, "duration": 3.0})
|
||
st.rerun()
|
||
|
||
script["voiceover_timeline"] = updated_timeline
|
||
|
||
updated_scenes = []
|
||
for i, scene in enumerate(scenes):
|
||
with st.container():
|
||
st.markdown(f"#### 🎬 Scene {scene['id']}")
|
||
c_vis, c_aud = st.columns([2, 1])
|
||
|
||
with c_vis:
|
||
new_visual = st.text_area(f"Visual Prompt (Scene {scene['id']})", value=scene.get("visual_prompt", ""), height=80, key=_ui_key(f"vp_{i}"))
|
||
new_video = st.text_area(f"Video Prompt (Scene {scene['id']})", value=scene.get("video_prompt", ""), height=80, key=_ui_key(f"vidp_{i}"))
|
||
scene["visual_prompt"] = new_visual
|
||
scene["video_prompt"] = new_video
|
||
|
||
with c_aud:
|
||
# 花字编辑保留
|
||
ft = scene.get("fancy_text", {})
|
||
if isinstance(ft, dict):
|
||
new_ft_text = st.text_input(
|
||
f"Fancy Text (Scene {scene['id']})",
|
||
value=ft.get("text", ""),
|
||
key=_ui_key(f"ft_{i}"),
|
||
)
|
||
# 兼容:旧数据可能没有 fancy_text 字段
|
||
if not isinstance(scene.get("fancy_text"), dict):
|
||
scene["fancy_text"] = {}
|
||
scene["fancy_text"]["text"] = new_ft_text
|
||
|
||
# 旁白/字幕已移至上方整体时间轴,此处仅作展示或删除
|
||
st.caption("注:旁白与字幕已移至上方整体时间轴编辑")
|
||
|
||
# Legacy 信息展示(只读,用于调试/对齐)
|
||
legacy_scene = scene.get("_legacy", {}) if isinstance(scene.get("_legacy", {}), dict) else {}
|
||
if legacy_scene:
|
||
with st.expander(f"Legacy 信息 (Scene {scene['id']})", expanded=False):
|
||
img_url = legacy_scene.get("image_url")
|
||
if img_url:
|
||
st.markdown(f"- image_url: `{img_url}`")
|
||
cam = legacy_scene.get("camera_movement")
|
||
if cam:
|
||
st.markdown(f"- camera_movement: {cam}")
|
||
vo = legacy_scene.get("voiceover")
|
||
if vo:
|
||
st.markdown(f"- voiceover: {vo}")
|
||
keyframe = legacy_scene.get("keyframe")
|
||
if keyframe:
|
||
st.markdown("- keyframe:")
|
||
st.json(keyframe)
|
||
rhythm = legacy_scene.get("rhythm")
|
||
if rhythm:
|
||
st.markdown("- rhythm:")
|
||
st.json(rhythm)
|
||
|
||
updated_scenes.append(scene)
|
||
st.divider()
|
||
|
||
st.session_state.script_data["scenes"] = updated_scenes
|
||
|
||
col_act1, col_act2 = st.columns([1, 4])
|
||
with col_act1:
|
||
if st.button("确认脚本 & 开始生图", type="primary"):
|
||
# DB: Update script in case user edited it
|
||
db.update_project_script(st.session_state.project_id, st.session_state.script_data)
|
||
st.session_state.current_step = 2
|
||
st.rerun()
|
||
|
||
# --- Step 3: Image Generation ---
|
||
if st.session_state.current_step >= 2:
|
||
with st.expander("🎨 3. 画面生成", expanded=(st.session_state.current_step == 2)):
|
||
|
||
# Debug: Show reference image
|
||
if st.session_state.uploaded_images:
|
||
with st.expander("🔍 查看生图参考底图 (Base Image)"):
|
||
for i, p in enumerate(st.session_state.uploaded_images):
|
||
st.info(f"图 {i+1} 文件名: {os.path.basename(p)}")
|
||
if os.path.exists(p):
|
||
st.image(p, width=200)
|
||
else:
|
||
st.warning("⚠️ 未检测到参考底图!将生成随机风格图像。建议在 Step 1 重新上传。")
|
||
|
||
# Trigger Generation Button
|
||
if not st.session_state.scene_images:
|
||
# Model Selection for Image Gen
|
||
img_model_options = ["Shubiaobiao (Gemini)", "Doubao (Volcengine)", "Gemini (Direct)", "Doubao (Group Image)"]
|
||
selected_img_model = st.radio("选择生图模型", img_model_options, horizontal=True, key=_ui_key("img_model"))
|
||
|
||
# Map to provider
|
||
if "Group Image" in selected_img_model:
|
||
img_provider = "doubao-group"
|
||
elif "Doubao" in selected_img_model:
|
||
img_provider = "doubao"
|
||
elif "Direct" in selected_img_model:
|
||
img_provider = "gemini"
|
||
else:
|
||
img_provider = "shubiaobiao"
|
||
|
||
# Remember the chosen provider so the per-scene "重生" buttons can reuse it
# (they map "doubao-group" back to single-image "doubao").
st.session_state.selected_img_provider = img_provider

if st.button("🚀 执行 AI 生图", type="primary"):
    # Non-blocking acquire of the global image-generation slot: refuse right
    # away instead of queueing behind another job.
    with limits.acquire_image(blocking=False) as ok:
        if not ok:
            st.warning("系统正在生成其他任务(生图并发已达上限),请稍后再试。")
            st.stop()

        img_gen = ImageGenerator()
        # Pass ALL uploaded images as reference
        base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else []

        if not base_imgs:
            st.error("No base image found (未找到参考底图). Please upload in Step 1.")
            st.stop()

        scenes = st.session_state.script_data.get("scenes", [])
        if not scenes:
            # Nothing to render; avoid reporting success for an empty script.
            st.error("脚本中没有分镜,无法生图。")
            st.stop()

        # visual_anchor is passed to every request to keep the generated images consistent
        visual_anchor = st.session_state.script_data.get("visual_anchor", "")

        if img_provider == "doubao-group":
            # --- Group generation: a single batch request for all scenes ---
            with st.spinner("正在进行 Doubao 组图生成 (Batch Group Generation)..."):
                try:
                    t0 = perf_counter()
                    results = img_gen.generate_group_images_doubao(
                        scenes=scenes,
                        reference_images=base_imgs,
                        visual_anchor=visual_anchor,
                        project_id=st.session_state.project_id,
                    ) or {}
                    _record_metrics(st.session_state.project_id, {
                        "image_gen_total_s": round(perf_counter() - t0, 3),
                        "image_provider": img_provider,
                        "image_generated": len(results),
                    })

                    for s_id, path in results.items():
                        st.session_state.scene_images[s_id] = path
                        db.save_asset(st.session_state.project_id, s_id, "image", "completed", local_path=path)
                        # A new image makes this scene's old video stale
                        db.clear_asset(st.session_state.project_id, s_id, "video", status="pending")
                        st.session_state.scene_videos.pop(s_id, None)

                    if len(results) == len(scenes):
                        st.success("组图生成完成!")
                        db.update_project_status(st.session_state.project_id, "images_generated")
                        time.sleep(1)
                        st.rerun()
                    else:
                        # No rerun here: it would wipe this warning immediately. The
                        # image gallery rendered later in this same script run already
                        # shows the images stored above.
                        st.warning(f"部分图片生成失败: {len(results)}/{len(scenes)}")
                except Exception as e:
                    st.error(f"组图生成失败: {e}")
        else:
            # --- Parallel generation (default): one request per scene, using only
            # the merchant-uploaded images as references (no chaining) ---
            total_scenes = len(scenes)
            progress_bar = st.progress(0)
            status_text = st.empty()

            try:
                t0 = perf_counter()
                # Worker threads within this single run; the global slot is already held.
                max_workers = 6
                img_generated_count = 0
                img_failed_ids = []
                with ThreadPoolExecutor(max_workers=max_workers) as ex:
                    futures = {
                        ex.submit(
                            img_gen.generate_single_scene_image,
                            scene=scene,
                            original_image_path=list(base_imgs),  # ONLY merchant images
                            previous_image_path=None,
                            model_provider=img_provider,
                            visual_anchor=visual_anchor,
                            project_id=st.session_state.project_id,
                        ): scene["id"]
                        for scene in scenes
                    }

                    done = 0
                    for fut in as_completed(futures):
                        scene_id = futures[fut]
                        done += 1
                        status_text.text(f"已完成 {done}/{total_scenes}(Scene {scene_id})")
                        try:
                            img_path = fut.result()
                            if img_path:
                                st.session_state.scene_images[scene_id] = img_path
                                db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path)
                                # Image changed => the old video for this scene is wrong
                                db.clear_asset(st.session_state.project_id, scene_id, "video", status="pending")
                                st.session_state.scene_videos.pop(scene_id, None)
                                img_generated_count += 1
                            else:
                                # The generator returned nothing: report it instead of failing silently.
                                img_failed_ids.append(scene_id)
                                st.warning(f"Scene {scene_id} 生成失败:未返回图片")
                        except Exception as e:
                            img_failed_ids.append(scene_id)
                            st.warning(f"Scene {scene_id} 生成失败:{e}")

                        progress_bar.progress(done / total_scenes)

                _record_metrics(st.session_state.project_id, {
                    "image_gen_total_s": round(perf_counter() - t0, 3),
                    "image_provider": img_provider,
                    # Images produced by THIS run (same meaning as the group branch's len(results)).
                    "image_generated": img_generated_count,
                })

                if not img_failed_ids:
                    status_text.text("生图完成!")
                    st.success("生图完成!")
                    db.update_project_status(st.session_state.project_id, "images_generated")
                    time.sleep(1)
                    st.rerun()
                else:
                    # Stay on this run so the per-scene warnings above remain readable.
                    status_text.text(f"生图结束:成功 {img_generated_count}/{total_scenes}")
                    st.warning(f"部分图片生成失败: {img_generated_count}/{total_scenes}(失败分镜: {img_failed_ids})")
                    if img_generated_count:
                        db.update_project_status(st.session_state.project_id, "images_generated")
            except PermissionError as e:
                st.error(f"生图服务拒绝访问 (可能余额不足): {e}")
            except Exception as e:
                st.error(f"生图发生错误: {e}")
|
||
|
||
# Display & Regenerate Images
if st.session_state.scene_images:
    cols = st.columns(4)
    scenes = st.session_state.script_data.get("scenes", [])

    for i, scene in enumerate(scenes):
        scene_id = scene["id"]
        img_path = st.session_state.scene_images.get(scene_id)

        with cols[i % 4]:
            st.markdown(f"**Scene {scene_id}**")

            # Expand Prompt info
            with st.expander("Prompt", expanded=False):
                st.caption(scene.get("visual_prompt", "No prompt"))

            if img_path and os.path.exists(img_path):
                # NOTE(review): `use_column_width` is deprecated in recent Streamlit
                # releases; switch to `use_container_width=True` once the pinned
                # Streamlit version supports it for st.image.
                st.image(img_path, width=None, use_column_width=True)
            else:
                st.error("Image missing")

            # Regenerate this scene only. References are the merchant uploads
            # (no chaining from previously generated images).
            if st.button(f"🔄 重生 S{scene_id}", key=f"regen_img_{scene_id}"):
                new_path = None
                if not st.session_state.uploaded_images:
                    st.error("缺少参考底图,请在 Step 1 重新上传图片")
                else:
                    # Honour the same global concurrency slot as the batch button.
                    with limits.acquire_image(blocking=False) as ok:
                        if not ok:
                            st.warning("系统正在生成其他任务(生图并发已达上限),请稍后再试。")
                        else:
                            with st.spinner(f"正在重绘 Scene {scene_id}..."):
                                # Group mode has no single-image call; fall back to plain doubao.
                                provider = st.session_state.get("selected_img_provider", "shubiaobiao")
                                if provider == "doubao-group":
                                    provider = "doubao"
                                try:
                                    new_path = ImageGenerator().generate_single_scene_image(
                                        scene=scene,
                                        original_image_path=list(st.session_state.uploaded_images),
                                        previous_image_path=None,
                                        model_provider=provider,
                                        visual_anchor=st.session_state.script_data.get("visual_anchor", ""),
                                        project_id=st.session_state.project_id,
                                    )
                                except PermissionError as e:
                                    st.error(f"生图服务拒绝访问 (可能余额不足): {e}")
                                except Exception as e:
                                    st.error(f"生图发生错误: {e}")
                                else:
                                    if not new_path:
                                        st.error(f"Scene {scene_id} 重绘失败,请稍后重试")

                if new_path:
                    st.session_state.scene_images[scene_id] = new_path
                    db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=new_path)
                    # Invalidate stale video for this scene
                    db.clear_asset(st.session_state.project_id, scene_id, "video", status="pending")
                    st.session_state.scene_videos.pop(scene_id, None)
                    st.rerun()
|
||
|
||
# Advance to Step 4 (video generation) and re-run the script so it renders
# with the new step active.
next_step_clicked = st.button("下一步:生成视频", type="primary")
if next_step_clicked:
    st.session_state["current_step"] = 3
    st.rerun()
|
||
|
||
# --- Step 4: Video Generation ---
if st.session_state.current_step >= 3:
    with st.expander("🎥 4. 视频生成 (Volcengine I2V)", expanded=(st.session_state.current_step == 3)):
        scenes = st.session_state.script_data.get("scenes", [])
        vid_gen = VideoGenerator()

        st.caption("简化策略:第 4 步完成“生成→轮询→下载到服务器本地”。当至少有一个分镜视频落盘后,才允许进入第 5 步合成。")

        if st.button("🎬 生成分镜视频并下载到服务器(阻塞)", type="primary"):
            # Non-blocking acquire of the global video slot: refuse instead of queueing.
            with limits.acquire_video(blocking=False) as ok:
                if not ok:
                    st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
                    st.stop()

                if not st.session_state.project_id:
                    st.error("缺少 project_id,无法生成视频。")
                    st.stop()

                if not scenes:
                    st.error("脚本中没有分镜,无法生成视频。")
                    st.stop()

                # First restore into the session any video the DB records as
                # downloaded that still exists on local disk.
                for scene in scenes:
                    scene_id = scene["id"]
                    existing = st.session_state.scene_videos.get(scene_id)
                    if existing and os.path.exists(existing):
                        continue
                    asset = db.get_asset(st.session_state.project_id, scene_id, "video")
                    local_path = (asset or {}).get("local_path")
                    if local_path and os.path.exists(local_path):
                        st.session_state.scene_videos[scene_id] = local_path

                t0 = perf_counter()
                total = len(scenes)
                ready_paths = [st.session_state.scene_videos.get(s["id"]) for s in scenes]
                done = sum(1 for p in ready_paths if p and os.path.exists(p))
                progress = st.progress(min(1.0, done / total))
                status_text = st.empty()

                # Resume or submit one task per scene that still lacks a local video.
                tasks = {}  # scene_id -> task_id
                vid_failed = {}  # scene_id -> human-readable failure reason
                for scene in scenes:
                    scene_id = scene["id"]
                    existing = st.session_state.scene_videos.get(scene_id)
                    if existing and os.path.exists(existing):
                        continue

                    asset = db.get_asset(st.session_state.project_id, scene_id, "video") or {}
                    task_id = asset.get("task_id")
                    # Resume a task from an earlier click (it may still be running), but
                    # resubmit one already recorded as failed: re-polling a task that
                    # ended "failed"/"cancelled" can never yield a video.
                    if task_id and asset.get("status") != "failed":
                        tasks[scene_id] = task_id
                        continue

                    image_path = st.session_state.scene_images.get(scene_id)
                    if not (image_path and os.path.exists(image_path)):
                        vid_failed[scene_id] = "缺少源图片"
                        continue
                    prompt = scene.get("video_prompt", "High quality video")
                    new_task_id = vid_gen.submit_scene_video_task(
                        st.session_state.project_id, scene_id, image_path, prompt
                    )
                    if new_task_id:
                        tasks[scene_id] = new_task_id
                    else:
                        vid_failed[scene_id] = "提交任务失败"

                if tasks:
                    db.update_project_status(st.session_state.project_id, "videos_processing")

                out_dir = path_utils.project_videos_dir(st.session_state.project_id)
                pending = set(tasks)
                deadline = time.time() + 15 * 60  # 15 min

                # Poll every pending task (5 s between rounds) and download finished
                # videos to the project's local video directory.
                while pending and time.time() < deadline:
                    status_text.text(f"视频生成/下载中:已完成 {done}/{total},队列中 {len(pending)} ...")

                    for scene_id in list(pending):
                        task_id = tasks[scene_id]
                        status, url = vid_gen.check_task_status(task_id)

                        if status == "succeeded" and url:
                            out_name = path_utils.unique_filename(
                                prefix="scene_video",
                                ext="mp4",
                                project_id=st.session_state.project_id,
                                scene_id=scene_id,
                                extra=(task_id[-8:] if isinstance(task_id, str) else None),
                            )
                            target_path = str(out_dir / out_name)
                            ok_dl = vid_gen._download_video_to(url, target_path)
                            if ok_dl and os.path.exists(target_path):
                                st.session_state.scene_videos[scene_id] = target_path
                                db.save_asset(
                                    st.session_state.project_id,
                                    scene_id,
                                    "video",
                                    "completed",
                                    local_path=target_path,
                                    task_id=task_id,
                                    metadata={"downloaded_at": time.time()},
                                )
                                done += 1
                            else:
                                db.save_asset(
                                    st.session_state.project_id,
                                    scene_id,
                                    "video",
                                    "failed",
                                    task_id=task_id,
                                    metadata={"download_error": True, "checked_at": time.time()},
                                )
                                vid_failed[scene_id] = "下载失败"
                            pending.discard(scene_id)
                        elif status in ["failed", "cancelled"]:
                            db.save_asset(
                                st.session_state.project_id,
                                scene_id,
                                "video",
                                "failed",
                                task_id=task_id,
                                metadata={"volc_status": status, "checked_at": time.time()},
                            )
                            vid_failed[scene_id] = f"生成失败 ({status})"
                            pending.discard(scene_id)
                        else:
                            # Still running (or succeeded without a URL yet): record and keep polling.
                            db.update_asset_metadata(
                                st.session_state.project_id,
                                scene_id,
                                "video",
                                {"volc_status": status, "checked_at": time.time()},
                            )

                    progress.progress(min(1.0, done / total))
                    if pending:
                        time.sleep(5)

                status_text.text(f"视频生成/下载结束:已完成 {done}/{total}")

                if pending:
                    st.warning(f"仍有 {len(pending)} 个分镜未在本轮完成:{sorted(pending)}。可再次点击按钮继续轮询与下载。")
                if vid_failed:
                    st.warning(
                        "以下分镜视频生成失败(可点击对应分镜的“重生”或再次点击上方按钮重试):"
                        + ";".join(f"S{sid}: {why}" for sid, why in vid_failed.items())
                    )

                _record_metrics(st.session_state.project_id, {
                    "video_blocking_total_s": round(perf_counter() - t0, 3),
                    "video_done": done,
                    "video_total": total,
                })

                has_any_video = any(
                    p and os.path.exists(p) for p in (st.session_state.scene_videos or {}).values()
                )
                if has_any_video:
                    db.update_project_status(st.session_state.project_id, "videos_generated")
                if has_any_video and not pending and not vid_failed:
                    st.success("已生成并下载到服务器本地。进入第 5 步合成。")
                    st.session_state.current_step = 4
                    st.rerun()
                elif has_any_video:
                    # Partial result: do NOT rerun, or the warnings above would vanish.
                    # The "进入第 5 步" button below is already enabled.
                    st.info("部分分镜视频已就绪,可直接进入第 5 步合成,或重试未完成/失败的分镜。")
                else:
                    st.error("未生成任何可用视频,请检查生视频接口或稍后重试。")

        # Display Videos (even when partially available)
        if st.session_state.scene_videos or scenes:
            cols = st.columns(4)

            for i, scene in enumerate(scenes):
                scene_id = scene["id"]
                vid_path = st.session_state.scene_videos.get(scene_id)

                with cols[i % 4]:
                    st.markdown(f"**Scene {scene_id}**")
                    # Expand Prompt info
                    with st.expander("Prompt", expanded=False):
                        st.caption(scene.get("video_prompt", "No prompt"))

                    if vid_path and os.path.exists(vid_path):
                        st.video(vid_path)
                    else:
                        asset = db.get_asset(st.session_state.project_id, scene_id, "video")
                        status = (asset or {}).get("status") or "pending"
                        task_id = (asset or {}).get("task_id")
                        if task_id:
                            st.caption(f"状态: {status} | Task: {str(task_id)[-6:]}")
                        else:
                            st.caption(f"状态: {status}")
                        st.warning("暂无本地视频(请点击上方“生成分镜视频并下载到服务器”)。")

                    # Per-scene regenerate: always submits a NEW task and polls it
                    # here for at most 60 rounds x 5 s (~5 min).
                    if st.button(f"🔄 重生 S{scene_id}", key=f"regen_vid_{scene_id}"):
                        vid_regen_ok = False
                        img_path = st.session_state.scene_images.get(scene_id)
                        if not (img_path and os.path.exists(img_path)):
                            st.error("缺少源图片")
                        else:
                            # Honour the same global concurrency slot as the batch button.
                            with limits.acquire_video(blocking=False) as ok:
                                if not ok:
                                    st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
                                else:
                                    with st.spinner(f"正在重生成 Scene {scene_id} 视频..."):
                                        t_id = vid_gen.submit_scene_video_task(
                                            st.session_state.project_id,
                                            scene_id,
                                            img_path,
                                            scene.get("video_prompt", "High quality video"),
                                        )
                                        if not t_id:
                                            st.error("提交任务失败")
                                        else:
                                            for _ in range(60):
                                                status, url = vid_gen.check_task_status(t_id)
                                                if status == "succeeded" and url:
                                                    # Per-project unique name avoids cross-project overwrite
                                                    out_name = path_utils.unique_filename(
                                                        prefix="scene_video",
                                                        ext="mp4",
                                                        project_id=st.session_state.project_id,
                                                        scene_id=scene_id,
                                                        extra=(t_id[-8:] if isinstance(t_id, str) else None),
                                                    )
                                                    target_dir = path_utils.project_videos_dir(st.session_state.project_id)
                                                    target_path = str(target_dir / out_name)
                                                    if vid_gen._download_video_to(url, target_path) and os.path.exists(target_path):
                                                        st.session_state.scene_videos[scene_id] = target_path
                                                        db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=target_path, task_id=t_id)
                                                        vid_regen_ok = True
                                                    else:
                                                        db.save_asset(
                                                            st.session_state.project_id,
                                                            scene_id,
                                                            "video",
                                                            "failed",
                                                            task_id=t_id,
                                                            metadata={"download_error": True, "checked_at": time.time()},
                                                        )
                                                        st.error("视频下载失败,请稍后重试")
                                                    break
                                                if status in ["failed", "cancelled"]:
                                                    db.save_asset(
                                                        st.session_state.project_id,
                                                        scene_id,
                                                        "video",
                                                        "failed",
                                                        task_id=t_id,
                                                        metadata={"volc_status": status, "checked_at": time.time()},
                                                    )
                                                    st.error(f"生成失败: {status}")
                                                    break
                                                time.sleep(5)
                                            else:
                                                st.warning(f"Scene {scene_id} 视频生成超时,请稍后重试。")
                        if vid_regen_ok:
                            st.rerun()

        c_act1, c_act2 = st.columns([1, 4])
        with c_act1:
            if st.button("🧹 清空视频并重新生成", type="secondary"):
                # Forget session videos and reset the DB video assets to pending
                st.session_state.scene_videos = {}
                if st.session_state.project_id:
                    db.clear_assets(st.session_state.project_id, "video", status="pending")
                st.rerun()

        with c_act2:
            can_compose = any(
                p and os.path.exists(p) for p in (st.session_state.scene_videos or {}).values()
            )
            if st.button("进入第 5 步:合成最终成片", type="primary", disabled=not can_compose):
                st.session_state.current_step = 4
                st.rerun()
            if not can_compose:
                st.caption("⚠️ 需要至少一个分镜视频已下载到服务器本地后,才能进入合成。")
|
||
|
||
# --- Step 5: Final Composition & Tuning ---
if st.session_state.current_step >= 4:
    with st.expander("🎞️ 5. 合成结果与微调", expanded=(st.session_state.current_step == 4)):
        # Tabs: Preview vs Tuning
        tab_preview, tab_tune = st.tabs(["🎥 预览", "🎛️ 微调 (Track Editor)"])

        with tab_tune:
            st.subheader("轨道微调 (Tuning)")

            # Global Settings
            col_g1, col_g2 = st.columns(2)
            with col_g1:
                # BGM files (.mp3 / .mp4, any letter case). One pattern covers both
                # extensions; a second "*.mp3" glob would list every mp3 twice.
                bgm_dir = config.ASSETS_DIR / "bgm"
                bgm_files = sorted(
                    {f for f in bgm_dir.glob("*.[mM][pP][34]") if f.is_file() and not f.name.startswith('.')},
                    key=lambda f: f.name,
                )
                bgm_names = [f.name for f in bgm_files]

                # Pre-select the BGM recommended for the script's bgm_style
                bgm_style = st.session_state.script_data.get("bgm_style", "")
                recommended_bgm = match_bgm_by_style(bgm_style, bgm_dir)
                default_idx = 0
                if recommended_bgm:
                    rec_name = Path(recommended_bgm).name
                    if rec_name in bgm_names:
                        default_idx = bgm_names.index(rec_name) + 1  # +1 for the leading "None" option

                selected_bgm = st.selectbox(
                    f"背景音乐 (BGM) - 推荐风格: {bgm_style or '未指定'}",
                    ["None"] + bgm_names,
                    index=default_idx,
                )

                # Resolve the BGM path once. When the directory is empty or the
                # chosen file is gone, warn and compose WITHOUT BGM, so the
                # composer never receives a path that does not exist.
                tune_bgm_path = None
                if not bgm_names:
                    st.warning(f"BGM 目录为空:{bgm_dir}(本次合成将不含 BGM)")
                elif selected_bgm != "None":
                    candidate = bgm_dir / selected_bgm
                    if candidate.exists():
                        tune_bgm_path = str(candidate)
                    else:
                        st.warning(f"所选 BGM 文件不存在:{candidate}(本次合成将不含 BGM)")
            with col_g2:
                # Voice Select
                selected_voice = st.selectbox("配音音色 (TTS)", [config.VOLC_TTS_DEFAULT_VOICE, "zh_female_meilinvyou_saturn_bigtts"])

            st.divider()

            # Global Timeline Editor
            st.markdown("### 🎙️ 旁白与字幕时间轴 (单位: 秒)")

            def _timeline_seconds(item, key, legacy_ratio_key, legacy_ratio_default, lo, hi):
                """Return ``item[key]`` in seconds, clamped to ``[lo, hi]``.

                Legacy items store a ratio under ``legacy_ratio_key`` instead; it is
                converted as ratio * 12 s. Non-numeric values fall back to the legacy
                default. Clamping is required because ``st.number_input`` raises when
                ``value`` lies outside ``min_value``/``max_value``.
                """
                raw = item.get(key)
                try:
                    if raw is None:
                        raw = float(item.get(legacy_ratio_key, legacy_ratio_default)) * 12
                    seconds = float(raw)
                except (TypeError, ValueError):
                    seconds = float(legacy_ratio_default) * 12
                return min(max(seconds, lo), hi)

            timeline = st.session_state.script_data.get("voiceover_timeline") or []
            updated_timeline = []
            for i, item in enumerate(timeline):
                c1, c2, c3, c4 = st.columns([3, 3, 1, 1])
                with c1:
                    item["text"] = st.text_input(f"旁白 #{i+1}", value=item.get("text", ""), key=_ui_key(f"tune_vo_{i}"))
                with c2:
                    item["subtitle"] = st.text_input(f"字幕 #{i+1}", value=item.get("subtitle", item.get("text", "")), key=_ui_key(f"tune_sub_{i}"))
                with c3:
                    item["start_time"] = st.number_input(
                        f"开始(秒) #{i+1}",
                        value=_timeline_seconds(item, "start_time", "start_ratio", 0, 0.0, 30.0),
                        min_value=0.0,
                        max_value=30.0,
                        step=0.5,
                        key=_ui_key(f"tune_start_{i}"),
                    )
                with c4:
                    item["duration"] = st.number_input(
                        f"时长(秒) #{i+1}",
                        value=_timeline_seconds(item, "duration", "duration_ratio", 0.25, 0.5, 15.0),
                        min_value=0.5,
                        max_value=15.0,
                        step=0.5,
                        key=_ui_key(f"tune_dur_{i}"),
                    )
                # Drop legacy ratio fields now that absolute seconds are stored
                item.pop("start_ratio", None)
                item.pop("duration_ratio", None)
                updated_timeline.append(item)

            st.session_state.script_data["voiceover_timeline"] = updated_timeline

            st.divider()

            # Per-scene editor (fancy text only)
            scenes = st.session_state.script_data.get("scenes", [])
            updated_scenes = []
            for i, scene in enumerate(scenes):
                with st.container():
                    st.markdown(f"**Scene {scene['id']}**")
                    ft = scene.get("fancy_text", {})
                    ft_text = ft.get("text", "") if isinstance(ft, dict) else ""
                    new_ft = st.text_input("花字", value=ft_text, key=_ui_key(f"tune_ft_{i}"))
                    # Older data may lack a fancy_text dict
                    if not isinstance(scene.get("fancy_text"), dict):
                        scene["fancy_text"] = {}
                    scene["fancy_text"]["text"] = new_ft
                    updated_scenes.append(scene)

            st.session_state.script_data["scenes"] = updated_scenes

            if st.button("🔄 重新合成 (Re-Compose)", type="primary"):
                with st.spinner("正在应用修改并重新合成..."):
                    try:
                        # Fetch any scene videos not yet on local disk before composing
                        _ensure_local_videos(st.session_state.project_id, st.session_state.script_data.get("scenes", []))
                        composer = VideoComposer(voice_type=selected_voice)

                        # Save updated script first
                        db.update_project_script(st.session_state.project_id, st.session_state.script_data)

                        t0 = perf_counter()
                        output_path = composer.compose_from_script(
                            script=st.session_state.script_data,
                            video_map=st.session_state.scene_videos,
                            bgm_path=tune_bgm_path,
                            output_name=f"final_{st.session_state.project_id}_{int(time.time())}",  # Unique name for history
                            project_id=st.session_state.project_id,
                        )
                        _record_metrics(st.session_state.project_id, {
                            "compose_s": round(perf_counter() - t0, 3),
                            "bgm_used": bool(tune_bgm_path and Path(tune_bgm_path).exists()),
                        })
                        st.session_state.final_video = output_path
                        db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
                        # Same final state as the first-compose path in the preview tab
                        db.update_project_status(st.session_state.project_id, "completed")

                        st.success("重新合成成功!")
                        st.rerun()
                    except Exception as e:
                        st.error(f"合成失败: {e}")

        with tab_preview:
            st.subheader("🛠️ 导出与交付")

            col_ex1, col_ex2 = st.columns([1, 2])
            with col_ex1:
                if st.button("📦 生成剪映素材包 (ZIP)", type="primary", key="btn_export_zip"):
                    with st.spinner("正在打包视频、音频与字幕..."):
                        try:
                            zip_path = export_utils.create_capcut_package(
                                st.session_state.project_id,
                                st.session_state.script_data,
                                {"scene_videos": st.session_state.scene_videos},
                            )
                            st.session_state['export_zip_path'] = zip_path
                            # Remember the owning project so another project's package is
                            # never offered for download under this project's name.
                            st.session_state['export_zip_project_id'] = st.session_state.project_id
                            st.success("打包完成!请点击下载")
                        except Exception as e:
                            st.error(f"打包失败: {e}")

            with col_ex2:
                export_zip = st.session_state.get('export_zip_path')
                if (
                    export_zip
                    and st.session_state.get('export_zip_project_id') == st.session_state.project_id
                    and os.path.exists(export_zip)
                ):
                    zip_name = f"capcut_pack_{st.session_state.project_id}.zip"
                    with open(export_zip, "rb") as f:
                        st.download_button(
                            label=f"📥 点击下载素材包 ({zip_name})",
                            data=f,
                            file_name=zip_name,
                            mime="application/zip",
                        )

            st.divider()

            # Collect every composed version of this project: timestamped names plus
            # the legacy un-suffixed name, newest first.
            found_files = [
                str(p) for p in config.OUTPUT_DIR.glob(f"final_{st.session_state.project_id}_*.mp4")
            ]
            legacy_final = config.OUTPUT_DIR / f"final_{st.session_state.project_id}.mp4"
            if legacy_final.exists() and str(legacy_final) not in found_files:
                found_files.append(str(legacy_final))
            found_files.sort(key=os.path.getmtime, reverse=True)

            if not found_files:
                st.info("暂无合成视频,请先点击开始合成。")
                if st.button("✨ 开始首次合成", type="primary"):
                    with st.spinner("正在进行多轨合成..."):
                        # Default compose: default TTS voice + BGM matched to the script's bgm_style
                        bgm_style = st.session_state.script_data.get("bgm_style", "")
                        bgm_path = match_bgm_by_style(bgm_style, config.ASSETS_DIR / "bgm")
                        if bgm_path and not Path(bgm_path).exists():
                            st.warning(f"推荐的 BGM 文件不存在:{bgm_path}(本次将不含 BGM)")
                            bgm_path = None

                        try:
                            # Fetch any scene videos not yet on local disk before composing
                            _ensure_local_videos(st.session_state.project_id, st.session_state.script_data.get("scenes", []))
                            composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE)

                            # Timestamped name so every composition is kept as a version
                            output_name = f"final_{st.session_state.project_id}_{int(time.time())}"
                            t0 = perf_counter()
                            output_path = composer.compose_from_script(
                                script=st.session_state.script_data,
                                video_map=st.session_state.scene_videos,
                                bgm_path=bgm_path,
                                output_name=output_name,
                                project_id=st.session_state.project_id,
                            )
                            _record_metrics(st.session_state.project_id, {
                                "compose_s": round(perf_counter() - t0, 3),
                                "bgm_used": bool(bgm_path and Path(bgm_path).exists()),
                            })
                            st.session_state.final_video = output_path
                            db.save_asset(st.session_state.project_id, 0, "final_video", "completed", local_path=output_path)
                            db.update_project_status(st.session_state.project_id, "completed")
                            st.rerun()
                        except Exception as e:
                            st.error(f"合成失败: {e}")
            else:
                st.success(f"共找到 {len(found_files)} 个历史版本")
                # Re-composing applies the current tuning and back-fills missing downloads
                st.caption("💡 重新合成将应用当前的微调设置。如果分镜未下载,系统将尝试自动补全。")

                for idx, vid_path in enumerate(found_files):
                    mtime = os.path.getmtime(vid_path)
                    time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime))
                    file_name = os.path.basename(vid_path)

                    with st.container():
                        if idx == 0:
                            st.markdown(f"### 🟢 最新版本 ({time_str})")
                        else:
                            st.markdown(f"#### ⚪ 历史版本 {len(found_files)-idx} ({time_str})")

                        _, c2, _ = st.columns([1, 2, 1])
                        with c2:
                            st.video(vid_path)
                            with open(vid_path, "rb") as file:
                                st.download_button(
                                    label=f"📥 下载 ({file_name})",
                                    data=file,
                                    file_name=file_name,
                                    mime="video/mp4",
                                    key=f"dl_btn_{file_name}",
                                )
                    st.divider()
|
||
|
||
# ============================================================
|
||
# Page: History (New)
|
||
# ============================================================
|
||
elif st.session_state.view_mode == "history":
|
||
st.header("📜 历史任务")
|
||
|
||
projects = db.list_projects_for_user(_current_user())
|
||
|
||
for proj in projects:
|
||
with st.expander(f"{proj['name']} ({proj['updated_at']})"):
|
||
st.caption(f"ID: {proj['id']} | Status: {proj.get('status', 'unknown')}")
|
||
|
||
assets = db.get_assets(proj['id'])
|
||
|
||
if assets:
|
||
st.markdown("#### 📥 下载素材")
|
||
|
||
# Group assets
|
||
final_vids = [a for a in assets if a['asset_type'] == 'final_video']
|
||
scene_imgs = [a for a in assets if a['asset_type'] == 'image']
|
||
scene_vids = [a for a in assets if a['asset_type'] == 'video']
|
||
|
||
# Final Video Download
|
||
if final_vids:
|
||
fv = final_vids[0]
|
||
if fv['local_path'] and os.path.exists(fv['local_path']):
|
||
st.success(f"✅ **最终成片已生成**")
|
||
with open(fv['local_path'], "rb") as f:
|
||
st.download_button(
|
||
label=f"📥 下载最终成片 ({os.path.basename(fv['local_path'])})",
|
||
data=f,
|
||
file_name=os.path.basename(fv['local_path']),
|
||
mime="video/mp4"
|
||
)
|
||
|
||
# Script Download
|
||
if proj.get("script_data"):
|
||
try:
|
||
script_json = json.dumps(proj["script_data"] if isinstance(proj["script_data"], dict) else json.loads(proj["script_data"]), ensure_ascii=False, indent=2)
|
||
st.download_button(
|
||
label="📥 下载脚本/字幕配置 (JSON)",
|
||
data=script_json,
|
||
file_name=f"script_{proj['id']}.json",
|
||
mime="application/json"
|
||
)
|
||
except:
|
||
pass
|
||
|
||
st.markdown("---")
|
||
|
||
c1, c2 = st.columns(2)
|
||
with c1:
|
||
st.markdown("**分镜图片**")
|
||
for img in scene_imgs:
|
||
if img['local_path'] and os.path.exists(img['local_path']):
|
||
with open(img['local_path'], "rb") as f:
|
||
st.download_button(f"Scene {img['scene_id']} 图片", f, key=f"dl_img_{proj['id']}_{img['id']}", file_name=f"scene_{img['scene_id']}.png")
|
||
|
||
with c2:
|
||
st.markdown("**分镜视频**")
|
||
for vid in scene_vids:
|
||
if vid['local_path'] and os.path.exists(vid['local_path']):
|
||
with open(vid['local_path'], "rb") as f:
|
||
st.download_button(f"Scene {vid['scene_id']} 视频", f, key=f"dl_vid_{proj['id']}_{vid['id']}", file_name=f"scene_{vid['scene_id']}.mp4")
|
||
else:
|
||
st.info("暂无生成素材")
|
||
|
||
# ============================================================
|
||
# Page: Settings (New)
|
||
# ============================================================
|
||
elif st.session_state.view_mode == "settings":
|
||
st.header("⚙️ 设置配置")
|
||
|
||
st.subheader("Prompt 配置")
|
||
|
||
# Script Generation Prompt
|
||
cu = _current_user() or {}
|
||
user_prompt = None
|
||
try:
|
||
user_prompt = db.get_user_prompt(cu.get("id"), "prompt_script_gen") if cu.get("id") else None
|
||
except Exception:
|
||
user_prompt = None
|
||
current_prompt = user_prompt or db.get_config("prompt_script_gen")
|
||
|
||
# 显示当前状态
|
||
if user_prompt:
|
||
st.info("✅ 已加载自定义 Prompt(当前用户)")
|
||
elif current_prompt:
|
||
st.info("✅ 已加载自定义 Prompt(全局默认)")
|
||
else:
|
||
st.warning("⚠️ 使用默认 Prompt(数据库中无自定义配置)")
|
||
# Load default from instance if not in DB
|
||
temp_gen = ScriptGenerator()
|
||
current_prompt = temp_gen.default_system_prompt
|
||
|
||
new_prompt = st.text_area("脚本拆分 Prompt (System Prompt)", value=current_prompt, height=400)
|
||
|
||
col_save, col_reset = st.columns([1, 3])
|
||
with col_save:
|
||
if st.button("💾 保存配置", type="primary"):
|
||
# Save per-user prompt
|
||
db.set_user_prompt(cu.get("id"), "prompt_script_gen", new_prompt)
|
||
saved = db.get_user_prompt(cu.get("id"), "prompt_script_gen")
|
||
if saved == new_prompt:
|
||
st.success("✅ 已保存为“当前用户 Prompt”。下次生成脚本仅影响你自己的账号。")
|
||
else:
|
||
st.error("❌ 保存可能失败,请检查日志")
|
||
|
||
with col_reset:
|
||
if st.button("🔄 恢复默认"):
|
||
temp_gen = ScriptGenerator()
|
||
# Clear per-user override: set empty to remove effect
|
||
db.set_user_prompt(cu.get("id"), "prompt_script_gen", "")
|
||
st.success("已清除当前用户 Prompt,将回退到全局/默认 Prompt。")
|
||
st.rerun()
|
||
|
||
st.markdown("---")
st.subheader("账号与权限")

# Self-service password change for the logged-in user. On success the new
# password is stored and `_logout()` is called so the user signs in again.
with st.expander("修改我的密码", expanded=False):
    pwd_candidate = st.text_input("新密码", type="password", key=_ui_key("pwd_new"))
    pwd_repeat = st.text_input("确认新密码", type="password", key=_ui_key("pwd_new2"))
    save_pwd_clicked = st.button("保存新密码", type="primary", key=_ui_key("pwd_save"))
    if save_pwd_clicked:
        # An empty entry counts as too short
        if len(pwd_candidate or "") < 6:
            st.error("密码至少 6 位")
        elif pwd_candidate != pwd_repeat:
            st.error("两次输入不一致")
        else:
            db.reset_user_password(cu.get("id"), pwd_candidate)
            st.success("密码已更新,请重新登录。")
            _logout()
|
||
|
||
# Admin console: list accounts, create/update a user, reset a password or
# toggle the active flag. Only rendered for users whose role is "admin".
if cu.get("role") == "admin":
    st.markdown("### Admin:账号管理")
    users = db.list_users() or []
    if users:
        st.dataframe(
            [{k: u.get(k) for k in ["username", "role", "is_active", "last_login_at", "created_at", "id"]} for u in users],
            use_container_width=True,
        )

    with st.expander("创建/更新用户", expanded=False):
        u_name = st.text_input("用户名", key=_ui_key("adm_u_name"))
        u_role = st.selectbox("角色", ["user", "admin"], index=0, key=_ui_key("adm_u_role"))
        u_active = st.selectbox("是否启用", ["启用", "禁用"], index=0, key=_ui_key("adm_u_active"))
        u_pwd = st.text_input("初始/重置密码", type="password", key=_ui_key("adm_u_pwd"))
        if st.button("保存用户", type="primary", key=_ui_key("adm_u_save")):
            # Validate the name exactly as it will be stored: a whitespace-only
            # entry must not pass the check and then be saved as "".
            clean_name = (u_name or "").strip()
            if not clean_name:
                st.error("用户名不能为空")
            elif not u_pwd or len(u_pwd) < 6:
                st.error("密码至少 6 位")
            else:
                uid = db.upsert_user(
                    username=clean_name,
                    password=u_pwd,
                    role=u_role,
                    is_active=(1 if u_active == "启用" else 0),
                )
                st.success(f"已保存用户:{clean_name} (id={uid})")
                st.rerun()

    with st.expander("重置/禁用用户(按用户名)", expanded=False):
        uname = st.selectbox(
            "选择用户",
            options=[u.get("username") for u in users if u.get("username")],
            key=_ui_key("adm_sel_user"),
        ) if users else None
        new_pass = st.text_input("新密码(可选)", type="password", key=_ui_key("adm_reset_pwd"))
        new_active = st.selectbox("状态", ["不修改", "启用", "禁用"], key=_ui_key("adm_reset_active"))
        if st.button("应用修改", type="secondary", key=_ui_key("adm_apply")):
            target = next((u for u in users if u.get("username") == uname), None)
            if not target:
                st.error("未找到用户")
            elif not new_pass and new_active == "不修改":
                # Nothing requested: don't claim an update happened
                st.info("未选择任何修改")
            else:
                if new_pass:
                    if len(new_pass) < 6:
                        st.error("密码至少 6 位")
                        st.stop()
                    db.reset_user_password(target.get("id"), new_pass)
                if new_active != "不修改":
                    db.set_user_active(target.get("id"), 1 if new_active == "启用" else 0)
                st.success("已更新")
                st.rerun()
|
||
|