344 lines
11 KiB
Python
344 lines
11 KiB
Python
import argparse
|
||
import base64
|
||
import json
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
from typing import List, Optional
|
||
|
||
import requests
|
||
|
||
|
||
# NanoBanana2 生图接口,文档 https://api.wuyinkeji.com/doc/65
|
||
NANOBANANA_URL = "https://api.wuyinkeji.com/api/async/image_nanoBanana2"
|
||
# 全模型通用结果详情接口,文档见 https://api.wuyinkeji.com/doc/47
|
||
RESULT_DETAIL_URL = "https://api.wuyinkeji.com/api/async/detail"
|
||
CONFIG_FILE_NAME = "config_nano_banana.json"
|
||
|
||
|
||
def load_config(base_dir: Path) -> dict:
|
||
"""
|
||
从配置文件中读取默认参数(如 api_key、prompt 等)。
|
||
"""
|
||
config_path = base_dir / CONFIG_FILE_NAME
|
||
if not config_path.exists():
|
||
return {}
|
||
try:
|
||
with config_path.open("r", encoding="utf-8") as f:
|
||
return json.load(f) or {}
|
||
except Exception as e:
|
||
print(f"读取配置文件失败,将忽略配置文件。错误: {e}")
|
||
return {}
|
||
|
||
|
||
def load_api_key(cli_key: Optional[str], config: dict) -> str:
|
||
"""
|
||
优先级:命令行参数 > 配置文件 > 环境变量 WUYIN_API_KEY。
|
||
"""
|
||
key_from_config = config.get("api_key") if isinstance(config, dict) else None
|
||
key = cli_key or key_from_config or os.environ.get("WUYIN_API_KEY")
|
||
if not key:
|
||
raise SystemExit(
|
||
"未找到 API 密钥,请在 config_nano_banana.json 中填写 api_key,"
|
||
"或通过参数 --api-key 传入,或在环境变量 WUYIN_API_KEY 中设置。"
|
||
)
|
||
return key
|
||
|
||
|
||
def collect_reference_images(input_dir: Path, max_files: int = 14) -> List[str]:
|
||
"""
|
||
从指定目录读取参考图,返回 base64 字符串数组。
|
||
"""
|
||
if not input_dir.exists():
|
||
return []
|
||
|
||
exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"}
|
||
files = [p for p in sorted(input_dir.iterdir()) if p.suffix.lower() in exts]
|
||
files = files[:max_files]
|
||
|
||
encoded_list: List[str] = []
|
||
for p in files:
|
||
with p.open("rb") as f:
|
||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||
encoded_list.append(b64)
|
||
return encoded_list
|
||
|
||
|
||
def create_task(
|
||
api_key: str,
|
||
prompt: str,
|
||
size: str,
|
||
aspect_ratio: str,
|
||
ref_b64_list: List[str],
|
||
) -> str:
|
||
"""
|
||
调用 NanoBanana2 异步图片生成接口,返回任务 id。
|
||
使用 JSON 请求体,urls 字段为数组,避免类型错误。
|
||
"""
|
||
headers = {
|
||
"Authorization": api_key,
|
||
"Content-Type": "application/json;charset=utf-8;",
|
||
}
|
||
|
||
data: dict = {
|
||
"prompt": prompt,
|
||
"size": size,
|
||
"aspectRatio": aspect_ratio,
|
||
"key": api_key,
|
||
}
|
||
|
||
if ref_b64_list:
|
||
# urls 为字符串数组,元素可以是 URL 或 Base64
|
||
data["urls"] = ref_b64_list
|
||
|
||
resp = requests.post(NANOBANANA_URL, json=data, headers=headers, timeout=30)
|
||
resp.raise_for_status()
|
||
payload = resp.json()
|
||
|
||
if payload.get("code") != 200:
|
||
raise RuntimeError(f"创建任务失败: {payload}")
|
||
|
||
data_obj = payload.get("data") or {}
|
||
task_id = data_obj.get("id")
|
||
if not task_id:
|
||
raise RuntimeError(f"返回数据中缺少任务 id: {payload}")
|
||
|
||
return task_id
|
||
|
||
|
||
def extract_image_urls(data_obj) -> List[str]:
|
||
"""
|
||
尝试从结果详情 data 字段中提取图片 URL,兼容多种字段命名。
|
||
"""
|
||
if not data_obj:
|
||
return []
|
||
|
||
# 常见字段名兼容
|
||
candidates = []
|
||
if isinstance(data_obj, dict):
|
||
# 兼容多种字段命名,包括结果详情返回的 result 数组
|
||
for key in ("img_url", "img_urls", "image_urls", "urls", "images", "result"):
|
||
if key in data_obj and data_obj[key]:
|
||
val = data_obj[key]
|
||
if isinstance(val, str):
|
||
candidates.append(val)
|
||
elif isinstance(val, list):
|
||
candidates.extend(str(v) for v in val if v)
|
||
elif isinstance(data_obj, list):
|
||
candidates.extend(str(v) for v in data_obj if v)
|
||
|
||
# 去重
|
||
seen = set()
|
||
result: List[str] = []
|
||
for u in candidates:
|
||
if u not in seen:
|
||
seen.add(u)
|
||
result.append(u)
|
||
return result
|
||
|
||
|
||
def query_result(
|
||
api_key: str,
|
||
task_id: str,
|
||
poll_interval: float = 5.0,
|
||
max_wait: float = 300.0,
|
||
) -> List[str]:
|
||
"""
|
||
轮询结果详情接口,直到任务完成或超时,返回图片 URL 列表。
|
||
"""
|
||
start = time.time()
|
||
params = {"key": api_key, "id": task_id}
|
||
headers = {
|
||
"Authorization": api_key,
|
||
"Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
|
||
}
|
||
|
||
while True:
|
||
resp = requests.get(RESULT_DETAIL_URL, params=params, headers=headers, timeout=30)
|
||
resp.raise_for_status()
|
||
payload = resp.json()
|
||
|
||
if payload.get("code") != 200:
|
||
# 有些平台在排队/处理中阶段也返回 code=200,这里仅在明确失败时直接抛出
|
||
msg = payload.get("msg") or "未知错误"
|
||
raise RuntimeError(f"查询任务失败: code={payload.get('code')}, msg={msg}")
|
||
|
||
data_obj = payload.get("data") or {}
|
||
status = data_obj.get("status")
|
||
|
||
# 约定:0 排队中,1 生成中,2 成功,3 失败
|
||
if status == 2:
|
||
urls = extract_image_urls(data_obj)
|
||
if urls:
|
||
return urls
|
||
# 没有找到 URL 但状态为成功,直接返回空列表交由上层处理
|
||
return []
|
||
elif status == 3:
|
||
reason = data_obj.get("fail_reason") or payload.get("msg") or "未知原因"
|
||
raise RuntimeError(f"任务生成失败: {reason}")
|
||
|
||
# 如果未提供 status,但已经带有图片 URL,也视为成功
|
||
urls = extract_image_urls(data_obj)
|
||
if urls:
|
||
return urls
|
||
|
||
if time.time() - start > max_wait:
|
||
raise TimeoutError("等待任务结果超时,请稍后在控制台或结果查询接口自行确认。")
|
||
|
||
time.sleep(poll_interval)
|
||
|
||
|
||
def download_images(urls: List[str], output_dir: Path, task_id: str) -> None:
|
||
"""
|
||
将图片 URL 下载到指定目录。
|
||
"""
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
for idx, url in enumerate(urls, start=1):
|
||
try:
|
||
resp = requests.get(url, stream=True, timeout=60)
|
||
resp.raise_for_status()
|
||
except Exception as e:
|
||
print(f"下载失败: {url} | 错误: {e}")
|
||
continue
|
||
|
||
# 根据 URL 粗略推断扩展名
|
||
ext = ".jpg"
|
||
for cand in (".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"):
|
||
if cand in url.lower():
|
||
ext = cand
|
||
break
|
||
|
||
filename = output_dir / f"{task_id}_{idx}{ext}"
|
||
with filename.open("wb") as f:
|
||
for chunk in resp.iter_content(chunk_size=8192):
|
||
if chunk:
|
||
f.write(chunk)
|
||
|
||
print(f"已保存: {filename}")
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
|
||
parser = argparse.ArgumentParser(
|
||
description=(
|
||
"NanoBanana2 调用脚本:"
|
||
"读取 ./01 目录中的参考图,生成图片并保存到 ./save 目录。"
|
||
)
|
||
)
|
||
parser.add_argument(
|
||
"-p",
|
||
"--prompt",
|
||
required=False,
|
||
help="提示词,可在此修改,也可留空在运行时输入。",
|
||
)
|
||
parser.add_argument(
|
||
"--size",
|
||
default="1K",
|
||
choices=["1K", "2K", "4K"],
|
||
help="输出图像大小,支持 1K/2K/4K,默认 1K。",
|
||
)
|
||
parser.add_argument(
|
||
"--aspect-ratio",
|
||
default="auto",
|
||
help="输出图像比例,如 auto、1:1、16:9 等,默认 auto。",
|
||
)
|
||
parser.add_argument(
|
||
"--input-dir",
|
||
default="01",
|
||
help="参考图目录,相对于当前脚本所在目录,默认 01。",
|
||
)
|
||
parser.add_argument(
|
||
"--output-dir",
|
||
default="save",
|
||
help="生成图片保存目录,相对于当前脚本所在目录,默认 save。",
|
||
)
|
||
parser.add_argument(
|
||
"--api-key",
|
||
help="速创API接口密钥;如未提供,将尝试从环境变量 WUYIN_API_KEY 读取。",
|
||
)
|
||
parser.add_argument(
|
||
"--poll-interval",
|
||
type=float,
|
||
default=5.0,
|
||
help="轮询结果详情接口的间隔秒数,默认 5 秒。",
|
||
)
|
||
parser.add_argument(
|
||
"--max-wait",
|
||
type=float,
|
||
default=300.0,
|
||
help="等待结果的最长时间(秒),默认 300 秒。",
|
||
)
|
||
return parser.parse_args()
|
||
|
||
|
||
def main() -> None:
|
||
args = parse_args()
|
||
|
||
base_dir = Path(__file__).resolve().parent
|
||
config = load_config(base_dir)
|
||
|
||
api_key = load_api_key(args.api_key, config)
|
||
|
||
# 提示词优先级:命令行参数 > 配置文件 > 运行时输入
|
||
prompt = args.prompt or (config.get("prompt") if isinstance(config, dict) else None)
|
||
if not prompt:
|
||
prompt = input("请输入提示词(prompt):").strip()
|
||
if not prompt:
|
||
raise SystemExit("提示词不能为空。")
|
||
|
||
input_dir = base_dir / args.input_dir
|
||
output_dir = base_dir / args.output_dir
|
||
|
||
# 本地参考图(转为 base64)
|
||
ref_b64_list = collect_reference_images(input_dir)
|
||
|
||
# 配置中的参考图 URL(直接透传给接口)
|
||
extra_urls_cfg = config.get("reference_urls") if isinstance(config, dict) else None
|
||
extra_urls: List[str] = []
|
||
if isinstance(extra_urls_cfg, str):
|
||
extra_urls = [extra_urls_cfg]
|
||
elif isinstance(extra_urls_cfg, list):
|
||
extra_urls = [str(u) for u in extra_urls_cfg if u]
|
||
|
||
combined_urls: List[str] = []
|
||
combined_urls.extend(ref_b64_list)
|
||
combined_urls.extend(extra_urls)
|
||
|
||
if ref_b64_list:
|
||
print(f"已从目录 {input_dir} 读取 {len(ref_b64_list)} 张本地参考图。")
|
||
if extra_urls:
|
||
print(f"已从配置文件读取 {len(extra_urls)} 个参考图 URL。")
|
||
if not combined_urls:
|
||
print(f"未找到任何参考图,将仅根据提示词生成。")
|
||
|
||
print("正在创建 NanoBanana2 任务...")
|
||
task_id = create_task(
|
||
api_key=api_key,
|
||
prompt=prompt,
|
||
size=args.size,
|
||
aspect_ratio=args.aspect_ratio,
|
||
ref_b64_list=combined_urls,
|
||
)
|
||
print(f"任务已创建,任务 id: {task_id}")
|
||
|
||
print("开始轮询任务结果(结果详情接口地址可根据官方文档调整)...")
|
||
urls = query_result(
|
||
api_key=api_key,
|
||
task_id=task_id,
|
||
poll_interval=args.poll_interval,
|
||
max_wait=args.max_wait,
|
||
)
|
||
|
||
if not urls:
|
||
print("任务完成,但未在结果中找到图片 URL,请登录速创API控制台或检查结果详情接口。")
|
||
return
|
||
|
||
print(f"任务完成,共获取到 {len(urls)} 个图片 URL,开始下载...")
|
||
download_images(urls, output_dir=output_dir, task_id=task_id)
|
||
print("全部处理完成。")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|