"""HTTP endpoints for registering, unregistering and listing ComfyUI servers."""

import asyncio
import logging
from typing import Any, Dict, List, Optional

import aiohttp
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from app.comfy.comfy_server import server_manager

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/comfy", tags=["ComfyUI Server Management"])


class ServerRegistrationRequest(BaseModel):
    """Request body for registering a ComfyUI server."""

    name: str  # identifier; the same value is later passed to /unregister
    http_url: str  # HTTP URL of the ComfyUI server
    ws_url: str  # WebSocket URL of the ComfyUI server
    # Optional settings. The defaults reproduce the values the endpoint used
    # to hard-code, so clients that omit them behave exactly as before.
    max_concurrent_tasks: int = Field(default=1, ge=1)
    capabilities: Dict[str, Any] = Field(default_factory=dict)
    metadata: Dict[str, Any] = Field(default_factory=dict)


class ServerUnregisterRequest(BaseModel):
    """Request body for unregistering a ComfyUI server."""

    name: str  # name the server was registered under


class ServerStatusResponse(BaseModel):
    """Status snapshot of one registered server, as returned by /list."""

    name: str
    http_url: str
    ws_url: str
    status: str  # string value of the server's status object
    last_health_check: Optional[str] = None  # ISO-8601 string, or None if never checked
    current_tasks: int
    max_concurrent_tasks: int
    capabilities: Dict[str, Any]
    metadata: Dict[str, Any]


def _to_status_response(server) -> ServerStatusResponse:
    """Convert a server record from ``server_manager`` into a ServerStatusResponse."""
    return ServerStatusResponse(
        name=server.name,
        http_url=server.http_url,
        ws_url=server.ws_url,
        # ``status`` exposes ``.value`` (presumably an Enum -- verify in comfy_server).
        status=server.status.value,
        last_health_check=(
            server.last_health_check.isoformat()
            if server.last_health_check
            else None
        ),
        current_tasks=server.current_tasks,
        max_concurrent_tasks=server.max_concurrent_tasks,
        capabilities=server.capabilities,
        metadata=server.metadata,
    )


@router.post("/register", response_model=Dict[str, str])
async def register_server(request: ServerRegistrationRequest):
    """Register a ComfyUI server with ``server_manager``.

    Returns:
        ``{"message": ..., "status": "success"}`` on success.

    Raises:
        HTTPException: 400 if ``server_manager.register_server`` returns a
            falsy result; 500 if it raises.
    """
    logger.debug("Registration request: %s", request)
    try:
        success = await server_manager.register_server(
            name=request.name,
            http_url=request.http_url,
            ws_url=request.ws_url,
            max_concurrent_tasks=request.max_concurrent_tasks,
            capabilities=request.capabilities,
            metadata=request.metadata,
        )
    except Exception as e:
        logger.exception(f"注册服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注册失败: {str(e)}") from e

    # Raised outside the try block so the 400 is not re-wrapped as a 500.
    if not success:
        raise HTTPException(status_code=400, detail="服务器注册失败")

    logger.info(f"服务器 {request.name} 注册成功")
    return {"message": f"服务器 {request.name} 注册成功", "status": "success"}


@router.post("/unregister", response_model=Dict[str, str])
async def unregister_server(request: ServerUnregisterRequest):
    """Unregister a ComfyUI server by name.

    Returns:
        ``{"message": ..., "status": "success"}`` on success.

    Raises:
        HTTPException: 404 if ``server_manager.unregister_server`` returns a
            falsy result (reported as "server does not exist"); 500 if it raises.
    """
    try:
        success = await server_manager.unregister_server(request.name)
    except Exception as e:
        logger.exception(f"注销服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注销失败: {str(e)}") from e

    # Raised outside the try block so the 404 is not re-wrapped as a 500.
    if not success:
        raise HTTPException(status_code=404, detail=f"服务器 {request.name} 不存在")

    logger.info(f"服务器 {request.name} 注销成功")
    return {"message": f"服务器 {request.name} 注销成功", "status": "success"}


@router.get("/list", response_model=List[ServerStatusResponse])
async def list_all_servers():
    """Return the status of every server known to ``server_manager``.

    Raises:
        HTTPException: 500 if fetching or converting the server list fails.
    """
    try:
        servers = await server_manager.get_all_servers()
        return [_to_status_response(server) for server in servers]
    except Exception as e:
        logger.exception(f"获取服务器列表时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}") from e