113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel
|
|
from typing import Dict, List, Optional, Any
|
|
import logging
|
|
import aiohttp
|
|
import asyncio
|
|
|
|
from app.comfy.comfy_server import server_manager
|
|
|
|
# Module logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)

# Every endpoint below is mounted under /api/comfy and grouped under a single
# OpenAPI tag.
router = APIRouter(
    prefix="/api/comfy",
    tags=["ComfyUI Server Management"],
)
|
|
|
|
|
|
class ServerRegistrationRequest(BaseModel):
    """Request body for registering a ComfyUI server."""

    # Name the server is registered under.
    name: str
    # HTTP endpoint of the ComfyUI server.
    http_url: str
    # WebSocket endpoint of the ComfyUI server.
    ws_url: str
|
|
|
|
|
|
class ServerUnregisterRequest(BaseModel):
    """Request body for unregistering a previously registered ComfyUI server."""

    # Name the server was registered under.
    name: str
|
|
|
|
|
|
class ServerStatusResponse(BaseModel):
    """Status snapshot of one registered ComfyUI server, as returned by the API."""

    # Identity and endpoints of the server.
    name: str
    http_url: str
    ws_url: str
    # Current status value reported by the server manager.
    status: str
    # Timestamp (ISO string) of the most recent health check; None if never checked.
    last_health_check: Optional[str] = None
    # Tasks currently running versus the configured concurrency limit.
    current_tasks: int
    max_concurrent_tasks: int
    # Free-form dictionaries carried through from the server record.
    capabilities: Dict[str, Any]
    metadata: Dict[str, Any]
|
|
|
|
|
|
@router.post("/register", response_model=Dict[str, str])
async def register_server(request: ServerRegistrationRequest):
    """Register a ComfyUI server with the server manager.

    The server is registered with ``max_concurrent_tasks=1`` and empty
    ``capabilities`` / ``metadata``; the request only supplies the name and
    the two URLs.

    Args:
        request: Name, HTTP URL and WebSocket URL of the server to register.

    Returns:
        ``{"message": ..., "status": "success"}`` on success.

    Raises:
        HTTPException: 400 when ``server_manager.register_server`` returns a
            falsy result; 500 when it raises an unexpected error.
    """
    # Debug-level log instead of a bare print() to stdout.
    logger.debug("Received server registration request: %s", request)

    # Keep the try narrow: only the manager call is expected to fail
    # unexpectedly. Raising the 400 inside it would let the broad handler
    # below rewrite it into a 500.
    try:
        success = await server_manager.register_server(
            name=request.name,
            http_url=request.http_url,
            ws_url=request.ws_url,
            max_concurrent_tasks=1,
            capabilities={},
            metadata={},
        )
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"注册服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注册失败: {str(e)}") from e

    if not success:
        raise HTTPException(status_code=400, detail="服务器注册失败")

    logger.info(f"服务器 {request.name} 注册成功")
    return {"message": f"服务器 {request.name} 注册成功", "status": "success"}
|
|
|
|
|
|
@router.post("/unregister", response_model=Dict[str, str])
async def unregister_server(request: ServerUnregisterRequest):
    """Unregister a ComfyUI server from the server manager.

    Args:
        request: Carries the name of the server to remove.

    Returns:
        ``{"message": ..., "status": "success"}`` on success.

    Raises:
        HTTPException: 404 when ``server_manager.unregister_server`` returns a
            falsy result (no such server); 500 when it raises an unexpected
            error.
    """
    # Only the manager call sits inside the try so the 404 below is not
    # swallowed by the broad handler and turned into a 500.
    try:
        success = await server_manager.unregister_server(request.name)
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"注销服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注销失败: {str(e)}") from e

    if not success:
        raise HTTPException(status_code=404, detail=f"服务器 {request.name} 不存在")

    logger.info(f"服务器 {request.name} 注销成功")
    return {"message": f"服务器 {request.name} 注销成功", "status": "success"}
|
|
|
|
|
|
def _to_status_response(server) -> ServerStatusResponse:
    """Convert one server record from ``server_manager`` into the API response model."""
    # NOTE(review): assumes ``status`` is an Enum-like object exposing ``.value``
    # and ``last_health_check`` is a datetime (or None) -- confirm against the
    # server_manager record type.
    return ServerStatusResponse(
        name=server.name,
        http_url=server.http_url,
        ws_url=server.ws_url,
        status=server.status.value,
        last_health_check=(
            server.last_health_check.isoformat()
            if server.last_health_check
            else None
        ),
        current_tasks=server.current_tasks,
        max_concurrent_tasks=server.max_concurrent_tasks,
        capabilities=server.capabilities,
        metadata=server.metadata,
    )


@router.get("/list", response_model=List[ServerStatusResponse])
async def list_all_servers():
    """Return a status snapshot for every server known to ``server_manager``.

    Returns:
        A list of ``ServerStatusResponse``, one per server.

    Raises:
        HTTPException: 500 when fetching or converting the server list fails.
    """
    try:
        servers = await server_manager.get_all_servers()
        return [_to_status_response(server) for server in servers]
    except Exception as e:
        # logger.exception keeps the traceback; chain the cause for debugging.
        logger.exception(f"获取服务器列表时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}") from e
|