"""
Project management API routes.

Provides project CRUD endpoints plus status, asset and prompt-debug queries.
"""
import os
import time
import logging
from typing import List, Optional, Any, Dict
from pathlib import Path

from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel

import config
from modules.db_manager import db
from modules.script_gen import ScriptGenerator
from modules.image_gen import ImageGenerator
from modules.video_gen import VideoGenerator
from modules.legacy_path_mapper import map_legacy_local_path

logger = logging.getLogger(__name__)

router = APIRouter()


# ============================================================
# Pydantic Models
# ============================================================

class ProductInfo(BaseModel):
    """Free-form product metadata supplied when a project is created."""
    category: str = ""
    price: str = ""
    tags: str = ""
    params: str = ""
    style_hint: str = ""


class ProjectCreateRequest(BaseModel):
    """Request body for the create-project endpoint."""
    name: str
    product_info: ProductInfo


class ProjectResponse(BaseModel):
    """Full project record returned by the project-detail endpoint."""
    id: str
    name: str
    status: str
    product_info: dict
    script_data: Optional[dict] = None
    created_at: float
    updated_at: float


class ProjectListItem(BaseModel):
    """Summary row returned by the project-list endpoint."""
    id: str
    name: str
    status: str
    updated_at: float


class PromptDebugResponse(BaseModel):
    """
    Prompts actually used for the last script generation
    (read from ``projects.script_data._debug``).

    ``raw_output`` is omitted unless explicitly requested, to avoid very large
    payloads and leaking unrelated model output.
    """
    project_id: str
    provider: Optional[str] = None
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None
    image_urls: Optional[List[str]] = None
    raw_output: Optional[str] = None


class SceneAssetResponse(BaseModel):
    """A single generated asset (image/video) for one scene."""
    scene_id: int
    asset_type: str
    status: str
    local_path: Optional[str] = None
    url: Optional[str] = None


# ============================================================
# API Endpoints
# ============================================================

@router.get("", response_model=List[ProjectListItem])
async def list_projects():
    """Return all projects as summary rows (as provided by ``db.list_projects``)."""
    projects = db.list_projects()
    return projects


@router.post("", response_model=dict)
async def create_project(request: ProjectCreateRequest):
    """Create a new project and return its generated id.

    The id has the form ``PROJ-<unix seconds>``. When that id is already
    taken (two creations within the same second), a ``-<n>`` suffix is
    appended until an unused id is found.
    """
    base_id = f"PROJ-{int(time.time())}"
    project_id = base_id
    suffix = 1
    # There is no ``await`` between this check and ``create_project``, so
    # the check-then-create is not interleaved with other requests on this
    # event loop. NOTE(review): multiple worker processes could still race.
    while db.get_project(project_id):
        project_id = f"{base_id}-{suffix}"
        suffix += 1

    product_info_dict = request.product_info.model_dump()
    db.create_project(project_id, request.name, product_info_dict)

    return {
        "id": project_id,
        "message": "项目创建成功"
    }


@router.get("/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: str):
    """Return the full project record, or 404 if it does not exist."""
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")
    return project


@router.get("/{project_id}/prompt-debug", response_model=PromptDebugResponse)
async def get_prompt_debug(project_id: str, include_raw_output: bool = False):
    """
    Return the system/user prompts actually used when this project's script
    was generated.

    - Source: ``projects.script_data._debug`` (written by ScriptGenerator at
      generation time). Missing data yields ``None`` fields, not an error.
    - ``include_raw_output=true`` also returns ``raw_output`` (may be large).
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    script_data: Dict[str, Any] = project.get("script_data") or {}
    debug: Dict[str, Any] = script_data.get("_debug") or {}

    resp = {
        "project_id": project_id,
        "provider": debug.get("provider"),
        "system_prompt": debug.get("system_prompt"),
        "user_prompt": debug.get("user_prompt"),
        "image_urls": debug.get("image_urls"),
    }
    if include_raw_output:
        resp["raw_output"] = debug.get("raw_output")
    return resp


def _is_under(path: Any, root: Any) -> bool:
    """Return True if *path* lies inside directory *root*, compared by path components.

    Both sides are made absolute and normalized (``..`` collapsed, symlinks
    not followed). Unlike a substring test, ``/data/output_old/x`` is not
    considered to be under ``/data/output``.
    """
    try:
        Path(os.path.abspath(path)).relative_to(os.path.abspath(root))
    except ValueError:
        return False
    return True


@router.get("/{project_id}/assets", response_model=List[dict])
async def get_project_assets(project_id: str):
    """Return all assets of a project, adding a browser-accessible ``url``.

    A ``url`` (and normalized ``local_path``) is added only when the asset's
    file exists on disk and either the legacy path mapper supplied a URL or
    the file lives directly under OUTPUT_DIR / TEMP_DIR. Other assets are
    returned unchanged.
    """
    assets = db.get_assets(project_id)

    for asset in assets:
        source_path, mapped_url = map_legacy_local_path(asset.get("local_path"))
        if not (source_path and os.path.exists(source_path)):
            continue

        local_path = Path(source_path)
        if mapped_url:
            asset["url"] = mapped_url
            asset["local_path"] = source_path
        elif _is_under(local_path, config.OUTPUT_DIR):
            # Static mounts serve files by bare name, matching how they are
            # written into these directories.
            asset["url"] = f"/static/output/{local_path.name}"
            asset["local_path"] = source_path
        elif _is_under(local_path, config.TEMP_DIR):
            asset["url"] = f"/static/temp/{local_path.name}"
            asset["local_path"] = source_path

    return assets
async def _run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the worker threadpool and await its result.

    Script/image/video generation is synchronous and long-running. Calling
    it directly inside an ``async def`` endpoint would freeze the whole event
    loop, so these calls go through Starlette's threadpool instead.
    """
    from fastapi.concurrency import run_in_threadpool
    return await run_in_threadpool(func, *args, **kwargs)


def _safe_upload_path(project_id: str, filename: Optional[str]) -> Path:
    """Build a TEMP_DIR destination for an uploaded file that cannot escape TEMP_DIR.

    The client-supplied name is reduced to its final component, with both
    '/' and '\\' treated as separators. It is prefixed with the project id so
    two projects uploading the same file name do not overwrite each other.
    Raises HTTPException(400) for empty names or names that are only '.' / '..'.
    """
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="文件名无效")
    return config.TEMP_DIR / f"{project_id}_{name}"


