Files
Banana/core/generator.js
2026-03-03 10:38:37 +08:00

242 lines
7.9 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

const fs = require("fs");
const path = require("path");
const axios = require("axios");
const { S3Client, PutObjectCommand, DeleteObjectsCommand } = require("@aws-sdk/client-s3");
/** MIME type sent to R2 for each accepted image extension (keys are lower-case and dot-prefixed). */
const CONTENT_TYPES = {
  ".jpg": "image/jpeg",
  ".jpeg": "image/jpeg",
  ".png": "image/png",
  ".webp": "image/webp",
  ".bmp": "image/bmp",
  ".gif": "image/gif",
};
/** Extensions treated as images; derived from CONTENT_TYPES so the two lists cannot drift apart. */
const IMAGE_EXTS = new Set(Object.keys(CONTENT_TYPES));
/**
 * Report whether every setting needed to talk to Cloudflare R2 is present.
 *
 * Behaves like a chained `&&`: returns `config` itself when it is falsy, otherwise the first
 * falsy required setting, otherwise the value of `r2_public_url`. Use the result for its truthiness.
 * @param {object} config
 * @returns {*} truthy when R2 uploads can be attempted
 */
function isR2Configured(config) {
  const requiredKeys = [
    "r2_account_id",
    "r2_access_key_id",
    "r2_secret_access_key",
    "r2_bucket",
    "r2_public_url",
  ];
  return requiredKeys.reduce((ok, key) => ok && config[key], config);
}
/**
 * Build an S3-compatible client aimed at the account's Cloudflare R2 endpoint.
 * @param {object} config - must carry r2_account_id, r2_access_key_id and r2_secret_access_key
 * @returns {S3Client} client using region "auto" and path-style addressing
 */
function createR2Client(config) {
  const {
    r2_account_id: accountId,
    r2_access_key_id: accessKeyId,
    r2_secret_access_key: secretAccessKey,
  } = config;
  return new S3Client({
    region: "auto",
    endpoint: `https://${accountId}.r2.cloudflarestorage.com`,
    forcePathStyle: true,
    credentials: { accessKeyId, secretAccessKey },
  });
}
/**
 * Upload local reference images to R2 under a fresh per-run prefix.
 *
 * Paths that do not exist, or whose extension is not in IMAGE_EXTS, are skipped silently.
 * Each key is `ref/<runId>/<index>-<basename>`, where index is the position in `filePaths`;
 * the index keeps two files that share a basename (e.g. `a/1.png` and `b/1.png`) from
 * overwriting each other. Public URLs percent-encode every key segment so names containing
 * spaces, `#`, `?` or non-ASCII characters still resolve.
 *
 * If any upload fails, objects already uploaded by this call are deleted (best effort) before
 * the original error is rethrown, so a failed call does not leave orphaned objects behind.
 *
 * @param {object} args
 * @param {S3Client} args.client
 * @param {string} args.bucket
 * @param {string} args.publicBaseUrl - public base URL of the bucket; trailing slashes are ignored
 * @param {string[]} args.filePaths
 * @returns {Promise<{urls: string[], keys: string[]}>} parallel arrays, one entry per uploaded file
 * @throws the first upload error encountered
 */
async function uploadRefFilesToR2({ client, bucket, publicBaseUrl, filePaths }) {
  const runId = Date.now().toString(36) + "-" + Math.random().toString(36).slice(2, 8);
  const prefix = `ref/${runId}`;
  const base = (publicBaseUrl || "").replace(/\/+$/, "");
  const urls = [];
  const keys = [];
  try {
    for (const [i, filePath] of (filePaths || []).entries()) {
      const ext = path.extname(filePath).toLowerCase();
      if (!IMAGE_EXTS.has(ext) || !fs.existsSync(filePath)) continue;
      const key = `${prefix}/${i}-${path.basename(filePath)}`;
      await client.send(
        new PutObjectCommand({
          Bucket: bucket,
          Key: key,
          Body: fs.readFileSync(filePath),
          ContentType: CONTENT_TYPES[ext] || "application/octet-stream",
        })
      );
      keys.push(key);
      // The stored key stays raw; only the URL form needs escaping.
      urls.push(`${base}/${key.split("/").map(encodeURIComponent).join("/")}`);
    }
  } catch (err) {
    // The caller never receives `keys` on failure, so clean up what already made it to R2.
    try {
      await deleteR2Objects({ client, bucket, keys });
    } catch {
      // Best effort: the upload error rethrown below is the one the caller needs to see.
    }
    throw err;
  }
  return { urls, keys };
}
/**
 * Delete the given object keys from an R2 bucket.
 *
 * Keys are sent in batches of at most 1000, the DeleteObjects per-request limit. DeleteObjects
 * reports per-key failures in the response's `Errors` array instead of rejecting. Those failures
 * are collected and thrown as one Error, so a failed cleanup is not mistaken for a successful one.
 *
 * @param {object} args
 * @param {S3Client} args.client
 * @param {string} args.bucket
 * @param {string[]} args.keys - no-op when empty or missing
 * @returns {Promise<void>}
 * @throws {Error} when a request fails or R2 reports any key it could not delete
 */
