Files
Banana/nano_banana_client.py
2026-03-03 10:38:37 +08:00

344 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import base64
import json
import os
import time
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse

import requests
# Common prefix of the wuyinkeji asynchronous API endpoints used below.
_ASYNC_API_BASE = "https://api.wuyinkeji.com/api/async"

# NanoBanana2 image-generation endpoint; docs: https://api.wuyinkeji.com/doc/65
NANOBANANA_URL = f"{_ASYNC_API_BASE}/image_nanoBanana2"

# Model-agnostic result-detail endpoint; docs: https://api.wuyinkeji.com/doc/47
RESULT_DETAIL_URL = f"{_ASYNC_API_BASE}/detail"

# File name of the optional JSON config read by load_config().
CONFIG_FILE_NAME = "config_nano_banana.json"
def load_config(base_dir: Path) -> dict:
    """
    Load default parameters (e.g. api_key, prompt, reference_urls) from the
    JSON file CONFIG_FILE_NAME inside base_dir.

    Args:
        base_dir: Directory expected to contain the config file.

    Returns:
        The parsed top-level JSON object. An empty dict is returned when the
        file does not exist, cannot be read or decoded, or its top-level value
        is not a JSON object, so callers can always use ``config.get(...)``.
    """
    config_path = base_dir / CONFIG_FILE_NAME
    if not config_path.exists():
        return {}
    try:
        with config_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: unreadable path (permissions, a directory, ...).
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        # The config is optional, so fall back to "no config" with a warning.
        print(f"读取配置文件失败,将忽略配置文件。错误: {e}")
        return {}
    if isinstance(data, dict):
        return data
    # A truthy non-object (list, string, number) used to be returned as-is and
    # would break callers' .get() calls; falsy values (e.g. null) stay silent.
    if data:
        print(f"配置文件顶层不是 JSON 对象,将忽略配置文件。类型: {type(data).__name__}")
    return {}
def load_api_key(cli_key: Optional[str], config: dict) -> str:
    """
    Resolve the API key, taking the first non-empty source in this order:
    command-line argument, ``config["api_key"]`` (only if config is a dict),
    then the WUYIN_API_KEY environment variable.

    Raises:
        SystemExit: if none of the sources yields a key.
    """
    sources = (
        cli_key,
        config.get("api_key") if isinstance(config, dict) else None,
        os.environ.get("WUYIN_API_KEY"),
    )
    for candidate in sources:
        if candidate:
            return candidate
    raise SystemExit(
        "未找到 API 密钥,请在 config_nano_banana.json 中填写 api_key"
        "或通过参数 --api-key 传入,或在环境变量 WUYIN_API_KEY 中设置。"
    )
def collect_reference_images(input_dir: Path, max_files: int = 14) -> List[str]:
    """
    Read reference images from input_dir and return them base64-encoded.

    Only regular files with a known image extension (case-insensitive) are
    used, in sorted path order, and at most max_files of them. The strings are
    raw base64 without a ``data:`` URI prefix.

    Args:
        input_dir: Directory to scan (not recursive).
        max_files: Upper bound on the number of images; <= 0 yields [].

    Returns:
        Base64 strings, or [] when input_dir is missing or not a directory.
    """
    # is_dir() rather than exists(): a file path would make iterdir() raise.
    if max_files <= 0 or not input_dir.is_dir():
        return []
    image_exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}
    images = [
        p
        for p in sorted(input_dir.iterdir())
        # is_file() skips sub-directories that merely end in ".png" etc.
        if p.is_file() and p.suffix.lower() in image_exts
    ]
    return [base64.b64encode(p.read_bytes()).decode("ascii") for p in images[:max_files]]
