# ComfyUI-WorkflowPublisher/app/comfy/comfy_server.py
# (Source-viewer header from the original paste: 448 lines, 17 KiB, Python.)

import asyncio
import json
import logging
import time
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Set
from urllib.parse import urlparse

import aiohttp
from aiohttp import ClientTimeout
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.database.connection import AsyncSessionLocal
from app.database.models import ComfyUIServer as ComfyUIServerModel
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ServerStatus(Enum):
    """Lifecycle state of a registered ComfyUI server.

    The member values are the strings written to / read from the database
    ``status`` column.
    """

    # Health check passed and the server is accepting tasks.
    ONLINE = "online"
    # Initial state on registration, until a health check succeeds.
    OFFLINE = "offline"
    # Reachable, but every task slot (max_concurrent_tasks) is in use.
    BUSY = "busy"
    # A health check failed after the server had been usable.
    ERROR = "error"
@dataclass
class ComfyUIServerInfo:
    """In-memory description of one ComfyUI server.

    Mirrors a row of the ``ComfyUIServer`` database model. ``capabilities``
    and ``metadata`` are plain dicts here and JSON strings in the database
    (columns ``capabilities`` and ``server_metadata``).
    """

    name: str  # Server name, chosen by whoever registers the server
    http_url: str
    ws_url: str
    status: ServerStatus = ServerStatus.OFFLINE
    last_health_check: Optional[datetime] = None
    current_tasks: int = 0
    max_concurrent_tasks: int = 1
    # Server capability info (e.g. supported models); None is normalized to {}.
    capabilities: Optional[Dict[str, Any]] = None
    # Any other metadata; None is normalized to {}.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Normalize missing dicts so callers can always index / iterate them
        # (and so a fresh dict is created per instance, never shared).
        if self.capabilities is None:
            self.capabilities = {}
        if self.metadata is None:
            self.metadata = {}

    @classmethod
    def from_model(cls, model: ComfyUIServerModel) -> "ComfyUIServerInfo":
        """Build an instance from a database model row.

        Empty/NULL JSON columns become empty dicts.

        Raises:
            ValueError: if ``model.status`` is not a ``ServerStatus`` value, or
                a JSON column holds malformed JSON (``json.JSONDecodeError``).
        """
        capabilities = json.loads(model.capabilities) if model.capabilities else {}
        metadata = json.loads(model.server_metadata) if model.server_metadata else {}
        return cls(
            name=model.name,
            http_url=model.http_url,
            ws_url=model.ws_url,
            status=ServerStatus(model.status),
            last_health_check=model.last_health_check,
            current_tasks=model.current_tasks,
            max_concurrent_tasks=model.max_concurrent_tasks,
            capabilities=capabilities,
            metadata=metadata,
        )

    def to_model(self) -> ComfyUIServerModel:
        """Convert to a new (not yet persisted) database model instance.

        The enum is stored by value and the dicts are serialized to JSON.
        """
        return ComfyUIServerModel(
            name=self.name,
            http_url=self.http_url,
            ws_url=self.ws_url,
            status=self.status.value,
            last_health_check=self.last_health_check,
            current_tasks=self.current_tasks,
            max_concurrent_tasks=self.max_concurrent_tasks,
            capabilities=json.dumps(self.capabilities),
            server_metadata=json.dumps(self.metadata),
        )
