"""HTTP routes for registering and monitoring ComfyUI servers.

All routes are mounted under ``/api/comfy`` and delegate the actual work to
the shared ``server_manager`` imported from ``workflow_service``.
"""

import logging
from typing import Dict, List, Optional, Any

from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel, HttpUrl

from workflow_service.comfy.comfy_server import (
    server_manager,
    ComfyUIServerInfo,
    ServerStatus,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/comfy", tags=["ComfyUI Server Management"])


class ServerRegistrationRequest(BaseModel):
    """Request body for registering a server."""

    name: str
    http_url: str
    ws_url: str


class ServerHeartbeatRequest(BaseModel):
    """Request body for a server heartbeat."""

    name: str


class ServerUnregisterRequest(BaseModel):
    """Request body for unregistering a server."""

    name: str


class ServerStatusResponse(BaseModel):
    """Response body describing the state of one server."""

    name: str
    http_url: str
    ws_url: str
    status: str
    # ISO-formatted timestamps, or None when the manager has no value.
    last_heartbeat: Optional[str] = None
    last_health_check: Optional[str] = None
    current_tasks: int
    max_concurrent_tasks: int
    capabilities: Dict[str, Any]
    metadata: Dict[str, Any]


def _to_status_response(server: ComfyUIServerInfo) -> ServerStatusResponse:
    """Convert a server info object from the manager into the API response model."""
    heartbeat = server.last_heartbeat
    health_check = server.last_health_check
    return ServerStatusResponse(
        name=server.name,
        http_url=server.http_url,
        ws_url=server.ws_url,
        status=server.status.value,
        last_heartbeat=heartbeat.isoformat() if heartbeat else None,
        last_health_check=health_check.isoformat() if health_check else None,
        current_tasks=server.current_tasks,
        max_concurrent_tasks=server.max_concurrent_tasks,
        capabilities=server.capabilities,
        metadata=server.metadata,
    )


@router.post("/register", response_model=Dict[str, str])
async def register_server(request: ServerRegistrationRequest):
    """Register a ComfyUI server.

    Raises:
        HTTPException: 400 when the manager reports the registration failed,
            500 when an unexpected error occurs.
    """
    try:
        logger.debug("收到服务器注册请求: %s", request)
        # The request model carries no capacity or capability fields, so the
        # server is registered with fixed defaults.
        success = await server_manager.register_server(
            name=request.name,
            http_url=request.http_url,
            ws_url=request.ws_url,
            max_concurrent_tasks=1,
            capabilities={},
            metadata={},
        )
        if not success:
            raise HTTPException(status_code=400, detail="服务器注册失败")

        logger.info(f"服务器 {request.name} 注册成功")
        return {"message": f"服务器 {request.name} 注册成功", "status": "success"}
    except HTTPException:
        # Deliberate HTTP errors must reach the client with their own status
        # code instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        logger.exception(f"注册服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注册失败: {str(e)}") from e


@router.post("/heartbeat", response_model=Dict[str, str])
async def update_heartbeat(request: ServerHeartbeatRequest):
    """Update the heartbeat of a server.

    Raises:
        HTTPException: 404 when the manager reports no update for the given
            name, 500 when an unexpected error occurs.
    """
    try:
        success = await server_manager.update_server_heartbeat(request.name)
        if not success:
            raise HTTPException(
                status_code=404, detail=f"服务器 {request.name} 不存在"
            )

        return {
            "message": f"服务器 {request.name} 心跳更新成功",
            "status": "success",
        }
    except HTTPException:
        # Keep the 404 intact rather than converting it into a 500.
        raise
    except Exception as e:
        logger.exception(f"更新服务器 {request.name} 心跳时发生错误: {e}")
        raise HTTPException(
            status_code=500, detail=f"心跳更新失败: {str(e)}"
        ) from e


@router.post("/unregister", response_model=Dict[str, str])
async def unregister_server(request: ServerUnregisterRequest):
    """Unregister a ComfyUI server.

    Raises:
        HTTPException: 404 when the manager reports nothing was unregistered
            for the given name, 500 when an unexpected error occurs.
    """
    try:
        success = await server_manager.unregister_server(request.name)
        if not success:
            raise HTTPException(
                status_code=404, detail=f"服务器 {request.name} 不存在"
            )

        logger.info(f"服务器 {request.name} 注销成功")
        return {"message": f"服务器 {request.name} 注销成功", "status": "success"}
    except HTTPException:
        # Keep the 404 intact rather than converting it into a 500.
        raise
    except Exception as e:
        logger.exception(f"注销服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注销失败: {str(e)}") from e


@router.get("/status/{server_name}", response_model=ServerStatusResponse)
async def get_server_status(server_name: str):
    """Return the status of one server.

    Raises:
        HTTPException: 404 when the manager returns ``None`` for the name,
            500 when an unexpected error occurs.
    """
    try:
        server_info = await server_manager.get_server_status(server_name)
        if server_info is None:
            raise HTTPException(
                status_code=404, detail=f"服务器 {server_name} 不存在"
            )
        return _to_status_response(server_info)
    except HTTPException:
        # Keep the 404 intact rather than converting it into a 500.
        raise
    except Exception as e:
        logger.exception(f"获取服务器 {server_name} 状态时发生错误: {e}")
        raise HTTPException(
            status_code=500, detail=f"获取状态失败: {str(e)}"
        ) from e


@router.get("/list", response_model=List[ServerStatusResponse])
async def list_all_servers():
    """Return the status of every server known to the manager.

    Raises:
        HTTPException: 500 when an unexpected error occurs.
    """
    try:
        servers = await server_manager.get_all_servers()
        return [_to_status_response(server) for server in servers]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取服务器列表时发生错误: {e}")
        raise HTTPException(
            status_code=500, detail=f"获取列表失败: {str(e)}"
        ) from e


@router.get("/health", response_model=Dict[str, Any])
async def get_system_health():
    """Return aggregate counts and utilization across all servers.

    ``utilization_rate`` is a percentage of current tasks over total capacity
    and is 0 when the total capacity is 0.

    Raises:
        HTTPException: 500 when an unexpected error occurs.
    """
    try:
        servers = await server_manager.get_all_servers()

        def count_with_status(status: ServerStatus) -> int:
            return sum(1 for s in servers if s.status == status)

        total_tasks = sum(s.current_tasks for s in servers)
        total_capacity = sum(s.max_concurrent_tasks for s in servers)

        return {
            "total_servers": len(servers),
            "online_servers": count_with_status(ServerStatus.ONLINE),
            "busy_servers": count_with_status(ServerStatus.BUSY),
            "offline_servers": count_with_status(ServerStatus.OFFLINE),
            "error_servers": count_with_status(ServerStatus.ERROR),
            "current_tasks": total_tasks,
            "total_capacity": total_capacity,
            # Guard against division by zero when no capacity is registered.
            "utilization_rate": (
                (total_tasks / total_capacity * 100) if total_capacity > 0 else 0
            ),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取系统健康状态时发生错误: {e}")
        raise HTTPException(
            status_code=500, detail=f"获取健康状态失败: {str(e)}"
        ) from e


@router.post("/force_health_check", response_model=Dict[str, str])
async def force_health_check(background_tasks: BackgroundTasks):
    """Schedule a health check to run in the background.

    The response is returned as soon as the task is queued; it does not
    report the outcome of the health check itself.

    Raises:
        HTTPException: 500 when scheduling the task fails.
    """
    try:
        # NOTE(review): this calls a private method of server_manager; a
        # public entry point would be preferable if one exists.
        background_tasks.add_task(server_manager._perform_health_checks)
        return {"message": "健康检查已启动", "status": "success"}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"启动健康检查时发生错误: {e}")
        raise HTTPException(
            status_code=500, detail=f"启动健康检查失败: {str(e)}"
        ) from e