477 lines
18 KiB
Python
477 lines
18 KiB
Python
import asyncio
import json
import logging
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

import aiohttp
from aiohttp import ClientTimeout
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from workflow_service.database.connection import AsyncSessionLocal
from workflow_service.database.models import ComfyUIServer as ComfyUIServerModel
|
|
|
|
# Module-level logger; used for registration, health-check and
# database-persistence messages throughout this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ServerStatus(Enum):
    """Lifecycle state of a ComfyUI server as tracked by the manager.

    Each value is the lowercase string written to the database ``status``
    column and parsed back with ``ServerStatus(value)``.
    """

    # Reachable; the only state in which new tasks are handed out.
    ONLINE = "online"
    # Initial state, and the state after a heartbeat timeout.
    OFFLINE = "offline"
    # Reachable but every concurrent-task slot is taken.
    BUSY = "busy"
    # A health probe failed while the server was considered reachable.
    ERROR = "error"
|
|
|
|
|
|
@dataclass
class ComfyUIServerInfo:
    """In-memory record describing one registered ComfyUI server.

    Converted to and from the ``ComfyUIServerModel`` ORM row via
    :meth:`from_model` / :meth:`to_model`. The ``capabilities`` and
    ``metadata`` dicts are persisted as JSON strings.
    """

    name: str  # Server name, chosen by whoever registers the server.
    http_url: str
    ws_url: str
    status: ServerStatus = ServerStatus.OFFLINE
    last_heartbeat: Optional[datetime] = None
    last_health_check: Optional[datetime] = None
    current_tasks: int = 0
    max_concurrent_tasks: int = 1
    # Server capability info (e.g. supported models); used for capability
    # matching when a server is selected for a task.
    capabilities: Dict[str, Any] = field(default_factory=dict)
    # Any other free-form metadata.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Callers may still pass ``None`` explicitly; normalize to an empty
        # dict so the rest of the code can always treat these as mappings.
        if self.capabilities is None:
            self.capabilities = {}
        if self.metadata is None:
            self.metadata = {}

    @classmethod
    def from_model(cls, model: ComfyUIServerModel) -> "ComfyUIServerInfo":
        """Build an instance from a database row.

        Empty/NULL JSON columns become empty dicts.

        Raises:
            ValueError: if ``model.status`` is not a valid ``ServerStatus``
                value, or a JSON column holds malformed text
                (``json.JSONDecodeError`` subclasses ``ValueError``).
        """
        capabilities = json.loads(model.capabilities) if model.capabilities else {}
        metadata = json.loads(model.server_metadata) if model.server_metadata else {}

        return cls(
            name=model.name,
            http_url=model.http_url,
            ws_url=model.ws_url,
            status=ServerStatus(model.status),
            last_heartbeat=model.last_heartbeat,
            last_health_check=model.last_health_check,
            current_tasks=model.current_tasks,
            max_concurrent_tasks=model.max_concurrent_tasks,
            capabilities=capabilities,
            metadata=metadata,
        )

    def to_model(self) -> ComfyUIServerModel:
        """Return a new (unsaved) ORM row mirroring this record.

        ``status`` is stored as its string value and the dict fields as JSON.
        """
        return ComfyUIServerModel(
            name=self.name,
            http_url=self.http_url,
            ws_url=self.ws_url,
            status=self.status.value,
            last_heartbeat=self.last_heartbeat,
            last_health_check=self.last_health_check,
            current_tasks=self.current_tasks,
            max_concurrent_tasks=self.max_concurrent_tasks,
            capabilities=json.dumps(self.capabilities),
            server_metadata=json.dumps(self.metadata),
        )
