316 lines
9.2 KiB
Python
316 lines
9.2 KiB
Python
"""
|
||
项目管理 API 路由
|
||
提供项目 CRUD 和状态查询
|
||
"""
|
||
import os
|
||
import time
|
||
import logging
|
||
from typing import List, Optional, Any, Dict
|
||
from pathlib import Path
|
||
|
||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||
from pydantic import BaseModel
|
||
|
||
import config
|
||
from modules.db_manager import db
|
||
from modules.script_gen import ScriptGenerator
|
||
from modules.image_gen import ImageGenerator
|
||
from modules.video_gen import VideoGenerator
|
||
from modules.legacy_path_mapper import map_legacy_local_path
|
||
|
||
# Router that collects every project endpoint declared in this module; it is
# created here without a prefix.
router = APIRouter()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ============================================================
|
||
# Pydantic Models
|
||
# ============================================================
|
||
|
||
class ProductInfo(BaseModel):
    """Descriptive attributes of the product a project is about.

    Every field is a plain string that defaults to the empty string, so a
    client may omit any of them.
    """

    category: str = ""
    price: str = ""
    tags: str = ""
    params: str = ""
    # Presumably a free-text hint about the desired visual style -- TODO confirm.
    style_hint: str = ""
|
||
|
||
|
||
class ProjectCreateRequest(BaseModel):
    """Request body for creating a project: a name plus product attributes."""

    name: str
    product_info: ProductInfo
|
||
|
||
|
||
class ProjectResponse(BaseModel):
    """Full project record, used as the response model of the detail endpoint."""

    id: str
    name: str
    status: str
    product_info: dict
    # Absent (None) when the project has no stored script data.
    script_data: Optional[dict] = None
    # Timestamps are floats; presumably epoch seconds -- TODO confirm.
    created_at: float
    updated_at: float
|
||
|
||
|
||
class ProjectListItem(BaseModel):
    """Compact project summary, used as the item type of the list endpoint."""

    id: str
    name: str
    status: str
    # Float timestamp; presumably epoch seconds -- TODO confirm.
    updated_at: float
|
||
|
||
|
||
class PromptDebugResponse(BaseModel):
    """
    Exports the prompts actually used for a script-generation run
    (taken from projects.script_data._debug).

    Note: raw_output is not returned by default, to avoid oversized payloads
    and leaking unrelated information.
    """

    project_id: str
    provider: Optional[str] = None
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None
    image_urls: Optional[List[str]] = None
    raw_output: Optional[str] = None
|
||
|
||
|
||
class SceneAssetResponse(BaseModel):
    """Schema for one asset belonging to a scene.

    ``local_path`` and ``url`` are optional and default to None.
    """

    scene_id: int
    asset_type: str
    status: str
    local_path: Optional[str] = None
    url: Optional[str] = None
|
||
|
||
|
||
# ============================================================
|
||
# API Endpoints
|
||
# ============================================================
|
||
|
||
@router.get("", response_model=List[ProjectListItem])
async def list_projects():
    """Return the list of all projects.

    The rows are whatever ``db.list_projects()`` yields; the route's
    response model is ``List[ProjectListItem]``.
    """
    return db.list_projects()
|
||
|
||
|
||
@router.post("", response_model=dict)
async def create_project(request: ProjectCreateRequest):
    """Create a new project and return its generated id.

    Args:
        request: Project name and product attributes.

    Returns:
        A dict with the new project ``id`` and a success ``message``.
    """
    # Local import so this block does not depend on a file-level import change.
    import uuid

    # A timestamp with one-second resolution is not unique: two projects
    # created within the same second would receive the same id. Keep the
    # readable "PROJ-<epoch>" prefix and append a short random suffix.
    project_id = f"PROJ-{int(time.time())}-{uuid.uuid4().hex[:8]}"

    product_info_dict = request.product_info.model_dump()
    db.create_project(project_id, request.name, product_info_dict)

    return {
        "id": project_id,
        "message": "项目创建成功",
    }
