import argparse
import base64
import json
import os
import time
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse

import requests

# NanoBanana2 image-generation endpoint, docs: https://api.wuyinkeji.com/doc/65
NANOBANANA_URL = "https://api.wuyinkeji.com/api/async/image_nanoBanana2"
# Model-agnostic result-detail endpoint, docs: https://api.wuyinkeji.com/doc/47
RESULT_DETAIL_URL = "https://api.wuyinkeji.com/api/async/detail"
CONFIG_FILE_NAME = "config_nano_banana.json"

# Image file extensions, shared by reference-image discovery and by the
# extension guess for downloaded results. The order is the priority used by
# the substring fallback in _guess_extension.
IMAGE_EXTENSIONS = (".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif")

# Task status codes of the result-detail endpoint as used by this script:
# 0 queued, 1 generating, 2 succeeded, 3 failed.
_STATUS_IN_PROGRESS = (0, 1)
_STATUS_SUCCESS = 2
_STATUS_FAILED = 3

# Fields of the result-detail `data` object that may carry output image URLs.
_RESULT_URL_FIELDS = ("img_url", "img_urls", "image_urls", "urls", "images", "result")
# Keys tried when a list element is an object instead of a plain URL string.
_ITEM_URL_FIELDS = ("url", "img_url", "image_url")


def load_config(base_dir: Path) -> dict:
    """Load default parameters (api_key, prompt, reference_urls, ...) from the config file.

    Args:
        base_dir: Directory containing ``CONFIG_FILE_NAME``.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        empty, unreadable, not valid JSON, or not a JSON object at top level.
    """
    config_path = base_dir / CONFIG_FILE_NAME
    if not config_path.exists():
        return {}
    try:
        with config_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # Best effort: a broken config must not stop the script.
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"读取配置文件失败,将忽略配置文件。错误: {e}")
        return {}
    if not data:
        return {}
    if not isinstance(data, dict):
        # Callers use .get() on the result; a top-level list/str/number is unusable.
        print("配置文件顶层不是 JSON 对象,将忽略配置文件。")
        return {}
    return data


def load_api_key(cli_key: Optional[str], config: dict) -> str:
    """Resolve the API key.

    Priority: command-line argument > config file ``api_key`` > environment
    variable ``WUYIN_API_KEY``. Empty, whitespace-only and non-string values
    are skipped; the returned key is stripped of surrounding whitespace.

    Raises:
        SystemExit: if no usable key is found in any source.
    """
    key_from_config = config.get("api_key") if isinstance(config, dict) else None
    for candidate in (cli_key, key_from_config, os.environ.get("WUYIN_API_KEY")):
        # A non-string value (e.g. a number in the JSON config) would be
        # rejected by requests as a header value, so it is ignored here.
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
    raise SystemExit(
        "未找到 API 密钥,请在 config_nano_banana.json 中填写 api_key,"
        "或通过参数 --api-key 传入,或在环境变量 WUYIN_API_KEY 中设置。"
    )


def collect_reference_images(input_dir: Path, max_files: int = 14) -> List[str]:
    """Read reference images from ``input_dir`` and return them base64-encoded.

    Only regular files with an extension in IMAGE_EXTENSIONS are used, in
    sorted path order, at most ``max_files`` of them.

    Returns:
        Base64 strings (without a ``data:`` URI prefix); empty when
        ``input_dir`` does not exist or is not a directory.
    """
    if not input_dir.is_dir():
        return []
    files = [
        p
        for p in sorted(input_dir.iterdir())
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    ]
    if len(files) > max_files:
        print(
            f"目录 {input_dir} 中的参考图超过 {max_files} 张,"
            f"仅使用前 {max_files} 张。"
        )
        files = files[:max_files]
    return [base64.b64encode(p.read_bytes()).decode("ascii") for p in files]


def create_task(
    api_key: str,
    prompt: str,
    size: str,
    aspect_ratio: str,
    ref_b64_list: List[str],
) -> str:
    """Submit a NanoBanana2 async image-generation task and return its id.

    A JSON body is used so that ``urls`` is sent as a real array, which
    avoids type errors on the server side.

    Args:
        api_key: API key, sent both as the Authorization header and as ``key``.
        prompt: Text prompt.
        size: Output size, e.g. "1K".
        aspect_ratio: Output aspect ratio, e.g. "auto" or "16:9".
        ref_b64_list: Reference images as URLs and/or base64 strings; omitted
            from the request when empty.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: if the API reports a non-200 ``code`` or no task id.
    """
    headers = {
        "Authorization": api_key,
        "Content-Type": "application/json;charset=utf-8;",
    }
    data: dict = {
        "prompt": prompt,
        "size": size,
        "aspectRatio": aspect_ratio,
        "key": api_key,
    }
    if ref_b64_list:
        # `urls` is an array of strings; each element may be a URL or base64.
        data["urls"] = ref_b64_list

    resp = requests.post(NANOBANANA_URL, json=data, headers=headers, timeout=30)
    resp.raise_for_status()
    payload = resp.json()
    if payload.get("code") != 200:
        raise RuntimeError(f"创建任务失败: {payload}")

    data_obj = payload.get("data")
    task_id = data_obj.get("id") if isinstance(data_obj, dict) else None
    if not task_id:
        raise RuntimeError(f"返回数据中缺少任务 id: {payload}")
    return str(task_id)


