代码提交

This commit is contained in:
2026-03-03 10:38:37 +08:00
parent e904d7af1d
commit 000d1ef1a8
22 changed files with 7043 additions and 0 deletions

241
core/generator.js Normal file
View File

@@ -0,0 +1,241 @@
const fs = require("fs");
const path = require("path");
const axios = require("axios");
const { S3Client, PutObjectCommand, DeleteObjectsCommand } = require("@aws-sdk/client-s3");
const IMAGE_EXTS = new Set([".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"]);
const CONTENT_TYPES = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
".bmp": "image/bmp",
".gif": "image/gif",
};
function isR2Configured(config) {
return (
config &&
config.r2_account_id &&
config.r2_access_key_id &&
config.r2_secret_access_key &&
config.r2_bucket &&
config.r2_public_url
);
}
function createR2Client(config) {
const endpoint = `https://${config.r2_account_id}.r2.cloudflarestorage.com`;
return new S3Client({
region: "auto",
endpoint,
forcePathStyle: true,
credentials: {
accessKeyId: config.r2_access_key_id,
secretAccessKey: config.r2_secret_access_key,
},
});
}
/** 将指定本地文件路径上传到 R2返回 { urls, keys } */
async function uploadRefFilesToR2({ client, bucket, publicBaseUrl, filePaths }) {
const runId = Date.now().toString(36) + "-" + Math.random().toString(36).slice(2, 8);
const prefix = `ref/${runId}`;
const urls = [];
const keys = [];
const base = (publicBaseUrl || "").replace(/\/$/, "");
for (let i = 0; i < filePaths.length; i++) {
const filePath = filePaths[i];
if (!fs.existsSync(filePath)) continue;
const ext = path.extname(filePath).toLowerCase();
if (!IMAGE_EXTS.has(ext)) continue;
const name = path.basename(filePath);
const key = `${prefix}/${name}`;
const body = fs.readFileSync(filePath);
const contentType = CONTENT_TYPES[ext] || "application/octet-stream";
await client.send(
new PutObjectCommand({
Bucket: bucket,
Key: key,
Body: body,
ContentType: contentType,
})
);
urls.push(`${base}/${key}`);
keys.push(key);
}
return { urls, keys };
}
async function deleteR2Objects({ client, bucket, keys }) {
if (!keys || keys.length === 0) return;
await client.send(
new DeleteObjectsCommand({
Bucket: bucket,
Delete: { Objects: keys.map((Key) => ({ Key })) },
})
);
}
function extractImageUrls(dataObj) {
if (!dataObj) return [];
const candidates = [];
if (Array.isArray(dataObj)) {
dataObj.forEach((v) => v && candidates.push(String(v)));
} else if (typeof dataObj === "object") {
for (const key of ["img_url", "img_urls", "image_urls", "urls", "images", "result"]) {
const val = dataObj[key];
if (!val) continue;
if (typeof val === "string") candidates.push(val);
else if (Array.isArray(val)) val.forEach((v) => v && candidates.push(String(v)));
}
}
const seen = new Set();
return candidates.filter((u) => !seen.has(u) && seen.add(u));
}
/**
* 执行一次生图流程,支持进度回调
* @param {object} config - 完整配置api_key, r2_*, api_create_url, api_result_url, poll_interval_seconds, max_wait_seconds 等)
* @param {object} options - { prompt, size, aspectRatio, refFilePaths [], saveDir, extraParams {} }
* @param {function} onProgress - (step, message) => void
*/
async function runGeneration(config, options, onProgress = () => {}) {
const apiKey = config.api_key || process.env.WUYIN_API_KEY;
if (!apiKey) throw new Error("未配置 API 密钥");
const prompt = (options.prompt || "").trim();
if (!prompt) throw new Error("请输入提示词");
const createUrl = config.api_create_url || "https://api.wuyinkeji.com/api/async/image_nanoBanana2";
const resultUrl = config.api_result_url || "https://api.wuyinkeji.com/api/async/detail";
const pollIntervalMs = (config.poll_interval_seconds || 5) * 1000;
const maxWaitMs = (config.max_wait_seconds || 300) * 1000;
const saveDir = options.saveDir || config.default_save_dir || path.join(process.cwd(), "save");
const extraParams = options.extraParams || config.extra_params || {};
let r2KeysToDelete = [];
let combinedUrls = [];
if (isR2Configured(config) && options.refFilePaths && options.refFilePaths.length > 0) {
onProgress("upload", "正在上传参考图到 R2…");
const client = createR2Client(config);
const { urls, keys } = await uploadRefFilesToR2({
client,
bucket: config.r2_bucket,
publicBaseUrl: config.r2_public_url,
filePaths: options.refFilePaths,
});
r2KeysToDelete = keys;
combinedUrls = urls;
onProgress("upload_done", `已上传 ${urls.length} 张参考图`);
}
try {
onProgress("create", "正在创建生图任务…");
const body = {
prompt,
size: options.size || "1K",
aspectRatio: options.aspectRatio || "auto",
key: apiKey,
...extraParams,
};
if (combinedUrls.length > 0) body.urls = combinedUrls;
const createResp = await axios.post(createUrl, body, {
headers: {
Authorization: apiKey,
"Content-Type": "application/json;charset=utf-8;",
},
timeout: 30000,
});
const payload = createResp.data;
if (!payload || payload.code !== 200) {
throw new Error("创建任务失败: " + (payload?.msg || JSON.stringify(payload)));
}
const taskId = (payload.data || {}).id;
if (!taskId) throw new Error("返回数据中缺少任务 id");
onProgress("poll", "正在等待生成结果…");
const imageUrls = await pollResult({
apiKey,
taskId,
resultUrl,
pollIntervalMs,
maxWaitMs,
});
if (!imageUrls || imageUrls.length === 0) {
throw new Error("任务完成但未获取到图片 URL");
}
onProgress("download", `正在保存 ${imageUrls.length} 张图片…`);
if (!fs.existsSync(saveDir)) fs.mkdirSync(saveDir, { recursive: true });
const savedPaths = [];
for (let i = 0; i < imageUrls.length; i++) {
const url = imageUrls[i];
const resp = await axios.get(url, { responseType: "arraybuffer", timeout: 60000 });
let ext = ".jpg";
for (const e of [".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"]) {
if (url.toLowerCase().includes(e)) {
ext = e;
break;
}
}
const filePath = path.join(saveDir, `${taskId}_${i + 1}${ext}`);
fs.writeFileSync(filePath, resp.data);
savedPaths.push(filePath);
}
onProgress("done", `已保存到 ${saveDir}`);
return { taskId, saveDir, count: imageUrls.length, savedPaths };
} finally {
if (r2KeysToDelete.length > 0 && isR2Configured(config)) {
try {
const client = createR2Client(config);
await deleteR2Objects({ client, bucket: config.r2_bucket, keys: r2KeysToDelete });
onProgress("cleanup", "已删除 R2 临时参考图");
} catch (e) {
onProgress("cleanup_error", "删除 R2 参考图失败: " + e.message);
}
}
}
}
async function pollResult({ apiKey, taskId, resultUrl, pollIntervalMs, maxWaitMs }) {
const start = Date.now();
const headers = {
Authorization: apiKey,
"Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
};
while (true) {
const resp = await axios.get(resultUrl, {
params: { key: apiKey, id: taskId },
headers,
timeout: 30000,
});
const payload = resp.data;
if (!payload || payload.code !== 200) {
throw new Error(`查询失败: ${payload?.msg || "未知错误"}`);
}
const data = payload.data || {};
const status = data.status;
if (status === 2) return extractImageUrls(data);
if (status === 3) throw new Error("任务生成失败: " + (data.message || payload.msg || "未知原因"));
const urls = extractImageUrls(data);
if (urls.length > 0) return urls;
if (Date.now() - start > maxWaitMs) {
throw new Error("等待结果超时");
}
await new Promise((r) => setTimeout(r, pollIntervalMs));
}
}
module.exports = {
runGeneration,
isR2Configured,
uploadRefFilesToR2,
deleteR2Objects,
createR2Client,
};