async function deleteR2Objects({ client, bucket, keys }) {
  if (!keys || keys.length === 0) return;
  const MAX_KEYS_PER_REQUEST = 1000;
  const failures = [];
  for (let start = 0; start < keys.length; start += MAX_KEYS_PER_REQUEST) {
    const batch = keys.slice(start, start + MAX_KEYS_PER_REQUEST);
    const result = await client.send(
      new DeleteObjectsCommand({
        Bucket: bucket,
        Delete: { Objects: batch.map((Key) => ({ Key })) },
      })
    );
    for (const e of result?.Errors ?? []) {
      failures.push(`${e.Key}: ${e.Code || e.Message || "unknown"}`);
    }
  }
  if (failures.length > 0) {
    throw new Error(`删除 ${failures.length} 个 R2 对象失败: ${failures.join("; ")}`);
  }
}
/**
 * Collect candidate image URLs from a task-result payload, de-duplicated in first-seen order.
 *
 * - Array input: every truthy element, converted with String().
 * - Object input: the fields img_url, img_urls, image_urls, urls, images, result (in that order).
 *   Non-empty string fields are taken as-is; array fields contribute each truthy element,
 *   stringified. Any other value type is ignored.
 * - Anything else (null, undefined, primitives): an empty array.
 * @param {*} dataObj
 * @returns {string[]}
 */
function extractImageUrls(dataObj) {
  const stringifyTruthy = (list) => list.filter(Boolean).map(String);
  let found = [];
  if (Array.isArray(dataObj)) {
    found = stringifyTruthy(dataObj);
  } else if (dataObj && typeof dataObj === "object") {
    const fields = ["img_url", "img_urls", "image_urls", "urls", "images", "result"];
    found = fields.flatMap((field) => {
      const value = dataObj[field];
      if (typeof value === "string") return value ? [value] : [];
      if (Array.isArray(value)) return stringifyTruthy(value);
      return [];
    });
  }
  return [...new Set(found)];
}
/**
 * Pick a file extension for a downloaded result image.
 * Uses the extension of the URL's path only (host, query string and fragment are ignored),
 * then the response's Content-Type, and falls back to ".jpg".
 * @param {string} url
 * @param {string} [contentType] - value of the response's Content-Type header, if any
 * @returns {string} a lower-case, dot-prefixed extension
 */
