import asyncio
import json
import logging
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

import aiohttp
from aiohttp import ClientTimeout
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from workflow_service.database.connection import AsyncSessionLocal
from workflow_service.database.models import ComfyUIServer as ComfyUIServerModel

logger = logging.getLogger(__name__)


class ServerStatus(Enum):
    """Status of a registered ComfyUI server (stored in the DB by its value)."""

    ONLINE = "online"
    OFFLINE = "offline"
    BUSY = "busy"
    ERROR = "error"


@dataclass
class ComfyUIServerInfo:
    """In-memory record describing one ComfyUI server."""

    name: str  # server name, chosen freely by whoever registers the server
    http_url: str
    ws_url: str
    status: ServerStatus = ServerStatus.OFFLINE
    last_health_check: Optional[datetime] = None
    current_tasks: int = 0
    max_concurrent_tasks: int = 1
    # Capability info (e.g. supported models); matched by _check_capabilities.
    capabilities: Dict[str, Any] = field(default_factory=dict)
    # Any other free-form metadata.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Callers may still pass None explicitly; normalize to empty dicts.
        if self.capabilities is None:
            self.capabilities = {}
        if self.metadata is None:
            self.metadata = {}

    @classmethod
    def from_model(cls, model: ComfyUIServerModel) -> "ComfyUIServerInfo":
        """Build an instance from a DB row (JSON columns are decoded).

        Raises:
            ValueError: if ``model.status`` is not a valid ``ServerStatus``
                value, or a JSON column cannot be decoded.
        """
        capabilities = json.loads(model.capabilities) if model.capabilities else {}
        metadata = json.loads(model.server_metadata) if model.server_metadata else {}
        return cls(
            name=model.name,
            http_url=model.http_url,
            ws_url=model.ws_url,
            status=ServerStatus(model.status),
            last_health_check=model.last_health_check,
            current_tasks=model.current_tasks,
            max_concurrent_tasks=model.max_concurrent_tasks,
            capabilities=capabilities,
            metadata=metadata,
        )

    def to_model(self) -> ComfyUIServerModel:
        """Convert to a new (unsaved) DB model; dict fields are JSON-encoded."""
        return ComfyUIServerModel(
            name=self.name,
            http_url=self.http_url,
            ws_url=self.ws_url,
            status=self.status.value,
            last_health_check=self.last_health_check,
            current_tasks=self.current_tasks,
            max_concurrent_tasks=self.max_concurrent_tasks,
            capabilities=json.dumps(self.capabilities),
            server_metadata=json.dumps(self.metadata),
        )