@router.post("/{project_id}/upload-images")
async def upload_product_images(
    project_id: str,
    files: List[UploadFile] = File(...)
):
    """Upload product main images into TEMP_DIR and return the saved paths."""
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    # Validate every file name before writing anything, so a bad name in the
    # batch does not leave a partial upload behind.
    targets = [(_safe_upload_path(project_id, f.filename), f) for f in files]

    saved_paths = []
    for file_path, upload in targets:
        content = await upload.read()
        with open(file_path, "wb") as out:
            out.write(content)
        saved_paths.append(str(file_path))

    product_info = project.get("product_info") or {}
    product_info["uploaded_images"] = saved_paths
    # TODO: this change is NOT persisted. No db method for updating
    # product_info is used here, so generate-script / generate-images will
    # not see these uploads until an ``update_product_info``-style method is
    # added and called.

    return {
        "message": f"上传成功 {len(saved_paths)} 张图片",
        "paths": saved_paths
    }


@router.post("/{project_id}/generate-script")
async def generate_script(
    project_id: str,
    model_provider: str = "shubiaobiao"
):
    """Generate the storyboard script for a project and store it.

    Raises 404 if the project is missing and 500 if the generator returns
    nothing.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    # ``or {}`` also guards against a stored NULL product_info.
    product_info = project.get("product_info") or {}
    image_paths = product_info.get("uploaded_images", [])

    gen = ScriptGenerator()
    script = await _run_blocking(
        gen.generate_script,
        project["name"],
        product_info,
        image_paths,
        model_provider=model_provider,
    )

    if not script:
        raise HTTPException(status_code=500, detail="脚本生成失败")

    db.update_project_script(project_id, script)
    return {
        "message": "脚本生成成功",
        "script": script
    }


@router.post("/{project_id}/generate-images")
async def generate_images(
    project_id: str,
    model_provider: str = "shubiaobiao"
):
    """Generate one image per script scene and record each as an asset.

    Each successfully generated image is appended to the reference list for
    later scenes, so scenes are generated sequentially. Raises 400 when the
    script or product images are missing, and 500 when no image at all was
    produced (the project status is then left unchanged).
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    script_data = project.get("script_data")
    if not script_data:
        raise HTTPException(status_code=400, detail="请先生成脚本")

    product_info = project.get("product_info") or {}
    base_imgs = product_info.get("uploaded_images", [])
    if not base_imgs:
        raise HTTPException(status_code=400, detail="请先上传商品主图")

    img_gen = ImageGenerator()
    scenes = script_data.get("scenes", [])
    visual_anchor = script_data.get("visual_anchor", "")

    results = {}
    current_refs = list(base_imgs)
    for scene in scenes:
        scene_id = scene["id"]
        img_path = await _run_blocking(
            img_gen.generate_single_scene_image,
            scene=scene,
            original_image_path=current_refs,
            previous_image_path=None,
            model_provider=model_provider,
            visual_anchor=visual_anchor,
        )
        if img_path:
            results[scene_id] = img_path
            # Later scenes use earlier outputs as extra references.
            current_refs.append(img_path)
            db.save_asset(project_id, scene_id, "image", "completed", local_path=img_path)

    if not results:
        # Consistent with generate-script / generate-videos: do not mark the
        # project as having images when none were produced.
        raise HTTPException(status_code=500, detail="图片生成失败")

    db.update_project_status(project_id, "images_generated")
    return {
        "message": f"生成成功 {len(results)} 张图片",
        "images": results
    }


@router.post("/{project_id}/generate-videos")
async def generate_videos(project_id: str):
    """Generate per-scene videos from the completed scene images.

    Raises 400 when there is no script or no completed image, and 500 when
    the video generator returns nothing.
    """
    project = db.get_project(project_id)
    if not project:
        raise HTTPException(status_code=404, detail="项目不存在")

    script_data = project.get("script_data")
    if not script_data:
        raise HTTPException(status_code=400, detail="请先生成脚本")

    # Map scene_id -> image path for completed images that actually have a
    # path. If a scene has several completed images, the last one returned by
    # the db wins.
    assets = db.get_assets(project_id, "image")
    scene_images = {
        a["scene_id"]: a["local_path"]
        for a in assets
        if a["status"] == "completed" and a.get("local_path")
    }
    if not scene_images:
        raise HTTPException(status_code=400, detail="请先生成分镜图片")

    vid_gen = VideoGenerator()
    videos = await _run_blocking(
        vid_gen.generate_scene_videos,
        project_id,
        script_data,
        scene_images,
    )

    if not videos:
        raise HTTPException(status_code=500, detail="视频生成失败")

    for sid, path in videos.items():
        db.save_asset(project_id, sid, "video", "completed", local_path=path)
    db.update_project_status(project_id, "videos_generated")

    return {
        "message": f"生成成功 {len(videos)} 个视频",
        "videos": videos
    }