代码提交

This commit is contained in:
2026-03-03 10:38:37 +08:00
parent e904d7af1d
commit 000d1ef1a8
22 changed files with 7043 additions and 0 deletions

343
nano_banana_client.py Normal file
View File

@@ -0,0 +1,343 @@
import argparse
import base64
import json
import os
import time
from pathlib import Path
from typing import List, Optional
import requests
# NanoBanana2 生图接口,文档 https://api.wuyinkeji.com/doc/65
NANOBANANA_URL = "https://api.wuyinkeji.com/api/async/image_nanoBanana2"
# 全模型通用结果详情接口,文档见 https://api.wuyinkeji.com/doc/47
RESULT_DETAIL_URL = "https://api.wuyinkeji.com/api/async/detail"
CONFIG_FILE_NAME = "config_nano_banana.json"
def load_config(base_dir: Path) -> dict:
"""
从配置文件中读取默认参数(如 api_key、prompt 等)。
"""
config_path = base_dir / CONFIG_FILE_NAME
if not config_path.exists():
return {}
try:
with config_path.open("r", encoding="utf-8") as f:
return json.load(f) or {}
except Exception as e:
print(f"读取配置文件失败,将忽略配置文件。错误: {e}")
return {}
def load_api_key(cli_key: Optional[str], config: dict) -> str:
"""
优先级:命令行参数 > 配置文件 > 环境变量 WUYIN_API_KEY。
"""
key_from_config = config.get("api_key") if isinstance(config, dict) else None
key = cli_key or key_from_config or os.environ.get("WUYIN_API_KEY")
if not key:
raise SystemExit(
"未找到 API 密钥,请在 config_nano_banana.json 中填写 api_key"
"或通过参数 --api-key 传入,或在环境变量 WUYIN_API_KEY 中设置。"
)
return key
def collect_reference_images(input_dir: Path, max_files: int = 14) -> List[str]:
"""
从指定目录读取参考图,返回 base64 字符串数组。
"""
if not input_dir.exists():
return []
exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}
files = [p for p in sorted(input_dir.iterdir()) if p.suffix.lower() in exts]
files = files[:max_files]
encoded_list: List[str] = []
for p in files:
with p.open("rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
encoded_list.append(b64)
return encoded_list
def create_task(
api_key: str,
prompt: str,
size: str,
aspect_ratio: str,
ref_b64_list: List[str],
) -> str:
"""
调用 NanoBanana2 异步图片生成接口,返回任务 id。
使用 JSON 请求体urls 字段为数组,避免类型错误。
"""
headers = {
"Authorization": api_key,
"Content-Type": "application/json;charset=utf-8;",
}
data: dict = {
"prompt": prompt,
"size": size,
"aspectRatio": aspect_ratio,
"key": api_key,
}
if ref_b64_list:
# urls 为字符串数组,元素可以是 URL 或 Base64
data["urls"] = ref_b64_list
resp = requests.post(NANOBANANA_URL, json=data, headers=headers, timeout=30)
resp.raise_for_status()
payload = resp.json()
if payload.get("code") != 200:
raise RuntimeError(f"创建任务失败: {payload}")
data_obj = payload.get("data") or {}
task_id = data_obj.get("id")
if not task_id:
raise RuntimeError(f"返回数据中缺少任务 id: {payload}")
return task_id
def extract_image_urls(data_obj) -> List[str]:
"""
尝试从结果详情 data 字段中提取图片 URL兼容多种字段命名。
"""
if not data_obj:
return []
# 常见字段名兼容
candidates = []
if isinstance(data_obj, dict):
# 兼容多种字段命名,包括结果详情返回的 result 数组
for key in ("img_url", "img_urls", "image_urls", "urls", "images", "result"):
if key in data_obj and data_obj[key]:
val = data_obj[key]
if isinstance(val, str):
candidates.append(val)
elif isinstance(val, list):
candidates.extend(str(v) for v in val if v)
elif isinstance(data_obj, list):
candidates.extend(str(v) for v in data_obj if v)
# 去重
seen = set()
result: List[str] = []
for u in candidates:
if u not in seen:
seen.add(u)
result.append(u)
return result
def query_result(
api_key: str,
task_id: str,
poll_interval: float = 5.0,
max_wait: float = 300.0,
) -> List[str]:
"""
轮询结果详情接口,直到任务完成或超时,返回图片 URL 列表。
"""
start = time.time()
params = {"key": api_key, "id": task_id}
headers = {
"Authorization": api_key,
"Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
}
while True:
resp = requests.get(RESULT_DETAIL_URL, params=params, headers=headers, timeout=30)
resp.raise_for_status()
payload = resp.json()
if payload.get("code") != 200:
# 有些平台在排队/处理中阶段也返回 code=200这里仅在明确失败时直接抛出
msg = payload.get("msg") or "未知错误"
raise RuntimeError(f"查询任务失败: code={payload.get('code')}, msg={msg}")
data_obj = payload.get("data") or {}
status = data_obj.get("status")
# 约定0 排队中1 生成中2 成功3 失败
if status == 2:
urls = extract_image_urls(data_obj)
if urls:
return urls
# 没有找到 URL 但状态为成功,直接返回空列表交由上层处理
return []
elif status == 3:
reason = data_obj.get("fail_reason") or payload.get("msg") or "未知原因"
raise RuntimeError(f"任务生成失败: {reason}")
# 如果未提供 status但已经带有图片 URL也视为成功
urls = extract_image_urls(data_obj)
if urls:
return urls
if time.time() - start > max_wait:
raise TimeoutError("等待任务结果超时,请稍后在控制台或结果查询接口自行确认。")
time.sleep(poll_interval)
def download_images(urls: List[str], output_dir: Path, task_id: str) -> None:
"""
将图片 URL 下载到指定目录。
"""
output_dir.mkdir(parents=True, exist_ok=True)
for idx, url in enumerate(urls, start=1):
try:
resp = requests.get(url, stream=True, timeout=60)
resp.raise_for_status()
except Exception as e:
print(f"下载失败: {url} | 错误: {e}")
continue
# 根据 URL 粗略推断扩展名
ext = ".jpg"
for cand in (".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"):
if cand in url.lower():
ext = cand
break
filename = output_dir / f"{task_id}_{idx}{ext}"
with filename.open("wb") as f:
for chunk in resp.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
print(f"已保存: {filename}")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"NanoBanana2 调用脚本:"
"读取 ./01 目录中的参考图,生成图片并保存到 ./save 目录。"
)
)
parser.add_argument(
"-p",
"--prompt",
required=False,
help="提示词,可在此修改,也可留空在运行时输入。",
)
parser.add_argument(
"--size",
default="1K",
choices=["1K", "2K", "4K"],
help="输出图像大小,支持 1K/2K/4K默认 1K。",
)
parser.add_argument(
"--aspect-ratio",
default="auto",
help="输出图像比例,如 auto、1:1、16:9 等,默认 auto。",
)
parser.add_argument(
"--input-dir",
default="01",
help="参考图目录,相对于当前脚本所在目录,默认 01。",
)
parser.add_argument(
"--output-dir",
default="save",
help="生成图片保存目录,相对于当前脚本所在目录,默认 save。",
)
parser.add_argument(
"--api-key",
help="速创API接口密钥如未提供将尝试从环境变量 WUYIN_API_KEY 读取。",
)
parser.add_argument(
"--poll-interval",
type=float,
default=5.0,
help="轮询结果详情接口的间隔秒数,默认 5 秒。",
)
parser.add_argument(
"--max-wait",
type=float,
default=300.0,
help="等待结果的最长时间(秒),默认 300 秒。",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
base_dir = Path(__file__).resolve().parent
config = load_config(base_dir)
api_key = load_api_key(args.api_key, config)
# 提示词优先级:命令行参数 > 配置文件 > 运行时输入
prompt = args.prompt or (config.get("prompt") if isinstance(config, dict) else None)
if not prompt:
prompt = input("请输入提示词(prompt)").strip()
if not prompt:
raise SystemExit("提示词不能为空。")
input_dir = base_dir / args.input_dir
output_dir = base_dir / args.output_dir
# 本地参考图(转为 base64
ref_b64_list = collect_reference_images(input_dir)
# 配置中的参考图 URL直接透传给接口
extra_urls_cfg = config.get("reference_urls") if isinstance(config, dict) else None
extra_urls: List[str] = []
if isinstance(extra_urls_cfg, str):
extra_urls = [extra_urls_cfg]
elif isinstance(extra_urls_cfg, list):
extra_urls = [str(u) for u in extra_urls_cfg if u]
combined_urls: List[str] = []
combined_urls.extend(ref_b64_list)
combined_urls.extend(extra_urls)
if ref_b64_list:
print(f"已从目录 {input_dir} 读取 {len(ref_b64_list)} 张本地参考图。")
if extra_urls:
print(f"已从配置文件读取 {len(extra_urls)} 个参考图 URL。")
if not combined_urls:
print(f"未找到任何参考图,将仅根据提示词生成。")
print("正在创建 NanoBanana2 任务...")
task_id = create_task(
api_key=api_key,
prompt=prompt,
size=args.size,
aspect_ratio=args.aspect_ratio,
ref_b64_list=combined_urls,
)
print(f"任务已创建,任务 id: {task_id}")
print("开始轮询任务结果(结果详情接口地址可根据官方文档调整)...")
urls = query_result(
api_key=api_key,
task_id=task_id,
poll_interval=args.poll_interval,
max_wait=args.max_wait,
)
if not urls:
print("任务完成,但未在结果中找到图片 URL请登录速创API控制台或检查结果详情接口。")
return
print(f"任务完成,共获取到 {len(urls)} 个图片 URL开始下载...")
download_images(urls, output_dir=output_dir, task_id=task_id)
print("全部处理完成。")
if __name__ == "__main__":
main()