140 lines
4.7 KiB
Python
140 lines
4.7 KiB
Python
import json
|
||
from typing import Dict, Optional, List, Any
|
||
from datetime import datetime, timedelta
|
||
|
||
from fastapi import APIRouter, Body, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
|
||
from workflow_service.comfy.comfy_queue import queue_manager
|
||
from workflow_service.comfy.comfy_workflow import ComfyWorkflow
|
||
from workflow_service.database.api import get_workflow, get_workflow_runs_recent
|
||
|
||
# Router for the workflow-run endpoints; every route below is mounted under
# the "/api/run" prefix and tagged "Run".
run_router = APIRouter(prefix="/api/run", tags=["Run"])
|
||
|
||
|
||
@run_router.post("")
async def run_workflow(
    workflow_name: str,
    workflow_version: Optional[str] = None,
    data: Dict = Body(...),
):
    """
    Submit a workflow for asynchronous execution.

    The stored workflow definition is loaded, wrapped in a ``ComfyWorkflow``
    and handed to the queue manager. The call returns immediately with the
    run ID; the caller can use that ID to query the execution status.

    Args:
        workflow_name: Name of the workflow to run. Must be non-empty.
        workflow_version: Version to look up. When omitted, ``None`` is passed
            to the lookup (reported as "latest" in the not-found message).
        data: JSON request body, forwarded to the queue as ``request_data``.

    Returns:
        A ``JSONResponse`` with status 202 whose body contains
        ``workflow_run_id``, ``status`` ("queued") and ``message``.

    Raises:
        HTTPException: 400 if ``workflow_name`` is empty, 404 if no matching
            workflow is found, 500 for any other failure while submitting.
    """
    try:
        if not workflow_name:
            raise HTTPException(status_code=400, detail="`workflow_name` 字段是必需的")

        # Fetch the workflow definition.
        workflow_data = await get_workflow(workflow_name, workflow_version)
        if not workflow_data:
            detail = (
                f"工作流 '{workflow_name}'"
                + (f" 带版本 '{workflow_version}'" if workflow_version else " (最新版)")
                + " 未找到。"
            )
            raise HTTPException(status_code=404, detail=detail)

        workflow = json.loads(workflow_data["workflow_json"])
        flow = ComfyWorkflow(workflow_name, workflow)

        # Submit to the queue.
        workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data)

        return JSONResponse(
            content={
                "workflow_run_id": workflow_run_id,
                "status": "queued",
                "message": "工作流已提交到队列,正在等待执行",
            },
            status_code=202,
        )

    except HTTPException:
        # HTTPException is itself an Exception subclass; without this clause
        # the deliberate 400/404 responses above would be caught by the
        # generic handler below and reported to the client as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"提交工作流失败: {str(e)}") from e
|
||
|
||
|
||
@run_router.get("")
async def get_runs(
    limit: int = 10, status: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    List workflow runs from the last 24 hours.

    Runs are fetched for the window ending at ``datetime.now()``, optionally
    filtered by exact ``status`` match, and then truncated to the first
    ``limit`` entries. There is no offset/cursor: only the leading ``limit``
    items of whatever order ``get_workflow_runs_recent`` returns are exposed.

    Args:
        limit: Maximum number of runs to return. Must be non-negative.
        status: When given (and non-empty), only runs whose ``status`` field
            equals this value are returned.

    Returns:
        A list of dicts with the keys ``id``, ``workflow_name``, ``status``,
        ``created_at``, ``updated_at``, ``api_spec``, ``result``,
        ``error_message`` and ``completed_at``. Missing fields are ``None``,
        except ``status`` which defaults to ``"unknown"``.

    Raises:
        HTTPException: 400 if ``limit`` is negative, 500 if fetching or
            formatting the runs fails.
    """
    # A negative limit would turn the slice below into "drop the last N
    # items" rather than "return at most N", so reject it up front. This is
    # done outside the try block so the 400 is not re-wrapped as a 500.
    if limit < 0:
        raise HTTPException(status_code=400, detail="`limit` 必须是非负整数")

    try:
        end_time = datetime.now()
        start_time = end_time - timedelta(hours=24)

        recent_runs = await get_workflow_runs_recent(start_time, end_time)

        # Apply the status filter when one was requested.
        if status:
            recent_runs = [run for run in recent_runs if run.get("status") == status]

        # Cap the result size and project each run onto the response fields.
        return [
            {
                "id": run.get("id"),  # the database id field
                "workflow_name": run.get("workflow_name"),
                "status": run.get("status", "unknown"),
                "created_at": run.get("created_at"),
                "updated_at": run.get("updated_at"),
                "api_spec": run.get("api_spec"),
                "result": run.get("result"),
                "error_message": run.get("error_message"),
                "completed_at": run.get("completed_at"),
            }
            for run in recent_runs[:limit]
        ]

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取运行列表失败: {str(e)}") from e
|
||
|
||
|
||
@run_router.get("/metrics")
async def get_run_metrics() -> Dict[str, Any]:
    """
    Return overview statistics for runs in the last 24 hours.

    Returns:
        A dict with the number of runs whose status is ``running``,
        ``pending``, ``completed`` or ``failed``, the total number of runs in
        the window, and an ISO-formatted timestamp taken when the response
        is built.

    Raises:
        HTTPException: 500 if fetching or counting the runs fails.
    """
    try:
        # The reporting window is the 24 hours leading up to now.
        window_end = datetime.now()
        window_start = window_end - timedelta(hours=24)

        runs_in_window = await get_workflow_runs_recent(window_start, window_end)

        # Tally runs per status; runs without a status count as "unknown".
        tally = {}
        for record in runs_in_window:
            label = record.get("status", "unknown")
            tally.setdefault(label, 0)
            tally[label] += 1

        overview = {
            "running_tasks": tally.get("running", 0),
            "pending_tasks": tally.get("pending", 0),
            "completed_tasks": tally.get("completed", 0),
            "failed_tasks": tally.get("failed", 0),
        }
        overview["total_tasks_24h"] = len(runs_in_window)
        overview["timestamp"] = datetime.now().isoformat()
        return overview
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"获取运行概览失败: {str(err)}")
|
||
|
||
|
||
@run_router.get("/{workflow_run_id}")
async def get_run_status(workflow_run_id: str):
    """
    Return the execution status of a workflow run.

    Args:
        workflow_run_id: ID of the run, taken from the URL path.

    Returns:
        Whatever ``queue_manager.get_task_status`` returns for that ID.

    Raises:
        HTTPException: 500 if the status lookup raises any exception.
    """
    try:
        return await queue_manager.get_task_status(workflow_run_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(err)}")
|