def create_task(
    api_key: str,
    prompt: str,
    size: str,
    aspect_ratio: str,
    ref_b64_list: List[str],
) -> str:
    """
    Submit an asynchronous NanoBanana2 image-generation task.

    A JSON request body is used so that ``urls`` is sent as a real array,
    avoiding type errors. The key is sent both in the Authorization header
    and as the ``key`` body field.

    Args:
        api_key: wuyinkeji API key.
        prompt: Text prompt.
        size: Output size, e.g. "1K", "2K" or "4K".
        aspect_ratio: Output aspect ratio, e.g. "auto" or "16:9".
        ref_b64_list: Reference images; each element may be a URL or a base64
            string. The ``urls`` field is omitted when this is empty.

    Returns:
        The created task id, as a string.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: if the response is not JSON, ``code`` != 200, or the
            task id is missing.
    """
    headers = {
        "Authorization": api_key,
        "Content-Type": "application/json;charset=utf-8;",
    }
    body: dict = {
        "prompt": prompt,
        "size": size,
        "aspectRatio": aspect_ratio,
        "key": api_key,
    }
    if ref_b64_list:
        # urls is an array of strings: URLs or base64 data.
        body["urls"] = list(ref_b64_list)

    resp = requests.post(NANOBANANA_URL, json=body, headers=headers, timeout=30)
    resp.raise_for_status()
    try:
        payload = resp.json()
    except ValueError as e:
        raise RuntimeError(f"创建任务失败: 响应不是合法 JSON: {resp.text[:200]}") from e
    if not isinstance(payload, dict) or payload.get("code") != 200:
        raise RuntimeError(f"创建任务失败: {payload}")

    data_obj = payload.get("data")
    # Guard against a non-object "data", which previously raised AttributeError.
    task_id = data_obj.get("id") if isinstance(data_obj, dict) else None
    if not task_id:
        raise RuntimeError(f"返回数据中缺少任务 id: {payload}")
    # Normalize so the declared str return type holds even for numeric ids.
    return str(task_id)
def extract_image_urls(data_obj) -> List[str]:
    """
    Collect image URLs from the ``data`` field of a result-detail response.

    Field naming varies, so several shapes are accepted:
      * a dict: values under img_url / img_urls / image_urls / urls / images /
        result, each either a string or a list of entries;
      * a list of entries;
      * a bare string, treated as a single URL.

    List entries that are dicts are searched for a "url" / "img_url" /
    "image_url" string instead of being stringified. Other truthy entries are
    converted with str(). Falsy entries are dropped. Duplicates are removed,
    keeping first-seen order.

    NOTE(review): "urls" is also the request field carrying reference images.
    If the detail endpoint echoes request parameters, those inputs would be
    picked up here; confirm against a real response.
    """
    if not data_obj:
        return []
    if isinstance(data_obj, str):
        return [data_obj]

    entries: list = []
    if isinstance(data_obj, dict):
        for key in ("img_url", "img_urls", "image_urls", "urls", "images", "result"):
            value = data_obj.get(key)
            if isinstance(value, str) and value:
                entries.append(value)
            elif isinstance(value, list):
                entries.extend(value)
    elif isinstance(data_obj, list):
        entries.extend(data_obj)

    candidates: List[str] = []
    for entry in entries:
        if not entry:
            continue
        if isinstance(entry, dict):
            # str(dict) is never a usable URL; look inside the object instead.
            nested = next(
                (
                    entry[k]
                    for k in ("url", "img_url", "image_url")
                    if isinstance(entry.get(k), str) and entry[k]
                ),
                None,
            )
            if nested:
                candidates.append(nested)
        else:
            candidates.append(str(entry))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(candidates))
def query_result(
    api_key: str,
    task_id: str,
    poll_interval: float = 5.0,
    max_wait: float = 300.0,
) -> List[str]:
    """
    Poll the result-detail endpoint until the task finishes or max_wait elapses.

    Status convention used here: 0 queued, 1 generating, 2 succeeded,
    3 failed. A status sent as a numeric string (e.g. "2") is accepted too.
    Only when the response carries no status at all is the presence of image
    URLs taken as success.

    Args:
        api_key: wuyinkeji API key.
        task_id: Id returned by create_task().
        poll_interval: Seconds between polls.
        max_wait: Maximum seconds to keep polling.

    Returns:
        Image URLs. May be empty if the task succeeded but no URL was found.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: if the response is malformed, ``code`` != 200, or the
            task reports failure (status 3).
        TimeoutError: if max_wait elapses first.
    """
    start = time.monotonic()  # immune to wall-clock adjustments
    params = {"key": api_key, "id": task_id}
    headers = {
        "Authorization": api_key,
        "Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
    }
    while True:
        resp = requests.get(RESULT_DETAIL_URL, params=params, headers=headers, timeout=30)
        resp.raise_for_status()
        try:
            payload = resp.json()
        except ValueError as e:
            raise RuntimeError(f"查询任务失败: 响应不是合法 JSON: {resp.text[:200]}") from e
        if not isinstance(payload, dict):
            raise RuntimeError(f"查询任务失败: 响应格式异常: {payload}")
        if payload.get("code") != 200:
            # Any non-200 business code is treated as a hard failure.
            msg = payload.get("msg") or "未知错误"
            raise RuntimeError(f"查询任务失败: code={payload.get('code')}, msg={msg}")

        data_obj = payload.get("data") or {}
        status = data_obj.get("status") if isinstance(data_obj, dict) else None
        if isinstance(status, str):
            # Accept "2" as well as 2.
            try:
                status = int(status.strip())
            except ValueError:
                pass

        if status == 2:
            return extract_image_urls(data_obj)
        if status == 3:
            reason = data_obj.get("fail_reason") or payload.get("msg") or "未知原因"
            raise RuntimeError(f"任务生成失败: {reason}")
        if status is None:
            # No status field at all: image URLs present are taken as success.
            # (With an explicit 0/1 the task is still running, so keep polling.)
            urls = extract_image_urls(data_obj)
            if urls:
                return urls

        elapsed = time.monotonic() - start
        if elapsed > max_wait:
            raise TimeoutError("等待任务结果超时,请稍后在控制台或结果查询接口自行确认。")
        # Don't sleep past the deadline; the next poll then runs near max_wait.
        time.sleep(min(poll_interval, max(max_wait - elapsed, 0.0)))
