225 lines
7.6 KiB
Python
225 lines
7.6 KiB
Python
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
|
from pydantic import BaseModel, HttpUrl
|
|
from typing import Dict, List, Optional, Any
|
|
import logging
|
|
|
|
from workflow_service.comfy.comfy_server import (
|
|
server_manager,
|
|
ComfyUIServerInfo,
|
|
ServerStatus,
|
|
)
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
# Router for the ComfyUI server-management endpoints; every route declared on
# it is served under the "/api/comfy" prefix.
router = APIRouter(prefix="/api/comfy", tags=["ComfyUI Server Management"])
|
|
|
|
|
|
class ServerRegistrationRequest(BaseModel):
    """Request body for registering a ComfyUI server."""

    # Name under which the server is registered with the server manager.
    name: str
    # HTTP endpoint of the server; accepted as a plain string (no URL validation here).
    http_url: str
    # WebSocket endpoint of the server; accepted as a plain string.
    ws_url: str
|
|
|
|
|
|
class ServerHeartbeatRequest(BaseModel):
    """Request body for a server heartbeat update."""

    # Name of the server whose heartbeat should be refreshed.
    name: str
|
|
|
|
|
|
class ServerUnregisterRequest(BaseModel):
    """Request body for unregistering a server."""

    # Name of the server to remove from the server manager.
    name: str
|
|
|
|
|
|
class ServerStatusResponse(BaseModel):
    """Response model describing the state of a single server."""

    # Identity and endpoints of the server.
    name: str
    http_url: str
    ws_url: str

    # Status as a string; the endpoints in this module fill it from
    # the `.value` of the server's status enum.
    status: str

    # Timestamps as strings (the endpoints in this module produce them with
    # `.isoformat()`), or None when no value has been recorded.
    last_heartbeat: Optional[str] = None
    last_health_check: Optional[str] = None

    # Number of tasks currently assigned and the configured concurrency limit.
    current_tasks: int
    max_concurrent_tasks: int

    # Free-form dictionaries passed through unchanged from the server record.
    capabilities: Dict[str, Any]
    metadata: Dict[str, Any]
|
|
|
|
|
|
@router.post("/register", response_model=Dict[str, str])
async def register_server(request: ServerRegistrationRequest):
    """Register a ComfyUI server with the server manager.

    The server is registered with ``max_concurrent_tasks=1`` and empty
    ``capabilities`` / ``metadata`` dictionaries.

    Args:
        request: Name plus HTTP and WebSocket URLs of the server to register.

    Returns:
        A dict with a human-readable ``message`` and ``status == "success"``.

    Raises:
        HTTPException: 400 when the server manager reports that registration
            did not succeed; 500 when any other error occurs.
    """
    try:
        # Debug-level log instead of a bare print so the output goes through
        # the configured logging handlers.
        logger.debug(f"收到服务器注册请求: {request}")

        success = await server_manager.register_server(
            name=request.name,
            http_url=request.http_url,
            ws_url=request.ws_url,
            max_concurrent_tasks=1,
            capabilities={},
            metadata={},
        )

        if not success:
            raise HTTPException(status_code=400, detail="服务器注册失败")

        logger.info(f"服务器 {request.name} 注册成功")
        return {"message": f"服务器 {request.name} 注册成功", "status": "success"}

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) propagate unchanged
        # rather than being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"注册服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注册失败: {str(e)}") from e
