# ComfyUI-WorkflowPublisher/app/routes/comfy_server.py
#
# 113 lines
# 3.4 KiB
# Python

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Dict, List, Optional, Any
import logging
import aiohttp
import asyncio
from app.comfy.comfy_server import server_manager
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router for the ComfyUI server-management endpoints; every route declared
# on it is served under the "/api/comfy" prefix.
router = APIRouter(prefix="/api/comfy", tags=["ComfyUI Server Management"])
class ServerRegistrationRequest(BaseModel):
    """Request body for registering a ComfyUI server."""

    # Name under which the server is registered with the server manager.
    name: str
    # HTTP endpoint of the server (forwarded as ``http_url``).
    http_url: str
    # WebSocket endpoint of the server (forwarded as ``ws_url``).
    ws_url: str
class ServerUnregisterRequest(BaseModel):
    """Request body for unregistering a ComfyUI server."""

    # Name of the server to remove from the server manager.
    name: str
class ServerStatusResponse(BaseModel):
    """Response model describing one registered server and its state."""

    # Identity and endpoints of the server.
    name: str
    http_url: str
    ws_url: str

    # Status as a plain string; the list endpoint fills it from
    # ``server.status.value``.
    status: str
    # ISO-8601 text of the last health check, or None when none is recorded.
    last_health_check: Optional[str] = None

    # Task counters copied from the server record.
    current_tasks: int
    max_concurrent_tasks: int

    # Free-form dictionaries copied from the server record.
    capabilities: Dict[str, Any]
    metadata: Dict[str, Any]
@router.post("/register", response_model=Dict[str, str])
async def register_server(request: ServerRegistrationRequest):
    """Register a ComfyUI server with the server manager.

    The server is registered with ``max_concurrent_tasks=1`` and empty
    ``capabilities`` / ``metadata`` dictionaries.

    Returns:
        A dict with a human-readable ``message`` and ``status`` set to
        ``"success"``.

    Raises:
        HTTPException: 400 when the manager reports that registration did
            not succeed; 500 when the manager call itself raises.
    """
    # Debug output goes through the logger rather than stdout.
    logger.debug("Server registration request received: %s", request)

    # Only the manager call is guarded. The 400 below must stay outside the
    # ``try`` so the broad handler cannot turn it into a 500.
    try:
        success = await server_manager.register_server(
            name=request.name,
            http_url=request.http_url,
            ws_url=request.ws_url,
            max_concurrent_tasks=1,
            capabilities={},
            metadata={},
        )
    except Exception as e:
        logger.exception(f"注册服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注册失败: {str(e)}") from e

    if not success:
        raise HTTPException(status_code=400, detail="服务器注册失败")

    logger.info(f"服务器 {request.name} 注册成功")
    return {"message": f"服务器 {request.name} 注册成功", "status": "success"}
@router.post("/unregister", response_model=Dict[str, str])
async def unregister_server(request: ServerUnregisterRequest):
    """Unregister a ComfyUI server by name.

    Returns:
        A dict with a human-readable ``message`` and ``status`` set to
        ``"success"``.

    Raises:
        HTTPException: 404 when the manager reports that the server was not
            removed (treated as "does not exist"); 500 when the manager call
            itself raises.
    """
    # Only the manager call is guarded. The 404 below must stay outside the
    # ``try`` so the broad handler cannot turn it into a 500.
    try:
        success = await server_manager.unregister_server(request.name)
    except Exception as e:
        logger.exception(f"注销服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注销失败: {str(e)}") from e

    if not success:
        raise HTTPException(status_code=404, detail=f"服务器 {request.name} 不存在")

    logger.info(f"服务器 {request.name} 注销成功")
    return {"message": f"服务器 {request.name} 注销成功", "status": "success"}
@router.get("/list", response_model=List[ServerStatusResponse])
async def list_all_servers():
    """Return the status of every server known to the server manager.

    Each server record is converted to a ``ServerStatusResponse``. The
    last health-check timestamp is rendered with ``isoformat()`` and is
    ``None`` when the record has none.

    Raises:
        HTTPException: 500 when fetching or converting the servers fails.
    """
    try:
        registered = await server_manager.get_all_servers()
        payload = []
        for srv in registered:
            entry = ServerStatusResponse(
                name=srv.name,
                http_url=srv.http_url,
                ws_url=srv.ws_url,
                status=srv.status.value,
                last_health_check=(
                    None
                    if not srv.last_health_check
                    else srv.last_health_check.isoformat()
                ),
                current_tasks=srv.current_tasks,
                max_concurrent_tasks=srv.max_concurrent_tasks,
                capabilities=srv.capabilities,
                metadata=srv.metadata,
            )
            payload.append(entry)
        return payload
    except Exception as e:
        logger.error(f"获取服务器列表时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")