390 lines
11 KiB
JavaScript
390 lines
11 KiB
JavaScript
const fs = require("fs");
|
||
const path = require("path");
|
||
const axios = require("axios");
|
||
const { S3Client, PutObjectCommand, DeleteObjectsCommand } = require("@aws-sdk/client-s3");
|
||
|
||
// Both wuyinkeji endpoints used by this script live under the same async API prefix.
const WUYIN_ASYNC_API_BASE = "https://api.wuyinkeji.com/api/async";

/** NanoBanana2 image-generation (task creation) endpoint; docs: https://api.wuyinkeji.com/doc/65 */
const NANOBANANA_URL = `${WUYIN_ASYNC_API_BASE}/image_nanoBanana2`;

/** Model-agnostic task-result detail endpoint; docs: https://api.wuyinkeji.com/doc/47 */
const RESULT_DETAIL_URL = `${WUYIN_ASYNC_API_BASE}/detail`;

/** Name of the JSON config file, looked up in the directory passed to loadConfig. */
const CONFIG_FILE_NAME = "config_nano_banana.json";
|
||
|
||
/**
 * Loads the optional JSON config file (CONFIG_FILE_NAME) from `baseDir`.
 *
 * A missing file, unreadable or invalid JSON, or JSON whose top level is not
 * a plain object all yield `{}`, so callers can always read properties off
 * the result.
 *
 * @param {string} baseDir - Directory that contains the config file.
 * @returns {Object} The parsed configuration object, or `{}` when unavailable.
 */
function loadConfig(baseDir) {
  const configPath = path.join(baseDir, CONFIG_FILE_NAME);
  if (!fs.existsSync(configPath)) return {};

  try {
    // Editors such as Windows Notepad often save UTF-8 with a BOM, which
    // JSON.parse rejects; strip it so such files still load.
    const raw = fs.readFileSync(configPath, "utf8").replace(/^\uFEFF/, "");
    const parsed = JSON.parse(raw);
    // Only a JSON object is a usable config; arrays, strings, numbers and
    // null would make every lookup silently miss.
    if (parsed === null || typeof parsed !== "object" || Array.isArray(parsed)) {
      console.log("配置文件内容不是 JSON 对象,将忽略配置文件。");
      return {};
    }
    return parsed;
  } catch (e) {
    console.log("读取配置文件失败,将忽略配置文件。错误:", e.message);
    return {};
  }
}
|
||
|
||
/**
 * Resolves the API key used to authenticate against the wuyinkeji API.
 *
 * A truthy `api_key` in the config wins. Otherwise the `WUYIN_API_KEY`
 * environment variable is used.
 *
 * @param {Object} config - Parsed configuration object.
 * @returns {string} The API key.
 * @throws {Error} When neither source provides a truthy key.
 */
function ensureApiKey(config) {
  const fromConfig = config.api_key;
  const resolved = fromConfig ? fromConfig : process.env.WUYIN_API_KEY;
  if (resolved) {
    return resolved;
  }

  const hint = [
    "未找到 API 密钥,请在 config_nano_banana.json 中填写 api_key,",
    "或设置环境变量 WUYIN_API_KEY。",
  ].join("");
  throw new Error(hint);
}
|
||
|
||
/**
 * Returns the default prompt from the config with surrounding whitespace removed.
 *
 * @param {Object} config - Parsed configuration object.
 * @returns {string} The trimmed prompt.
 * @throws {Error} When `prompt` is missing, not a string, or only whitespace.
 */
function ensurePrompt(config) {
  const prompt = config.prompt;
  // A non-string value (e.g. a number or array in the JSON) used to crash with
  // "prompt.trim is not a function". Report it with the same friendly message
  // as a missing prompt instead.
  const trimmed = typeof prompt === "string" ? prompt.trim() : "";
  if (!trimmed) {
    throw new Error(
      "未在 config_nano_banana.json 中找到有效的 prompt,请先在文件中填写默认提示词。"
    );
  }
  return trimmed;
}
|
||
|
||
/**
 * Reports whether every setting needed to use Cloudflare R2 as the image
 * host for reference-image uploads is present (truthy) in the config.
 *
 * @param {Object|null|undefined} config - Parsed configuration object.
 * @returns {boolean} `true` only when all R2 settings are set.
 */