|
|
|
|
|
|
@router.post("/heartbeat", response_model=Dict[str, str])
async def update_heartbeat(request: ServerHeartbeatRequest):
    """Refresh the heartbeat of a registered server.

    Args:
        request: Carries the name of the server sending the heartbeat.

    Returns:
        A dict with a human-readable ``message`` and ``status == "success"``.

    Raises:
        HTTPException: 404 when the server manager reports the update did not
            succeed (reported to the client as "server does not exist");
            500 when any other error occurs.
    """
    try:
        success = await server_manager.update_server_heartbeat(request.name)

        if not success:
            raise HTTPException(status_code=404, detail=f"服务器 {request.name} 不存在")

        return {
            "message": f"服务器 {request.name} 心跳更新成功",
            "status": "success",
        }

    except HTTPException:
        # Keep the 404 intact; without this clause the generic handler below
        # would convert it into a 500.
        raise
    except Exception as e:
        logger.error(f"更新服务器 {request.name} 心跳时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"心跳更新失败: {str(e)}") from e
|
|
|
|
|
|
@router.post("/unregister", response_model=Dict[str, str])
async def unregister_server(request: ServerUnregisterRequest):
    """Unregister a ComfyUI server from the server manager.

    Args:
        request: Carries the name of the server to remove.

    Returns:
        A dict with a human-readable ``message`` and ``status == "success"``.

    Raises:
        HTTPException: 404 when the server manager reports the removal did not
            succeed (reported to the client as "server does not exist");
            500 when any other error occurs.
    """
    try:
        success = await server_manager.unregister_server(request.name)

        if not success:
            raise HTTPException(status_code=404, detail=f"服务器 {request.name} 不存在")

        logger.info(f"服务器 {request.name} 注销成功")
        return {"message": f"服务器 {request.name} 注销成功", "status": "success"}

    except HTTPException:
        # Keep the 404 intact; without this clause the generic handler below
        # would convert it into a 500.
        raise
    except Exception as e:
        logger.error(f"注销服务器 {request.name} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"注销失败: {str(e)}") from e
|
|
|
|
|
|
@router.get("/status/{server_name}", response_model=ServerStatusResponse)
async def get_server_status(server_name: str):
    """Return the status of a single server.

    Args:
        server_name: Name of the server, taken from the URL path.

    Returns:
        A ``ServerStatusResponse`` built from the server record; timestamps
        are converted with ``isoformat()`` and left as ``None`` when unset.

    Raises:
        HTTPException: 404 when the server manager returns ``None`` for the
            given name; 500 when any other error occurs.
    """
    try:
        server_info = await server_manager.get_server_status(server_name)

        if server_info is None:
            raise HTTPException(status_code=404, detail=f"服务器 {server_name} 不存在")

        return ServerStatusResponse(
            name=server_info.name,
            http_url=server_info.http_url,
            ws_url=server_info.ws_url,
            status=server_info.status.value,
            last_heartbeat=(
                server_info.last_heartbeat.isoformat()
                if server_info.last_heartbeat
                else None
            ),
            last_health_check=(
                server_info.last_health_check.isoformat()
                if server_info.last_health_check
                else None
            ),
            current_tasks=server_info.current_tasks,
            max_concurrent_tasks=server_info.max_concurrent_tasks,
            capabilities=server_info.capabilities,
            metadata=server_info.metadata,
        )

    except HTTPException:
        # Keep the 404 intact; without this clause the generic handler below
        # would convert it into a 500.
        raise
    except Exception as e:
        logger.error(f"获取服务器 {server_name} 状态时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
|
|
|
|
|
@router.get("/list", response_model=List[ServerStatusResponse])
async def list_all_servers():
    """Return the status of every server known to the server manager.

    Returns:
        A list with one ``ServerStatusResponse`` per server, in the order
        the server manager yields them.

    Raises:
        HTTPException: 500 when fetching or converting the servers fails.
    """

    def to_response(entry):
        # Convert one server record into the response model; timestamps
        # become ISO strings, or None when unset.
        return ServerStatusResponse(
            name=entry.name,
            http_url=entry.http_url,
            ws_url=entry.ws_url,
            status=entry.status.value,
            last_heartbeat=(
                None if not entry.last_heartbeat else entry.last_heartbeat.isoformat()
            ),
            last_health_check=(
                None
                if not entry.last_health_check
                else entry.last_health_check.isoformat()
            ),
            current_tasks=entry.current_tasks,
            max_concurrent_tasks=entry.max_concurrent_tasks,
            capabilities=entry.capabilities,
            metadata=entry.metadata,
        )

    try:
        known_servers = await server_manager.get_all_servers()

        responses = []
        for entry in known_servers:
            responses.append(to_response(entry))
        return responses

    except Exception as e:
        logger.error(f"获取服务器列表时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")
|
|
|
|
|
|
@router.get("/health", response_model=Dict[str, Any])
async def get_system_health():
    """Return aggregate health figures across all servers.

    Returns:
        A dict with the total number of servers, the number of servers in each
        of the ONLINE / BUSY / OFFLINE / ERROR states, the summed current
        tasks and capacity, and ``utilization_rate`` as a percentage
        (``0`` when the total capacity is not positive).

    Raises:
        HTTPException: 500 when fetching or aggregating the servers fails.
    """
    try:
        fleet = await server_manager.get_all_servers()

        def count_in_state(state):
            # Number of servers whose status equals the given state.
            return sum(1 for node in fleet if node.status == state)

        server_count = len(fleet)
        online_count = count_in_state(ServerStatus.ONLINE)
        busy_count = count_in_state(ServerStatus.BUSY)
        offline_count = count_in_state(ServerStatus.OFFLINE)
        error_count = count_in_state(ServerStatus.ERROR)

        running_tasks = sum(node.current_tasks for node in fleet)
        capacity = sum(node.max_concurrent_tasks for node in fleet)

        # Guard against division by zero when no capacity is registered.
        if capacity > 0:
            utilization = running_tasks / capacity * 100
        else:
            utilization = 0

        return {
            "total_servers": server_count,
            "online_servers": online_count,
            "busy_servers": busy_count,
            "offline_servers": offline_count,
            "error_servers": error_count,
            "current_tasks": running_tasks,
            "total_capacity": capacity,
            "utilization_rate": utilization,
        }

    except Exception as e:
        logger.error(f"获取系统健康状态时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"获取健康状态失败: {str(e)}")
|
|
|
|
|
|
@router.post("/force_health_check", response_model=Dict[str, str])
async def force_health_check(background_tasks: BackgroundTasks):
    """Schedule a health check as a background task.

    The check is only queued on ``background_tasks``; this handler responds
    without waiting for it.

    Args:
        background_tasks: FastAPI background-task collection for this request.

    Returns:
        A dict with a human-readable ``message`` and ``status == "success"``.

    Raises:
        HTTPException: 500 when queuing the task fails.
    """
    try:
        # NOTE(review): this reaches into a private method of server_manager;
        # confirm whether a public entry point exists.
        health_check_job = server_manager._perform_health_checks
        background_tasks.add_task(health_check_job)
    except Exception as e:
        logger.error(f"启动健康检查时发生错误: {e}")
        raise HTTPException(status_code=500, detail=f"启动健康检查失败: {str(e)}")
    else:
        return {"message": "健康检查已启动", "status": "success"}
|