代码提交
This commit is contained in:
241
core/generator.js
Normal file
241
core/generator.js
Normal file
@@ -0,0 +1,241 @@
|
||||
const fs = require("fs");
|
||||
const path = require("path");
|
||||
const axios = require("axios");
|
||||
const { S3Client, PutObjectCommand, DeleteObjectsCommand } = require("@aws-sdk/client-s3");
|
||||
|
||||
const IMAGE_EXTS = new Set([".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"]);
|
||||
const CONTENT_TYPES = {
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
".bmp": "image/bmp",
|
||||
".gif": "image/gif",
|
||||
};
|
||||
|
||||
function isR2Configured(config) {
|
||||
return (
|
||||
config &&
|
||||
config.r2_account_id &&
|
||||
config.r2_access_key_id &&
|
||||
config.r2_secret_access_key &&
|
||||
config.r2_bucket &&
|
||||
config.r2_public_url
|
||||
);
|
||||
}
|
||||
|
||||
function createR2Client(config) {
|
||||
const endpoint = `https://${config.r2_account_id}.r2.cloudflarestorage.com`;
|
||||
return new S3Client({
|
||||
region: "auto",
|
||||
endpoint,
|
||||
forcePathStyle: true,
|
||||
credentials: {
|
||||
accessKeyId: config.r2_access_key_id,
|
||||
secretAccessKey: config.r2_secret_access_key,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/** 将指定本地文件路径上传到 R2,返回 { urls, keys } */
|
||||
async function uploadRefFilesToR2({ client, bucket, publicBaseUrl, filePaths }) {
|
||||
const runId = Date.now().toString(36) + "-" + Math.random().toString(36).slice(2, 8);
|
||||
const prefix = `ref/${runId}`;
|
||||
const urls = [];
|
||||
const keys = [];
|
||||
const base = (publicBaseUrl || "").replace(/\/$/, "");
|
||||
|
||||
for (let i = 0; i < filePaths.length; i++) {
|
||||
const filePath = filePaths[i];
|
||||
if (!fs.existsSync(filePath)) continue;
|
||||
const ext = path.extname(filePath).toLowerCase();
|
||||
if (!IMAGE_EXTS.has(ext)) continue;
|
||||
const name = path.basename(filePath);
|
||||
const key = `${prefix}/${name}`;
|
||||
const body = fs.readFileSync(filePath);
|
||||
const contentType = CONTENT_TYPES[ext] || "application/octet-stream";
|
||||
|
||||
await client.send(
|
||||
new PutObjectCommand({
|
||||
Bucket: bucket,
|
||||
Key: key,
|
||||
Body: body,
|
||||
ContentType: contentType,
|
||||
})
|
||||
);
|
||||
urls.push(`${base}/${key}`);
|
||||
keys.push(key);
|
||||
}
|
||||
return { urls, keys };
|
||||
}
|
||||
|
||||
async function deleteR2Objects({ client, bucket, keys }) {
|
||||
if (!keys || keys.length === 0) return;
|
||||
await client.send(
|
||||
new DeleteObjectsCommand({
|
||||
Bucket: bucket,
|
||||
Delete: { Objects: keys.map((Key) => ({ Key })) },
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
function extractImageUrls(dataObj) {
|
||||
if (!dataObj) return [];
|
||||
const candidates = [];
|
||||
if (Array.isArray(dataObj)) {
|
||||
dataObj.forEach((v) => v && candidates.push(String(v)));
|
||||
} else if (typeof dataObj === "object") {
|
||||
for (const key of ["img_url", "img_urls", "image_urls", "urls", "images", "result"]) {
|
||||
const val = dataObj[key];
|
||||
if (!val) continue;
|
||||
if (typeof val === "string") candidates.push(val);
|
||||
else if (Array.isArray(val)) val.forEach((v) => v && candidates.push(String(v)));
|
||||
}
|
||||
}
|
||||
const seen = new Set();
|
||||
return candidates.filter((u) => !seen.has(u) && seen.add(u));
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行一次生图流程,支持进度回调
|
||||
* @param {object} config - 完整配置(api_key, r2_*, api_create_url, api_result_url, poll_interval_seconds, max_wait_seconds 等)
|
||||
* @param {object} options - { prompt, size, aspectRatio, refFilePaths [], saveDir, extraParams {} }
|
||||
* @param {function} onProgress - (step, message) => void
|
||||
*/
|
||||
async function runGeneration(config, options, onProgress = () => {}) {
|
||||
const apiKey = config.api_key || process.env.WUYIN_API_KEY;
|
||||
if (!apiKey) throw new Error("未配置 API 密钥");
|
||||
|
||||
const prompt = (options.prompt || "").trim();
|
||||
if (!prompt) throw new Error("请输入提示词");
|
||||
|
||||
const createUrl = config.api_create_url || "https://api.wuyinkeji.com/api/async/image_nanoBanana2";
|
||||
const resultUrl = config.api_result_url || "https://api.wuyinkeji.com/api/async/detail";
|
||||
const pollIntervalMs = (config.poll_interval_seconds || 5) * 1000;
|
||||
const maxWaitMs = (config.max_wait_seconds || 300) * 1000;
|
||||
const saveDir = options.saveDir || config.default_save_dir || path.join(process.cwd(), "save");
|
||||
const extraParams = options.extraParams || config.extra_params || {};
|
||||
|
||||
let r2KeysToDelete = [];
|
||||
let combinedUrls = [];
|
||||
|
||||
if (isR2Configured(config) && options.refFilePaths && options.refFilePaths.length > 0) {
|
||||
onProgress("upload", "正在上传参考图到 R2…");
|
||||
const client = createR2Client(config);
|
||||
const { urls, keys } = await uploadRefFilesToR2({
|
||||
client,
|
||||
bucket: config.r2_bucket,
|
||||
publicBaseUrl: config.r2_public_url,
|
||||
filePaths: options.refFilePaths,
|
||||
});
|
||||
r2KeysToDelete = keys;
|
||||
combinedUrls = urls;
|
||||
onProgress("upload_done", `已上传 ${urls.length} 张参考图`);
|
||||
}
|
||||
|
||||
try {
|
||||
onProgress("create", "正在创建生图任务…");
|
||||
const body = {
|
||||
prompt,
|
||||
size: options.size || "1K",
|
||||
aspectRatio: options.aspectRatio || "auto",
|
||||
key: apiKey,
|
||||
...extraParams,
|
||||
};
|
||||
if (combinedUrls.length > 0) body.urls = combinedUrls;
|
||||
|
||||
const createResp = await axios.post(createUrl, body, {
|
||||
headers: {
|
||||
Authorization: apiKey,
|
||||
"Content-Type": "application/json;charset=utf-8;",
|
||||
},
|
||||
timeout: 30000,
|
||||
});
|
||||
const payload = createResp.data;
|
||||
if (!payload || payload.code !== 200) {
|
||||
throw new Error("创建任务失败: " + (payload?.msg || JSON.stringify(payload)));
|
||||
}
|
||||
const taskId = (payload.data || {}).id;
|
||||
if (!taskId) throw new Error("返回数据中缺少任务 id");
|
||||
|
||||
onProgress("poll", "正在等待生成结果…");
|
||||
const imageUrls = await pollResult({
|
||||
apiKey,
|
||||
taskId,
|
||||
resultUrl,
|
||||
pollIntervalMs,
|
||||
maxWaitMs,
|
||||
});
|
||||
|
||||
if (!imageUrls || imageUrls.length === 0) {
|
||||
throw new Error("任务完成但未获取到图片 URL");
|
||||
}
|
||||
|
||||
onProgress("download", `正在保存 ${imageUrls.length} 张图片…`);
|
||||
if (!fs.existsSync(saveDir)) fs.mkdirSync(saveDir, { recursive: true });
|
||||
const savedPaths = [];
|
||||
for (let i = 0; i < imageUrls.length; i++) {
|
||||
const url = imageUrls[i];
|
||||
const resp = await axios.get(url, { responseType: "arraybuffer", timeout: 60000 });
|
||||
let ext = ".jpg";
|
||||
for (const e of [".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"]) {
|
||||
if (url.toLowerCase().includes(e)) {
|
||||
ext = e;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const filePath = path.join(saveDir, `${taskId}_${i + 1}${ext}`);
|
||||
fs.writeFileSync(filePath, resp.data);
|
||||
savedPaths.push(filePath);
|
||||
}
|
||||
onProgress("done", `已保存到 ${saveDir}`);
|
||||
return { taskId, saveDir, count: imageUrls.length, savedPaths };
|
||||
} finally {
|
||||
if (r2KeysToDelete.length > 0 && isR2Configured(config)) {
|
||||
try {
|
||||
const client = createR2Client(config);
|
||||
await deleteR2Objects({ client, bucket: config.r2_bucket, keys: r2KeysToDelete });
|
||||
onProgress("cleanup", "已删除 R2 临时参考图");
|
||||
} catch (e) {
|
||||
onProgress("cleanup_error", "删除 R2 参考图失败: " + e.message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function pollResult({ apiKey, taskId, resultUrl, pollIntervalMs, maxWaitMs }) {
|
||||
const start = Date.now();
|
||||
const headers = {
|
||||
Authorization: apiKey,
|
||||
"Content-Type": "application/x-www-form-urlencoded;charset=utf-8;",
|
||||
};
|
||||
while (true) {
|
||||
const resp = await axios.get(resultUrl, {
|
||||
params: { key: apiKey, id: taskId },
|
||||
headers,
|
||||
timeout: 30000,
|
||||
});
|
||||
const payload = resp.data;
|
||||
if (!payload || payload.code !== 200) {
|
||||
throw new Error(`查询失败: ${payload?.msg || "未知错误"}`);
|
||||
}
|
||||
const data = payload.data || {};
|
||||
const status = data.status;
|
||||
if (status === 2) return extractImageUrls(data);
|
||||
if (status === 3) throw new Error("任务生成失败: " + (data.message || payload.msg || "未知原因"));
|
||||
const urls = extractImageUrls(data);
|
||||
if (urls.length > 0) return urls;
|
||||
if (Date.now() - start > maxWaitMs) {
|
||||
throw new Error("等待结果超时");
|
||||
}
|
||||
await new Promise((r) => setTimeout(r, pollIntervalMs));
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
runGeneration,
|
||||
isR2Configured,
|
||||
uploadRefFilesToR2,
|
||||
deleteR2Objects,
|
||||
createR2Client,
|
||||
};
|
||||
Reference in New Issue
Block a user