function guessImageExtension(url, contentType) {
  let pathname;
  try {
    pathname = new URL(url).pathname;
  } catch {
    // Not an absolute URL: strip any query/fragment by hand.
    pathname = String(url).split(/[?#]/)[0];
  }
  const fromPath = path.extname(pathname).toLowerCase();
  if (IMAGE_EXTS.has(fromPath)) return fromPath;
  const mime = String(contentType || "").split(";")[0].trim().toLowerCase();
  const fromMime = Object.keys(CONTENT_TYPES).find((ext) => CONTENT_TYPES[ext] === mime);
  return fromMime || ".jpg";
}

/**
 * Run one image-generation job end to end, reporting progress through `onProgress`.
 *
 * Steps:
 * 1. If reference images are given and R2 is configured, upload them to R2.
 * 2. Create the async task.
 * 3. Poll for the result.
 * 4. Download every result image into `saveDir`.
 *
 * If reference images are given but R2 is not configured, an "upload_skipped" progress step is
 * emitted and generation proceeds without them. Any R2 objects uploaded for this run are deleted
 * in a `finally` block, whether the job succeeds or fails. A cleanup failure is reported via
 * onProgress("cleanup_error", …) rather than thrown.
 *
 * @param {object} config - full config (api_key, r2_*, api_create_url, api_result_url,
 *   poll_interval_seconds, max_wait_seconds, default_save_dir, extra_params, …)
 * @param {object} options - { prompt, size, aspectRatio, refFilePaths [], saveDir, extraParams {} }
 * @param {function} onProgress - (step, message) => void
 * @returns {Promise<{taskId: *, saveDir: string, count: number, savedPaths: string[]}>}
 * @throws {Error} on a missing API key or prompt, task-creation failure, polling failure or
 *   timeout, an empty result, or a download/write error
 */
async function runGeneration(config, options, onProgress = () => {}) {
  const apiKey = config.api_key || process.env.WUYIN_API_KEY;
  if (!apiKey) throw new Error("未配置 API 密钥");
  const prompt = (options.prompt || "").trim();
  if (!prompt) throw new Error("请输入提示词");
  const createUrl = config.api_create_url || "https://api.wuyinkeji.com/api/async/image_nanoBanana2";
  const resultUrl = config.api_result_url || "https://api.wuyinkeji.com/api/async/detail";
  const pollIntervalMs = (config.poll_interval_seconds || 5) * 1000;
  const maxWaitMs = (config.max_wait_seconds || 300) * 1000;
  const saveDir = options.saveDir || config.default_save_dir || path.join(process.cwd(), "save");
  const extraParams = options.extraParams || config.extra_params || {};
  const refFilePaths = options.refFilePaths || [];
  let r2KeysToDelete = [];
  let combinedUrls = [];
  if (refFilePaths.length > 0 && !isR2Configured(config)) {
    // Reference images only reach the API as R2 public URLs (body.urls below); without R2
    // they are dropped, so report that instead of silently generating without them.
    onProgress("upload_skipped", `未配置 R2，已忽略 ${refFilePaths.length} 张参考图`);
  } else if (refFilePaths.length > 0) {
    onProgress("upload", "正在上传参考图到 R2…");
    const client = createR2Client(config);
    const { urls, keys } = await uploadRefFilesToR2({
      client,
      bucket: config.r2_bucket,
      publicBaseUrl: config.r2_public_url,
      filePaths: refFilePaths,
    });
    r2KeysToDelete = keys;
    combinedUrls = urls;
    onProgress("upload_done", `已上传 ${urls.length} 张参考图`);
  }
  try {
    onProgress("create", "正在创建生图任务…");
    const body = {
      prompt,
      size: options.size || "1K",
      aspectRatio: options.aspectRatio || "auto",
      key: apiKey,
      ...extraParams,
    };
    if (combinedUrls.length > 0) body.urls = combinedUrls;
    const createResp = await axios.post(createUrl, body, {
      headers: {
        Authorization: apiKey,
        "Content-Type": "application/json;charset=utf-8;",
      },
      timeout: 30000,
    });
    const payload = createResp.data;
    if (!payload || payload.code !== 200) {
      throw new Error("创建任务失败: " + (payload?.msg || JSON.stringify(payload)));
    }
    const taskId = (payload.data || {}).id;
    if (!taskId) throw new Error("返回数据中缺少任务 id");
    onProgress("poll", "正在等待生成结果…");
    const imageUrls = await pollResult({
      apiKey,
      taskId,
      resultUrl,
      pollIntervalMs,
      maxWaitMs,
    });
    if (!imageUrls || imageUrls.length === 0) {
      throw new Error("任务完成但未获取到图片 URL");
    }
    onProgress("download", `正在保存 ${imageUrls.length} 张图片…`);
    fs.mkdirSync(saveDir, { recursive: true });
    // taskId comes from the remote API: keep only filename-safe characters so it can neither
    // escape saveDir nor produce a name the OS rejects.
    const fileStem = String(taskId).replace(/[^\w.-]/g, "_");
    const savedPaths = [];
    for (const [i, url] of imageUrls.entries()) {
      const resp = await axios.get(url, { responseType: "arraybuffer", timeout: 60000 });
      const ext = guessImageExtension(url, resp.headers?.["content-type"]);
      const filePath = path.join(saveDir, `${fileStem}_${i + 1}${ext}`);
      fs.writeFileSync(filePath, resp.data);
      savedPaths.push(filePath);
    }
    onProgress("done", `已保存到 ${saveDir}`);
    return { taskId, saveDir, count: imageUrls.length, savedPaths };
  } finally {
    if (r2KeysToDelete.length > 0 && isR2Configured(config)) {
      try {
        const client = createR2Client(config);
        await deleteR2Objects({ client, bucket: config.r2_bucket, keys: r2KeysToDelete });
        onProgress("cleanup", "已删除 R2 临时参考图");
      } catch (e) {
        onProgress("cleanup_error", "删除 R2 参考图失败: " + e.message);
      }
    }
  }
}
/** True for poll failures worth retrying: no HTTP response at all (network error/timeout), HTTP 429, or HTTP 5xx. */
function isTransientPollError(err) {
  const status = err?.response?.status;
  if (status !== undefined) return status === 429 || status >= 500;
  return Boolean(err?.request);
}

/**
 * Poll the task-detail endpoint until the task finishes, fails, or `maxWaitMs` elapses.
 *
 * Outcomes per successful response:
 * - `data.status === 2`: return the URLs found in `data`.
 * - `data.status === 3`: throw a task-failure error.
 * - Otherwise, return early if `data` already contains any URLs.
 *
 * Error handling:
 * - Network errors, request timeouts, HTTP 429 and HTTP 5xx are retried on the next tick, so one
 *   dropped request does not abandon a task that may still be running.
 * - Any other HTTP error, or a payload whose `code` is not 200, aborts immediately.
 *
 * @param {object} args
 * @param {string} args.apiKey
 * @param {*} args.taskId
 * @param {string} args.resultUrl
 * @param {number} args.pollIntervalMs
 * @param {number} args.maxWaitMs
 * @returns {Promise<string[]>} image URLs (may be empty when status is 2 but no URL field is set)
 * @throws {Error} on task failure, a non-200 `code`, a non-transient HTTP error, or timeout
 */
async function pollResult({ apiKey, taskId, resultUrl, pollIntervalMs, maxWaitMs }) {
  const deadline = Date.now() + maxWaitMs;
  const headers = {
    Authorization: apiKey,
    "Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
  };
  let lastTransientError = null;
  while (true) {
    let resp;
    try {
      resp = await axios.get(resultUrl, {
        params: { key: apiKey, id: taskId },
        headers,
        timeout: 30000,
      });
    } catch (err) {
      if (!isTransientPollError(err)) throw err;
      lastTransientError = err;
    }
    if (resp) {
      lastTransientError = null;
      const payload = resp.data;
      if (!payload || payload.code !== 200) {
        throw new Error(`查询失败: ${payload?.msg || "未知错误"}`);
      }
      const data = payload.data || {};
      const status = data.status;
      if (status === 2) return extractImageUrls(data);
      if (status === 3) throw new Error("任务生成失败: " + (data.message || payload.msg || "未知原因"));
      // URLs already present are returned even if status has not reached 2 yet.
      const urls = extractImageUrls(data);
      if (urls.length > 0) return urls;
    }
    const remainingMs = deadline - Date.now();
    if (remainingMs < 0) {
      const detail = lastTransientError ? `（最近一次查询错误: ${lastTransientError.message}）` : "";
      throw new Error("等待结果超时" + detail, lastTransientError ? { cause: lastTransientError } : undefined);
    }
    // Never sleep past the deadline. A NaN remainder (non-numeric maxWaitMs) compares false
    // and falls back to the full interval.
    const waitMs = remainingMs < pollIntervalMs ? remainingMs : pollIntervalMs;
    await new Promise((r) => setTimeout(r, waitMs));
  }
}
// Public surface of this module. extractImageUrls and pollResult are helpers used only by
// runGeneration (directly or via pollResult) and are not exported.
module.exports = {
runGeneration,
isR2Configured,
uploadRefFilesToR2,
deleteR2Objects,
createR2Client,
};