242 lines
7.9 KiB
JavaScript
242 lines
7.9 KiB
JavaScript
const fs = require("fs");
|
||
const path = require("path");
|
||
const axios = require("axios");
|
||
const { S3Client, PutObjectCommand, DeleteObjectsCommand } = require("@aws-sdk/client-s3");
|
||
|
||
/**
 * Extension (lower-case, with leading dot) -> MIME type for every image
 * format this module accepts. Frozen so it cannot be mutated at runtime.
 */
const CONTENT_TYPES = Object.freeze({
  ".jpg": "image/jpeg",
  ".jpeg": "image/jpeg",
  ".png": "image/png",
  ".webp": "image/webp",
  ".bmp": "image/bmp",
  ".gif": "image/gif",
});

/**
 * Set of accepted image extensions. Derived from CONTENT_TYPES so the two
 * tables can never drift apart when a format is added or removed.
 */
const IMAGE_EXTS = new Set(Object.keys(CONTENT_TYPES));
|
||
|
||
/**
 * Tell whether every setting needed to talk to R2 is present.
 *
 * Always returns a strict boolean. (The previous `a && b && c` chain leaked
 * the last operand - e.g. the public URL string, `""`, `null` or
 * `undefined` - which breaks `=== true` comparisons and serialisation.)
 *
 * @param {object} [config] - Configuration object; may be null/undefined.
 * @returns {boolean} true only when all r2_* fields are truthy.
 */
function isR2Configured(config) {
  const requiredFields = [
    "r2_account_id",
    "r2_access_key_id",
    "r2_secret_access_key",
    "r2_bucket",
    "r2_public_url",
  ];
  if (!config) return false;
  return requiredFields.every((field) => Boolean(config[field]));
}
|
||
|
||
/**
 * Build an S3-compatible client pointed at the account's R2 endpoint.
 *
 * @param {object} config - Must carry r2_account_id, r2_access_key_id and
 *   r2_secret_access_key.
 * @returns {S3Client} Client using path-style addressing and region "auto".
 */
function createR2Client(config) {
  const {
    r2_account_id: accountId,
    r2_access_key_id: accessKeyId,
    r2_secret_access_key: secretAccessKey,
  } = config;

  const clientOptions = {
    region: "auto",
    endpoint: `https://${accountId}.r2.cloudflarestorage.com`,
    forcePathStyle: true,
    credentials: { accessKeyId, secretAccessKey },
  };
  return new S3Client(clientOptions);
}
|
||
|
||
/**
 * Upload the given local image files to R2 under a fresh `ref/<runId>/`
 * prefix and return their public URLs and object keys.
 *
 * Entries that do not exist on disk or whose extension is not in IMAGE_EXTS
 * are skipped. Files sharing a basename get a numeric prefix so they do not
 * overwrite each other. URLs are percent-encoded per path segment, so names
 * with spaces or non-ASCII characters yield valid links; `keys` keep the raw
 * object key.
 *
 * If any upload fails, the objects already uploaded in this call are deleted
 * on a best-effort basis before the original error is rethrown; otherwise
 * the caller would never learn their keys and they would be orphaned.
 *
 * @param {object} args
 * @param {object} args.client - S3-compatible client exposing `send()`.
 * @param {string} args.bucket - Target bucket name.
 * @param {string} args.publicBaseUrl - Public base URL; trailing slashes are ignored.
 * @param {string[]} args.filePaths - Local file paths to upload.
 * @returns {Promise<{urls: string[], keys: string[]}>} Parallel arrays, in upload order.
 */