def _guess_image_extension(url: str, content_type: Optional[str] = None) -> str:
    """
    Choose a file extension for a downloaded image.

    Uses the suffix of the URL *path* (ignoring host, query and fragment), then
    the response Content-Type, and falls back to ".jpg".
    """
    suffix = os.path.splitext(urlparse(url).path)[1].lower()
    if suffix in (".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"):
        return suffix
    mime = (content_type or "").split(";", 1)[0].strip().lower()
    return {
        "image/png": ".png",
        "image/jpeg": ".jpg",
        "image/webp": ".webp",
        "image/bmp": ".bmp",
        "image/gif": ".gif",
    }.get(mime, ".jpg")


def _stream_to_file(resp, filename: Path) -> None:
    """
    Stream a response body into filename through a temporary ".part" file,
    so an interrupted transfer never leaves a truncated file at the final name.
    """
    part = filename.with_name(filename.name + ".part")
    try:
        with part.open("wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
        part.replace(filename)
    except BaseException:
        # Clean up the incomplete temp file on any failure (incl. Ctrl+C),
        # then let the caller decide how to report it.
        if part.exists():
            part.unlink()
        raise


def download_images(urls: List[str], output_dir: Path, task_id: str) -> None:
    """
    Download each image URL into output_dir as ``{task_id}_{index}{ext}``.

    Best effort: a URL that fails (request error, HTTP error, interrupted
    stream, file error) is reported and skipped, and the remaining URLs are
    still processed. output_dir and its parents are created if needed.

    Args:
        urls: Image URLs to fetch; the index in file names starts at 1.
        output_dir: Destination directory.
        task_id: Used as the file-name prefix.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, url in enumerate(urls, start=1):
        try:
            # `with` releases the streamed connection on every path.
            with requests.get(url, stream=True, timeout=60) as resp:
                resp.raise_for_status()
                ext = _guess_image_extension(url, resp.headers.get("Content-Type"))
                filename = output_dir / f"{task_id}_{idx}{ext}"
                _stream_to_file(resp, filename)
        except (requests.RequestException, OSError) as e:
            print(f"下载失败: {url} | 错误: {e}")
            continue
        print(f"已保存: {filename}")
def _parse_float_arg(text: str) -> float:
    """Parse an argparse value as float, reporting bad input as a usage error."""
    try:
        return float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"无效的数值: {text!r}") from None


def _positive_float(text: str) -> float:
    """argparse ``type=``: a float strictly greater than 0 (NaN rejected)."""
    value = _parse_float_arg(text)
    if not value > 0:
        raise argparse.ArgumentTypeError(f"必须大于 0: {text}")
    return value


def _non_negative_float(text: str) -> float:
    """argparse ``type=``: a float >= 0 (NaN rejected)."""
    value = _parse_float_arg(text)
    if not value >= 0:
        raise argparse.ArgumentTypeError(f"不能为负数: {text}")
    return value


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse command-line options.

    Args:
        argv: Arguments to parse. None (the default) means sys.argv[1:].

    Returns:
        Namespace with prompt, size, aspect_ratio, input_dir, output_dir,
        api_key, poll_interval and max_wait.
    """
    parser = argparse.ArgumentParser(
        description=(
            "NanoBanana2 调用脚本:"
            "读取 ./01 目录中的参考图,生成图片并保存到 ./save 目录。"
        )
    )
    parser.add_argument(
        "-p",
        "--prompt",
        required=False,
        help="提示词,可在此修改,也可留空在运行时输入。",
    )
    parser.add_argument(
        "--size",
        default="1K",
        choices=["1K", "2K", "4K"],
        help="输出图像大小,支持 1K/2K/4K默认 1K。",
    )
    parser.add_argument(
        "--aspect-ratio",
        default="auto",
        help="输出图像比例,如 auto、1:1、16:9 等,默认 auto。",
    )
    parser.add_argument(
        "--input-dir",
        default="01",
        help="参考图目录,相对于当前脚本所在目录,默认 01。",
    )
    parser.add_argument(
        "--output-dir",
        default="save",
        help="生成图片保存目录,相对于当前脚本所在目录,默认 save。",
    )
    parser.add_argument(
        "--api-key",
        # The key is also read from the config file, which the old help omitted.
        help=(
            "速创API接口密钥;如未提供,将依次尝试配置文件 config_nano_banana.json "
            "中的 api_key 和环境变量 WUYIN_API_KEY。"
        ),
    )
    parser.add_argument(
        "--poll-interval",
        # Must be > 0: time.sleep() rejects negatives and 0 would hammer the API.
        type=_positive_float,
        default=5.0,
        help="轮询结果详情接口的间隔秒数,默认 5 秒。",
    )
    parser.add_argument(
        "--max-wait",
        type=_non_negative_float,
        default=300.0,
        help="等待结果的最长时间(秒),默认 300 秒。",
    )
    return parser.parse_args(argv)
def _resolve_prompt(cli_prompt: Optional[str], config: dict) -> str:
    """
    Pick the prompt: --prompt first, then config "prompt", then interactive input.

    Candidates are stripped, and blank or non-string ones are skipped, so a
    whitespace-only --prompt falls through instead of being sent to the API.

    Raises:
        SystemExit: if the interactively entered prompt is blank as well.
    """
    for candidate in (cli_prompt, config.get("prompt")):
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
    prompt = input("请输入提示词(prompt)").strip()
    if not prompt:
        raise SystemExit("提示词不能为空。")
    return prompt


def _config_reference_urls(config: dict) -> List[str]:
    """
    Normalize config "reference_urls" (one string or a list) into a list of
    non-blank strings; any other type yields [].
    """
    raw = config.get("reference_urls")
    if isinstance(raw, str):
        items = [raw]
    elif isinstance(raw, list):
        items = raw
    else:
        return []
    # Blank entries are dropped (an empty string used to be sent as a URL).
    cleaned = (str(item).strip() for item in items if item)
    return [url for url in cleaned if url]


def main() -> None:
    """
    Command-line entry point.

    1. Resolve the API key and prompt (CLI > config file > env/interactive).
    2. Gather reference images: local files from --input-dir as base64,
       followed by config "reference_urls"; both go in the same ``urls`` array.
    3. Create the NanoBanana2 task, poll for its result and download the images.

    Raises:
        SystemExit: on missing key/prompt, or when the API or network fails.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    config = load_config(base_dir)
    if not isinstance(config, dict):
        # Normalize once so everything below can call config.get() directly.
        config = {}
    api_key = load_api_key(args.api_key, config)
    prompt = _resolve_prompt(args.prompt, config)

    input_dir = base_dir / args.input_dir
    output_dir = base_dir / args.output_dir

    # NOTE(review): local files are capped by collect_reference_images (14 by
    # default) but config URLs are appended uncapped; confirm the API limit.
    ref_b64_list = collect_reference_images(input_dir)
    extra_urls = _config_reference_urls(config)
    combined_urls: List[str] = list(ref_b64_list) + extra_urls

    if ref_b64_list:
        print(f"已从目录 {input_dir} 读取 {len(ref_b64_list)} 张本地参考图。")
    if extra_urls:
        print(f"已从配置文件读取 {len(extra_urls)} 个参考图 URL。")
    if not combined_urls:
        print("未找到任何参考图,将仅根据提示词生成。")

    try:
        print("正在创建 NanoBanana2 任务...")
        task_id = create_task(
            api_key=api_key,
            prompt=prompt,
            size=args.size,
            aspect_ratio=args.aspect_ratio,
            ref_b64_list=combined_urls,
        )
        print(f"任务已创建,任务 id: {task_id}")
        print("开始轮询任务结果(结果详情接口地址可根据官方文档调整)...")
        urls = query_result(
            api_key=api_key,
            task_id=task_id,
            poll_interval=args.poll_interval,
            max_wait=args.max_wait,
        )
    except (RuntimeError, TimeoutError, requests.RequestException) as e:
        # Report API/network failures as a one-line CLI error, not a traceback.
        raise SystemExit(f"执行失败: {e}") from e

    if not urls:
        print("任务完成,但未在结果中找到图片 URL请登录速创API控制台或检查结果详情接口。")
        return
    print(f"任务完成,共获取到 {len(urls)} 个图片 URL开始下载...")
    download_images(urls, output_dir=output_dir, task_id=task_id)
    print("全部处理完成。")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()