function isR2Configured(config) {
  if (!config) return false;
  // Return a real boolean. The previous `a && b && ...` chain handed the
  // caller whichever config value happened to be last or falsy.
  return [
    "r2_account_id",
    "r2_access_key_id",
    "r2_secret_access_key",
    "r2_bucket",
    "r2_public_url",
  ].every((key) => Boolean(config[key]));
}
|
||
|
||
/**
 * Builds an S3-API client pointed at the Cloudflare R2 endpoint of the
 * configured account.
 *
 * @param {Object} config - Must provide r2_account_id, r2_access_key_id and
 *   r2_secret_access_key.
 * @returns {S3Client} A client usable with the S3 command objects.
 */
function createR2Client(config) {
  const {
    r2_account_id: accountId,
    r2_access_key_id: accessKeyId,
    r2_secret_access_key: secretAccessKey,
  } = config;

  const clientOptions = {
    region: "auto",
    endpoint: `https://${accountId}.r2.cloudflarestorage.com`,
    // Path-style addressing, i.e. https://<accountId>.r2.cloudflarestorage.com/<bucket>/<key>,
    // is the more dependable choice across S3-compatible services.
    forcePathStyle: true,
    credentials: { accessKeyId, secretAccessKey },
  };
  return new S3Client(clientOptions);
}
|
||
|
||
/**
 * Uploads the reference images found in `inputDir` to R2.
 *
 * Returns their public URLs together with the object keys written, so the
 * caller can delete them afterwards.
 *
 * Each call uses a fresh `ref/<runId>/` key prefix, so separate runs do not
 * overwrite each other's objects. If any upload fails, the objects this call
 * already uploaded are deleted (best effort) before the error is rethrown,
 * because the caller never receives their keys in that case.
 *
 * @param {Object} options
 * @param {S3Client} options.client - R2 (S3-compatible) client.
 * @param {string} options.bucket - Target bucket name.
 * @param {string} options.publicBaseUrl - Public base URL serving the bucket (trailing "/" optional).
 * @param {string} options.inputDir - Directory scanned for image files.
 * @param {number} [options.maxFiles=14] - Maximum number of images to upload.
 * @returns {Promise<{urls: string[], keys: string[]}>} Public URLs and their object keys,
 *   in upload order; both empty when `inputDir` does not exist.
 * @throws {Error} The first upload error, after cleanup of earlier uploads.
 */
async function uploadRefImagesToR2({ client, bucket, publicBaseUrl, inputDir, maxFiles = 14 }) {
  if (!fs.existsSync(inputDir)) return { urls: [], keys: [] };

  const contentTypes = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
    ".gif": "image/gif",
  };
  const exts = new Set(Object.keys(contentTypes));
  const names = fs
    .readdirSync(inputDir)
    .filter((name) => exts.has(path.extname(name).toLowerCase()))
    .slice(0, maxFiles);

  const runId = Date.now().toString(36) + "-" + Math.random().toString(36).slice(2, 8);
  const prefix = `ref/${runId}`;
  const base = publicBaseUrl.replace(/\/$/, "");
  const urls = [];
  const keys = [];

  try {
    for (const name of names) {
      const key = `${prefix}/${name}`;
      await client.send(
        new PutObjectCommand({
          Bucket: bucket,
          Key: key,
          Body: fs.readFileSync(path.join(inputDir, name)),
          ContentType: contentTypes[path.extname(name).toLowerCase()] || "application/octet-stream",
        })
      );
      keys.push(key);
      // File names may contain spaces, '#', '?' or non-ASCII characters (e.g.
      // Chinese). Percent-encode the name so the URL handed to the API
      // actually resolves to the uploaded object. The runId prefix is already URL-safe.
      urls.push(`${base}/${prefix}/${encodeURIComponent(name)}`);
    }
  } catch (err) {
    if (keys.length > 0) {
      try {
        await client.send(
          new DeleteObjectsCommand({
            Bucket: bucket,
            Delete: { Objects: keys.map((Key) => ({ Key })) },
          })
        );
      } catch (cleanupErr) {
        console.warn("删除 R2 参考图失败:", cleanupErr.message);
      }
    }
    throw err;
  }

  return { urls, keys };
}
|
||
|
||
/**
 * Deletes the given object keys from an R2 bucket.
 *
 * DeleteObjects accepts at most 1000 keys per request, so longer lists are
 * sent in batches. The S3 API reports per-object failures in the response's
 * `Errors` list rather than failing the request. Those failures are turned
 * into a thrown Error; otherwise callers would report success for objects
 * that still exist.
 *
 * @param {Object} options
 * @param {S3Client} options.client - R2 (S3-compatible) client.
 * @param {string} options.bucket - Bucket holding the objects.
 * @param {string[]} options.keys - Object keys to delete; no-op when empty or missing.
 * @returns {Promise<void>}
 * @throws {Error} When a request fails or any object could not be deleted.
 */
