ComfyUI-WorkflowPublisher/workflow_service/routes/comfy_server.py

225 lines
7.6 KiB
Python

from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel, HttpUrl
from typing import Dict, List, Optional, Any
import logging
from workflow_service.comfy.comfy_server import (
server_manager,
ComfyUIServerInfo,
ServerStatus,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that groups the ComfyUI server-management endpoints; every route
# declared on it is served under the /api/comfy prefix.
router = APIRouter(prefix="/api/comfy", tags=["ComfyUI Server Management"])
class ServerRegistrationRequest(BaseModel):
    """Server registration request model (body of ``POST /register``)."""

    # Name under which the server is registered with the server manager.
    name: str
    # HTTP URL of the server, forwarded unchanged to the server manager.
    http_url: str
    # WebSocket URL of the server, forwarded unchanged to the server manager.
    ws_url: str
class ServerHeartbeatRequest(BaseModel):
    """Server heartbeat request model (body of ``POST /heartbeat``)."""

    # Name of the server whose heartbeat should be updated.
    name: str
class ServerUnregisterRequest(BaseModel):
    """Server unregistration request model (body of ``POST /unregister``)."""

    # Name of the server to unregister.
    name: str
class ServerStatusResponse(BaseModel):
    """Server status response model returned by the status and list routes."""

    name: str
    http_url: str
    ws_url: str
    # The routes in this module fill this with the ``.value`` of the
    # server's status enum member.
    status: str
    # The routes in this module fill the two timestamps with ``isoformat()``
    # strings, or None when the server info carries no (truthy) timestamp.
    last_heartbeat: Optional[str] = None
    last_health_check: Optional[str] = None
    current_tasks: int
    max_concurrent_tasks: int
    # Copied as-is from the server info object.
    capabilities: Dict[str, Any]
    metadata: Dict[str, Any]
@router.post("/register", response_model=Dict[str, str])
async def register_server(request: ServerRegistrationRequest):
    """Register a ComfyUI server with the server manager.

    Args:
        request: Name, HTTP URL and WebSocket URL of the server to register.

    Returns:
        A dict with a confirmation ``message`` and ``status`` set to
        ``"success"``.

    Raises:
        HTTPException: 400 when the manager reports that registration did
            not succeed; 500 when the manager call raises.
    """
    try:
        # Log through the module logger instead of printing to stdout.
        logger.debug("收到服务器注册请求: %s", request)
        success = await server_manager.register_server(
            name=request.name,
            http_url=request.http_url,
            ws_url=request.ws_url,
            # Fixed values: this endpoint does not take them from the client.
            max_concurrent_tasks=1,
            capabilities={},
            metadata={},
        )
        if not success:
            raise HTTPException(status_code=400, detail="服务器注册失败")
        logger.info(f"服务器 {request.name} 注册成功")
        return {"message": f"服务器 {request.name} 注册成功", "status": "success"}
    except HTTPException:
        # HTTPException is itself an Exception subclass; without this clause
        # the 400 above would be caught below and reported as a 500.
        raise
    except Exception as e:
        logger.exception(f"注册服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注册失败: {str(e)}") from e
@router.post("/heartbeat", response_model=Dict[str, str])
async def update_heartbeat(request: ServerHeartbeatRequest):
    """Update the heartbeat of a registered server.

    Args:
        request: Carries the name of the server sending the heartbeat.

    Returns:
        A dict with a confirmation ``message`` and ``status`` set to
        ``"success"``.

    Raises:
        HTTPException: 404 when the manager reports the update as
            unsuccessful (treated here as "server does not exist"); 500 when
            the manager call raises.
    """
    try:
        success = await server_manager.update_server_heartbeat(request.name)
        if not success:
            raise HTTPException(
                status_code=404, detail=f"服务器 {request.name} 不存在"
            )
        return {
            "message": f"服务器 {request.name} 心跳更新成功",
            "status": "success",
        }
    except HTTPException:
        # HTTPException is itself an Exception subclass; without this clause
        # the 404 above would be caught below and reported as a 500.
        raise
    except Exception as e:
        logger.exception(f"更新服务器 {request.name} 心跳时发生错误: {e}")
        raise HTTPException(
            status_code=500, detail=f"心跳更新失败: {str(e)}"
        ) from e
@router.post("/unregister", response_model=Dict[str, str])
async def unregister_server(request: ServerUnregisterRequest):
    """Unregister a ComfyUI server from the server manager.

    Args:
        request: Carries the name of the server to remove.

    Returns:
        A dict with a confirmation ``message`` and ``status`` set to
        ``"success"``.

    Raises:
        HTTPException: 404 when the manager reports the removal as
            unsuccessful (treated here as "server does not exist"); 500 when
            the manager call raises.
    """
    try:
        success = await server_manager.unregister_server(request.name)
        if not success:
            raise HTTPException(
                status_code=404, detail=f"服务器 {request.name} 不存在"
            )
        logger.info(f"服务器 {request.name} 注销成功")
        return {"message": f"服务器 {request.name} 注销成功", "status": "success"}
    except HTTPException:
        # HTTPException is itself an Exception subclass; without this clause
        # the 404 above would be caught below and reported as a 500.
        raise
    except Exception as e:
        logger.exception(f"注销服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注销失败: {str(e)}") from e
@router.get("/status/{server_name}", response_model=ServerStatusResponse)
async def get_server_status(server_name: str):
    """Return the current status of a single server.

    Args:
        server_name: Name of the server to look up (path parameter).

    Returns:
        A ``ServerStatusResponse`` built from the manager's server info, with
        the status as its enum value and timestamps as ISO-format strings.

    Raises:
        HTTPException: 404 when the manager returns ``None`` for the name;
            500 when the lookup or the response construction raises.
    """
    try:
        server_info = await server_manager.get_server_status(server_name)
        if server_info is None:
            raise HTTPException(
                status_code=404, detail=f"服务器 {server_name} 不存在"
            )
        return ServerStatusResponse(
            name=server_info.name,
            http_url=server_info.http_url,
            ws_url=server_info.ws_url,
            status=server_info.status.value,
            last_heartbeat=(
                server_info.last_heartbeat.isoformat()
                if server_info.last_heartbeat
                else None
            ),
            last_health_check=(
                server_info.last_health_check.isoformat()
                if server_info.last_health_check
                else None
            ),
            current_tasks=server_info.current_tasks,
            max_concurrent_tasks=server_info.max_concurrent_tasks,
            capabilities=server_info.capabilities,
            metadata=server_info.metadata,
        )
    except HTTPException:
        # HTTPException is itself an Exception subclass; without this clause
        # the 404 above would be caught below and reported as a 500.
        raise
    except Exception as e:
        logger.exception(f"获取服务器 {server_name} 状态时发生错误: {e}")
        raise HTTPException(
            status_code=500, detail=f"获取状态失败: {str(e)}"
        ) from e
@router.get("/list", response_model=List[ServerStatusResponse])
async def list_all_servers():
    """Return a status entry for every server the server manager reports.

    Returns:
        A list of ``ServerStatusResponse`` objects, one per server, with the
        status as its enum value and timestamps as ISO-format strings (or
        ``None`` when a timestamp is not set).

    Raises:
        HTTPException: 500 when fetching or converting the servers raises.
    """
    try:
        entries = []
        for info in await server_manager.get_all_servers():
            # Timestamps are serialised only when present.
            if info.last_heartbeat:
                heartbeat_text = info.last_heartbeat.isoformat()
            else:
                heartbeat_text = None
            if info.last_health_check:
                health_check_text = info.last_health_check.isoformat()
            else:
                health_check_text = None

            entries.append(
                ServerStatusResponse(
                    name=info.name,
                    http_url=info.http_url,
                    ws_url=info.ws_url,
                    status=info.status.value,
                    last_heartbeat=heartbeat_text,
                    last_health_check=health_check_text,
                    current_tasks=info.current_tasks,
                    max_concurrent_tasks=info.max_concurrent_tasks,
                    capabilities=info.capabilities,
                    metadata=info.metadata,
                )
            )
        return entries
    except Exception as e:
        logger.error(f"获取服务器列表时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")
@router.get("/health", response_model=Dict[str, Any])
async def get_system_health():
    """Summarise the state of all servers known to the server manager.

    Returns:
        A dict with the total number of servers, the number of servers in
        each of the ONLINE / BUSY / OFFLINE / ERROR states, the summed
        current tasks and capacity, and ``utilization_rate`` as a percentage
        (``0`` when the total capacity is not positive).

    Raises:
        HTTPException: 500 when fetching or aggregating the servers raises.
    """
    try:
        fleet = await server_manager.get_all_servers()

        def count_in_state(state):
            # Number of servers whose status equals the given state.
            return sum(1 for member in fleet if member.status == state)

        server_total = len(fleet)
        online_count = count_in_state(ServerStatus.ONLINE)
        busy_count = count_in_state(ServerStatus.BUSY)
        offline_count = count_in_state(ServerStatus.OFFLINE)
        error_count = count_in_state(ServerStatus.ERROR)

        running_tasks = sum(member.current_tasks for member in fleet)
        capacity = sum(member.max_concurrent_tasks for member in fleet)

        # Guard against dividing by zero when no capacity is registered.
        if capacity > 0:
            utilization = running_tasks / capacity * 100
        else:
            utilization = 0

        return {
            "total_servers": server_total,
            "online_servers": online_count,
            "busy_servers": busy_count,
            "offline_servers": offline_count,
            "error_servers": error_count,
            "current_tasks": running_tasks,
            "total_capacity": capacity,
            "utilization_rate": utilization,
        }
    except Exception as e:
        logger.error(f"获取系统健康状态时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取健康状态失败: {str(e)}")
@router.post("/force_health_check", response_model=Dict[str, str])
async def force_health_check(background_tasks: BackgroundTasks):
    """Trigger a health check without waiting for it to finish.

    Args:
        background_tasks: FastAPI background-task collection for this
            request; the health check is added to it.

    Returns:
        A dict with a "started" ``message`` and ``status`` set to
        ``"success"``.

    Raises:
        HTTPException: 500 when adding the background task raises.
    """
    try:
        # Run the health check in the background.
        health_check_job = server_manager._perform_health_checks
        background_tasks.add_task(health_check_job)
    except Exception as e:
        logger.error(f"启动健康检查时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"启动健康检查失败: {str(e)}")
    else:
        return {"message": "健康检查已启动", "status": "success"}