def _url_items(values) -> List[str]:
    """Turn one field value (a string or a list) into a list of non-empty URL strings."""
    if isinstance(values, str):
        return [values] if values else []
    if not isinstance(values, list):
        return []
    urls: List[str] = []
    for item in values:
        if isinstance(item, str):
            if item:
                urls.append(item)
        elif isinstance(item, dict):
            # NOTE(review): assumes object elements look like {"url": ...};
            # the detail-endpoint schema for this case is not confirmed.
            for key in _ITEM_URL_FIELDS:
                val = item.get(key)
                if isinstance(val, str) and val:
                    urls.append(val)
                    break
    return urls


def extract_image_urls(data_obj) -> List[str]:
    """Extract image URLs from the ``data`` field of a result-detail response.

    Several field names are checked (see _RESULT_URL_FIELDS, including the
    ``result`` array). A bare list ``data`` is treated as a list of URLs.

    Returns:
        Unique URLs in first-seen order; empty if none are found.
    """
    if not data_obj:
        return []
    candidates: List[str] = []
    if isinstance(data_obj, dict):
        for key in _RESULT_URL_FIELDS:
            candidates.extend(_url_items(data_obj.get(key)))
    elif isinstance(data_obj, list):
        candidates = _url_items(data_obj)
    # dict preserves insertion order: deduplicate without reordering.
    return list(dict.fromkeys(candidates))


def _normalize_status(raw):
    """Return ``raw`` as an int when it is a digit string; otherwise unchanged."""
    if isinstance(raw, str) and raw.strip().isdigit():
        return int(raw.strip())
    return raw


def query_result(
    api_key: str,
    task_id: str,
    poll_interval: float = 5.0,
    max_wait: float = 300.0,
) -> List[str]:
    """Poll the result-detail endpoint until the task finishes or times out.

    Args:
        api_key: API key, sent as the Authorization header and ``key`` param.
        task_id: Id returned by create_task.
        poll_interval: Seconds to sleep between polls.
        max_wait: Give up after roughly this many seconds.

    Returns:
        Output image URLs. May be empty when the task succeeded but no
        recognised URL field was present; the caller handles that case.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: on a non-200 API ``code`` or a failed task.
        TimeoutError: when no result arrives within ``max_wait`` seconds.
    """
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + max_wait
    params = {"key": api_key, "id": task_id}
    headers = {
        "Authorization": api_key,
        "Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
    }
    while True:
        resp = requests.get(RESULT_DETAIL_URL, params=params, headers=headers, timeout=30)
        resp.raise_for_status()
        payload = resp.json()
        if payload.get("code") != 200:
            # Queued/in-progress tasks are expected to still report code=200,
            # so any other code is treated as an explicit failure.
            msg = payload.get("msg") or "未知错误"
            raise RuntimeError(f"查询任务失败: code={payload.get('code')}, msg={msg}")

        data_obj = payload.get("data") or {}
        raw_status = data_obj.get("status") if isinstance(data_obj, dict) else None
        status = _normalize_status(raw_status)

        if status == _STATUS_SUCCESS:
            return extract_image_urls(data_obj)
        if status == _STATUS_FAILED:
            reason = data_obj.get("fail_reason") or payload.get("msg") or "未知原因"
            raise RuntimeError(f"任务生成失败: {reason}")
        if status not in _STATUS_IN_PROGRESS:
            # No (recognised) status: treat the response as finished if it
            # already carries image URLs. Skipped while queued/generating so
            # URL-like fields present before completion are not taken as results.
            urls = extract_image_urls(data_obj)
            if urls:
                return urls

        if time.monotonic() > deadline:
            raise TimeoutError("等待任务结果超时,请稍后在控制台或结果查询接口自行确认。")
        time.sleep(poll_interval)


def _guess_extension(url: str) -> str:
    """Guess an image file extension for ``url``; defaults to ``.jpg``.

    The suffix of the URL path is preferred; otherwise the first known
    extension found anywhere in the URL (e.g. in a query string) is used.
    """
    suffix = Path(urlparse(url).path).suffix.lower()
    if suffix in IMAGE_EXTENSIONS:
        return suffix
    lowered = url.lower()
    for cand in IMAGE_EXTENSIONS:
        if cand in lowered:
            return cand
    return ".jpg"