|
||
|
||
|
||
@router.get("/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: str):
    """Return the details of one project.

    Raises:
        HTTPException: 404 when ``db.get_project`` returns a falsy value.
    """
    record = db.get_project(project_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="项目不存在")
|
||
|
||
|
||
@router.get("/{project_id}/prompt-debug", response_model=PromptDebugResponse)
async def get_prompt_debug(project_id: str, include_raw_output: bool = False):
    """
    Return the system_prompt / user_prompt that were actually used when the
    project's script was generated.

    - Data source: projects.script_data._debug (written by ScriptGenerator
      while generating the script).
    - With include_raw_output=true the raw_output is returned as well
      (it may be large).

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    record = db.get_project(project_id)
    if not record:
        raise HTTPException(status_code=404, detail="项目不存在")

    # Missing or empty containers collapse to {} so the lookups below yield None.
    stored_script: Dict[str, Any] = record.get("script_data") or {}
    debug_info: Dict[str, Any] = stored_script.get("_debug") or {}

    exported_keys = ["provider", "system_prompt", "user_prompt", "image_urls"]
    if include_raw_output:
        exported_keys.append("raw_output")

    payload = {"project_id": project_id}
    payload.update((key, debug_info.get(key)) for key in exported_keys)
    return payload
|
||
|
||
|
||
@router.get("/{project_id}/assets", response_model=List[dict])
async def get_project_assets(project_id: str):
    """Return all assets of a project, adding an accessible URL where possible.

    Each asset's ``local_path`` is passed through ``map_legacy_local_path``.
    When the resulting path exists on disk, ``url`` is set to the mapped URL
    if one was returned, otherwise to ``/static/output/<file name>`` or
    ``/static/temp/<file name>`` depending on which configured directory
    string occurs in the path. ``local_path`` is rewritten to the resolved
    path in those cases. Assets that match none of these are left untouched.
    """
    records = db.get_assets(project_id)

    for item in records:
        source_path, mapped_url = map_legacy_local_path(item.get("local_path"))
        if not source_path or not os.path.exists(source_path):
            continue

        resolved = Path(source_path)
        if mapped_url:
            public_url = mapped_url
        elif str(config.OUTPUT_DIR) in str(resolved):
            public_url = f"/static/output/{resolved.name}"
        elif str(config.TEMP_DIR) in str(resolved):
            public_url = f"/static/temp/{resolved.name}"
        else:
            # Not under any known static root: leave the asset as stored.
            continue

        item["url"] = public_url
        item["local_path"] = source_path

    return records
|
||
|
||
|
||
@router.post("/{project_id}/upload-images")
async def upload_product_images(
    project_id: str,
    files: List[UploadFile] = File(...)
):
    """Upload the product's main images into the temp directory.

    Args:
        project_id: Id of the project the images belong to.
        files: Uploaded image files.

    Returns:
        A dict with a summary ``message`` and the saved ``paths``.

    Raises:
        HTTPException: 404 when the project does not exist; 400 when an
            upload has no usable file name.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    # The file name is client-controlled. Reduce it to its final path
    # component so values like "../../x" or "/etc/x" cannot escape TEMP_DIR.
    # Backslashes are normalised first so Windows-style paths are handled too.
    # All names are validated before anything is written to disk.
    safe_names = []
    for file in files:
        raw_name = (file.filename or "").replace("\\", "/")
        safe_name = os.path.basename(raw_name)
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="无效的文件名")
        safe_names.append(safe_name)

    saved_paths = []
    for file, safe_name in zip(files, safe_names):
        # Save into the temp directory.
        file_path = config.TEMP_DIR / safe_name
        content = await file.read()
        with open(file_path, "wb") as f:
            f.write(content)
        saved_paths.append(str(file_path))

    # Update the project's product_info.
    product_info = project.get("product_info", {})
    product_info["uploaded_images"] = saved_paths

    # NOTE(review): this only changes the in-memory dict; nothing visible in
    # this function writes it back to the database. The original author noted
    # that an update_product_info method should be added for that.

    return {
        "message": f"上传成功 {len(saved_paths)} 张图片",
        "paths": saved_paths
    }