class ComfyUIServerManager:
    """Registry of ComfyUI servers with DB persistence and periodic health checks.

    All mutations of ``self.servers`` happen while holding ``self.lock``.
    Network probes are deliberately performed outside the lock so that a slow
    or unreachable server cannot stall allocation/release for other callers.
    """

    def __init__(self):
        self.servers: Dict[str, ComfyUIServerInfo] = {}  # name -> server_info
        self.lock = asyncio.Lock()
        self.health_check_interval = 30  # health-check interval in seconds
        self._health_check_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self):
        """Load servers from the database once; later calls are no-ops."""
        if self._initialized:
            return
        async with self.lock:
            # Re-check: another coroutine may have finished loading while we
            # were waiting for the lock.
            if self._initialized:
                return
            try:
                await self._load_servers_from_db()
                logger.info(f"从数据库加载了 {len(self.servers)} 个服务器")
            except Exception as e:
                logger.error(f"从数据库加载服务器失败: {e}")
            # Mark as initialized even on failure to avoid retrying on every call.
            self._initialized = True

    async def _load_servers_from_db(self):
        """Load every DB row into ``self.servers``; bad rows are logged and skipped."""
        async with AsyncSessionLocal() as session:
            async with session.begin():
                result = await session.execute(select(ComfyUIServerModel))
                models = result.scalars().all()
                for model in models:
                    try:
                        server_info = ComfyUIServerInfo.from_model(model)
                        self.servers[server_info.name] = server_info
                    except Exception as e:
                        logger.error(f"加载服务器 {model.name} 失败: {e}")

    async def _save_server_to_db(self, server: ComfyUIServerInfo):
        """Insert or update the DB row for *server* (best-effort; errors are logged)."""
        try:
            async with AsyncSessionLocal() as session:
                # session.begin() commits on successful exit; no explicit
                # commit is needed (the original committed twice).
                async with session.begin():
                    existing = await session.execute(
                        select(ComfyUIServerModel).where(
                            ComfyUIServerModel.name == server.name
                        )
                    )
                    if existing.scalar_one_or_none() is not None:
                        await session.execute(
                            update(ComfyUIServerModel)
                            .where(ComfyUIServerModel.name == server.name)
                            .values(
                                http_url=server.http_url,
                                ws_url=server.ws_url,
                                status=server.status.value,
                                last_health_check=server.last_health_check,
                                current_tasks=server.current_tasks,
                                max_concurrent_tasks=server.max_concurrent_tasks,
                                capabilities=json.dumps(server.capabilities),
                                server_metadata=json.dumps(server.metadata),
                                updated_at=datetime.utcnow(),
                            )
                        )
                    else:
                        session.add(server.to_model())
        except Exception as e:
            logger.error(f"保存服务器 {server.name} 到数据库失败: {e}")

    async def _delete_server_from_db(self, name: str):
        """Delete the DB row named *name* (best-effort; errors are logged)."""
        try:
            async with AsyncSessionLocal() as session:
                # Committed automatically when the begin() block exits.
                async with session.begin():
                    await session.execute(
                        delete(ComfyUIServerModel).where(
                            ComfyUIServerModel.name == name
                        )
                    )
        except Exception as e:
            logger.error(f"从数据库删除服务器 {name} 失败: {e}")

    @staticmethod
    def _host_port(url: str) -> Tuple[Optional[str], int]:
        """Return ``(hostname, port)`` of *url*; port defaults to 80 for http, else 443.

        Raises:
            ValueError: if the URL contains an invalid port.
        """
        parsed = urlparse(url)
        return parsed.hostname, parsed.port or (80 if parsed.scheme == "http" else 443)

    async def register_server(
        self,
        name: str,
        http_url: str,
        ws_url: str,
        max_concurrent_tasks: int = 1,
        capabilities: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Register (or re-register) a ComfyUI server.

        A server already registered at the same host:port is replaced, and
        its old DB row is removed when it was stored under a different name.
        Always returns True.
        """
        await self.initialize()
        await self._ensure_health_check_started()

        host, port = self._host_port(http_url)

        server_info = ComfyUIServerInfo(
            name=name,
            http_url=http_url.rstrip("/"),
            ws_url=ws_url.rstrip("/"),
            status=ServerStatus.OFFLINE,  # offline until a health check succeeds
            max_concurrent_tasks=max_concurrent_tasks,
            capabilities=capabilities or {},
            metadata=metadata or {},
        )

        # Probe the new server before taking the lock: the request can take up
        # to the probe timeout and must not block other lock users.
        if await self._check_server_health(server_info):
            server_info.status = ServerStatus.ONLINE
            server_info.last_health_check = datetime.now()

        async with self.lock:
            # Is this host:port already registered (possibly under another name)?
            existing_server_name = None
            for existing_server in self.servers.values():
                if self._host_port(existing_server.http_url) == (host, port):
                    existing_server_name = existing_server.name
                    logger.info(f"检测到IP {host}:{port} 已被服务器 {existing_server_name} 使用,将自动更新配置")
                    break

            if existing_server_name:
                self.servers.pop(existing_server_name, None)
                logger.info(f"已删除旧服务器 {existing_server_name} 的记录")

            if name in self.servers:
                logger.warning(f"服务器名称 {name} 已存在,将更新配置")

            self.servers[name] = server_info
            await self._save_server_to_db(server_info)

            # Only drop the old row when it lives under a different name;
            # otherwise we would delete the row we just saved.
            if existing_server_name and existing_server_name != name:
                await self._delete_server_from_db(existing_server_name)

        logger.info(f"服务器 {name} 注册成功: {http_url} (IP: {host}:{port})")
        return True

    async def unregister_server(self, name: str) -> bool:
        """Remove server *name*; return False if it was not registered."""
        await self.initialize()
        async with self.lock:
            if name not in self.servers:
                return False
            del self.servers[name]
            await self._delete_server_from_db(name)
            logger.info(f"服务器 {name} 已注销")
            return True

    async def get_available_server(
        self, required_capabilities: Optional[Dict[str, Any]] = None
    ) -> Optional[ComfyUIServerInfo]:
        """Return the least-loaded ONLINE server with spare capacity, or None.

        This does not reserve capacity; callers must still call
        ``allocate_server`` (which may fail if another caller got there first).
        """
        await self.initialize()
        await self._ensure_health_check_started()

        async with self.lock:
            available_servers = [
                server
                for server in self.servers.values()
                if server.status == ServerStatus.ONLINE
                and server.current_tasks < server.max_concurrent_tasks
                and (
                    not required_capabilities
                    or self._check_capabilities(
                        server.capabilities, required_capabilities
                    )
                )
            ]
            if not available_servers:
                return None
            # Simple load balancing: fewest running tasks wins.
            return min(available_servers, key=lambda s: s.current_tasks)

    async def allocate_server(self, name: str) -> bool:
        """Reserve one task slot on *name*; return False if it is unavailable."""
        await self.initialize()
        async with self.lock:
            server = self.servers.get(name)
            if (
                server is None
                or server.status != ServerStatus.ONLINE
                or server.current_tasks >= server.max_concurrent_tasks
            ):
                return False
            server.current_tasks += 1
            if server.current_tasks >= server.max_concurrent_tasks:
                server.status = ServerStatus.BUSY
            await self._save_server_to_db(server)
            return True

    async def release_server(self, name: str) -> bool:
        """Free one task slot on *name*; return False if it is not registered."""
        await self.initialize()
        async with self.lock:
            server = self.servers.get(name)
            if server is None:
                return False
            if server.current_tasks > 0:
                server.current_tasks -= 1
            if (
                server.status == ServerStatus.BUSY
                and server.current_tasks < server.max_concurrent_tasks
            ):
                server.status = ServerStatus.ONLINE
            await self._save_server_to_db(server)
            return True

    async def get_server_status(self, name: str) -> Optional[ComfyUIServerInfo]:
        """Return the live record for *name*, or None if unknown."""
        await self.initialize()
        async with self.lock:
            return self.servers.get(name)

    async def get_all_servers(self) -> List[ComfyUIServerInfo]:
        """Return a list of all registered server records."""
        await self.initialize()
        async with self.lock:
            return list(self.servers.values())

    def _check_capabilities(
        self, server_capabilities: Dict[str, Any], required_capabilities: Dict[str, Any]
    ) -> bool:
        """Return True if every required key is present with an equal value.

        Nested dicts on both sides are compared recursively as subsets.
        """
        for key, value in required_capabilities.items():
            if key not in server_capabilities:
                return False
            if isinstance(value, dict) and isinstance(server_capabilities[key], dict):
                if not self._check_capabilities(server_capabilities[key], value):
                    return False
            elif server_capabilities[key] != value:
                return False
        return True

    async def _check_server_health(self, server: ComfyUIServerInfo) -> bool:
        """Return True if GET ``{http_url}/system_stats`` answers 200 within 10 s."""
        try:
            timeout = ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(f"{server.http_url}/system_stats") as response:
                    return response.status == 200
        except Exception as e:
            logger.debug(f"服务器 {server.name} 健康检查失败: {e}")
        return False

    async def _health_check_loop(self):
        """Run health checks forever, sleeping ``health_check_interval`` between sweeps."""
        while True:
            try:
                await self._perform_health_checks()
            except Exception as e:
                logger.error(f"健康检查过程中发生错误: {e}")
            await asyncio.sleep(self.health_check_interval)

    async def _perform_health_checks(self):
        """Probe servers whose last check is older than the interval and update status.

        Probes run concurrently and outside ``self.lock``; results are applied
        under the lock, skipping servers unregistered/replaced in the meantime.
        """
        current_time = datetime.now()
        interval = timedelta(seconds=self.health_check_interval)

        async with self.lock:
            # NOTE(review): last_health_check may come from the DB; if that column
            # returns timezone-aware datetimes, this naive subtraction raises.
            due = [
                server
                for server in self.servers.values()
                if not server.last_health_check
                or current_time - server.last_health_check > interval
            ]
        if not due:
            return

        results = await asyncio.gather(
            *(self._check_server_health(server) for server in due)
        )

        async with self.lock:
            for server, is_healthy in zip(due, results):
                if self.servers.get(server.name) is not server:
                    continue
                server.last_health_check = current_time
                if is_healthy:
                    # Recover only from a failed/offline state; a healthy BUSY
                    # server stays BUSY until a slot is released.
                    if server.status in (ServerStatus.OFFLINE, ServerStatus.ERROR):
                        server.status = (
                            ServerStatus.BUSY
                            if server.current_tasks >= server.max_concurrent_tasks
                            else ServerStatus.ONLINE
                        )
                        logger.info(f"服务器 {server.name} 恢复在线")
                elif server.status in (ServerStatus.ONLINE, ServerStatus.BUSY):
                    server.status = ServerStatus.ERROR
                    logger.warning(f"服务器 {server.name} 健康检查失败")
                await self._save_server_to_db(server)

    def _start_health_check(self):
        """Start the health-check task if an event loop is running and none is active."""
        try:
            # Only used to detect whether we are inside a running loop.
            asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("没有运行中的事件循环,健康检查将在事件循环启动后自动开始")
            return
        if self._health_check_task is None or self._health_check_task.done():
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def _ensure_health_check_started(self):
        """Make sure the background health-check task is running."""
        if self._health_check_task is None or self._health_check_task.done():
            self._start_health_check()

    async def shutdown(self):
        """Cancel the health-check task and wait for it to finish."""
        task = self._health_check_task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._health_check_task = None


# Global server-manager instance
server_manager = ComfyUIServerManager()


# Compatibility helpers replacing the former static configuration
async def get_available_servers() -> List[ComfyUIServerInfo]:
    """Return all registered servers (compatibility helper)."""
    return await server_manager.get_all_servers()


async def get_server_for_task(
    required_capabilities: Optional[Dict[str, Any]] = None
) -> Optional[ComfyUIServerInfo]:
    """Pick a suitable server for a task (compatibility helper)."""
    return await server_manager.get_available_server(required_capabilities)