chore: sync code and project files
This commit is contained in:
652
app.py
652
app.py
@@ -26,6 +26,7 @@ from modules import path_utils
|
||||
from modules import limits
|
||||
from modules.legacy_path_mapper import map_legacy_local_path
|
||||
from modules.legacy_normalizer import normalize_legacy_project
|
||||
import extra_streamlit_components as stx
|
||||
|
||||
# Page Config
|
||||
st.set_page_config(
|
||||
@@ -35,6 +36,84 @@ st.set_page_config(
|
||||
initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Auth (login + remember cookie)
|
||||
# ============================================================
|
||||
COOKIE_NAME = "vf_session"
|
||||
cookie_manager = stx.CookieManager(key="vf_cookie_mgr")
|
||||
|
||||
def _current_user() -> dict:
|
||||
u = st.session_state.get("current_user")
|
||||
return u if isinstance(u, dict) else None
|
||||
|
||||
def _set_current_user(u: dict):
|
||||
st.session_state.current_user = u
|
||||
|
||||
def _try_restore_login():
|
||||
if _current_user():
|
||||
return
|
||||
try:
|
||||
token = cookie_manager.get(COOKIE_NAME)
|
||||
except Exception:
|
||||
token = None
|
||||
if token:
|
||||
u = db.validate_session(token)
|
||||
if u:
|
||||
_set_current_user(u)
|
||||
|
||||
def _logout():
|
||||
try:
|
||||
token = cookie_manager.get(COOKIE_NAME)
|
||||
except Exception:
|
||||
token = None
|
||||
if token:
|
||||
db.revoke_session(token)
|
||||
try:
|
||||
cookie_manager.delete(COOKIE_NAME)
|
||||
except Exception:
|
||||
pass
|
||||
st.session_state.pop("current_user", None)
|
||||
st.rerun()
|
||||
|
||||
def _login_gate():
|
||||
_try_restore_login()
|
||||
u = _current_user()
|
||||
if u:
|
||||
return
|
||||
st.title("登录")
|
||||
# Login page runs before the global CSS block below; inject minimal CSS here.
|
||||
st.markdown(
|
||||
"""
|
||||
<style>
|
||||
button[aria-label="Show password text"],
|
||||
button[aria-label="Hide password text"] {
|
||||
display: none !important;
|
||||
}
|
||||
</style>
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
c1, c2 = st.columns([1, 2])
|
||||
with c1:
|
||||
username = st.text_input("用户名", value="", key="login_user")
|
||||
password = st.text_input("密码", value="", type="password", key="login_pass")
|
||||
if st.button("登录", type="primary"):
|
||||
au = db.authenticate_user(username.strip(), password)
|
||||
if not au:
|
||||
st.error("用户名或密码错误,或账号已禁用")
|
||||
st.stop()
|
||||
token = db.create_session(au["id"])
|
||||
try:
|
||||
cookie_manager.set(COOKIE_NAME, token, max_age=7 * 24 * 3600)
|
||||
except Exception:
|
||||
# fallback: session-only
|
||||
pass
|
||||
_set_current_user(au)
|
||||
st.rerun()
|
||||
st.stop()
|
||||
|
||||
_login_gate()
|
||||
|
||||
# ============================================================
|
||||
# BGM 智能匹配函数
|
||||
# ============================================================
|
||||
@@ -109,6 +188,12 @@ st.markdown("""
|
||||
.stTextInput input, .stTextArea textarea {
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Hide Streamlit password reveal (eye) buttons to avoid accidental exposure */
|
||||
button[aria-label="Show password text"],
|
||||
button[aria-label="Hide password text"] {
|
||||
display: none !important;
|
||||
}
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
@@ -150,9 +235,9 @@ def _ui_key(suffix: str) -> str:
|
||||
|
||||
def load_project(project_id):
|
||||
"""Load project state from DB"""
|
||||
data = db.get_project(project_id)
|
||||
data = db.get_project_for_user(project_id, _current_user())
|
||||
if not data:
|
||||
st.error("Project not found")
|
||||
st.error("Project not found / no permission")
|
||||
return
|
||||
|
||||
st.session_state.project_id = project_id
|
||||
@@ -203,6 +288,7 @@ def load_project(project_id):
|
||||
st.session_state.uploaded_images = []
|
||||
|
||||
# Restore assets
|
||||
# RBAC: assets are project-scoped, so permission already checked above.
|
||||
assets = db.get_assets(project_id)
|
||||
images = {}
|
||||
videos = {}
|
||||
@@ -238,11 +324,107 @@ def load_project(project_id):
|
||||
else:
|
||||
st.session_state.current_step = 0
|
||||
|
||||
# ============================================================
|
||||
# Helper Functions (must be defined before Sidebar uses them)
|
||||
# ============================================================
|
||||
def _record_metrics(project_id: str, patch: dict):
|
||||
"""Persist lightweight timing/diagnostic metrics into project.product_info['_metrics']."""
|
||||
if not project_id or not isinstance(patch, dict) or not patch:
|
||||
return
|
||||
try:
|
||||
proj = db.get_project(project_id) or {}
|
||||
product_info = proj.get("product_info") or {}
|
||||
metrics = product_info.get("_metrics") if isinstance(product_info.get("_metrics"), dict) else {}
|
||||
metrics.update(patch)
|
||||
metrics["updated_at"] = time.time()
|
||||
product_info["_metrics"] = metrics
|
||||
db.update_project_product_info(project_id, product_info)
|
||||
except Exception:
|
||||
# metrics must never break UX
|
||||
pass
|
||||
|
||||
|
||||
def _get_metrics(project_id: str) -> dict:
|
||||
try:
|
||||
proj = db.get_project(project_id) or {}
|
||||
product_info = proj.get("product_info") or {}
|
||||
m = product_info.get("_metrics")
|
||||
return m if isinstance(m, dict) else {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _ensure_local_videos(project_id: str, scenes: list):
|
||||
"""
|
||||
Step 5 合成前的保障逻辑:
|
||||
检查分镜视频是否已下载到服务器本地。如果只有 URL 没有本地文件,则后台静默下载。
|
||||
"""
|
||||
if not project_id or not scenes:
|
||||
return
|
||||
|
||||
vid_gen = VideoGenerator()
|
||||
downloaded = 0
|
||||
missing_assets = []
|
||||
|
||||
for scene in scenes:
|
||||
scene_id = scene["id"]
|
||||
local_path = st.session_state.scene_videos.get(scene_id)
|
||||
|
||||
# 如果本地文件不存在,尝试补全
|
||||
if not local_path or not os.path.exists(local_path):
|
||||
asset = db.get_asset(project_id, scene_id, "video")
|
||||
if not asset:
|
||||
missing_assets.append(f"Scene {scene_id}")
|
||||
continue
|
||||
|
||||
meta = asset.get("metadata") or {}
|
||||
video_url = meta.get("video_url")
|
||||
task_id = asset.get("task_id")
|
||||
|
||||
# 如果 DB 里没有 URL,但有 task_id,尝试现场查询一次
|
||||
if not video_url and task_id:
|
||||
logger.info(f"Checking task {task_id} status for Scene {scene_id} during composition...")
|
||||
status, url = vid_gen.check_task_status(task_id)
|
||||
if status == "succeeded" and url:
|
||||
video_url = url
|
||||
db.update_asset_metadata(project_id, scene_id, "video", {"video_url": url, "volc_status": status})
|
||||
|
||||
if video_url:
|
||||
out_name = path_utils.unique_filename(
|
||||
prefix="scene_video",
|
||||
ext="mp4",
|
||||
project_id=project_id,
|
||||
scene_id=scene_id,
|
||||
extra=f"auto_{int(time.time())}"
|
||||
)
|
||||
logger.info(f"Auto-downloading missing video for Scene {scene_id} from {video_url}")
|
||||
target_dir = path_utils.project_videos_dir(project_id)
|
||||
target_path = str(target_dir / out_name)
|
||||
if vid_gen._download_video_to(video_url, target_path):
|
||||
st.session_state.scene_videos[scene_id] = target_path
|
||||
db.save_asset(project_id, scene_id, "video", "completed", local_path=target_path)
|
||||
downloaded += 1
|
||||
else:
|
||||
missing_assets.append(f"Scene {scene_id} (未生成/无URL)")
|
||||
|
||||
if downloaded > 0:
|
||||
logger.info(f"Successfully auto-downloaded {downloaded} missing videos for project {project_id}")
|
||||
|
||||
if missing_assets:
|
||||
msg = f"无法合成:以下分镜缺少视频素材,请先在第 4 步生成:{', '.join(missing_assets)}"
|
||||
logger.warning(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# ============================================================
|
||||
# Sidebar
|
||||
# ============================================================
|
||||
with st.sidebar:
|
||||
st.title("📽️ Video Flow")
|
||||
cu = _current_user()
|
||||
if cu:
|
||||
st.caption(f"登录用户: {cu.get('username')} ({cu.get('role')})")
|
||||
if st.button("退出登录", type="secondary"):
|
||||
_logout()
|
||||
|
||||
# Mode Selection - 正确计算 index
|
||||
mode_options = ["🛠️ 工作台", "📜 历史任务", "⚙️ 设置"]
|
||||
@@ -262,7 +444,7 @@ with st.sidebar:
|
||||
if st.session_state.view_mode == "workspace":
|
||||
# Project Selection
|
||||
st.subheader("Current Project")
|
||||
projects = db.list_projects()
|
||||
projects = db.list_projects_for_user(_current_user())
|
||||
proj_options = {p['id']: f"{p.get('name', 'Untitled')} ({p['id']})" for p in projects}
|
||||
|
||||
selected_proj_id = st.selectbox(
|
||||
@@ -307,13 +489,6 @@ with st.sidebar:
|
||||
for k in keys:
|
||||
if k in m:
|
||||
st.caption(f"{k}: {m.get(k)}")
|
||||
# 在线剪辑入口(React Editor)
|
||||
web_base_url = os.getenv("WEB_BASE_URL", "http://localhost:3000").rstrip("/")
|
||||
st.markdown(
|
||||
f"[打开在线剪辑器]({web_base_url}/editor/{st.session_state.project_id})",
|
||||
unsafe_allow_html=False,
|
||||
)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# Navigation / Progress
|
||||
@@ -335,35 +510,6 @@ with st.sidebar:
|
||||
if key != "view_mode": del st.session_state[key]
|
||||
st.rerun()
|
||||
|
||||
# ============================================================
|
||||
# Helper Functions
|
||||
# ============================================================
|
||||
def _record_metrics(project_id: str, patch: dict):
|
||||
"""Persist lightweight timing/diagnostic metrics into project.product_info['_metrics']."""
|
||||
if not project_id or not isinstance(patch, dict) or not patch:
|
||||
return
|
||||
try:
|
||||
proj = db.get_project(project_id) or {}
|
||||
product_info = proj.get("product_info") or {}
|
||||
metrics = product_info.get("_metrics") if isinstance(product_info.get("_metrics"), dict) else {}
|
||||
metrics.update(patch)
|
||||
metrics["updated_at"] = time.time()
|
||||
product_info["_metrics"] = metrics
|
||||
db.update_project_product_info(project_id, product_info)
|
||||
except Exception:
|
||||
# metrics must never break UX
|
||||
pass
|
||||
|
||||
|
||||
def _get_metrics(project_id: str) -> dict:
|
||||
try:
|
||||
proj = db.get_project(project_id) or {}
|
||||
product_info = proj.get("product_info") or {}
|
||||
m = product_info.get("_metrics")
|
||||
return m if isinstance(m, dict) else {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def save_uploaded_file(project_id: str, uploaded_file):
|
||||
"""Save uploaded file to per-project upload dir (avoid overwrites across projects)."""
|
||||
if uploaded_file is None:
|
||||
@@ -467,13 +613,24 @@ if st.session_state.view_mode == "workspace":
|
||||
# DB: Create Project
|
||||
# 将 uploaded_images 保存到 product_info 以便持久化
|
||||
product_info = {"category": category, "price": price, "tags": tags, "params": params, "style_hint": style_hint, "uploaded_images": image_paths}
|
||||
db.create_project(st.session_state.project_id, product_name, product_info)
|
||||
db.create_project(
|
||||
st.session_state.project_id,
|
||||
product_name,
|
||||
product_info,
|
||||
owner_user_id=(_current_user() or {}).get("id"),
|
||||
)
|
||||
|
||||
# Call Script Generator
|
||||
with st.spinner(f"正在分析商品信息并生成脚本 ({selected_model_label})..."):
|
||||
gen = ScriptGenerator()
|
||||
t0 = perf_counter()
|
||||
script = gen.generate_script(product_name, product_info, image_paths, model_provider=model_provider)
|
||||
script = gen.generate_script(
|
||||
product_name,
|
||||
product_info,
|
||||
image_paths,
|
||||
model_provider=model_provider,
|
||||
user_id=(_current_user() or {}).get("id"),
|
||||
)
|
||||
_record_metrics(st.session_state.project_id, {
|
||||
"script_gen_s": round(perf_counter() - t0, 3),
|
||||
"script_model": model_provider,
|
||||
@@ -689,9 +846,9 @@ if st.session_state.view_mode == "workspace":
|
||||
if not ok:
|
||||
st.warning("系统正在生成其他任务(生图并发已达上限),请稍后再试。")
|
||||
st.stop()
|
||||
img_gen = ImageGenerator()
|
||||
# Pass ALL uploaded images as reference
|
||||
base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else []
|
||||
img_gen = ImageGenerator()
|
||||
# Pass ALL uploaded images as reference
|
||||
base_imgs = st.session_state.uploaded_images if st.session_state.uploaded_images else []
|
||||
|
||||
if not base_imgs:
|
||||
st.error("No base image found (未找到参考底图). Please upload in Step 1.")
|
||||
@@ -742,7 +899,7 @@ if st.session_state.view_mode == "workspace":
|
||||
total_scenes = len(scenes)
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
|
||||
try:
|
||||
t0 = perf_counter()
|
||||
# Parallel workers within a single run; global semaphore already acquired above.
|
||||
@@ -755,7 +912,7 @@ if st.session_state.view_mode == "workspace":
|
||||
img_gen.generate_single_scene_image,
|
||||
scene=scene,
|
||||
original_image_path=list(base_imgs), # ONLY merchant images
|
||||
previous_image_path=None,
|
||||
previous_image_path=None,
|
||||
model_provider=img_provider,
|
||||
visual_anchor=visual_anchor,
|
||||
project_id=st.session_state.project_id,
|
||||
@@ -768,17 +925,15 @@ if st.session_state.view_mode == "workspace":
|
||||
status_text.text(f"已完成 {done}/{total_scenes}(Scene {scene_id})")
|
||||
try:
|
||||
img_path = fut.result()
|
||||
if img_path:
|
||||
st.session_state.scene_images[scene_id] = img_path
|
||||
db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path)
|
||||
# Invalidate stale video for this scene (image changed => old video is wrong)
|
||||
db.clear_asset(st.session_state.project_id, scene_id, "video", status="pending")
|
||||
st.session_state.scene_videos.pop(scene_id, None)
|
||||
except Exception as e:
|
||||
img_path = None
|
||||
st.warning(f"Scene {scene_id} 生成失败:{e}")
|
||||
|
||||
if img_path:
|
||||
st.session_state.scene_images[scene_id] = img_path
|
||||
db.save_asset(st.session_state.project_id, scene_id, "image", "completed", local_path=img_path)
|
||||
# Invalidate stale video for this scene (image changed => old video is wrong)
|
||||
db.clear_asset(st.session_state.project_id, scene_id, "video", status="pending")
|
||||
st.session_state.scene_videos.pop(scene_id, None)
|
||||
|
||||
|
||||
progress_bar.progress(done / total_scenes)
|
||||
|
||||
status_text.text("生图完成!")
|
||||
@@ -863,112 +1018,152 @@ if st.session_state.view_mode == "workspace":
|
||||
scenes = st.session_state.script_data.get("scenes", [])
|
||||
vid_gen = VideoGenerator()
|
||||
|
||||
# Submit-only (non-blocking) to avoid freezing Streamlit under concurrency
|
||||
if st.button("🎬 提交图生视频任务(非阻塞)", type="primary"):
|
||||
st.caption("简化策略:第 4 步完成“生成→轮询→下载到服务器本地”。当至少有一个分镜视频落盘后,才允许进入第 5 步合成。")
|
||||
|
||||
if st.button("🎬 生成分镜视频并下载到服务器(阻塞)", type="primary"):
|
||||
with limits.acquire_video(blocking=False) as ok:
|
||||
if not ok:
|
||||
st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
|
||||
st.stop()
|
||||
t0 = perf_counter()
|
||||
submitted = 0
|
||||
|
||||
if not st.session_state.project_id:
|
||||
st.error("缺少 project_id,无法生成视频。")
|
||||
st.stop()
|
||||
|
||||
if not scenes:
|
||||
st.error("脚本中没有分镜,无法生成视频。")
|
||||
st.stop()
|
||||
|
||||
# 先把 DB 中已完成且存在的本地视频恢复到 session
|
||||
for scene in scenes:
|
||||
scene_id = scene["id"]
|
||||
existing = st.session_state.scene_videos.get(scene_id)
|
||||
if existing and os.path.exists(existing):
|
||||
continue
|
||||
asset = db.get_asset(st.session_state.project_id, scene_id, "video")
|
||||
local_path = (asset or {}).get("local_path")
|
||||
if local_path and os.path.exists(local_path):
|
||||
st.session_state.scene_videos[scene_id] = local_path
|
||||
|
||||
t0 = perf_counter()
|
||||
total = len(scenes)
|
||||
done = sum(1 for _sid, p in (st.session_state.scene_videos or {}).items() if p and os.path.exists(p))
|
||||
progress = st.progress(0.0)
|
||||
status_text = st.empty()
|
||||
|
||||
# 收集/提交任务
|
||||
tasks = {} # scene_id -> task_id
|
||||
for scene in scenes:
|
||||
scene_id = scene["id"]
|
||||
existing = st.session_state.scene_videos.get(scene_id)
|
||||
if existing and os.path.exists(existing):
|
||||
continue
|
||||
|
||||
asset = db.get_asset(st.session_state.project_id, scene_id, "video")
|
||||
task_id = (asset or {}).get("task_id")
|
||||
if task_id:
|
||||
tasks[scene_id] = task_id
|
||||
continue
|
||||
|
||||
image_path = st.session_state.scene_images.get(scene_id)
|
||||
prompt = scene.get("video_prompt", "High quality video")
|
||||
task_id = vid_gen.submit_scene_video_task(
|
||||
new_task_id = vid_gen.submit_scene_video_task(
|
||||
st.session_state.project_id, scene_id, image_path, prompt
|
||||
)
|
||||
if task_id:
|
||||
submitted += 1
|
||||
_record_metrics(st.session_state.project_id, {
|
||||
"video_submit_s": round(perf_counter() - t0, 3),
|
||||
"video_submitted": submitted,
|
||||
})
|
||||
if submitted:
|
||||
db.update_project_status(st.session_state.project_id, "videos_processing")
|
||||
st.success(f"已提交 {submitted} 个分镜视频任务。可点击下方“刷新恢复”下载结果。")
|
||||
time.sleep(0.5)
|
||||
st.rerun()
|
||||
else:
|
||||
st.warning("未提交任何任务(可能缺少图片或接口失败)。")
|
||||
if new_task_id:
|
||||
tasks[scene_id] = new_task_id
|
||||
|
||||
if tasks:
|
||||
db.update_project_status(st.session_state.project_id, "videos_processing")
|
||||
|
||||
out_dir = path_utils.project_videos_dir(st.session_state.project_id)
|
||||
pending = set(tasks.keys())
|
||||
deadline = time.time() + 15 * 60 # 15 min
|
||||
|
||||
while pending and time.time() < deadline:
|
||||
status_text.text(f"视频生成/下载中:已完成 {done}/{total},队列中 {len(pending)} ...")
|
||||
to_remove = []
|
||||
|
||||
for scene_id in list(pending):
|
||||
task_id = tasks.get(scene_id)
|
||||
if not task_id:
|
||||
to_remove.append(scene_id)
|
||||
continue
|
||||
|
||||
if st.button("🔄 刷新状态并恢复已完成任务", type="secondary"):
|
||||
with limits.acquire_video(blocking=False) as ok:
|
||||
if not ok:
|
||||
st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
|
||||
st.stop()
|
||||
t0 = perf_counter()
|
||||
updated = 0
|
||||
for scene in scenes:
|
||||
scene_id = scene["id"]
|
||||
asset = db.get_asset(st.session_state.project_id, scene_id, "video")
|
||||
if not asset or not asset.get("task_id"):
|
||||
continue
|
||||
# if already have local video, skip
|
||||
existing = st.session_state.scene_videos.get(scene_id)
|
||||
if existing and os.path.exists(existing):
|
||||
continue
|
||||
task_id = asset.get("task_id")
|
||||
# Query volc status; store URL for direct preview (no server download)
|
||||
status = None
|
||||
url = None
|
||||
# short retries for "succeeded but url missing"
|
||||
for attempt in range(3):
|
||||
status, url = vid_gen.check_task_status(task_id)
|
||||
if status == "succeeded" and url:
|
||||
break
|
||||
time.sleep(0.5 * (2 ** attempt))
|
||||
out_name = path_utils.unique_filename(
|
||||
prefix="scene_video",
|
||||
ext="mp4",
|
||||
project_id=st.session_state.project_id,
|
||||
scene_id=scene_id,
|
||||
extra=(task_id[-8:] if isinstance(task_id, str) else None),
|
||||
)
|
||||
target_path = str(out_dir / out_name)
|
||||
ok_dl = vid_gen._download_video_to(url, target_path)
|
||||
if ok_dl and os.path.exists(target_path):
|
||||
st.session_state.scene_videos[scene_id] = target_path
|
||||
db.save_asset(
|
||||
st.session_state.project_id,
|
||||
scene_id,
|
||||
"video",
|
||||
"completed",
|
||||
local_path=target_path,
|
||||
task_id=task_id,
|
||||
metadata={"downloaded_at": time.time()},
|
||||
)
|
||||
done += 1
|
||||
else:
|
||||
db.save_asset(
|
||||
st.session_state.project_id,
|
||||
scene_id,
|
||||
"video",
|
||||
"failed",
|
||||
task_id=task_id,
|
||||
metadata={"download_error": True, "checked_at": time.time()},
|
||||
)
|
||||
to_remove.append(scene_id)
|
||||
elif status in ["failed", "cancelled"]:
|
||||
db.save_asset(
|
||||
st.session_state.project_id,
|
||||
scene_id,
|
||||
"video",
|
||||
"failed",
|
||||
task_id=task_id,
|
||||
metadata={"volc_status": status, "checked_at": time.time()},
|
||||
)
|
||||
to_remove.append(scene_id)
|
||||
else:
|
||||
db.update_asset_metadata(
|
||||
st.session_state.project_id,
|
||||
scene_id,
|
||||
"video",
|
||||
{"volc_status": status, "checked_at": time.time()},
|
||||
)
|
||||
|
||||
meta_patch = {"checked_at": time.time(), "volc_status": status}
|
||||
if url:
|
||||
meta_patch["video_url"] = url
|
||||
db.update_asset_metadata(st.session_state.project_id, scene_id, "video", meta_patch)
|
||||
updated += 1
|
||||
for sid in to_remove:
|
||||
pending.discard(sid)
|
||||
|
||||
progress.progress(min(1.0, done / max(total, 1)))
|
||||
if pending:
|
||||
time.sleep(5)
|
||||
|
||||
if pending:
|
||||
st.warning(f"仍有 {len(pending)} 个分镜未在本轮完成:{sorted(list(pending))}。可再次点击按钮继续轮询与下载。")
|
||||
|
||||
_record_metrics(st.session_state.project_id, {
|
||||
"video_recover_s": round(perf_counter() - t0, 3),
|
||||
"video_recovered": updated,
|
||||
"video_blocking_total_s": round(perf_counter() - t0, 3),
|
||||
"video_done": done,
|
||||
"video_total": total,
|
||||
})
|
||||
if updated:
|
||||
st.success(f"已刷新 {updated} 个分镜状态(成功的将以 URL 直连预览)。")
|
||||
else:
|
||||
st.info("暂无可恢复的视频(可能仍在排队/生成中)。")
|
||||
time.sleep(0.5)
|
||||
st.rerun()
|
||||
|
||||
if st.button("📥 准备合成素材(下载成功的视频到服务器)", type="secondary"):
|
||||
with limits.acquire_video(blocking=False) as ok:
|
||||
if not ok:
|
||||
st.warning("系统正在处理其他视频任务(并发已达上限),请稍后再试。")
|
||||
st.stop()
|
||||
downloaded = 0
|
||||
for scene in scenes:
|
||||
scene_id = scene["id"]
|
||||
existing = st.session_state.scene_videos.get(scene_id)
|
||||
if existing and os.path.exists(existing):
|
||||
continue
|
||||
asset = db.get_asset(st.session_state.project_id, scene_id, "video")
|
||||
meta = (asset or {}).get("metadata") or {}
|
||||
video_url = meta.get("video_url")
|
||||
if not video_url:
|
||||
continue
|
||||
out_name = path_utils.unique_filename(
|
||||
prefix="scene_video",
|
||||
ext="mp4",
|
||||
project_id=st.session_state.project_id,
|
||||
scene_id=scene_id,
|
||||
)
|
||||
target_path = str(path_utils.project_videos_dir(st.session_state.project_id) / out_name)
|
||||
if vid_gen._download_video_to(video_url, target_path):
|
||||
st.session_state.scene_videos[scene_id] = target_path
|
||||
db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=target_path, task_id=(asset or {}).get("task_id"), metadata=meta)
|
||||
downloaded += 1
|
||||
if downloaded:
|
||||
st.success(f"已下载 {downloaded} 段视频,可进入合成。")
|
||||
if any(p and os.path.exists(p) for p in (st.session_state.scene_videos or {}).values()):
|
||||
db.update_project_status(st.session_state.project_id, "videos_generated")
|
||||
st.success("已生成并下载到服务器本地。进入第 5 步合成。")
|
||||
st.session_state.current_step = 4
|
||||
st.rerun()
|
||||
else:
|
||||
st.info("暂无可下载的视频(请先刷新状态获取 video_url)。")
|
||||
time.sleep(0.5)
|
||||
st.rerun()
|
||||
st.error("未生成任何可用视频,请检查生视频接口或稍后重试。")
|
||||
|
||||
# Display Videos (even when partially available)
|
||||
if st.session_state.scene_videos or scenes:
|
||||
@@ -987,43 +1182,14 @@ if st.session_state.view_mode == "workspace":
|
||||
if vid_path and os.path.exists(vid_path):
|
||||
st.video(vid_path)
|
||||
else:
|
||||
# Try URL preview from DB metadata
|
||||
asset = db.get_asset(st.session_state.project_id, scene_id, "video")
|
||||
meta = (asset or {}).get("metadata") or {}
|
||||
video_url = meta.get("video_url")
|
||||
if video_url:
|
||||
# Detect stale mapping: if source image signature differs, warn and avoid misleading preview
|
||||
stale = False
|
||||
try:
|
||||
cur_img = st.session_state.scene_images.get(scene_id)
|
||||
if cur_img and os.path.exists(cur_img):
|
||||
st_img = stat(cur_img)
|
||||
cur_size = int(getattr(st_img, "st_size", 0) or 0)
|
||||
cur_mtime = float(getattr(st_img, "st_mtime", 0.0) or 0.0)
|
||||
src_size = meta.get("source_image_size")
|
||||
src_mtime = meta.get("source_image_mtime")
|
||||
if (src_size and cur_size and int(src_size) != cur_size) or (src_mtime and cur_mtime and abs(float(src_mtime) - cur_mtime) > 1e-3):
|
||||
stale = True
|
||||
except Exception:
|
||||
stale = False
|
||||
if stale:
|
||||
st.warning("检测到该视频可能基于旧图片生成(图片已更新)。请点击“提交图生视频任务”重新生成,以避免主体不一致。")
|
||||
st.caption("URL 直连预览(不经服务器落盘)")
|
||||
st.video(video_url)
|
||||
status = (asset or {}).get("status") or "pending"
|
||||
task_id = (asset or {}).get("task_id")
|
||||
if task_id:
|
||||
st.caption(f"状态: {status} | Task: {str(task_id)[-6:]}")
|
||||
else:
|
||||
st.warning("Video missing")
|
||||
# --- Recovery Logic ---
|
||||
if asset and asset.get("task_id"):
|
||||
task_id = asset.get("task_id")
|
||||
if st.button(f"🔍 刷新URL (Task {task_id[-6:]})", key=f"recov_{scene_id}"):
|
||||
with st.spinner("查询任务状态中..."):
|
||||
status, url = vid_gen.check_task_status(task_id)
|
||||
patch = {"checked_at": time.time(), "volc_status": status}
|
||||
if url:
|
||||
patch["video_url"] = url
|
||||
db.update_asset_metadata(st.session_state.project_id, scene_id, "video", patch)
|
||||
st.success("已刷新任务状态。")
|
||||
st.rerun()
|
||||
st.caption(f"状态: {status}")
|
||||
st.warning("暂无本地视频(请点击上方“生成分镜视频并下载到服务器”)。")
|
||||
|
||||
# Per-scene regenerate button
|
||||
if st.button(f"🔄 重生 S{scene_id}", key=f"regen_vid_{scene_id}"):
|
||||
@@ -1046,14 +1212,11 @@ if st.session_state.view_mode == "workspace":
|
||||
scene_id=scene_id,
|
||||
extra=(t_id[-8:] if isinstance(t_id, str) else None),
|
||||
)
|
||||
new_path = vid_gen._download_video(
|
||||
url,
|
||||
out_name,
|
||||
output_dir=path_utils.project_videos_dir(st.session_state.project_id),
|
||||
)
|
||||
if new_path:
|
||||
st.session_state.scene_videos[scene_id] = new_path
|
||||
db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=new_path, task_id=t_id)
|
||||
target_dir = path_utils.project_videos_dir(st.session_state.project_id)
|
||||
target_path = str(target_dir / out_name)
|
||||
if vid_gen._download_video_to(url, target_path) and os.path.exists(target_path):
|
||||
st.session_state.scene_videos[scene_id] = target_path
|
||||
db.save_asset(st.session_state.project_id, scene_id, "video", "completed", local_path=target_path, task_id=t_id)
|
||||
st.rerun()
|
||||
break
|
||||
elif status in ["failed", "cancelled"]:
|
||||
@@ -1067,18 +1230,23 @@ if st.session_state.view_mode == "workspace":
|
||||
|
||||
c_act1, c_act2 = st.columns([1, 4])
|
||||
with c_act1:
|
||||
if st.button("🔄 重新生成所有视频", type="secondary"):
|
||||
if st.button("🧹 清空视频并重新生成", type="secondary"):
|
||||
# Clear videos and rerun
|
||||
st.session_state.scene_videos = {}
|
||||
# Also clear DB video assets to avoid stale URL preview
|
||||
# Also clear DB video assets
|
||||
if st.session_state.project_id:
|
||||
db.clear_assets(st.session_state.project_id, "video", status="pending")
|
||||
st.rerun()
|
||||
|
||||
with c_act2:
|
||||
if st.button("下一步:合成最终成片", type="primary"):
|
||||
can_compose = any(
|
||||
p and os.path.exists(p) for p in (st.session_state.scene_videos or {}).values()
|
||||
)
|
||||
if st.button("进入第 5 步:合成最终成片", type="primary", disabled=not can_compose):
|
||||
st.session_state.current_step = 4
|
||||
st.rerun()
|
||||
if not can_compose:
|
||||
st.caption("⚠️ 需要至少一个分镜视频已下载到服务器本地后,才能进入合成。")
|
||||
|
||||
# --- Step 5: Final Composition & Tuning ---
|
||||
if st.session_state.current_step >= 4:
|
||||
@@ -1174,6 +1342,9 @@ if st.session_state.view_mode == "workspace":
|
||||
|
||||
if st.button("🔄 重新合成 (Re-Compose)", type="primary"):
|
||||
with st.spinner("正在应用修改并重新合成..."):
|
||||
# 自动补齐视频下载逻辑 (关键优化)
|
||||
_ensure_local_videos(st.session_state.project_id, st.session_state.script_data.get("scenes", []))
|
||||
|
||||
composer = VideoComposer(voice_type=selected_voice)
|
||||
|
||||
bgm_path = None
|
||||
@@ -1255,6 +1426,9 @@ if st.session_state.view_mode == "workspace":
|
||||
st.info("暂无合成视频,请先点击开始合成。")
|
||||
if st.button("✨ 开始首次合成", type="primary"):
|
||||
with st.spinner("正在进行多轨合成..."):
|
||||
# 自动补齐视频下载逻辑 (关键优化)
|
||||
_ensure_local_videos(st.session_state.project_id, st.session_state.script_data.get("scenes", []))
|
||||
|
||||
# Default compose logic with smart BGM matching
|
||||
composer = VideoComposer(voice_type=config.VOLC_TTS_DEFAULT_VOICE)
|
||||
|
||||
@@ -1290,6 +1464,9 @@ if st.session_state.view_mode == "workspace":
|
||||
else:
|
||||
st.success(f"共找到 {len(found_files)} 个历史版本")
|
||||
|
||||
# 提示用户:如果重新合成,系统会自动补全未下载的素材
|
||||
st.caption("💡 重新合成将应用当前的微调设置。如果分镜未下载,系统将尝试自动补全。")
|
||||
|
||||
# 遍历显示所有历史版本
|
||||
for idx, vid_path in enumerate(found_files):
|
||||
mtime = os.path.getmtime(vid_path)
|
||||
@@ -1324,7 +1501,7 @@ if st.session_state.view_mode == "workspace":
|
||||
elif st.session_state.view_mode == "history":
|
||||
st.header("📜 历史任务")
|
||||
|
||||
projects = db.list_projects()
|
||||
projects = db.list_projects_for_user(_current_user())
|
||||
|
||||
for proj in projects:
|
||||
with st.expander(f"{proj['name']} ({proj['updated_at']})"):
|
||||
@@ -1394,11 +1571,19 @@ elif st.session_state.view_mode == "settings":
|
||||
st.subheader("Prompt 配置")
|
||||
|
||||
# Script Generation Prompt
|
||||
current_prompt = db.get_config("prompt_script_gen")
|
||||
cu = _current_user() or {}
|
||||
user_prompt = None
|
||||
try:
|
||||
user_prompt = db.get_user_prompt(cu.get("id"), "prompt_script_gen") if cu.get("id") else None
|
||||
except Exception:
|
||||
user_prompt = None
|
||||
current_prompt = user_prompt or db.get_config("prompt_script_gen")
|
||||
|
||||
# 显示当前状态
|
||||
if current_prompt:
|
||||
st.info("✅ 已加载自定义 Prompt(来自数据库)")
|
||||
if user_prompt:
|
||||
st.info("✅ 已加载自定义 Prompt(当前用户)")
|
||||
elif current_prompt:
|
||||
st.info("✅ 已加载自定义 Prompt(全局默认)")
|
||||
else:
|
||||
st.warning("⚠️ 使用默认 Prompt(数据库中无自定义配置)")
|
||||
# Load default from instance if not in DB
|
||||
@@ -1410,18 +1595,89 @@ elif st.session_state.view_mode == "settings":
|
||||
col_save, col_reset = st.columns([1, 3])
|
||||
with col_save:
|
||||
if st.button("💾 保存配置", type="primary"):
|
||||
db.set_config("prompt_script_gen", new_prompt, "System prompt for script generation step")
|
||||
# 验证保存
|
||||
saved = db.get_config("prompt_script_gen")
|
||||
# Save per-user prompt
|
||||
db.set_user_prompt(cu.get("id"), "prompt_script_gen", new_prompt)
|
||||
saved = db.get_user_prompt(cu.get("id"), "prompt_script_gen")
|
||||
if saved == new_prompt:
|
||||
st.success("✅ 配置已保存并验证成功!下次生成脚本时将使用新 Prompt。")
|
||||
st.success("✅ 已保存为“当前用户 Prompt”。下次生成脚本仅影响你自己的账号。")
|
||||
else:
|
||||
st.error("❌ 保存可能失败,请检查日志")
|
||||
|
||||
with col_reset:
|
||||
if st.button("🔄 恢复默认"):
|
||||
temp_gen = ScriptGenerator()
|
||||
db.set_config("prompt_script_gen", temp_gen.default_system_prompt, "System prompt for script generation step (DEFAULT)")
|
||||
st.success("已恢复默认 Prompt,请刷新页面查看")
|
||||
# Clear per-user override: set empty to remove effect
|
||||
db.set_user_prompt(cu.get("id"), "prompt_script_gen", "")
|
||||
st.success("已清除当前用户 Prompt,将回退到全局/默认 Prompt。")
|
||||
st.rerun()
|
||||
|
||||
st.markdown("---")
|
||||
st.subheader("账号与权限")
|
||||
|
||||
# Password change for current user
|
||||
with st.expander("修改我的密码", expanded=False):
|
||||
p1 = st.text_input("新密码", type="password", key=_ui_key("pwd_new"))
|
||||
p2 = st.text_input("确认新密码", type="password", key=_ui_key("pwd_new2"))
|
||||
if st.button("保存新密码", type="primary", key=_ui_key("pwd_save")):
|
||||
if not p1 or len(p1) < 6:
|
||||
st.error("密码至少 6 位")
|
||||
elif p1 != p2:
|
||||
st.error("两次输入不一致")
|
||||
else:
|
||||
db.reset_user_password(cu.get("id"), p1)
|
||||
st.success("密码已更新,请重新登录。")
|
||||
_logout()
|
||||
|
||||
# Admin console
|
||||
if (cu.get("role") == "admin"):
|
||||
st.markdown("### Admin:账号管理")
|
||||
users = db.list_users()
|
||||
if users:
|
||||
st.dataframe(
|
||||
[{k: u.get(k) for k in ["username", "role", "is_active", "last_login_at", "created_at", "id"]} for u in users],
|
||||
use_container_width=True,
|
||||
)
|
||||
|
||||
with st.expander("创建/更新用户", expanded=False):
|
||||
u_name = st.text_input("用户名", key=_ui_key("adm_u_name"))
|
||||
u_role = st.selectbox("角色", ["user", "admin"], index=0, key=_ui_key("adm_u_role"))
|
||||
u_active = st.selectbox("是否启用", ["启用", "禁用"], index=0, key=_ui_key("adm_u_active"))
|
||||
u_pwd = st.text_input("初始/重置密码", type="password", key=_ui_key("adm_u_pwd"))
|
||||
if st.button("保存用户", type="primary", key=_ui_key("adm_u_save")):
|
||||
if not u_name:
|
||||
st.error("用户名不能为空")
|
||||
elif not u_pwd or len(u_pwd) < 6:
|
||||
st.error("密码至少 6 位")
|
||||
else:
|
||||
uid = db.upsert_user(
|
||||
username=u_name.strip(),
|
||||
password=u_pwd,
|
||||
role=u_role,
|
||||
is_active=(1 if u_active == "启用" else 0),
|
||||
)
|
||||
st.success(f"已保存用户:{u_name} (id={uid})")
|
||||
st.rerun()
|
||||
|
||||
with st.expander("重置/禁用用户(按用户名)", expanded=False):
|
||||
uname = st.selectbox(
|
||||
"选择用户",
|
||||
options=[u.get("username") for u in users if u.get("username")],
|
||||
key=_ui_key("adm_sel_user"),
|
||||
) if users else None
|
||||
new_pass = st.text_input("新密码(可选)", type="password", key=_ui_key("adm_reset_pwd"))
|
||||
new_active = st.selectbox("状态", ["不修改", "启用", "禁用"], key=_ui_key("adm_reset_active"))
|
||||
if st.button("应用修改", type="secondary", key=_ui_key("adm_apply")):
|
||||
target = next((u for u in users if u.get("username") == uname), None)
|
||||
if not target:
|
||||
st.error("未找到用户")
|
||||
else:
|
||||
if new_pass:
|
||||
if len(new_pass) < 6:
|
||||
st.error("密码至少 6 位")
|
||||
st.stop()
|
||||
db.reset_user_password(target.get("id"), new_pass)
|
||||
if new_active != "不修改":
|
||||
db.set_user_active(target.get("id"), 1 if new_active == "启用" else 0)
|
||||
st.success("已更新")
|
||||
st.rerun()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user