|
|
|
|
|
|
class ComfyUIServerManager:
    """Registry, load balancer and health monitor for ComfyUI servers.

    Server records live in memory in ``self.servers`` (keyed by name) and are
    mirrored to the database on every state change. ``self.servers`` and the
    records inside it are mutated only while holding ``self.lock``. Outbound
    HTTP health probes run *outside* the lock, so a slow or unreachable
    server cannot stall allocate/release/heartbeat calls.
    """

    def __init__(self):
        self.servers: Dict[str, ComfyUIServerInfo] = {}  # name -> server_info
        self.lock = asyncio.Lock()
        self.health_check_interval = 30  # Health-check interval (seconds).
        self.heartbeat_timeout = 60  # Heartbeat timeout (seconds).
        self._health_check_task: Optional[asyncio.Task] = None
        self._initialized = False

    # ------------------------------------------------------------------
    # Initialization and persistence
    # ------------------------------------------------------------------

    async def initialize(self):
        """Load server records from the database (at most once).

        The ``_initialized`` flag is re-checked under the lock, so concurrent
        first callers do not load (and overwrite ``self.servers``) twice. A
        failed load is logged and still marks the manager as initialized, so
        it is not retried on every call.
        """
        if self._initialized:
            return

        async with self.lock:
            if self._initialized:
                # Another coroutine finished initializing while we waited.
                return
            try:
                await self._load_servers_from_db()
                logger.info(f"从数据库加载了 {len(self.servers)} 个服务器")
            except Exception as e:
                logger.error(f"从数据库加载服务器失败: {e}")
            # Mark as initialized even on failure to avoid repeated attempts.
            self._initialized = True

    async def _load_servers_from_db(self):
        """Populate ``self.servers`` from every ComfyUIServerModel row.

        Rows that fail to convert (e.g. unknown status value, malformed JSON)
        are logged and skipped. Called by :meth:`initialize` with the lock held.
        """
        async with AsyncSessionLocal() as session:
            async with session.begin():
                result = await session.execute(select(ComfyUIServerModel))
                models = result.scalars().all()

                # Convert inside the transaction so ORM attributes are still
                # loaded when read.
                for model in models:
                    try:
                        server_info = ComfyUIServerInfo.from_model(model)
                    except Exception as e:
                        logger.error(f"加载服务器 {model.name} 失败: {e}")
                        continue
                    self.servers[server_info.name] = server_info

    async def _save_server_to_db(self, server: ComfyUIServerInfo):
        """Upsert ``server`` into the database, matching rows by name.

        Best effort: errors are logged and swallowed, and the in-memory record
        is not rolled back.
        """
        try:
            async with AsyncSessionLocal() as session:
                # session.begin() commits on successful exit and rolls back on
                # error, so no explicit commit() is needed inside the block.
                async with session.begin():
                    existing = await session.execute(
                        select(ComfyUIServerModel).where(
                            ComfyUIServerModel.name == server.name
                        )
                    )
                    if existing.scalar_one_or_none() is not None:
                        await session.execute(
                            update(ComfyUIServerModel)
                            .where(ComfyUIServerModel.name == server.name)
                            .values(
                                http_url=server.http_url,
                                ws_url=server.ws_url,
                                status=server.status.value,
                                last_heartbeat=server.last_heartbeat,
                                last_health_check=server.last_health_check,
                                current_tasks=server.current_tasks,
                                max_concurrent_tasks=server.max_concurrent_tasks,
                                capabilities=json.dumps(server.capabilities),
                                server_metadata=json.dumps(server.metadata),
                                updated_at=datetime.utcnow(),
                            )
                        )
                    else:
                        session.add(server.to_model())
        except Exception as e:
            logger.error(f"保存服务器 {server.name} 到数据库失败: {e}")

    async def _delete_server_from_db(self, name: str):
        """Delete the row named ``name`` (best effort; errors are logged)."""
        try:
            async with AsyncSessionLocal() as session:
                # Committed automatically when the begin() block exits cleanly.
                async with session.begin():
                    await session.execute(
                        delete(ComfyUIServerModel).where(
                            ComfyUIServerModel.name == name
                        )
                    )
        except Exception as e:
            logger.error(f"从数据库删除服务器 {name} 失败: {e}")

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _host_port(url: str) -> Tuple[Optional[str], int]:
        """Return ``(hostname, port)`` of ``url``.

        A missing port defaults to 80 for ``http`` and 443 for anything else.
        """
        parsed = urlparse(url)
        return parsed.hostname, parsed.port or (80 if parsed.scheme == "http" else 443)

    @staticmethod
    def _status_when_reachable(server: ComfyUIServerInfo) -> ServerStatus:
        """Status for a reachable server: BUSY when at capacity, else ONLINE."""
        if server.current_tasks >= server.max_concurrent_tasks:
            return ServerStatus.BUSY
        return ServerStatus.ONLINE

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    async def register_server(
        self,
        name: str,
        http_url: str,
        ws_url: str,
        max_concurrent_tasks: int = 1,
        capabilities: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Register (or re-register) a ComfyUI server.

        If another entry already uses the same host+port, it is replaced by
        this one and its database row removed. A server re-registering under
        its existing name simply has its record updated.

        The server is probed once before being stored: it starts ONLINE if the
        probe succeeds, otherwise OFFLINE until a later health check.

        Returns:
            Always ``True``.
        """
        await self.initialize()
        await self._ensure_health_check_started()

        host, port = self._host_port(http_url)

        server_info = ComfyUIServerInfo(
            name=name,
            http_url=http_url.rstrip("/"),
            ws_url=ws_url.rstrip("/"),
            status=ServerStatus.OFFLINE,  # Offline until a probe succeeds.
            max_concurrent_tasks=max_concurrent_tasks,
            capabilities=capabilities or {},
            metadata=metadata or {},
        )

        # Probe before taking the lock: the HTTP call may take up to the
        # client timeout and must not block other manager operations.
        if await self._check_server_health(server_info):
            now = datetime.now()
            server_info.status = ServerStatus.ONLINE
            server_info.last_health_check = now
            server_info.last_heartbeat = now

        async with self.lock:
            # Look for an existing server on the same host+port.
            existing_server_name = None
            for existing_server in self.servers.values():
                if self._host_port(existing_server.http_url) == (host, port):
                    existing_server_name = existing_server.name
                    logger.info(f"检测到IP {host}:{port} 已被服务器 {existing_server_name} 使用,将自动更新配置")
                    break

            if existing_server_name is not None:
                del self.servers[existing_server_name]
                logger.info(f"已删除旧服务器 {existing_server_name} 的记录")

            if name in self.servers:
                logger.warning(f"服务器名称 {name} 已存在,将更新配置")

            self.servers[name] = server_info
            await self._save_server_to_db(server_info)

            # Only drop the old row when it belongs to a *different* name;
            # otherwise we would delete the row we just upserted.
            if existing_server_name is not None and existing_server_name != name:
                await self._delete_server_from_db(existing_server_name)

        logger.info(f"服务器 {name} 注册成功: {http_url} (IP: {host}:{port})")
        return True

    async def unregister_server(self, name: str) -> bool:
        """Remove server ``name`` from memory and the database.

        Returns:
            True if the server was registered, False otherwise.
        """
        await self.initialize()
        async with self.lock:
            if name not in self.servers:
                return False
            del self.servers[name]
            await self._delete_server_from_db(name)
            logger.info(f"服务器 {name} 已注销")
            return True

    async def update_server_heartbeat(self, name: str) -> bool:
        """Record a heartbeat from server ``name`` and mark it reachable.

        The status becomes BUSY if the server is at capacity, else ONLINE.

        Returns:
            True if the server is registered, False otherwise.
        """
        await self.initialize()
        async with self.lock:
            server = self.servers.get(name)
            if server is None:
                return False
            server.last_heartbeat = datetime.now()
            server.status = self._status_when_reachable(server)
            await self._save_server_to_db(server)
            return True

    # ------------------------------------------------------------------
    # Selection and slot accounting
    # ------------------------------------------------------------------

    async def get_available_server(
        self, required_capabilities: Optional[Dict[str, Any]] = None
    ) -> Optional[ComfyUIServerInfo]:
        """Return the least-loaded ONLINE server with a free slot.

        Args:
            required_capabilities: If non-empty, only servers whose
                capabilities satisfy it (see :meth:`_check_capabilities`)
                are considered.

        Returns:
            The chosen server, or None. The server is *not* reserved; use
            :meth:`allocate_server` to take a slot, which may fail if another
            caller took it first.
        """
        await self.initialize()
        await self._ensure_health_check_started()

        async with self.lock:
            candidates = [
                server
                for server in self.servers.values()
                if server.status == ServerStatus.ONLINE
                and server.current_tasks < server.max_concurrent_tasks
                and (
                    not required_capabilities
                    or self._check_capabilities(
                        server.capabilities, required_capabilities
                    )
                )
            ]
            if not candidates:
                return None
            # Simple load balancing: fewest running tasks wins (ties go to the
            # earliest-inserted server).
            return min(candidates, key=lambda s: s.current_tasks)

    async def allocate_server(self, name: str) -> bool:
        """Reserve one task slot on server ``name``.

        Succeeds only when the server is ONLINE and below its concurrency
        limit; it turns BUSY once the last slot is taken.

        Returns:
            True if a slot was reserved, False otherwise.
        """
        await self.initialize()
        async with self.lock:
            server = self.servers.get(name)
            if (
                server is None
                or server.status != ServerStatus.ONLINE
                or server.current_tasks >= server.max_concurrent_tasks
            ):
                return False
            server.current_tasks += 1
            if server.current_tasks >= server.max_concurrent_tasks:
                server.status = ServerStatus.BUSY
            await self._save_server_to_db(server)
            return True

    async def release_server(self, name: str) -> bool:
        """Release one task slot on server ``name``.

        A BUSY server drops back to ONLINE once it is below its limit.

        Returns:
            True if a slot was released, False if the server is unknown or
            had no running tasks.
        """
        await self.initialize()
        async with self.lock:
            server = self.servers.get(name)
            if server is None or server.current_tasks <= 0:
                return False
            server.current_tasks -= 1
            if (
                server.status == ServerStatus.BUSY
                and server.current_tasks < server.max_concurrent_tasks
            ):
                server.status = ServerStatus.ONLINE
            await self._save_server_to_db(server)
            return True

    async def get_server_status(self, name: str) -> Optional[ComfyUIServerInfo]:
        """Return the live record for ``name``, or None if unknown."""
        await self.initialize()
        async with self.lock:
            return self.servers.get(name)

    async def get_all_servers(self) -> List[ComfyUIServerInfo]:
        """Return a list of all registered server records (any status)."""
        await self.initialize()
        async with self.lock:
            return list(self.servers.values())

    def _check_capabilities(
        self, server_capabilities: Dict[str, Any], required_capabilities: Dict[str, Any]
    ) -> bool:
        """Return True if ``server_capabilities`` satisfies every requirement.

        Each required key must be present. Nested dicts are matched
        recursively (the server may have extra keys); any other value must be
        equal.
        """
        for key, value in required_capabilities.items():
            if key not in server_capabilities:
                return False
            if isinstance(value, dict) and isinstance(server_capabilities[key], dict):
                if not self._check_capabilities(server_capabilities[key], value):
                    return False
            elif server_capabilities[key] != value:
                return False
        return True

    # ------------------------------------------------------------------
    # Health checking
    # ------------------------------------------------------------------

    async def _check_server_health(self, server: ComfyUIServerInfo) -> bool:
        """Probe ``GET {http_url}/system_stats``.

        Returns True only for an HTTP 200 within 10 seconds. Any error is
        logged at debug level and reported as unhealthy.
        """
        try:
            timeout = ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(f"{server.http_url}/system_stats") as response:
                    if response.status == 200:
                        return True
        except Exception as e:
            logger.debug(f"服务器 {server.name} 健康检查失败: {e}")

        return False

    async def _health_check_loop(self):
        """Run health-check rounds forever, ``health_check_interval`` apart.

        Errors in a round are logged and do not stop the loop; cancellation
        (a BaseException) does.
        """
        while True:
            try:
                await self._perform_health_checks()
            except Exception as e:
                logger.error(f"健康检查过程中发生错误: {e}")

            await asyncio.sleep(self.health_check_interval)

    async def _perform_health_checks(self):
        """Run one round of heartbeat-timeout handling and health probes.

        No network I/O happens while ``self.lock`` is held:

        1. Under the lock, select servers whose last health check is older
           than ``health_check_interval`` (or that were never checked).
        2. Without the lock, probe those servers concurrently.
        3. Under the lock, for every registered server: mark it OFFLINE if
           its heartbeat is older than ``heartbeat_timeout``, then apply its
           probe result. Servers replaced or unregistered in the meantime are
           skipped. Changed records are persisted.
        """
        current_time = datetime.now()
        heartbeat_limit = timedelta(seconds=self.heartbeat_timeout)
        check_interval = timedelta(seconds=self.health_check_interval)

        # Phase 1: decide which servers need a probe.
        async with self.lock:
            due = [
                server
                for server in self.servers.values()
                if not server.last_health_check
                or current_time - server.last_health_check > check_interval
            ]

        # Phase 2: probe concurrently with the lock released.
        outcomes = await asyncio.gather(
            *(self._check_server_health(server) for server in due)
        )
        probes = {server.name: (server, ok) for server, ok in zip(due, outcomes)}

        # Phase 3: apply timeouts and probe results atomically.
        async with self.lock:
            for server in self.servers.values():
                changed = False

                # Heartbeat timeout. Servers already OFFLINE are skipped so
                # they are not re-logged and re-saved every round.
                if (
                    server.status != ServerStatus.OFFLINE
                    and server.last_heartbeat
                    and current_time - server.last_heartbeat > heartbeat_limit
                ):
                    server.status = ServerStatus.OFFLINE
                    logger.warning(f"服务器 {server.name} 心跳超时,标记为离线")
                    changed = True

                probe = probes.get(server.name)
                if probe is not None and probe[0] is server:
                    is_healthy = probe[1]
                    server.last_health_check = current_time
                    changed = True
                    if is_healthy and server.status in (
                        ServerStatus.OFFLINE,
                        ServerStatus.ERROR,
                    ):
                        # Restore to ONLINE, or BUSY if still at capacity.
                        server.status = self._status_when_reachable(server)
                        server.last_heartbeat = current_time  # Keep heartbeat in sync.
                        logger.info(f"服务器 {server.name} 恢复在线,心跳时间已同步")
                    elif not is_healthy and server.status in (
                        ServerStatus.ONLINE,
                        ServerStatus.BUSY,
                    ):
                        server.status = ServerStatus.ERROR
                        logger.warning(f"服务器 {server.name} 健康检查失败")

                if changed:
                    await self._save_server_to_db(server)

    def _start_health_check(self):
        """Start the background health-check task if it is not running.

        Without a running event loop this only logs at debug level; the task
        is started by a later call made from inside a loop.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("没有运行中的事件循环,健康检查将在事件循环启动后自动开始")
            return

        if self._health_check_task is None or self._health_check_task.done():
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def _ensure_health_check_started(self):
        """Start the health-check task unless one is already running."""
        if self._health_check_task is None or self._health_check_task.done():
            self._start_health_check()

    async def shutdown(self):
        """Cancel the background health-check task and wait for it to end."""
        if self._health_check_task and not self._health_check_task.done():
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass
|
|
|
|
|
|
# Global (module-level) manager instance used by the compatibility helpers in
# this module. Constructing it performs no I/O; the database load happens
# lazily via initialize().
server_manager = ComfyUIServerManager()
|
|
|
|
|
|
# Compatibility helpers that replace the former static server configuration.
async def get_available_servers() -> List[ComfyUIServerInfo]:
    """Return every server known to the global manager (compatibility shim).

    Note: despite the name, no filtering by status or load is applied. This is
    a thin wrapper over ``server_manager.get_all_servers()``, which also
    triggers the manager's lazy database initialization.
    """
    all_servers = await server_manager.get_all_servers()
    return all_servers
|
|
|
|
|
|
async def get_server_for_task(
    required_capabilities: Optional[Dict[str, Any]] = None
) -> Optional[ComfyUIServerInfo]:
    """Pick a server for a task (compatibility shim).

    Delegates to ``server_manager.get_available_server``. The returned server
    is not reserved (its ``current_tasks`` is not incremented);
    ``server_manager.allocate_server`` performs that step and may return
    False if the capacity was taken in the meantime.

    Args:
        required_capabilities: Optional capability constraints the server
            must satisfy; ``None`` (or empty) accepts any server.

    Returns:
        The least-loaded eligible server, or ``None`` if none is available.
    """
    return await server_manager.get_available_server(required_capabilities)
|