# heygem-gateway/api.py
#
# NOTE: this copy was taken from a repository web view (reported as 535 lines,
# 19 KiB). The viewer warned that the file contains ambiguous Unicode
# characters that might be confused with other characters; if that is not
# intentional, review the non-ASCII text in comments and strings.

import asyncio
import os
import shutil
import urllib.parse
import uuid
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import aiofiles
import aiohttp
import uvicorn
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger

from models import Task, TaskStatus
from processor import ProcessManager
# Task manager: owns the in-memory task registry, status transitions and callbacks.
class TaskManager:
    """In-memory registry of video-processing tasks.

    Tasks live only in ``self.tasks`` (lost on restart) and nothing in this
    class removes them. Every status change goes through
    :meth:`update_task_status`, which then awaits all callbacks registered for
    that task. Processing itself is delegated to a ``ProcessManager`` built
    from ``config/machines.json``, which receives a back-reference to this
    manager.
    """

    # Upper bound (seconds) for one HTTP status callback. update_task_status
    # awaits callbacks one after another, so an unresponsive endpoint would
    # otherwise hold up whoever reported the status change.
    callback_timeout_seconds: float = 30.0

    def __init__(self):
        self.tasks: Dict[str, Task] = {}
        # The back-reference presumably lets the process manager report status
        # changes (e.g. with machine_name) — TODO confirm in processor.py.
        self.process_manager = ProcessManager(
            config_path="config/machines.json", task_manager=self
        )
        self.callbacks: Dict[str, List[Callable[[Task], Any]]] = {}
        self._task_lock = asyncio.Lock()  # guards self.tasks and task mutation
        self._callback_lock = asyncio.Lock()  # guards self.callbacks
        self._started = False

    async def start(self):
        """Start the underlying process manager; calling it twice is a no-op."""
        if not self._started:
            await self.process_manager.start()
            self._started = True
            logger.info("任务管理器已启动")

    async def stop(self):
        """Stop the underlying process manager if it was started."""
        if self._started:
            await self.process_manager.stop()
            self._started = False
            logger.info("任务管理器已停止")

    async def add_task(
        self,
        audio_file_path: str,
        video_file_path: str,
        output_dir: str,
        description: str,
        callback_url: Optional[str] = None,
    ) -> Task:
        """Register a new PENDING task and return it.

        Args:
            audio_file_path: Local path of the input audio.
            video_file_path: Local path of the input video.
            output_dir: Directory the result MP4 will be written to.
            description: Free-text description stored on the task.
            callback_url: If given, an HTTP callback that receives the task
                on every status change.

        Returns:
            The newly created task (id ``task_<n>``).
        """
        async with self._task_lock:
            # Sequential ids stay unique only because tasks are never removed.
            task_id = f"task_{len(self.tasks) + 1}"
            output_file_path = os.path.join(
                output_dir, f"{task_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4"
            )
            task = Task(
                task_id=task_id,
                audio_file_path=audio_file_path,
                video_file_path=video_file_path,
                output_dir=output_dir,
                description=description,
                status=TaskStatus.PENDING,
                created_at=datetime.now(),
                callback_url=callback_url,
                output_file_path=output_file_path,
            )
            self.tasks[task_id] = task
            # Register the HTTP callback when a callback URL was supplied.
            if callback_url:
                await self.add_callback(
                    task_id, self._create_callback_handler(callback_url)
                )
        return task

    def _create_callback_handler(self, callback_url: str) -> Callable[[Task], Any]:
        """Build an async callback that POSTs the task as JSON to ``callback_url``.

        URLs not starting with ``http`` are logged and ignored. Delivery errors
        and non-2xx answers are logged, never raised, so a broken endpoint
        cannot fail a status update.
        """

        async def callback_handler(task: Task):
            try:
                if callback_url.startswith("http"):
                    timeout = aiohttp.ClientTimeout(total=self.callback_timeout_seconds)
                    async with aiohttp.ClientSession(timeout=timeout) as session:
                        # task.dict() holds datetime values (created_at) that the
                        # stdlib json encoder behind aiohttp's json= rejects; the
                        # model's own .json() serialises them.
                        async with session.post(
                            callback_url,
                            data=task.json(),
                            headers={"Content-Type": "application/json"},
                        ) as response:
                            if response.status >= 400:
                                logger.warning(
                                    f"Callback to {callback_url} returned HTTP {response.status}"
                                )
                else:
                    logger.info(f"回调URL: {callback_url} 不是HTTP URL")
            except Exception as e:
                logger.error(f"Error sending callback to {callback_url}: {str(e)}")

        return callback_handler

    async def process_task_sync(self, task_id: str) -> Task:
        """Submit a registered task to the process manager and record the result.

        Marks the task PROCESSING, then QUEUED if ``process_video`` accepted it,
        or FAILED if it returned a falsy value or raised. COMPLETED is not set
        here; presumably the process manager reports it later — TODO confirm.

        Returns:
            The task after the last status update.

        Raises:
            ValueError: if ``task_id`` is unknown.
        """
        task = await self.get_task(task_id)
        if not task:
            raise ValueError(f"Task {task_id} not found")
        try:
            await self.update_task_status(task_id, TaskStatus.PROCESSING)
            success = await self.process_manager.process_video(
                task_id=task_id,
                audio_file_path=task.audio_file_path,
                video_file_path=task.video_file_path,
                out_file_path=task.output_file_path,
            )
            if success:
                # Accepted for processing: the task is now queued.
                return await self.update_task_status(task_id, TaskStatus.QUEUED)
            return await self.update_task_status(
                task_id,
                TaskStatus.FAILED,
                error_message="Failed to queue video processing task",
            )
        except Exception as e:
            return await self.update_task_status(
                task_id, TaskStatus.FAILED, error_message=str(e)
            )

    async def update_task_status(
        self,
        task_id: str,
        status: TaskStatus,
        machine_name: Optional[str] = None,
        error_message: Optional[str] = None,
    ) -> Task:
        """Set a task's status (plus optional machine/error info) and fire its callbacks.

        Terminal states (COMPLETED, FAILED) also stamp ``completed_at``.
        Callbacks are awaited one by one after the registry lock is released;
        an exception from one callback is logged and does not stop the others.

        Returns:
            The updated task object.

        Raises:
            ValueError: if ``task_id`` is unknown.
        """
        async with self._task_lock:
            if task_id not in self.tasks:
                raise ValueError(f"Task {task_id} not found")
            task = self.tasks[task_id]
            logger.info(f"更新任务状态: {task_id}, 从 {task.status} -> {status}")
            task.status = status
            if machine_name:
                task.machine_name = machine_name
            if error_message:
                task.error_message = error_message
            if status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                task.completed_at = datetime.now()
                logger.info(f"任务完成时间已设置: {task_id}, 状态: {status}")
            # Snapshot: callbacks may be added/removed (under _callback_lock)
            # while we await them, and mutating a list mid-iteration silently
            # skips or repeats entries.
            registered = self.callbacks.get(task_id)
            callbacks = list(registered) if registered is not None else []

        # Run callbacks outside the lock: they may do network I/O, and
        # asyncio.Lock is not re-entrant, so a callback that updated a task
        # would deadlock and a slow one would block every other update.
        if registered is not None:
            logger.info(f"触发任务回调: {task_id}, 回调数量: {len(callbacks)}")
        for callback in callbacks:
            try:
                await callback(task)
                logger.info(f"任务回调执行成功: {task_id}")
            except Exception as e:
                logger.error(f"任务回调执行失败: {task_id}, 错误: {str(e)}")
        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or None if unknown."""
        return self.tasks.get(task_id)

    async def get_all_tasks(self) -> List[Task]:
        """Return all tasks in insertion order."""
        return list(self.tasks.values())

    async def add_callback(self, task_id: str, callback: Callable[[Task], Any]) -> None:
        """Register ``callback`` to be awaited on every status change of ``task_id``."""
        async with self._callback_lock:
            self.callbacks.setdefault(task_id, []).append(callback)

    async def remove_callback(
        self, task_id: str, callback: Callable[[Task], Any]
    ) -> None:
        """Unregister ``callback``; unknown task ids or callbacks are ignored."""
        async with self._callback_lock:
            registered = self.callbacks.get(task_id)
            if registered is None:
                return
            try:
                registered.remove(callback)
            except ValueError:
                return  # not registered (e.g. already removed)
            if not registered:
                # Drop the empty list so finished tasks do not accumulate entries.
                del self.callbacks[task_id]
# Create the FastAPI application and the single module-level task manager
# shared by all routes below.
app = FastAPI(title="视频处理API")
task_manager = TaskManager()


@app.on_event("startup")
async def startup_event():
    """Start the task manager (and with it the process manager) on app startup."""
    await task_manager.start()


@app.on_event("shutdown")
async def shutdown_event():
    """Stop the task manager on app shutdown."""
    await task_manager.stop()
# API routes
@app.post("/tasks/", response_model=Task)
async def create_task(
    audio: Optional[UploadFile] = File(None),
    video: Optional[UploadFile] = File(None),
    audio_url: Optional[str] = Form(None),
    video_url: Optional[str] = Form(None),
    description: str = Form(...),
    callback_url: Optional[str] = Form(None),
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """Create a video-processing task and hand it off in the background.

    Audio and video may each be given as an uploaded file or as a URL to
    download (exactly one per medium). The task is returned immediately; its
    progress is visible via ``GET /tasks/{task_id}`` or ``callback_url``.

    Raises:
        HTTPException: 400 when input preparation raises ValueError (missing or
            duplicate media, unusable file name, failed download), 500 for any
            other preparation error.
    """
    try:
        task, _temp_dir = await prepare_task_files(
            audio, video, audio_url, video_url, description, callback_url
        )
    except ValueError as e:
        # prepare_task_files / process_media_input report bad input as ValueError:
        # that is the client's fault, not a server error.
        logger.error(f"创建任务失败: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"创建任务失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Processing continues after the response has been sent.
    background_tasks.add_task(
        process_video_task,
        task.task_id,
        task.audio_file_path,
        task.video_file_path,
        task.output_file_path,
    )
    return task
def _file_streaming_response(output_file: str) -> StreamingResponse:
    """Stream ``output_file`` to the client as an MP4 attachment in 8 KiB chunks."""
    filename = os.path.basename(output_file)

    async def file_iterator():
        async with aiofiles.open(output_file, "rb") as f:
            while chunk := await f.read(8192):  # 8 KiB chunks
                yield chunk

    return StreamingResponse(
        file_iterator(),
        media_type="video/mp4",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )


@app.post("/tasks/sync")
async def create_task_sync(
    audio: Optional[UploadFile] = File(None),
    video: Optional[UploadFile] = File(None),
    audio_url: Optional[str] = Form(None),
    video_url: Optional[str] = Form(None),
    description: str = Form(...),
    callback_url: Optional[str] = Form(None),
):
    """Create a task, wait for it to finish, and stream the resulting MP4 back.

    Inputs follow the same rules as ``POST /tasks/``. The request waits up to
    30 minutes for the task to reach COMPLETED or FAILED.

    Raises:
        HTTPException: 400 for invalid input; 404 if the task or its output
            file is missing; 500 on processing failure, timeout, or any other
            error.
    """
    logger.info("创建新的视频处理任务并返回处理结果")
    try:
        task, _temp_dir = await prepare_task_files(
            audio, video, audio_url, video_url, description, callback_url
        )
    except ValueError as e:
        # Bad client input (see process_media_input / prepare_task_files).
        logger.error(f"创建任务失败: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"创建任务失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    logger.info(f"任务创建成功: {task.task_id}")
    logger.info(f"音频文件路径: {task.audio_file_path}")
    logger.info(f"视频文件路径: {task.video_file_path}")

    # Set by the callback below once the task reaches a terminal state.
    task_completed = asyncio.Event()
    logger.info(f"创建任务完成事件: {task.task_id}")

    async def task_completion_callback(completed_task: Task):
        logger.info(
            f"任务完成回调被触发: {completed_task.task_id}, 状态: {completed_task.status}"
        )
        if completed_task.task_id == task.task_id and completed_task.status in (
            TaskStatus.COMPLETED,
            TaskStatus.FAILED,
        ):
            logger.info(f"设置任务完成事件: {task.task_id}")
            task_completed.set()

    # Registered before processing starts so an early terminal state is not missed.
    await task_manager.add_callback(task.task_id, task_completion_callback)
    logger.info(f"已添加任务完成回调: {task.task_id}")
    try:
        await task_manager.process_task_sync(task.task_id)
        logger.info(f"开始处理任务: {task.task_id}")

        logger.info(f"等待任务完成: {task.task_id}")
        try:
            await asyncio.wait_for(task_completed.wait(), timeout=1800)  # 30 minutes
        except asyncio.TimeoutError:
            logger.error(f"任务处理超时: {task.task_id}")
            raise HTTPException(status_code=500, detail="任务处理超时") from None
        logger.info(f"任务完成事件已触发: {task.task_id}")

        result = await task_manager.get_task(task.task_id)
        if not result:
            raise HTTPException(status_code=404, detail="任务不存在")
        if result.status != TaskStatus.COMPLETED:
            raise HTTPException(
                status_code=500, detail=result.error_message or "处理失败"
            )

        output_file = result.output_file_path
        if not os.path.exists(output_file):
            logger.error(f"输出文件不存在: {output_file}")
            raise HTTPException(status_code=404, detail="输出文件不存在")
        return _file_streaming_response(output_file)
    except HTTPException:
        # Already carries the intended status code and detail; re-wrapping it
        # as a generic 500 would turn 404s into 500s with a mangled message.
        raise
    except Exception as e:
        # loguru ignores exc_info=True; logger.exception records the traceback.
        logger.exception(f"创建任务失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The callback is one-shot; unregister it so finished requests do not
        # leave closures behind in the task manager.
        await task_manager.remove_callback(task.task_id, task_completion_callback)
@app.get("/tasks/{task_id}", response_model=Task)
async def get_task(task_id: str):
    """Look up one task by id, answering 404 when it does not exist."""
    found = await task_manager.get_task(task_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
@app.get("/tasks/", response_model=List[Task])
async def get_all_tasks():
    """List every task currently held by the task manager."""
    all_tasks = await task_manager.get_all_tasks()
    return all_tasks
@app.post("/tasks/{task_id}/callback")
async def add_task_callback(
    task_id: str, callback_url: str, background_tasks: BackgroundTasks
):
    """Register an additional HTTP callback for an existing task.

    The callback is awaited on every later status change of the task.
    ``background_tasks`` is unused; it is kept so the endpoint signature does
    not change.

    Raises:
        HTTPException: 404 if the task does not exist.
    """
    if not await task_manager.get_task(task_id):
        raise HTTPException(status_code=404, detail="Task not found")
    # Reuse the manager's handler so this endpoint and the callback_url form
    # field share one delivery implementation (URL check, payload, error
    # handling) instead of a hand-copied variant that drifts.
    await task_manager.add_callback(
        task_id, task_manager._create_callback_handler(callback_url)
    )
    return {"message": "Callback added successfully"}
# Background-task handler used by POST /tasks/
async def process_video_task(
    task_id: str, audio_file_path: str, video_file_path: str, out_file_path: str
):
    """Submit a task to the process manager from a FastAPI background task.

    Follows the same status protocol as ``TaskManager.process_task_sync``:
    PROCESSING, then QUEUED if ``process_video`` accepted the job, otherwise
    FAILED. COMPLETED is not set here, because a truthy ``process_video``
    result only means the job was queued; the final state is reported later
    (presumably by the process manager — TODO confirm in processor.py).
    """
    try:
        await task_manager.update_task_status(task_id, TaskStatus.PROCESSING)
        success = await task_manager.process_manager.process_video(
            task_id=task_id,
            audio_file_path=audio_file_path,
            video_file_path=video_file_path,
            out_file_path=out_file_path,
        )
        if success:
            # Marking COMPLETED here would report success (and stamp
            # completed_at / fire callbacks) before the output file exists.
            await task_manager.update_task_status(task_id, TaskStatus.QUEUED)
        else:
            await task_manager.update_task_status(
                task_id, TaskStatus.FAILED, error_message="Video processing failed"
            )
    except Exception as e:
        await task_manager.update_task_status(
            task_id, TaskStatus.FAILED, error_message=str(e)
        )
    # TODO: the task's temp input directory is not cleaned up; it must outlive
    # the queued processing, so cleanup belongs where the task finishes.
# File download helper
async def download_file_from_url(
    url: str, destination_path: str, timeout: float = 300
) -> bool:
    """Download ``url`` to ``destination_path``, streaming the body in 8 KiB chunks.

    Args:
        url: Source URL.
        destination_path: Local file to create or overwrite.
        timeout: Total time budget in seconds for the request (default 300).

    Returns:
        True if the server answered 200 and the body was written completely;
        False on any other status or on any error. Errors are logged, never
        raised. If writing had started, the partial file is deleted.
    """
    file_opened = False
    try:
        logger.info(f"开始下载文件: {url} -> {destination_path}")
        async with aiohttp.ClientSession() as session:
            async with session.get(
                url, timeout=aiohttp.ClientTimeout(total=timeout)
            ) as response:
                if response.status != 200:
                    logger.error(f"下载失败HTTP状态码: {response.status}")
                    return False
                # Content-Length is only logged; printing the raw header avoids
                # int() turning a malformed value into a failed download.
                content_length = response.headers.get("Content-Length")
                if content_length:
                    logger.info(f"文件大小: {content_length} bytes")
                downloaded = 0
                async with aiofiles.open(destination_path, "wb") as f:
                    file_opened = True
                    async for chunk in response.content.iter_chunked(8192):
                        await f.write(chunk)
                        downloaded += len(chunk)
                logger.info(f"文件下载完成: {destination_path} ({downloaded} bytes)")
                return True
    except Exception as e:
        logger.error(f"下载文件时发生错误: {str(e)}")
        if file_opened:
            # Do not leave a truncated file behind; only remove what we wrote.
            try:
                os.remove(destination_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Could not remove partial download {destination_path}: {cleanup_error}"
                )
        return False
async def process_media_input(
    media_file: Optional[UploadFile],
    media_url: Optional[str],
    media_type: str,
    temp_dir: str,
) -> str:
    """Store one media input (uploaded file or URL) under ``temp_dir`` and return its path.

    Exactly one of ``media_file`` / ``media_url`` must be given. ``media_type``
    is the label used in messages and generated names (the callers in this
    file pass "音频" or "视频"); it also picks the default extension (.mp3 /
    .mp4) for URLs whose last path segment has none.

    Raises:
        ValueError: both or neither input given, an unusable upload file name,
            or a failed download.
    """
    if media_file and media_url:
        raise ValueError(f"{media_type}不能同时提供文件和URL")
    if not media_file and not media_url:
        raise ValueError(f"必须提供{media_type}文件或URL")

    if media_file:
        if not media_file.filename:
            raise ValueError(f"{media_type}文件名不能为空")
        # The file name is client-controlled: keep only its last component
        # (treating "\" as a separator too) so "../x" or "/abs/x" cannot make
        # os.path.join write outside temp_dir.
        safe_name = os.path.basename(media_file.filename.replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            raise ValueError(f"{media_type}文件名无效: {media_file.filename}")
        file_path = os.path.join(temp_dir, safe_name)
        logger.info(f"开始保存{media_type}文件: {file_path}")
        async with aiofiles.open(file_path, "wb") as f:
            # Copy in 1 MiB chunks instead of holding the whole upload in memory.
            while chunk := await media_file.read(1024 * 1024):
                await f.write(chunk)
        logger.info(f"{media_type}文件保存完成: {file_path}")
        return file_path

    # URL input: name the local file after the URL's last path segment.
    filename = os.path.basename(urllib.parse.urlparse(media_url).path)
    if filename in (".", ".."):
        filename = ""  # not a usable file name
    if "." not in filename:
        # No name or no extension: fall back to a default per media type.
        extension = {"音频": ".mp3", "视频": ".mp4"}.get(media_type, "")
        if filename:
            filename = f"{filename}{extension}"
        else:
            filename = f"{media_type}_{int(datetime.now().timestamp())}{extension}"
    file_path = os.path.join(temp_dir, filename)
    if not await download_file_from_url(media_url, file_path):
        raise ValueError(f"下载{media_type}文件失败: {media_url}")
    return file_path
async def prepare_task_files(
    audio: Optional[UploadFile],
    video: Optional[UploadFile],
    audio_url: Optional[str],
    video_url: Optional[str],
    description: str,
    callback_url: Optional[str],
) -> tuple[Task, str]:
    """Store both media inputs locally and register a task for them.

    Shared by ``POST /tasks/`` and ``POST /tasks/sync``. Inputs go to a fresh
    per-request directory ``temp/<timestamp>_<random>/`` (audio and video in
    separate sub-directories); outputs go to ``output/<YYYYMMDD>/``.

    Returns:
        ``(task, temp_dir)``: the created task and the per-request input directory.

    Raises:
        ValueError: invalid input or failed download (from process_media_input),
            or an input file missing after processing.
        Exception: any other error while preparing the task. In every failure
            case the per-request directory is removed before re-raising.
    """
    # The random suffix keeps directories unique: the timestamp alone has
    # one-second resolution, so concurrent requests shared a directory and
    # could overwrite each other's inputs.
    temp_dir = os.path.join(
        os.getcwd(),
        "temp",
        f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}",
    )
    # Separate sub-directories so an audio and a video upload that share a
    # file name (e.g. both sent as "blob") cannot overwrite each other.
    audio_dir = os.path.join(temp_dir, "audio")
    video_dir = os.path.join(temp_dir, "video")
    os.makedirs(audio_dir, exist_ok=True)
    os.makedirs(video_dir, exist_ok=True)
    try:
        audio_path = await process_media_input(audio, audio_url, "音频", audio_dir)
        video_path = await process_media_input(video, video_url, "视频", video_dir)
        # Sanity check that both inputs actually landed on disk.
        if not os.path.exists(audio_path) or not os.path.exists(video_path):
            raise ValueError("文件处理失败")
        output_dir = os.path.join("output", datetime.now().strftime("%Y%m%d"))
        os.makedirs(output_dir, exist_ok=True)
        task = await task_manager.add_task(
            audio_path, video_path, output_dir, description, callback_url
        )
        return task, temp_dir
    except Exception as e:
        # No task was handed back, so nothing will process these inputs.
        shutil.rmtree(temp_dir, ignore_errors=True)
        logger.error(f"准备任务文件失败: {str(e)}")
        raise
if __name__ == "__main__":
    # Development entry point: single worker process with hot reload.
    # uvicorn is already imported at module level.
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8101,
        log_level="info",
        workers=1,  # single-process mode
        reload=True,  # hot reload for development
    )