async function uploadRefFilesToR2({ client, bucket, publicBaseUrl, filePaths }) {
  const runId = `${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 8)}`;
  const prefix = `ref/${runId}`;
  const base = (publicBaseUrl || "").replace(/\/+$/, "");
  const urls = [];
  const keys = [];
  const takenNames = new Set();

  try {
    for (const filePath of filePaths || []) {
      if (!fs.existsSync(filePath)) continue;
      const ext = path.extname(filePath).toLowerCase();
      if (!IMAGE_EXTS.has(ext)) continue;

      // Disambiguate identical basenames coming from different directories.
      const baseName = path.basename(filePath);
      let name = baseName;
      for (let n = 2; takenNames.has(name); n++) {
        name = `${n}-${baseName}`;
      }
      takenNames.add(name);

      const key = `${prefix}/${name}`;
      const body = await fs.promises.readFile(filePath);
      const contentType = CONTENT_TYPES[ext] || "application/octet-stream";

      await client.send(
        new PutObjectCommand({
          Bucket: bucket,
          Key: key,
          Body: body,
          ContentType: contentType,
        })
      );

      // Record the key only after a successful upload.
      keys.push(key);
      const encodedKey = key.split("/").map(encodeURIComponent).join("/");
      urls.push(`${base}/${encodedKey}`);
    }
  } catch (err) {
    if (keys.length > 0) {
      try {
        await deleteR2Objects({ client, bucket, keys });
      } catch {
        // Best-effort rollback: the upload error is the one worth reporting.
      }
    }
    throw err;
  }

  return { urls, keys };
}
|
||
|
||
/**
 * Delete the given object keys from a bucket.
 *
 * Keys are sent in batches because the S3 DeleteObjects operation accepts at
 * most 1000 keys per request. DeleteObjects can also answer successfully
 * while reporting per-key failures in `Errors`; those are collected and
 * raised so callers do not mistake a partial delete for success.
 *
 * @param {object} args
 * @param {object} args.client - S3-compatible client exposing `send()`.
 * @param {string} args.bucket - Bucket name.
 * @param {string[]} args.keys - Object keys to delete; empty/missing is a no-op.
 * @returns {Promise<void>}
 * @throws {Error} When a request fails or any key is reported as not deleted.
 */
async function deleteR2Objects({ client, bucket, keys }) {
  if (!keys || keys.length === 0) return;

  const maxKeysPerRequest = 1000;
  const failures = [];

  for (let offset = 0; offset < keys.length; offset += maxKeysPerRequest) {
    const batch = keys.slice(offset, offset + maxKeysPerRequest);
    const response = await client.send(
      new DeleteObjectsCommand({
        Bucket: bucket,
        Delete: { Objects: batch.map((Key) => ({ Key })) },
      })
    );
    if (Array.isArray(response?.Errors) && response.Errors.length > 0) {
      failures.push(...response.Errors);
    }
  }

  if (failures.length > 0) {
    const detail = failures
      .map(({ Key, Code, Message }) => `${Key} (${Code || Message || "unknown"})`)
      .join("; ");
    throw new Error(`${failures.length} 个对象未能删除: ${detail}`);
  }
}
|
||
|
||
/**
 * Collect image URLs from a task-result payload.
 *
 * Accepts either an array of entries or an object, in which case the fields
 * img_url, img_urls, image_urls, urls, images and result are inspected in
 * that order. Each field may hold a string or an array.
 *
 * Only real strings are kept. Entries that are objects contribute their
 * string `url` property if they have one; anything else (numbers, plain
 * objects, null) is skipped. Previously such entries were coerced with
 * String(), producing bogus "URLs" like "[object Object]".
 *
 * @param {*} dataObj - Payload to inspect.
 * @returns {string[]} De-duplicated URLs in first-seen order.
 */
function extractImageUrls(dataObj) {
  if (!dataObj) return [];

  // A Set keeps insertion order, which gives stable de-duplication for free.
  const found = new Set();
  const collect = (value) => {
    if (typeof value === "string") {
      if (value) found.add(value);
    } else if (Array.isArray(value)) {
      value.forEach(collect);
    } else if (value && typeof value === "object" && typeof value.url === "string") {
      collect(value.url);
    }
  };

  if (Array.isArray(dataObj)) {
    collect(dataObj);
  } else if (typeof dataObj === "object") {
    const urlFields = ["img_url", "img_urls", "image_urls", "urls", "images", "result"];
    for (const field of urlFields) {
      collect(dataObj[field]);
    }
  }

  return [...found];
}
|
||
|
||
/**
 * Run one image-generation job end to end, reporting progress as it goes:
 * optionally upload reference images to R2, create the remote task, poll
 * until it yields image URLs, download the images into the save directory,
 * and finally remove the temporary R2 objects (even when the job failed).
 *
 * @param {object} config - Full configuration (api_key, r2_*, api_create_url,
 *   api_result_url, poll_interval_seconds, max_wait_seconds, default_save_dir,
 *   extra_params, ...). The API key falls back to process.env.WUYIN_API_KEY.
 * @param {object} options - { prompt, size, aspectRatio, refFilePaths[], saveDir, extraParams{} }.
 * @param {function} [onProgress] - (step, message) => void progress callback.
 * @returns {Promise<{taskId: *, saveDir: string, count: number, savedPaths: string[]}>}
 * @throws {Error} On missing API key or prompt, task creation failure, polling
 *   failure/timeout, an empty result, or a failed download.
 */
