Files
video-flow/api/routes/projects.py
2026-01-09 14:09:16 +08:00

316 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
项目管理 API 路由
提供项目 CRUD 和状态查询
"""
import logging
import os
import time
import uuid
from pathlib import Path
from typing import List, Optional, Any, Dict

from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

import config
from modules.db_manager import db
from modules.script_gen import ScriptGenerator
from modules.image_gen import ImageGenerator
from modules.video_gen import VideoGenerator
from modules.legacy_path_mapper import map_legacy_local_path
# Per-module logger so records from this file can be filtered by module name.
logger = logging.getLogger(name=__name__)
# Collects the project endpoints defined below; the application mounts it elsewhere.
router = APIRouter()
# ============================================================
# Pydantic Models
# ============================================================
class ProductInfo(BaseModel):
    """Free-text product attributes supplied when a project is created.

    Every field is a string defaulting to "". The create endpoint stores
    ``model_dump()`` of this model as the project's ``product_info`` dict.
    """
    # Product category (free text).
    category: str = ""
    # Price kept as free text; it is not parsed as a number here.
    price: str = ""
    # Tags as a single string; no delimiter format is enforced here.
    tags: str = ""
    # Product parameters/specs as a single string.
    params: str = ""
    # Style hint string; how it is consumed is up to the generators (not visible here).
    style_hint: str = ""
class ProjectCreateRequest(BaseModel):
    """Request body for creating a project."""
    # Display name of the new project.
    name: str
    # Product attributes; stored as a plain dict via model_dump().
    product_info: ProductInfo
class ProjectResponse(BaseModel):
    """Full project record returned by the single-project GET endpoint."""
    id: str
    name: str
    # Workflow status string; the generation endpoints set values such as
    # "images_generated" and "videos_generated".
    status: str
    # Product attributes as stored; may carry extra keys (e.g. "uploaded_images").
    product_info: dict
    # Generated script; empty/None means no script has been generated yet.
    script_data: Optional[dict] = None
    # Timestamps as floats — presumably Unix epoch seconds; confirm in db_manager.
    created_at: float
    updated_at: float
class ProjectListItem(BaseModel):
    """Summary row returned by the project list endpoint."""
    id: str
    name: str
    # Same workflow status string as ProjectResponse.status.
    status: str
    # Last-modified timestamp — presumably Unix epoch seconds; confirm in db_manager.
    updated_at: float
class PromptDebugResponse(BaseModel):
    """
    Prompts actually used by the script generation that produced the stored
    script, read from ``projects.script_data._debug``.
    Note: ``raw_output`` is omitted by default to keep the response small and
    to avoid exposing unrelated information.
    """
    project_id: str
    # Model provider name recorded by the script generator, if any.
    provider: Optional[str] = None
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None
    # Image URLs sent along with the prompt, if recorded.
    image_urls: Optional[List[str]] = None
    # Raw model output; only filled when the caller asks for it explicitly.
    raw_output: Optional[str] = None
class SceneAssetResponse(BaseModel):
    """Per-scene asset record with an optional browser-facing URL.

    NOTE(review): not referenced by the endpoints in this section; the assets
    endpoint returns plain dicts instead.
    """
    scene_id: int
    # Asset kind; the endpoints here save "image" and "video".
    asset_type: str
    # Asset status; the endpoints here save "completed".
    status: str
    # Filesystem path of the asset file, if any.
    local_path: Optional[str] = None
    # URL the asset can be fetched from, if one could be derived.
    url: Optional[str] = None
# ============================================================
# API Endpoints
# ============================================================
@router.get("", response_model=List[ProjectListItem])
async def list_projects():
    """Return every project as a summary row.

    Rows come straight from ``db.list_projects()``; the ``ProjectListItem``
    response model limits the output to id, name, status and updated_at.
    """
    return db.list_projects()
@router.post("", response_model=dict)
async def create_project(request: ProjectCreateRequest):
    """Create a project and return its generated id.

    The id keeps the ``PROJ-<unix seconds>`` prefix but appends a short random
    suffix: with the timestamp alone, two projects created within the same
    second were given the same id.

    Returns:
        ``{"id": <new project id>, "message": <success text>}``.
    """
    project_id = f"PROJ-{int(time.time())}-{uuid.uuid4().hex[:8]}"
    db.create_project(project_id, request.name, request.product_info.model_dump())
    return {
        "id": project_id,
        "message": "项目创建成功"
    }
@router.get("/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: str):
    """Return the full stored record for one project.

    Raises:
        HTTPException: 404 when no project with ``project_id`` exists.
    """
    found = db.get_project(project_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="项目不存在")
@router.get("/{project_id}/prompt-debug", response_model=PromptDebugResponse)
async def get_prompt_debug(project_id: str, include_raw_output: bool = False):
    """
    Return the system_prompt / user_prompt actually used when the project's
    script was generated.

    - Source: ``projects.script_data._debug`` (written by ScriptGenerator when
      it generates the script); missing data yields None fields.
    - ``include_raw_output=true`` also returns ``raw_output``, which may be large.

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    record = db.get_project(project_id)
    if not record:
        raise HTTPException(status_code=404, detail="项目不存在")

    debug_info: Dict[str, Any] = (record.get("script_data") or {}).get("_debug") or {}
    exported_keys = ("provider", "system_prompt", "user_prompt", "image_urls")
    payload: Dict[str, Any] = {"project_id": project_id}
    payload.update({key: debug_info.get(key) for key in exported_keys})
    if include_raw_output:
        payload["raw_output"] = debug_info.get("raw_output")
    return payload