async function deleteR2Objects({ client, bucket, keys }) {
  if (!keys || keys.length === 0) return;

  const BATCH_SIZE = 1000;
  const failures = [];
  for (let i = 0; i < keys.length; i += BATCH_SIZE) {
    const batch = keys.slice(i, i + BATCH_SIZE);
    const resp = await client.send(
      new DeleteObjectsCommand({
        Bucket: bucket,
        Delete: { Objects: batch.map((Key) => ({ Key })) },
      })
    );
    const errors = (resp && resp.Errors) || [];
    for (const e of errors) {
      failures.push(`${e.Key}: ${e.Code || ""} ${e.Message || ""}`.trim());
    }
  }

  if (failures.length > 0) {
    throw new Error(`部分 R2 对象删除失败(${failures.length} 个): ${failures.join("; ")}`);
  }
}
|
||
|
||
/**
 * Reads up to `maxFiles` images from `inputDir` and returns their raw
 * contents as base64 strings (without a data-URI prefix), in directory-listing order.
 *
 * @param {string} inputDir - Directory to scan.
 * @param {number} [maxFiles=14] - Upper bound on the number of images read.
 * @returns {Promise<string[]>} Base64 payloads; empty when the directory does not exist.
 */
async function collectReferenceImages(inputDir, maxFiles = 14) {
  const imageExtensions = [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"];
  if (!fs.existsSync(inputDir)) {
    return [];
  }

  const imagePaths = [];
  for (const entry of fs.readdirSync(inputDir)) {
    const fullPath = path.join(inputDir, entry);
    if (imageExtensions.includes(path.extname(fullPath).toLowerCase())) {
      imagePaths.push(fullPath);
    }
  }

  return imagePaths
    .slice(0, maxFiles)
    .map((imagePath) => fs.readFileSync(imagePath).toString("base64"));
}
|
||
|
||
/**
 * Submits a NanoBanana2 generation task and returns its task id.
 *
 * The API key is sent both as the `Authorization` header and as the `key`
 * field of the JSON body. Reference images (URLs and/or base64 strings) are
 * sent in `urls` only when at least one is supplied.
 *
 * @param {Object} options
 * @param {string} options.apiKey - wuyinkeji API key.
 * @param {string} options.prompt - Generation prompt.
 * @param {string} options.size - Output size value passed through to the API.
 * @param {string} options.aspectRatio - Aspect-ratio value passed through to the API.
 * @param {string[]} [options.refB64List] - Reference image URLs and/or base64 payloads.
 * @returns {Promise<*>} The task id found at `payload.data.id`.
 * @throws {Error} When the response `code` is not 200 or the id is missing.
 */
async function createTask({ apiKey, prompt, size, aspectRatio, refB64List }) {
  const hasRefs = Boolean(refB64List && refB64List.length > 0);
  const requestBody = {
    prompt,
    size,
    aspectRatio,
    key: apiKey,
    ...(hasRefs ? { urls: refB64List } : {}),
  };

  const { data: payload } = await axios.post(NANOBANANA_URL, requestBody, {
    headers: {
      Authorization: apiKey,
      "Content-Type": "application/json;charset=utf-8;",
    },
    timeout: 30000,
  });

  if (!payload || payload.code !== 200) {
    throw new Error(`创建任务失败: ${JSON.stringify(payload)}`);
  }

  const taskId = (payload.data || {}).id;
  if (!taskId) {
    throw new Error(`返回数据中缺少任务 id: ${JSON.stringify(payload)}`);
  }
  return taskId;
}
|
||
|
||
/**
 * Collects the image URLs contained in a task-result `data` value,
 * de-duplicated while keeping first-seen order.
 *
 * - An array is treated as a list of URLs: truthy entries, stringified.
 * - An object is searched, in order, under img_url, img_urls, image_urls,
 *   urls, images and result. Each may hold a single string or an array
 *   (truthy entries, stringified); other value types are ignored.
 * - Anything else (falsy values, strings, numbers, ...) yields an empty list.
 *
 * @param {*} dataObj - The `data` field of an API response.
 * @returns {string[]} Unique URLs.
 */
function extractImageUrls(dataObj) {
  if (!dataObj) return [];

  const truthyAsStrings = (values) => values.filter(Boolean).map(String);
  let found = [];

  if (Array.isArray(dataObj)) {
    found = truthyAsStrings(dataObj);
  } else if (typeof dataObj === "object") {
    // Several field names are accepted, including the `result` array of the detail endpoint.
    const fieldNames = ["img_url", "img_urls", "image_urls", "urls", "images", "result"];
    found = fieldNames.flatMap((field) => {
      const value = dataObj[field];
      if (!value) return [];
      if (typeof value === "string") return [value];
      return Array.isArray(value) ? truthyAsStrings(value) : [];
    });
  }

  // A Set preserves insertion order, so the first occurrence of each URL wins.
  return [...new Set(found)];
}
|
||
|
||
/**
 * Polls the result-detail endpoint until the task finishes, fails, or the
 * wait budget runs out.
 *
 * Status handling as implemented:
 * - 2: done; returns the extracted URLs, possibly empty.
 * - 3: failed; throws.
 * - Any other status: URLs already present in the payload are returned
 *   early; otherwise polling continues.
 *
 * Transient transport failures are retried until `maxWaitMs` elapses. These
 * are axios errors with no HTTP response (timeout, connection reset) and
 * HTTP 5xx responses. A single network blip would otherwise abandon a task
 * that has already been submitted. HTTP 4xx responses, non-axios errors and
 * API-level error codes still fail immediately.
 *
 * @param {Object} options
 * @param {string} options.apiKey - wuyinkeji API key.
 * @param {*} options.taskId - Id returned by createTask.
 * @param {number} options.pollIntervalMs - Delay between polls.
 * @param {number} options.maxWaitMs - Total time budget before giving up.
 * @returns {Promise<string[]>} Image URLs reported for the task.
 * @throws {Error} On API error codes, task failure, non-transient request
 *   errors, or timeout. On timeout, the last transient error, if any, is attached as `cause`.
 */
async function queryResult({ apiKey, taskId, pollIntervalMs, maxWaitMs }) {
  const headers = {
    Authorization: apiKey,
    "Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
  };
  const params = { key: apiKey, id: taskId };
  const start = Date.now();

  const isTransientHttpError = (err) =>
    Boolean(err && err.isAxiosError && (!err.response || err.response.status >= 500));

  let lastTransientError = null;

  while (true) {
    let resp = null;
    try {
      resp = await axios.get(RESULT_DETAIL_URL, {
        params,
        headers,
        timeout: 30000,
      });
      lastTransientError = null;
    } catch (err) {
      if (!isTransientHttpError(err)) throw err;
      lastTransientError = err;
      console.log("查询任务结果时请求失败,稍后重试:", err.message);
    }

    if (resp) {
      const payload = resp.data;
      if (!payload || payload.code !== 200) {
        const msg = (payload && payload.msg) || "未知错误";
        throw new Error(`查询任务失败: code=${payload && payload.code}, msg=${msg}`);
      }

      const data = payload.data || {};
      const status = data.status;
      if (status === 2) {
        return extractImageUrls(data);
      }
      if (status === 3) {
        const reason = data.message || payload.msg || "未知原因";
        throw new Error("任务生成失败: " + reason);
      }

      // Some responses already carry URLs before reporting the final status.
      const urls = extractImageUrls(data);
      if (urls.length > 0) return urls;
    }

    if (Date.now() - start > maxWaitMs) {
      const timeoutErr = new Error("等待任务结果超时,请稍后在控制台或结果查询接口自行确认。");
      if (lastTransientError) timeoutErr.cause = lastTransientError;
      throw timeoutErr;
    }

    await new Promise((resolve) => setTimeout(resolve, pollIntervalMs));
  }
}
|
||
|
||
/**
 * Picks a file extension for a downloaded image.
 *
 * Prefers the response's Content-Type, which describes the bytes actually
 * received. Next it uses the extension of the URL's path; host and query
 * string are ignored. Otherwise it falls back to ".jpg".
 *
 * @param {string} url - Source URL of the image.
 * @param {string|undefined} contentType - Response Content-Type header, if any.
 * @returns {string} Extension including the leading dot.
 */
function pickImageExtension(url, contentType) {
  const extByMime = new Map([
    ["image/png", ".png"],
    ["image/jpeg", ".jpg"],
    ["image/jpg", ".jpg"],
    ["image/webp", ".webp"],
    ["image/bmp", ".bmp"],
    ["image/gif", ".gif"],
  ]);
  const mime = String(contentType || "").split(";")[0].trim().toLowerCase();
  if (extByMime.has(mime)) return extByMime.get(mime);

  let pathname;
  try {
    pathname = new URL(url).pathname;
  } catch (e) {
    // Not an absolute URL: drop any query string/fragment and inspect the rest.
    pathname = String(url).split(/[?#]/)[0];
  }
  const urlExt = path.extname(pathname).toLowerCase();
  const knownExts = new Set([".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"]);
  return knownExts.has(urlExt) ? urlExt : ".jpg";
}

/**
 * Downloads each generated image into `outputDir` as `<taskId>_<n><ext>`.
 *
 * Downloads are sequential and best effort: a failed URL is logged and
 * skipped, and the remaining URLs are still attempted.
 *
 * @param {Object} options
 * @param {string[]} options.urls - Image URLs to fetch.
 * @param {string} options.outputDir - Destination directory; created if missing.
 * @param {*} options.taskId - Used as the file-name prefix.
 * @returns {Promise<string[]>} Paths of the files that were saved.
 */
async function downloadImages({ urls, outputDir, taskId }) {
  if (!fs.existsSync(outputDir)) {
    fs.mkdirSync(outputDir, { recursive: true });
  }

  const saved = [];
  for (const [index, url] of urls.entries()) {
    try {
      const resp = await axios.get(url, {
        responseType: "arraybuffer",
        timeout: 60000,
      });
      const contentType = resp.headers ? resp.headers["content-type"] : undefined;
      const ext = pickImageExtension(url, contentType);
      const filename = path.join(outputDir, `${taskId}_${index + 1}${ext}`);
      fs.writeFileSync(filename, resp.data);
      console.log("已保存:", filename);
      saved.push(filename);
    } catch (e) {
      console.log("下载失败:", url, "| 错误:", e.message);
    }
  }
  return saved;
}
|
||
|
||
/**
 * Normalizes the optional `reference_urls` config entry into a list of URLs.
 *
 * Accepts a single string or an array. Values are stringified and trimmed,
 * and blank entries are dropped, so a placeholder such as
 * `"reference_urls": ""` does not send an empty URL to the API.
 *
 * @param {Object} config - Parsed configuration object.
 * @returns {string[]} Non-empty reference image URLs; `[]` for any other type.
 */
function readConfiguredReferenceUrls(config) {
  const raw = config.reference_urls;
  let list = [];
  if (typeof raw === "string") {
    list = [raw];
  } else if (Array.isArray(raw)) {
    list = raw;
  }
  return list
    .filter(Boolean)
    .map((u) => String(u).trim())
    .filter((u) => u !== "");
}

/**
 * Entry point: generates images with NanoBanana2 from the config next to this script.
 *
 * Steps:
 * 1. Read the config from this script's directory.
 * 2. Gather reference images from `input_dir` (default "01"). They are
 *    uploaded to R2 when all R2 settings are present; otherwise they are
 *    sent inline as base64.
 * 3. Append any `reference_urls` from the config.
 * 4. Submit the task, poll for its result, and download the images into
 *    `output_dir` (default "save").
 *
 * Once the R2 upload has returned, the objects it created are deleted in a
 * `finally` block, whether the task succeeds or fails.
 *
 * @returns {Promise<void>}
 * @throws {Error} Propagates config, upload, task-creation and polling errors.
 */
async function main() {
  const baseDir = __dirname;
  const config = loadConfig(baseDir);

  const apiKey = ensureApiKey(config);
  const prompt = ensurePrompt(config);

  const size = config.size || "1K";
  const aspectRatio = config.aspectRatio || "auto";
  const pollIntervalMs = (config.poll_interval_seconds || 5) * 1000;
  const maxWaitMs = (config.max_wait_seconds || 300) * 1000;

  const inputDir = path.join(baseDir, config.input_dir || "01");
  const outputDir = path.join(baseDir, config.output_dir || "save");

  let combinedUrls = [];
  let r2KeysToDelete = [];
  const useR2 = isR2Configured(config);

  if (useR2) {
    // R2 image host: upload the images in the input directory and pass their
    // public URLs as reference images.
    const client = createR2Client(config);
    const { urls: r2Urls, keys } = await uploadRefImagesToR2({
      client,
      bucket: config.r2_bucket,
      publicBaseUrl: config.r2_public_url,
      inputDir,
      maxFiles: 14,
    });
    r2KeysToDelete = keys;
    if (r2Urls.length > 0) {
      console.log(`已上传 ${r2Urls.length} 张参考图到 R2 图床。`);
    }
    combinedUrls = [...r2Urls];
  } else {
    // No R2 configured: send the local images inline as base64.
    const refB64List = await collectReferenceImages(inputDir);
    if (refB64List.length > 0) {
      console.log(`已从目录 ${inputDir} 读取 ${refB64List.length} 张本地参考图(base64)。`);
    }
    combinedUrls = [...refB64List];
  }

  // Reference image URLs from the config are passed straight through to the API.
  // NOTE(review): these are not counted against the 14-image cap applied to
  // local files above; confirm the API's actual limit.
  const extraUrls = readConfiguredReferenceUrls(config);
  combinedUrls = [...combinedUrls, ...extraUrls];

  if (extraUrls.length > 0) {
    console.log(`已从配置文件读取 ${extraUrls.length} 个参考图 URL。`);
  }
  if (combinedUrls.length === 0) {
    console.log(`未找到任何参考图,将仅根据提示词生成。`);
  }

  try {
    console.log("正在创建 NanoBanana2 任务...");
    const taskId = await createTask({
      apiKey,
      prompt,
      size,
      aspectRatio,
      refB64List: combinedUrls,
    });
    console.log("任务已创建,任务 id:", taskId);

    console.log("开始轮询任务结果...");
    const urls = await queryResult({
      apiKey,
      taskId,
      pollIntervalMs,
      maxWaitMs,
    });

    if (!urls || urls.length === 0) {
      console.log("任务完成,但未在结果中找到图片 URL,请登录速创API控制台或检查结果详情接口。");
      return;
    }

    console.log(`任务完成,共获取到 ${urls.length} 个图片 URL,开始下载...`);
    await downloadImages({ urls, outputDir, taskId });
    console.log("全部处理完成。");
  } finally {
    // The uploaded reference images are only needed while the task runs;
    // remove them from R2 whether or not it succeeded.
    if (useR2 && r2KeysToDelete.length > 0) {
      try {
        const client = createR2Client(config);
        await deleteR2Objects({ client, bucket: config.r2_bucket, keys: r2KeysToDelete });
        console.log("已删除本次上传的 R2 参考图。");
      } catch (e) {
        console.warn("删除 R2 参考图失败:", e.message);
      }
    }
  }
}
|
||
|
||
/**
 * Rejection handler for main(): prints the error message, or the raw value
 * if it has no message, to stderr and exits with status 1.
 *
 * @param {*} err - The rejection reason from main().
 */
function reportFatalError(err) {
  const detail = err.message || err;
  console.error("执行出错:", detail);
  process.exit(1);
}

main().catch(reportFatalError);
|
||
|