class ComfyUIServerManager:
    """Registry, health monitor and simple load balancer for ComfyUI servers.

    Server records live in ``self.servers`` (name -> ``ComfyUIServerInfo``).
    They are loaded lazily from the database by :meth:`initialize` and written
    back to the database on every state change. A background task, started on
    demand, re-probes each server's ``/system_stats`` endpoint roughly every
    ``health_check_interval`` seconds.

    Concurrency: every mutation of ``self.servers`` and of the server records
    happens under ``self.lock``. Network health probes are run *outside* the
    lock so that a slow or unreachable server cannot stall allocate/release.
    """

    def __init__(self):
        self.servers: Dict[str, ComfyUIServerInfo] = {}  # name -> server_info
        self.lock = asyncio.Lock()
        self.health_check_interval = 30  # health-check interval (seconds)
        self._health_check_task: Optional[asyncio.Task] = None
        self._initialized = False

    async def initialize(self):
        """Load server records from the database once (idempotent).

        The flag is checked again after acquiring the lock, so concurrent
        first callers do not load the table a second time (which would
        replace the ``ComfyUIServerInfo`` objects already loaded).
        A failed load is logged and still marks the manager initialized so
        the load is not retried on every call.
        """
        if self._initialized:
            return
        async with self.lock:
            if self._initialized:
                # Another coroutine completed initialization while we waited.
                return
            try:
                await self._load_servers_from_db()
                self._initialized = True
                logger.info(f"从数据库加载了 {len(self.servers)} 个服务器")
            except Exception as e:
                logger.error(f"从数据库加载服务器失败: {e}")
                # Mark as initialized even on failure to avoid repeated attempts.
                self._initialized = True

    async def _load_servers_from_db(self):
        """Populate ``self.servers`` from every row of the server table.

        Called by :meth:`initialize` while it holds ``self.lock``. Rows that
        fail to convert are logged and skipped; database errors propagate.
        """
        async with AsyncSessionLocal() as session:
            async with session.begin():
                result = await session.execute(select(ComfyUIServerModel))
                # Convert rows while the transaction is still open; after the
                # commit the instances may be expired (depends on session config).
                for model in result.scalars().all():
                    try:
                        server_info = ComfyUIServerInfo.from_model(model)
                        self.servers[server_info.name] = server_info
                    except Exception as e:
                        logger.error(f"加载服务器 {model.name} 失败: {e}")

    async def _save_server_to_db(self, server: ComfyUIServerInfo):
        """Insert or update the database row for *server*, matched by name.

        Errors are logged and not re-raised.
        """
        try:
            async with AsyncSessionLocal() as session:
                # session.begin() commits when the block exits normally and
                # rolls back on an exception; no explicit commit() is needed.
                async with session.begin():
                    existing = await session.execute(
                        select(ComfyUIServerModel).where(
                            ComfyUIServerModel.name == server.name
                        )
                    )
                    if existing.scalar_one_or_none() is not None:
                        # Update the existing record.
                        await session.execute(
                            update(ComfyUIServerModel)
                            .where(ComfyUIServerModel.name == server.name)
                            .values(
                                http_url=server.http_url,
                                ws_url=server.ws_url,
                                status=server.status.value,
                                last_health_check=server.last_health_check,
                                current_tasks=server.current_tasks,
                                max_concurrent_tasks=server.max_concurrent_tasks,
                                capabilities=json.dumps(server.capabilities),
                                server_metadata=json.dumps(server.metadata),
                                updated_at=datetime.utcnow(),
                            )
                        )
                    else:
                        # No record yet: create one.
                        session.add(server.to_model())
        except Exception as e:
            logger.error(f"保存服务器 {server.name} 到数据库失败: {e}")

    async def _delete_server_from_db(self, name: str):
        """Delete the database row(s) for server *name*.

        Errors are logged and not re-raised.
        """
        try:
            async with AsyncSessionLocal() as session:
                # Committed automatically when the begin() block exits.
                async with session.begin():
                    await session.execute(
                        delete(ComfyUIServerModel).where(
                            ComfyUIServerModel.name == name
                        )
                    )
        except Exception as e:
            logger.error(f"从数据库删除服务器 {name} 失败: {e}")

    @staticmethod
    def _host_port(url: str):
        """Return ``(hostname, port)`` for *url*.

        A missing port defaults to 80 for ``http`` and 443 for any other scheme.
        Raises ValueError (from urllib) if the URL contains an invalid port.
        """
        parsed = urlparse(url)
        return parsed.hostname, parsed.port or (80 if parsed.scheme == "http" else 443)

    async def register_server(
        self,
        name: str,
        http_url: str,
        ws_url: str,
        max_concurrent_tasks: int = 1,
        capabilities: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Register (or re-register) a ComfyUI server.

        If a server with a different name is already registered at the same
        host:port, it is replaced (memory and database). If *name* is already
        registered, its configuration is overwritten; note that this also
        resets ``current_tasks`` to 0. The server is probed once immediately:
        it starts ONLINE if the probe succeeds, otherwise OFFLINE until a
        later health check succeeds.

        Returns:
            Always True.
        """
        # Make sure the manager is loaded and the health-check task is running.
        await self.initialize()
        await self._ensure_health_check_started()

        host, port = self._host_port(http_url)
        server_info = ComfyUIServerInfo(
            name=name,
            http_url=http_url.rstrip("/"),
            ws_url=ws_url.rstrip("/"),
            status=ServerStatus.OFFLINE,  # offline until a health check succeeds
            max_concurrent_tasks=max_concurrent_tasks,
            capabilities=capabilities or {},
            metadata=metadata or {},
        )
        # Probe before taking the lock: this is a network round trip (up to
        # 10 s), and holding the lock would block every other operation.
        if await self._check_server_health(server_info):
            server_info.status = ServerStatus.ONLINE
            server_info.last_health_check = datetime.now()

        async with self.lock:
            # Look for a server already registered at the same host:port.
            existing_server_name = None
            for existing_server in self.servers.values():
                if self._host_port(existing_server.http_url) == (host, port):
                    existing_server_name = existing_server.name
                    logger.info(f"检测到IP {host}:{port} 已被服务器 {existing_server_name} 使用,将自动更新配置")
                    break

            if existing_server_name and existing_server_name in self.servers:
                # Drop the old in-memory record; the new one replaces it.
                del self.servers[existing_server_name]
                logger.info(f"已删除旧服务器 {existing_server_name} 的记录")

            if name in self.servers:
                logger.warning(f"服务器名称 {name} 已存在,将更新配置")

            self.servers[name] = server_info
            await self._save_server_to_db(server_info)

            # Delete the old database row only if it belongs to a *different*
            # name. For a same-name re-registration the row was just updated
            # above, and deleting it would silently drop the server from the DB.
            if existing_server_name and existing_server_name != name:
                await self._delete_server_from_db(existing_server_name)

        logger.info(f"服务器 {name} 注册成功: {http_url} (IP: {host}:{port})")
        return True

    async def unregister_server(self, name: str) -> bool:
        """Remove server *name* from memory and from the database.

        Returns:
            True if the server was registered, False otherwise.
        """
        # Without this, an unregister issued before any other call would not
        # see servers that exist only in the database.
        await self.initialize()
        async with self.lock:
            if name not in self.servers:
                return False
            del self.servers[name]
            await self._delete_server_from_db(name)
            logger.info(f"服务器 {name} 已注销")
            return True

    async def get_available_server(
        self, required_capabilities: Optional[Dict[str, Any]] = None
    ) -> Optional[ComfyUIServerInfo]:
        """Return the least-loaded ONLINE server with spare capacity, or None.

        If *required_capabilities* is non-empty, only servers whose
        capabilities satisfy it (see :meth:`_check_capabilities`) qualify.
        This only selects a server; it does not reserve a task slot
        (use :meth:`allocate_server` for that).
        """
        await self.initialize()
        await self._ensure_health_check_started()
        async with self.lock:
            available_servers = [
                server
                for server in self.servers.values()
                if server.status == ServerStatus.ONLINE
                and server.current_tasks < server.max_concurrent_tasks
                and (
                    not required_capabilities
                    or self._check_capabilities(
                        server.capabilities, required_capabilities
                    )
                )
            ]
            if not available_servers:
                return None
            # Simple load balancing: pick the server with the fewest current tasks
            # (ties go to the first registered, since min() keeps the first).
            return min(available_servers, key=lambda s: s.current_tasks)

    async def allocate_server(self, name: str) -> bool:
        """Reserve one task slot on server *name*.

        Succeeds only if the server is ONLINE and below ``max_concurrent_tasks``.
        It becomes BUSY once its last slot is taken. The new state is persisted.

        Returns:
            True if a slot was reserved, False otherwise.
        """
        await self.initialize()
        async with self.lock:
            server = self.servers.get(name)
            if (
                server is not None
                and server.status == ServerStatus.ONLINE
                and server.current_tasks < server.max_concurrent_tasks
            ):
                server.current_tasks += 1
                if server.current_tasks >= server.max_concurrent_tasks:
                    server.status = ServerStatus.BUSY
                await self._save_server_to_db(server)
                return True
            return False

    async def release_server(self, name: str) -> bool:
        """Free one task slot on server *name*.

        A BUSY server that drops below its limit goes back to ONLINE. The new
        state is persisted.

        Returns:
            True if a slot was released, False if the server is unknown or had
            no running tasks.
        """
        await self.initialize()
        async with self.lock:
            server = self.servers.get(name)
            if server is not None and server.current_tasks > 0:
                server.current_tasks -= 1
                if (
                    server.status == ServerStatus.BUSY
                    and server.current_tasks < server.max_concurrent_tasks
                ):
                    server.status = ServerStatus.ONLINE
                await self._save_server_to_db(server)
                return True
            return False

    async def get_server_status(self, name: str) -> Optional[ComfyUIServerInfo]:
        """Return the manager's live record for *name* (not a copy), or None."""
        await self.initialize()
        async with self.lock:
            return self.servers.get(name)

    async def get_all_servers(self) -> List[ComfyUIServerInfo]:
        """Return a new list of all registered server records (live objects)."""
        await self.initialize()
        async with self.lock:
            return list(self.servers.values())

    def _check_capabilities(
        self, server_capabilities: Dict[str, Any], required_capabilities: Dict[str, Any]
    ) -> bool:
        """Return True if *server_capabilities* satisfies *required_capabilities*.

        Every required key must be present. When both sides hold a dict, the
        match recurses; otherwise the values must be equal.
        """
        for key, value in required_capabilities.items():
            if key not in server_capabilities:
                return False
            if isinstance(value, dict) and isinstance(server_capabilities[key], dict):
                if not self._check_capabilities(server_capabilities[key], value):
                    return False
            elif server_capabilities[key] != value:
                return False
        return True

    async def _check_server_health(self, server: ComfyUIServerInfo) -> bool:
        """Return True if GET ``{http_url}/system_stats`` answers HTTP 200 within 10 s.

        Any exception (connection error, timeout, ...) is logged at DEBUG
        level and reported as unhealthy.
        """
        try:
            timeout = ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(f"{server.http_url}/system_stats") as response:
                    return response.status == 200
        except Exception as e:
            logger.debug(f"服务器 {server.name} 健康检查失败: {e}")
        return False

    async def _health_check_loop(self):
        """Run health checks forever, sleeping ``health_check_interval`` between rounds.

        Errors are logged and do not stop the loop; cancellation does.
        """
        while True:
            try:
                await self._perform_health_checks()
            except Exception as e:
                logger.error(f"健康检查过程中发生错误: {e}")
            await asyncio.sleep(self.health_check_interval)

    async def _perform_health_checks(self):
        """Probe every server whose last check is older than the interval.

        The work runs in three phases:
          1. Under the lock, snapshot the servers that are due.
          2. Without the lock, probe them concurrently.
          3. Under the lock, apply the results and persist each checked server.
        A server that was unregistered or replaced during the probe is skipped.

        Status transitions on a probe result:
          * healthy -> ONLINE, or BUSY if all of its task slots are taken;
          * unhealthy while ONLINE or BUSY -> ERROR;
          * unhealthy while OFFLINE or ERROR -> unchanged.
        """
        current_time = datetime.now()
        interval = timedelta(seconds=self.health_check_interval)

        async with self.lock:
            due = [
                server
                for server in self.servers.values()
                if not server.last_health_check
                or current_time - server.last_health_check > interval
            ]
        if not due:
            return

        results = await asyncio.gather(
            *(self._check_server_health(server) for server in due)
        )

        async with self.lock:
            for server, is_healthy in zip(due, results):
                if self.servers.get(server.name) is not server:
                    # Unregistered or re-registered meanwhile: result is stale.
                    continue
                server.last_health_check = current_time
                if is_healthy:
                    healthy_status = (
                        ServerStatus.BUSY
                        if server.current_tasks >= server.max_concurrent_tasks
                        else ServerStatus.ONLINE
                    )
                    if server.status != healthy_status:
                        # Only OFFLINE/ERROR -> usable is a real recovery;
                        # ONLINE <-> BUSY is just a capacity correction.
                        if server.status not in (ServerStatus.ONLINE, ServerStatus.BUSY):
                            logger.info(f"服务器 {server.name} 恢复在线")
                        server.status = healthy_status
                elif server.status in (ServerStatus.ONLINE, ServerStatus.BUSY):
                    server.status = ServerStatus.ERROR
                    logger.warning(f"服务器 {server.name} 健康检查失败")
                await self._save_server_to_db(server)

    def _start_health_check(self):
        """Start the background health-check task if an event loop is running.

        Without a running loop this only logs at DEBUG level. A later call
        (via :meth:`_ensure_health_check_started`) starts the task.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("没有运行中的事件循环,健康检查将在事件循环启动后自动开始")
            return
        if self._health_check_task is None or self._health_check_task.done():
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def _ensure_health_check_started(self):
        """(Re)start the health-check task if it was never started or has finished."""
        if self._health_check_task is None or self._health_check_task.done():
            self._start_health_check()

    async def shutdown(self):
        """Cancel the health-check task and wait for it to finish."""
        if self._health_check_task and not self._health_check_task.done():
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass
# Global server manager instance, used by the compatibility functions in this module.
server_manager = ComfyUIServerManager()
# Compatibility functions that replace the former static server configuration.
async def get_available_servers() -> List[ComfyUIServerInfo]:
    """Return every registered server (compatibility function).

    NOTE: despite the name, no availability filtering is applied. This simply
    returns ``server_manager.get_all_servers()``; use ``get_server_for_task``
    to pick a server that can accept work.
    """
    all_servers = await server_manager.get_all_servers()
    return all_servers
async def get_server_for_task(
    required_capabilities: Optional[Dict[str, Any]] = None,
) -> Optional[ComfyUIServerInfo]:
    """Pick a server for a new task (compatibility function).

    Delegates to ``server_manager.get_available_server``.

    Args:
        required_capabilities: optional capability constraints. Nested dicts
            are matched recursively; other values must be equal.

    Returns:
        The ONLINE server with spare capacity and the fewest current tasks,
        or None if no server qualifies. No task slot is reserved; callers
        should follow up with ``server_manager.allocate_server``.
    """
    return await server_manager.get_available_server(required_capabilities)