def _static_url_for(path: Path, root: Any, url_prefix: str) -> Optional[str]:
    """Return ``<url_prefix>/<path relative to root>`` if *path* lies inside *root*.

    Containment is checked on normalized absolute paths (``os.path.abspath``,
    which does not follow symlinks). A sibling directory such as
    ``<root>_old`` therefore does not match, unlike a plain substring test.
    Files in subdirectories keep their relative path in the URL instead of
    being flattened to their basename.

    Returns:
        The URL string, or None when *path* is not strictly inside *root*.
    """
    try:
        relative = Path(os.path.abspath(path)).relative_to(os.path.abspath(root))
    except ValueError:
        return None
    if not relative.parts:  # path is root itself, not a file inside it
        return None
    return f"{url_prefix}/{relative.as_posix()}"


@router.get("/{project_id}/assets", response_model=List[dict])
async def get_project_assets(project_id: str):
    """List a project's assets, adding a browser-loadable ``url`` where possible.

    For each asset whose (legacy-mapped) file exists on disk, the first of
    these that applies is used:

    - the URL returned by ``map_legacy_local_path``;
    - a ``/static/output/...`` URL if the file is under ``config.OUTPUT_DIR``;
    - a ``/static/temp/...`` URL if the file is under ``config.TEMP_DIR``.

    In those cases ``local_path`` is also rewritten to the mapped source path.
    Assets whose file is missing, or lies outside both directories without a
    mapped URL, are returned unchanged (no ``url`` key).
    """
    assets = db.get_assets(project_id)
    for asset in assets:
        source_path, mapped_url = map_legacy_local_path(asset.get("local_path"))
        if not source_path or not os.path.exists(source_path):
            continue
        local_path = Path(source_path)
        url = (
            mapped_url
            or _static_url_for(local_path, config.OUTPUT_DIR, "/static/output")
            or _static_url_for(local_path, config.TEMP_DIR, "/static/temp")
        )
        if url:
            asset["url"] = url
            asset["local_path"] = source_path
    return assets
@router.post("/{project_id}/upload-images")
async def upload_product_images(
    project_id: str,
    files: List[UploadFile] = File(...)
):
    """Save uploaded product (main) images into ``config.TEMP_DIR``.

    Each file is stored as ``<project_id>_<basename>``. The client-supplied
    filename is untrusted, so only its final path component is kept.
    Otherwise a name like ``../x`` or an absolute path would escape TEMP_DIR
    (``Path / "/abs"`` discards the left side). The project-id prefix stops
    two projects that upload the same filename from overwriting each other.

    Returns:
        ``{"message": ..., "paths": [saved file paths]}``.

    Raises:
        HTTPException: 404 if the project does not exist; 400 if a file has
            no usable filename.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    saved_paths = []
    for upload in files:
        safe_name = Path(upload.filename or "").name
        if not safe_name:
            raise HTTPException(status_code=400, detail="文件名无效")
        file_path = config.TEMP_DIR / f"{project_id}_{safe_name}"
        # Read first so a failed read does not leave an empty file behind.
        content = await upload.read()
        with open(file_path, "wb") as f:
            f.write(content)
        saved_paths.append(str(file_path))

    # TODO: persist this. The dict is only updated in memory; nothing here
    # writes it back to the database, so generate-script / generate-images
    # will not see these paths until an update_product_info-style db method
    # is added and called.
    product_info = project.get("product_info", {})
    product_info["uploaded_images"] = saved_paths

    return {
        "message": f"上传成功 {len(saved_paths)} 张图片",
        "paths": saved_paths
    }
@router.post("/{project_id}/generate-script")
async def generate_script(
    project_id: str,
    model_provider: str = "shubiaobiao"
):
    """Generate the script for a project and store it.

    Args:
        project_id: Project to generate the script for.
        model_provider: Provider name forwarded to ``ScriptGenerator.generate_script``.

    Returns:
        ``{"message": ..., "script": <generated script>}``.

    Raises:
        HTTPException: 404 if the project does not exist; 500 if the
            generator returned nothing.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    # Tolerate a stored product_info of None as well as a missing key.
    product_info = project.get("product_info") or {}
    image_paths = product_info.get("uploaded_images", [])

    gen = ScriptGenerator()
    # generate_script is a plain synchronous call (presumably waiting on a
    # remote model). Running it in the threadpool keeps this async handler
    # from blocking the event loop for every other request.
    script = await run_in_threadpool(
        gen.generate_script,
        project["name"],
        product_info,
        image_paths,
        model_provider=model_provider,
    )
    if not script:
        raise HTTPException(status_code=500, detail="脚本生成失败")

    db.update_project_script(project_id, script)
    return {
        "message": "脚本生成成功",
        "script": script
    }