async function runGeneration(config, options, onProgress = () => {}) {
  const apiKey = config.api_key || process.env.WUYIN_API_KEY;
  if (!apiKey) throw new Error("未配置 API 密钥");

  const prompt = (options.prompt || "").trim();
  if (!prompt) throw new Error("请输入提示词");

  const createUrl = config.api_create_url || "https://api.wuyinkeji.com/api/async/image_nanoBanana2";
  const resultUrl = config.api_result_url || "https://api.wuyinkeji.com/api/async/detail";
  const pollIntervalMs = (config.poll_interval_seconds || 5) * 1000;
  const maxWaitMs = (config.max_wait_seconds || 300) * 1000;
  const saveDir = options.saveDir || config.default_save_dir || path.join(process.cwd(), "save");
  const extraParams = options.extraParams || config.extra_params || {};

  let r2KeysToDelete = [];
  let combinedUrls = [];

  if (isR2Configured(config) && options.refFilePaths && options.refFilePaths.length > 0) {
    onProgress("upload", "正在上传参考图到 R2…");
    const client = createR2Client(config);
    const { urls, keys } = await uploadRefFilesToR2({
      client,
      bucket: config.r2_bucket,
      publicBaseUrl: config.r2_public_url,
      filePaths: options.refFilePaths,
    });
    r2KeysToDelete = keys;
    combinedUrls = urls;
    onProgress("upload_done", `已上传 ${urls.length} 张参考图`);
  }

  try {
    onProgress("create", "正在创建生图任务…");
    // extraParams is spread last, so it may override the defaults above it.
    const body = {
      prompt,
      size: options.size || "1K",
      aspectRatio: options.aspectRatio || "auto",
      key: apiKey,
      ...extraParams,
    };
    if (combinedUrls.length > 0) body.urls = combinedUrls;

    const createResp = await axios.post(createUrl, body, {
      headers: {
        Authorization: apiKey,
        "Content-Type": "application/json;charset=utf-8;",
      },
      timeout: 30000,
    });
    const payload = createResp.data;
    if (!payload || payload.code !== 200) {
      throw new Error("创建任务失败: " + (payload?.msg || JSON.stringify(payload)));
    }
    const taskId = (payload.data || {}).id;
    if (!taskId) throw new Error("返回数据中缺少任务 id");

    onProgress("poll", "正在等待生成结果…");
    const imageUrls = await pollResult({
      apiKey,
      taskId,
      resultUrl,
      pollIntervalMs,
      maxWaitMs,
    });

    if (!imageUrls || imageUrls.length === 0) {
      throw new Error("任务完成但未获取到图片 URL");
    }

    onProgress("download", `正在保存 ${imageUrls.length} 张图片…`);
    // `recursive: true` is a no-op when the directory already exists.
    fs.mkdirSync(saveDir, { recursive: true });

    // The task id comes from the remote server; keep only filename-safe
    // characters so it can never introduce path separators or "..".
    const fileStem = String(taskId).replace(/[^\w-]/g, "_");
    const savedPaths = [];
    for (const [index, url] of imageUrls.entries()) {
      const resp = await axios.get(url, { responseType: "arraybuffer", timeout: 60000 });
      const ext = pickImageExtension(url, resp.headers?.["content-type"]);
      const filePath = path.join(saveDir, `${fileStem}_${index + 1}${ext}`);
      fs.writeFileSync(filePath, resp.data);
      savedPaths.push(filePath);
    }
    onProgress("done", `已保存到 ${saveDir}`);
    return { taskId, saveDir, count: imageUrls.length, savedPaths };
  } finally {
    // Temporary reference images are removed whether or not the job succeeded.
    if (r2KeysToDelete.length > 0 && isR2Configured(config)) {
      try {
        const client = createR2Client(config);
        await deleteR2Objects({ client, bucket: config.r2_bucket, keys: r2KeysToDelete });
        onProgress("cleanup", "已删除 R2 临时参考图");
      } catch (e) {
        // Cleanup is best-effort: report it, but never mask the job's outcome.
        onProgress("cleanup_error", "删除 R2 参考图失败: " + e.message);
      }
    }
  }
}

