Files
video-flow/api/tasks/audio_tasks.py
2026-01-09 14:09:16 +08:00

213 lines
5.6 KiB
Python

"""
音频处理 Celery 任务
TTS 生成、花字渲染等
"""
import os
import time
import logging
from pathlib import Path
from typing import Dict, Any, Optional
from celery import shared_task
import config
from modules import factory, ffmpeg_utils
from modules.text_renderer import renderer
logger = logging.getLogger(__name__)
@shared_task(bind=True, name="audio.generate_tts")
def generate_tts_task(
    self,
    text: str,
    voice_type: str = "zh_female_santongyongns_saturn_bigtts",
    target_duration: Optional[float] = None,
    output_path: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate TTS audio (asynchronous Celery task).

    Args:
        text: Text to synthesize.
        voice_type: Voice identifier passed to
            ``factory.generate_voiceover_volcengine``.
        target_duration: Desired duration in seconds. When set, the synthesized
            audio is re-timed with ``ffmpeg_utils.adjust_audio_duration`` and
            the adjusted audio is what ends up at ``output_path``.
        output_path: Where the final audio is written. Defaults to a unique
            ``tts_*.mp3`` file under ``config.TEMP_DIR``.

    Returns:
        {"status": "success", "path": "...", "url": "...", "task_id": "..."}

    Raises:
        RuntimeError: if synthesis or the duration adjustment produced no file.
        Any other exception from the TTS backend or ffmpeg helper is logged
        (with traceback) and re-raised.
    """
    import contextlib
    import uuid

    task_id = self.request.id
    logger.info(f"[Task {task_id}] 生成 TTS: {text[:30]}...")

    stamp = int(time.time())
    # A second-resolution timestamp alone collides between tasks started in
    # the same second; the task id (or a random token when run outside a
    # worker, where request.id may be None) keeps the names unique.
    unique = task_id or uuid.uuid4().hex
    if not output_path:
        output_path = str(config.TEMP_DIR / f"tts_{stamp}_{unique}.mp3")

    # ffmpeg cannot read and write the same file. When the duration must be
    # adjusted, synthesize into a scratch file and write the adjusted result
    # to output_path, so the caller's requested path is honoured.
    raw_path = (
        str(config.TEMP_DIR / f"tts_raw_{stamp}_{unique}.mp3")
        if target_duration
        else output_path
    )

    try:
        audio_path = factory.generate_voiceover_volcengine(
            text=text,
            voice_type=voice_type,
            output_path=raw_path,
        )
        if not audio_path or not os.path.exists(audio_path):
            raise RuntimeError("TTS 生成失败")

        if target_duration:
            try:
                ffmpeg_utils.adjust_audio_duration(
                    audio_path, target_duration, output_path
                )
            finally:
                # The unadjusted audio is only an intermediate: remove it
                # whether or not the adjustment succeeded (best effort).
                if audio_path != output_path:
                    with contextlib.suppress(OSError):
                        os.remove(audio_path)
            if not os.path.exists(output_path):
                raise RuntimeError("TTS 时长调整失败")
            audio_path = output_path

        # NOTE(review): this URL assumes the file lives in the directory served
        # at /static/temp (config.TEMP_DIR); a caller-supplied output_path
        # elsewhere yields a URL that does not resolve — confirm.
        output_url = f"/static/temp/{Path(audio_path).name}"

        return {
            "status": "success",
            "path": audio_path,
            "url": output_url,
            "task_id": task_id,
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"[Task {task_id}] TTS 生成失败: {e}")
        raise
@shared_task(bind=True, name="audio.generate_fancy_text")
def generate_fancy_text_task(
    self,
    text: str,
    style: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Render decorative ("fancy") text to an image (asynchronous Celery task).

    Args:
        text: Text to render.
        style: Style options forwarded to ``renderer.render``. When None or
            empty, a default is used: font_size 72, font color #FFFFFF and a
            #000000 stroke of width 5.

    Returns:
        {"status": "success", "path": "...", "url": "...", "task_id": "..."}

    Raises:
        RuntimeError: if the renderer returned no path or the file is missing.
        Any renderer exception is logged (with traceback) and re-raised.
    """
    task_id = self.request.id
    logger.info(f"[Task {task_id}] 生成花字: {text}")

    if not style:
        # Built fresh on every call so the renderer can never mutate a
        # shared default.
        style = {
            "font_size": 72,
            "font_color": "#FFFFFF",
            "stroke": {"color": "#000000", "width": 5},
        }

    try:
        # cache=False: always render a fresh image for this task.
        img_path = renderer.render(
            text=text,
            style=style,
            cache=False,
        )
        if not img_path or not os.path.exists(img_path):
            raise RuntimeError("花字生成失败")

        # NOTE(review): assumes the renderer writes into the directory served
        # at /static/temp — confirm against text_renderer.
        output_url = f"/static/temp/{Path(img_path).name}"

        return {
            "status": "success",
            "path": img_path,
            "url": output_url,
            "task_id": task_id,
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"[Task {task_id}] 花字生成失败: {e}")
        raise
@shared_task(bind=True, name="audio.batch_generate_tts")
def batch_generate_tts_task(
    self,
    items: list,
    voice_type: str = "zh_female_santongyongns_saturn_bigtts",
) -> Dict[str, Any]:
    """
    Generate TTS audio for several texts in one asynchronous Celery task.

    Items are processed independently. A failure on one item is recorded in
    that item's result entry and does not abort the rest of the batch.

    Args:
        items: List of dicts such as ``{"text": "...", "target_duration": 3.0}``.
            ``target_duration`` (seconds) is optional. When set, the audio is
            re-timed with ``ffmpeg_utils.adjust_audio_duration``.
        voice_type: Voice identifier used for every item.

    Returns:
        ``{"status": "success", "results": [...], "task_id": ...}``. Each
        result has ``index`` and ``status`` ("success" / "skipped" /
        "failed"), plus ``path``/``url`` on success or ``reason`` otherwise.
        The top-level status is "success" even when individual items failed.
    """
    import contextlib
    import uuid

    task_id = self.request.id
    total = len(items)
    logger.info(f"[Task {task_id}] 批量生成 TTS: {total}")

    # The timestamp alone collides between batches started in the same
    # second. Append the task id (or a random token when run outside a
    # worker) so concurrent batches never overwrite each other's files.
    stamp = f"{int(time.time())}_{task_id or uuid.uuid4().hex}"

    results = []
    for i, item in enumerate(items):
        text = item.get("text", "")
        target_duration = item.get("target_duration")

        if not text:
            results.append({"index": i, "status": "skipped", "reason": "空文本"})
        else:
            try:
                audio_path = factory.generate_voiceover_volcengine(
                    text=text,
                    voice_type=voice_type,
                    output_path=str(config.TEMP_DIR / f"tts_batch_{stamp}_{i}.mp3"),
                )
                # Same success criterion as the single-item task: the file
                # must actually exist, not just have a path returned.
                if not audio_path or not os.path.exists(audio_path):
                    raise RuntimeError("生成失败")

                if target_duration:
                    adjusted_path = str(
                        config.TEMP_DIR / f"tts_batch_adj_{stamp}_{i}.mp3"
                    )
                    try:
                        ffmpeg_utils.adjust_audio_duration(
                            audio_path, target_duration, adjusted_path
                        )
                    finally:
                        # The unadjusted audio is only an intermediate; remove
                        # it so it does not accumulate in TEMP_DIR (best effort).
                        with contextlib.suppress(OSError):
                            os.remove(audio_path)
                    if not os.path.exists(adjusted_path):
                        raise RuntimeError("时长调整失败")
                    audio_path = adjusted_path

                results.append({
                    "index": i,
                    "status": "success",
                    "path": audio_path,
                    "url": f"/static/temp/{Path(audio_path).name}",
                })
            except Exception as e:
                logger.warning(f"[Task {task_id}] 批量 TTS 第 {i} 项失败: {e}")
                results.append({"index": i, "status": "failed", "reason": str(e)})

        # Report progress for every item, skipped ones included, so the
        # progress value still reaches 1.0 when trailing items are empty.
        self.update_state(
            state="PROGRESS",
            meta={"progress": (i + 1) / total, "completed": i + 1, "total": total},
        )

    return {
        "status": "success",
        "results": results,
        "task_id": task_id,
    }