@router.post("/{project_id}/generate-images")
async def generate_images(
    project_id: str,
    model_provider: str = "shubiaobiao"
):
    """Generate one storyboard image per script scene and record them as assets.

    Args:
        project_id: Project whose script scenes are rendered.
        model_provider: Provider name forwarded to the image generator.

    Returns:
        ``{"message": ..., "images": {scene_id: image_path}}`` for the scenes
        that produced an image.

    Raises:
        HTTPException: 404 unknown project; 400 if there is no script or no
            uploaded product images; 500 if no image at all was generated. In
            that case the project status is left unchanged, consistent with
            the video endpoint.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    script_data = project.get("script_data")
    if not script_data:
        raise HTTPException(status_code=400, detail="请先生成脚本")

    product_info = project.get("product_info") or {}
    base_imgs = product_info.get("uploaded_images", [])
    if not base_imgs:
        raise HTTPException(status_code=400, detail="请先上传商品主图")

    img_gen = ImageGenerator()
    scenes = script_data.get("scenes") or []
    visual_anchor = script_data.get("visual_anchor", "")

    results = {}
    # Each generated image is appended to the list passed as
    # original_image_path, so later scenes receive the product photos plus
    # all earlier frames (presumably for visual consistency).
    current_refs = list(base_imgs)
    for scene in scenes:
        scene_id = scene["id"]
        # Synchronous generator call; offload it so the event loop stays free.
        img_path = await run_in_threadpool(
            img_gen.generate_single_scene_image,
            scene=scene,
            original_image_path=current_refs,
            previous_image_path=None,
            model_provider=model_provider,
            visual_anchor=visual_anchor,
        )
        if img_path:
            results[scene_id] = img_path
            current_refs.append(img_path)
            db.save_asset(project_id, scene_id, "image", "completed", local_path=img_path)

    if not results:
        raise HTTPException(status_code=500, detail="图片生成失败")

    db.update_project_status(project_id, "images_generated")
    return {
        "message": f"生成成功 {len(results)} 张图片",
        "images": results
    }
@router.post("/{project_id}/generate-videos")
async def generate_videos(project_id: str):
    """Generate per-scene videos from the project's completed scene images.

    Returns:
        ``{"message": ..., "videos": {scene_id: video_path}}``.

    Raises:
        HTTPException: 404 unknown project; 400 if there is no script or no
            completed scene image yet; 500 if the generator returned no videos.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    script_data = project.get("script_data")
    if not script_data:
        raise HTTPException(status_code=400, detail="请先生成脚本")

    # Map scene_id -> image path for completed image assets; if several rows
    # share a scene_id, the one returned last by get_assets wins.
    assets = db.get_assets(project_id, "image")
    scene_images = {a["scene_id"]: a["local_path"] for a in assets if a["status"] == "completed"}
    if not scene_images:
        raise HTTPException(status_code=400, detail="请先生成分镜图片")

    vid_gen = VideoGenerator()
    # Synchronous and presumably long-running; run it in the threadpool so
    # this async handler does not block the event loop.
    videos = await run_in_threadpool(
        vid_gen.generate_scene_videos,
        project_id,
        script_data,
        scene_images,
    )
    if not videos:
        raise HTTPException(status_code=500, detail="视频生成失败")

    for sid, path in videos.items():
        db.save_asset(project_id, sid, "video", "completed", local_path=path)
    db.update_project_status(project_id, "videos_generated")
    return {
        "message": f"生成成功 {len(videos)} 个视频",
        "videos": videos
    }