/**
 * Choose the file extension for a downloaded image.
 *
 * Order of preference: the extension of the URL's path (ignoring host, query
 * and fragment), then the response Content-Type, then the legacy "known
 * extension appears anywhere in the URL" scan, and finally ".jpg".
 */
function pickImageExtension(url, contentType) {
  const rawUrl = String(url);

  let pathname;
  try {
    pathname = new URL(rawUrl).pathname;
  } catch {
    // Not an absolute URL: strip query/fragment by hand.
    pathname = rawUrl.split(/[?#]/)[0];
  }
  const fromPath = path.extname(pathname).toLowerCase();
  if (IMAGE_EXTS.has(fromPath)) return fromPath;

  const mime = String(contentType || "").split(";")[0].trim().toLowerCase();
  const fromMime = Object.keys(CONTENT_TYPES).find((ext) => CONTENT_TYPES[ext] === mime);
  if (fromMime) return fromMime;

  const lowered = rawUrl.toLowerCase();
  const legacyOrder = [".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"];
  return legacyOrder.find((ext) => lowered.includes(ext)) || ".jpg";
}
|
||
|
||
/**
 * Poll the result endpoint until the task yields image URLs, fails, or the
 * wait budget is exhausted.
 *
 * A response with status 2 returns whatever URLs it carries (possibly none);
 * status 3 is treated as a failure; any other status returns early if URLs
 * are already present, otherwise polling continues.
 *
 * Timing inputs are validated: a non-numeric interval would make setTimeout
 * fire immediately (hammering the API) and a non-numeric budget would
 * disable the timeout entirely, so both fall back to 5 s / 300 s. The sleep
 * is clamped to the remaining budget so the deadline is not overshot by a
 * whole interval.
 *
 * @param {object} args
 * @param {string} args.apiKey - Sent as Authorization header and `key` query parameter.
 * @param {*} args.taskId - Sent as the `id` query parameter.
 * @param {string} args.resultUrl - Endpoint to query.
 * @param {number} args.pollIntervalMs - Delay between polls, in milliseconds.
 * @param {number} args.maxWaitMs - Total wait budget, in milliseconds.
 * @returns {Promise<string[]>} Image URLs extracted from the result payload.
 * @throws {Error} On a non-200 payload code, a failed task, or timeout.
 */
async function pollResult({ apiKey, taskId, resultUrl, pollIntervalMs, maxWaitMs }) {
  const requestedInterval = Number(pollIntervalMs);
  const intervalMs =
    Number.isFinite(requestedInterval) && requestedInterval > 0 ? requestedInterval : 5000;
  const requestedBudget = Number(maxWaitMs);
  const budgetMs =
    Number.isFinite(requestedBudget) && requestedBudget >= 0 ? requestedBudget : 300000;

  const startedAt = Date.now();
  const headers = {
    Authorization: apiKey,
    "Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
  };

  for (;;) {
    const resp = await axios.get(resultUrl, {
      params: { key: apiKey, id: taskId },
      headers,
      timeout: 30000,
    });
    const payload = resp.data;
    if (!payload || payload.code !== 200) {
      throw new Error(`查询失败: ${payload?.msg || "未知错误"}`);
    }

    const data = payload.data || {};
    const status = data.status;
    if (status === 2) return extractImageUrls(data);
    if (status === 3) throw new Error("任务生成失败: " + (data.message || payload.msg || "未知原因"));

    // Some responses carry URLs before the status flips to "done".
    const urls = extractImageUrls(data);
    if (urls.length > 0) return urls;

    const remainingMs = budgetMs - (Date.now() - startedAt);
    if (remainingMs <= 0) {
      throw new Error("等待结果超时");
    }
    await new Promise((resolve) => setTimeout(resolve, Math.min(intervalMs, remainingMs)));
  }
}
|
||
|
||
// Public surface of this module. Functions not listed here remain internal.
const publicApi = {
  runGeneration,
  isR2Configured,
  uploadRefFilesToR2,
  deleteR2Objects,
  createR2Client,
};

module.exports = publicApi;
|