const fs = require("fs"); const path = require("path"); const axios = require("axios"); const { S3Client, PutObjectCommand, DeleteObjectsCommand } = require("@aws-sdk/client-s3"); const IMAGE_EXTS = new Set([".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif"]); const CONTENT_TYPES = { ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png", ".webp": "image/webp", ".bmp": "image/bmp", ".gif": "image/gif", }; function isR2Configured(config) { return ( config && config.r2_account_id && config.r2_access_key_id && config.r2_secret_access_key && config.r2_bucket && config.r2_public_url ); } function createR2Client(config) { const endpoint = `https://${config.r2_account_id}.r2.cloudflarestorage.com`; return new S3Client({ region: "auto", endpoint, forcePathStyle: true, credentials: { accessKeyId: config.r2_access_key_id, secretAccessKey: config.r2_secret_access_key, }, }); } /** 将指定本地文件路径上传到 R2,返回 { urls, keys } */ async function uploadRefFilesToR2({ client, bucket, publicBaseUrl, filePaths }) { const runId = Date.now().toString(36) + "-" + Math.random().toString(36).slice(2, 8); const prefix = `ref/${runId}`; const urls = []; const keys = []; const base = (publicBaseUrl || "").replace(/\/$/, ""); for (let i = 0; i < filePaths.length; i++) { const filePath = filePaths[i]; if (!fs.existsSync(filePath)) continue; const ext = path.extname(filePath).toLowerCase(); if (!IMAGE_EXTS.has(ext)) continue; const name = path.basename(filePath); const key = `${prefix}/${name}`; const body = fs.readFileSync(filePath); const contentType = CONTENT_TYPES[ext] || "application/octet-stream"; await client.send( new PutObjectCommand({ Bucket: bucket, Key: key, Body: body, ContentType: contentType, }) ); urls.push(`${base}/${key}`); keys.push(key); } return { urls, keys }; } async function deleteR2Objects({ client, bucket, keys }) { if (!keys || keys.length === 0) return; await client.send( new DeleteObjectsCommand({ Bucket: bucket, Delete: { Objects: keys.map((Key) => ({ Key })) }, }) ); } function extractImageUrls(dataObj) { if (!dataObj) return []; const candidates = []; if (Array.isArray(dataObj)) { dataObj.forEach((v) => v && candidates.push(String(v))); } else if (typeof dataObj === "object") { for (const key of ["img_url", "img_urls", "image_urls", "urls", "images", "result"]) { const val = dataObj[key]; if (!val) continue; if (typeof val === "string") candidates.push(val); else if (Array.isArray(val)) val.forEach((v) => v && candidates.push(String(v))); } } const seen = new Set(); return candidates.filter((u) => !seen.has(u) && seen.add(u)); } /** * 执行一次生图流程,支持进度回调 * @param {object} config - 完整配置(api_key, r2_*, api_create_url, api_result_url, poll_interval_seconds, max_wait_seconds 等) * @param {object} options - { prompt, size, aspectRatio, refFilePaths [], saveDir, extraParams {} } * @param {function} onProgress - (step, message) => void */ async function runGeneration(config, options, onProgress = () => {}) { const apiKey = config.api_key || process.env.WUYIN_API_KEY; if (!apiKey) throw new Error("未配置 API 密钥"); const prompt = (options.prompt || "").trim(); if (!prompt) throw new Error("请输入提示词"); const createUrl = config.api_create_url || "https://api.wuyinkeji.com/api/async/image_nanoBanana2"; const resultUrl = config.api_result_url || "https://api.wuyinkeji.com/api/async/detail"; const pollIntervalMs = (config.poll_interval_seconds || 5) * 1000; const maxWaitMs = (config.max_wait_seconds || 300) * 1000; const saveDir = options.saveDir || config.default_save_dir || path.join(process.cwd(), "save"); const extraParams = options.extraParams || config.extra_params || {}; let r2KeysToDelete = []; let combinedUrls = []; if (isR2Configured(config) && options.refFilePaths && options.refFilePaths.length > 0) { onProgress("upload", "正在上传参考图到 R2…"); const client = createR2Client(config); const { urls, keys } = await uploadRefFilesToR2({ client, bucket: config.r2_bucket, publicBaseUrl: config.r2_public_url, filePaths: options.refFilePaths, }); r2KeysToDelete = keys; combinedUrls = urls; onProgress("upload_done", `已上传 ${urls.length} 张参考图`); } try { onProgress("create", "正在创建生图任务…"); const body = { prompt, size: options.size || "1K", aspectRatio: options.aspectRatio || "auto", key: apiKey, ...extraParams, }; if (combinedUrls.length > 0) body.urls = combinedUrls; const createResp = await axios.post(createUrl, body, { headers: { Authorization: apiKey, "Content-Type": "application/json;charset=utf-8;", }, timeout: 30000, }); const payload = createResp.data; if (!payload || payload.code !== 200) { throw new Error("创建任务失败: " + (payload?.msg || JSON.stringify(payload))); } const taskId = (payload.data || {}).id; if (!taskId) throw new Error("返回数据中缺少任务 id"); onProgress("poll", "正在等待生成结果…"); const imageUrls = await pollResult({ apiKey, taskId, resultUrl, pollIntervalMs, maxWaitMs, }); if (!imageUrls || imageUrls.length === 0) { throw new Error("任务完成但未获取到图片 URL"); } onProgress("download", `正在保存 ${imageUrls.length} 张图片…`); if (!fs.existsSync(saveDir)) fs.mkdirSync(saveDir, { recursive: true }); const savedPaths = []; for (let i = 0; i < imageUrls.length; i++) { const url = imageUrls[i]; const resp = await axios.get(url, { responseType: "arraybuffer", timeout: 60000 }); let ext = ".jpg"; for (const e of [".png", ".jpeg", ".jpg", ".webp", ".bmp", ".gif"]) { if (url.toLowerCase().includes(e)) { ext = e; break; } } const filePath = path.join(saveDir, `${taskId}_${i + 1}${ext}`); fs.writeFileSync(filePath, resp.data); savedPaths.push(filePath); } onProgress("done", `已保存到 ${saveDir}`); return { taskId, saveDir, count: imageUrls.length, savedPaths }; } finally { if (r2KeysToDelete.length > 0 && isR2Configured(config)) { try { const client = createR2Client(config); await deleteR2Objects({ client, bucket: config.r2_bucket, keys: r2KeysToDelete }); onProgress("cleanup", "已删除 R2 临时参考图"); } catch (e) { onProgress("cleanup_error", "删除 R2 参考图失败: " + e.message); } } } } async function pollResult({ apiKey, taskId, resultUrl, pollIntervalMs, maxWaitMs }) { const start = Date.now(); const headers = { Authorization: apiKey, "Content-Type": "application/x-www-form-urlencoded;charset=utf-8;", }; while (true) { const resp = await axios.get(resultUrl, { params: { key: apiKey, id: taskId }, headers, timeout: 30000, }); const payload = resp.data; if (!payload || payload.code !== 200) { throw new Error(`查询失败: ${payload?.msg || "未知错误"}`); } const data = payload.data || {}; const status = data.status; if (status === 2) return extractImageUrls(data); if (status === 3) throw new Error("任务生成失败: " + (data.message || payload.msg || "未知原因")); const urls = extractImageUrls(data); if (urls.length > 0) return urls; if (Date.now() - start > maxWaitMs) { throw new Error("等待结果超时"); } await new Promise((r) => setTimeout(r, pollIntervalMs)); } } module.exports = { runGeneration, isR2Configured, uploadRefFilesToR2, deleteR2Objects, createR2Client, };