def download_images(urls: List[str], output_dir: Path, task_id: str) -> None:
    """Download every URL into ``output_dir`` as ``{task_id}_{n}{ext}``.

    A failing URL is reported and skipped so the rest still download. Data is
    streamed into a ``.part`` file that is renamed only after a complete
    transfer, so a failure never leaves a truncated image behind.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, url in enumerate(urls, start=1):
        filename = output_dir / f"{task_id}_{idx}{_guess_extension(url)}"
        part_file = filename.with_name(filename.name + ".part")
        try:
            # The response context manager releases the pooled connection even
            # if streaming is interrupted.
            with requests.get(url, stream=True, timeout=60) as resp:
                resp.raise_for_status()
                with part_file.open("wb") as f:
                    for chunk in resp.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            part_file.replace(filename)
        except (requests.RequestException, OSError) as e:
            print(f"下载失败: {url} | 错误: {e}")
            if part_file.exists():
                part_file.unlink()
        else:
            print(f"已保存: {filename}")


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser(
        description=(
            "NanoBanana2 调用脚本:"
            "读取 ./01 目录中的参考图,生成图片并保存到 ./save 目录。"
        )
    )
    parser.add_argument(
        "-p",
        "--prompt",
        required=False,
        help="提示词;未提供时依次使用配置文件中的 prompt、运行时输入。",
    )
    parser.add_argument(
        "--size",
        default="1K",
        choices=["1K", "2K", "4K"],
        help="输出图像大小,支持 1K/2K/4K,默认 1K。",
    )
    parser.add_argument(
        "--aspect-ratio",
        default="auto",
        help="输出图像比例,如 auto、1:1、16:9 等,默认 auto。",
    )
    parser.add_argument(
        "--input-dir",
        default="01",
        help="参考图目录,相对于当前脚本所在目录,默认 01。",
    )
    parser.add_argument(
        "--output-dir",
        default="save",
        help="生成图片保存目录,相对于当前脚本所在目录,默认 save。",
    )
    parser.add_argument(
        "--api-key",
        help=(
            "速创API接口密钥;如未提供,将依次尝试配置文件 config_nano_banana.json "
            "中的 api_key 和环境变量 WUYIN_API_KEY。"
        ),
    )
    parser.add_argument(
        "--poll-interval",
        type=float,
        default=5.0,
        help="轮询结果详情接口的间隔秒数,默认 5 秒。",
    )
    parser.add_argument(
        "--max-wait",
        type=float,
        default=300.0,
        help="等待结果的最长时间(秒),默认 300 秒。",
    )
    return parser.parse_args()


def _resolve_prompt(cli_prompt: Optional[str], config: dict) -> str:
    """Pick the prompt: CLI argument > config ``prompt`` > interactive input.

    Raises:
        SystemExit: if the interactively entered prompt is empty.
    """
    for candidate in (cli_prompt, config.get("prompt")):
        if isinstance(candidate, str) and candidate.strip():
            return candidate
    prompt = input("请输入提示词(prompt):").strip()
    if not prompt:
        raise SystemExit("提示词不能为空。")
    return prompt


def _config_reference_urls(config: dict) -> List[str]:
    """Return config ``reference_urls`` (a string or a list) with blank entries removed."""
    raw = config.get("reference_urls")
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, list):
        return []
    return [u.strip() for u in raw if isinstance(u, str) and u.strip()]


def main() -> None:
    """Create a NanoBanana2 task, wait for it to finish, and download the images."""
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    config = load_config(base_dir)  # always a dict
    api_key = load_api_key(args.api_key, config)

    prompt = _resolve_prompt(args.prompt, config)

    input_dir = base_dir / args.input_dir
    output_dir = base_dir / args.output_dir

    # Local reference images are sent as base64; config URLs are passed through as-is.
    ref_b64_list = collect_reference_images(input_dir)
    extra_urls = _config_reference_urls(config)
    combined_urls: List[str] = ref_b64_list + extra_urls

    if ref_b64_list:
        print(f"已从目录 {input_dir} 读取 {len(ref_b64_list)} 张本地参考图。")
    if extra_urls:
        print(f"已从配置文件读取 {len(extra_urls)} 个参考图 URL。")
    if not combined_urls:
        print("未找到任何参考图,将仅根据提示词生成。")

    print("正在创建 NanoBanana2 任务...")
    task_id = create_task(
        api_key=api_key,
        prompt=prompt,
        size=args.size,
        aspect_ratio=args.aspect_ratio,
        ref_b64_list=combined_urls,
    )
    print(f"任务已创建,任务 id: {task_id}")

    print("开始轮询任务结果(结果详情接口地址可根据官方文档调整)...")
    urls = query_result(
        api_key=api_key,
        task_id=task_id,
        poll_interval=args.poll_interval,
        max_wait=args.max_wait,
    )
    if not urls:
        print("任务完成,但未在结果中找到图片 URL,请登录速创API控制台或检查结果详情接口。")
        return

    print(f"任务完成,共获取到 {len(urls)} 个图片 URL,开始下载...")
    download_images(urls, output_dir=output_dir, task_id=task_id)
    print("全部处理完成。")


if __name__ == "__main__":
    main()