|
||
|
||
|
||
@router.post("/{project_id}/generate-script")
async def generate_script(
    project_id: str,
    model_provider: str = "shubiaobiao",
):
    """Generate the script for a project and store it.

    Args:
        project_id: Id of the project to generate a script for.
        model_provider: Provider name forwarded to ``ScriptGenerator``.

    Returns:
        A dict with a success ``message`` and the generated ``script``.

    Raises:
        HTTPException: 404 when the project does not exist; 500 when the
            generator returns a falsy result.
    """
    record = db.get_project(project_id)
    if not record:
        raise HTTPException(status_code=404, detail="项目不存在")

    info = record.get("product_info", {})
    uploaded = info.get("uploaded_images", [])

    generated = ScriptGenerator().generate_script(
        record["name"],
        info,
        uploaded,
        model_provider=model_provider,
    )
    if not generated:
        raise HTTPException(status_code=500, detail="脚本生成失败")

    db.update_project_script(project_id, generated)
    return {"message": "脚本生成成功", "script": generated}
|
||
|
||
|
||
@router.post("/{project_id}/generate-images")
async def generate_images(
    project_id: str,
    model_provider: str = "shubiaobiao"
):
    """Generate one storyboard image per scene of the project's script.

    Args:
        project_id: Id of the project whose scenes are rendered.
        model_provider: Provider name forwarded to ``ImageGenerator``.

    Returns:
        A dict with a summary ``message`` and ``images`` mapping each scene
        id to the generated image path.

    Raises:
        HTTPException: 404 when the project does not exist; 400 when the
            script or the uploaded product images are missing; 500 when no
            scene produced an image.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    script_data = project.get("script_data")
    if not script_data:
        raise HTTPException(status_code=400, detail="请先生成脚本")

    # "or" also covers a stored None, which would otherwise break .get().
    product_info = project.get("product_info") or {}
    base_imgs = product_info.get("uploaded_images") or []

    if not base_imgs:
        raise HTTPException(status_code=400, detail="请先上传商品主图")

    img_gen = ImageGenerator()
    scenes = script_data.get("scenes", [])
    visual_anchor = script_data.get("visual_anchor", "")

    results = {}
    # Every generated image is appended here, so later scenes receive the
    # product images plus all earlier scene images as references.
    current_refs = list(base_imgs)

    for scene in scenes:
        scene_id = scene["id"]
        img_path = img_gen.generate_single_scene_image(
            scene=scene,
            original_image_path=current_refs,
            previous_image_path=None,
            model_provider=model_provider,
            visual_anchor=visual_anchor
        )

        if not img_path:
            logger.warning(
                "Image generation failed for project %s, scene %s",
                project_id,
                scene_id,
            )
            continue

        results[scene_id] = img_path
        current_refs.append(img_path)
        db.save_asset(project_id, scene_id, "image", "completed", local_path=img_path)

    # Do not mark the project as done or report success when nothing was
    # produced; this mirrors the failure handling of the video endpoint.
    if not results:
        raise HTTPException(status_code=500, detail="图片生成失败")

    db.update_project_status(project_id, "images_generated")

    return {
        "message": f"生成成功 {len(results)} 张图片",
        "images": results
    }
|
||
|
||
|
||
@router.post("/{project_id}/generate-videos")
async def generate_videos(project_id: str):
    """Generate storyboard videos from the project's completed scene images.

    Returns:
        A dict with a summary ``message`` and the ``videos`` mapping returned
        by ``VideoGenerator.generate_scene_videos``.

    Raises:
        HTTPException: 404 when the project does not exist; 400 when the
            script or the completed scene images are missing; 500 when the
            generator returns a falsy result.
    """
    record = db.get_project(project_id)
    if not record:
        raise HTTPException(status_code=404, detail="项目不存在")

    script = record.get("script_data")
    if not script:
        raise HTTPException(status_code=400, detail="请先生成脚本")

    # Collect the images already generated, keeping only completed ones.
    image_assets = db.get_assets(project_id, "image")
    images_by_scene = {
        entry["scene_id"]: entry["local_path"]
        for entry in image_assets
        if entry["status"] == "completed"
    }
    if not images_by_scene:
        raise HTTPException(status_code=400, detail="请先生成分镜图片")

    clips = VideoGenerator().generate_scene_videos(
        project_id,
        script,
        images_by_scene,
    )
    if not clips:
        raise HTTPException(status_code=500, detail="视频生成失败")

    for scene_id, clip_path in clips.items():
        db.save_asset(project_id, scene_id, "video", "completed", local_path=clip_path)
    db.update_project_status(project_id, "videos_generated")

    return {
        "message": f"生成成功 {len(clips)} 个视频",
        "videos": clips,
    }
|
||
|
||
|
||
|