ComfyUI-CustomNode/nodes/util_nodes.py

573 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
工具节点模块 - 重构版本
提供各种实用功能的ComfyUI节点包括
- 数据库日志记录
- VOD文件下载使用统一存储抽象层
- 模型管理
- 文件遍历
- Webhook通知
- 任务ID生成
本模块已经过重构使用统一的存储抽象层替代直接SDK调用。
"""
import glob
import json
import os
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Tuple
import comfy.model_management
import requests
import server
from loguru import logger
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from ..utils.object_storage import DownloadResult, get_provider
from ..utils.task_table import Task
class LogToDB:
"""
数据库日志记录节点
将ComfyUI工作流的执行结果记录到数据库中支持任务状态跟踪。
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"job_id": ("STRING", {"forceInput": True}),
"log": ("STRING", {"forceInput": True}),
"status": ("INT", {"default": 1, "max": 1}),
"sql_url": (
"STRING",
{"default": "mysql+pymysql://root:root@example.com:3306/test"},
),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "log2db"
OUTPUT_NODE = True
OUTPUT_IS_LIST = (True,)
CATEGORY = "不忘科技-自定义节点🚩/工具"
def log2db(
self, log: str, status: int, sql_url: str, unique_id: str, **kwargs
) -> Dict[str, Any]:
"""
记录日志到数据库
Args:
log: 日志内容
status: 状态代码
sql_url: 数据库连接URL
unique_id: 唯一ID
Returns:
Dict: 包含结果信息的字典
"""
try:
# 获取comfy服务器队列信息
(_, prompt_id, prompt, extra_data, outputs_to_execute) = next(
iter(
server.PromptServer.instance.prompt_queue.currently_running.values()
)
)
job_id = extra_data["client_id"]
engine = create_engine(sql_url, echo=True)
session = sessionmaker(bind=engine)()
# 查询现有任务
tasks = session.query(Task).filter(Task.prompt_id == prompt_id).all()
result = {
"curr_node_id": str(unique_id),
"last_node_id": list(prompt.keys())[-1],
"node_output": str(log),
}
if len(tasks) == 0:
# 不存在则插入
task = Task(
prompt_id=prompt_id,
job_id=job_id,
result=json.dumps(result),
status=status,
)
session.add(task)
logger.info(f"新增任务记录: {prompt_id}")
elif len(tasks) == 1:
# 存在则更新
session.query(Task).filter(Task.prompt_id == prompt_id).update(
{"result": json.dumps(result), "status": status}
)
logger.info(f"更新任务记录: {prompt_id}")
else:
# 异常情况
session.rollback()
raise RuntimeError("状态数据库prompt_id不唯一, 无法记录状态!")
session.commit()
session.close()
return {"ui": {"text": json.dumps(result)}, "result": (json.dumps(result),)}
except Exception as e:
logger.error(f"数据库日志记录失败: {e}")
return {"ui": {"text": str(e)}, "result": (str(e),)}
class VodToLocalNode:
"""
腾讯云VOD文件下载节点
使用统一的存储抽象层从腾讯云视频点播(VOD)服务下载媒体文件到本地。
支持通过文件ID和子应用ID定位和下载视频文件。
"""
def __init__(self):
"""
初始化VOD下载节点
使用统一的存储管理器获取VOD存储提供者
替代直接使用腾讯云VOD SDK的方式。
"""
try:
# 使用统一存储管理器获取VOD提供者
self.vod_provider = get_provider("vod")
logger.info("VOD下载节点初始化成功")
except Exception as e:
logger.error(f"VOD下载节点初始化失败: {e}")
raise RuntimeError(f"VOD节点初始化失败: {e}")
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"file_id": ("STRING", {"default": ""}),
"sub_app_id": ("STRING", {"default": ""}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("local_path",)
FUNCTION = "execute"
CATEGORY = "不忘科技-自定义节点🚩/工具"
def execute(self, file_id: str, sub_app_id: str) -> Tuple[str]:
"""
执行VOD文件下载
Args:
file_id: VOD文件ID
sub_app_id: 子应用ID
Returns:
tuple: 包含本地文件路径的元组
Raises:
Exception: 下载失败时抛出异常
"""
try:
# 参数验证
if not file_id or not file_id.strip():
raise ValueError("文件ID不能为空")
if not sub_app_id or not sub_app_id.strip():
raise ValueError("子应用ID不能为空")
# 调用下载逻辑
local_path = self.download_vod(file_id.strip(), sub_app_id.strip())
logger.info(f"VOD文件下载成功: {local_path}")
return (local_path,)
except Exception as e:
logger.error(f"VOD文件下载执行失败: {e}")
raise Exception(f"VOD文件下载失败: {str(e)}")
def create_directory(self, path: str) -> None:
"""
创建目录(如果不存在)
Args:
path: 目录路径
"""
p = Path(path)
if not p.exists():
p.mkdir(parents=True, exist_ok=True)
logger.info(f"目录已创建: {path}")
else:
logger.debug(f"目录已存在: {path}")
def download_vod(self, file_id: str, sub_app_id: str) -> str:
"""
下载腾讯云VOD文件到本地
Args:
file_id: VOD文件ID
sub_app_id: 子应用ID
Returns:
str: 本地文件路径
Raises:
Exception: 下载失败时抛出异常
"""
try:
# 生成本地存储路径
download_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "download", f"{sub_app_id}"
)
self.create_directory(download_dir)
output_path = os.path.join(download_dir, f"{file_id}.mp4")
# 如果文件已存在,直接返回
if os.path.exists(output_path):
logger.info(f"VOD文件已存在直接返回: {output_path}")
return output_path
logger.info(f"开始下载VOD文件: {file_id} -> {output_path}")
# 使用VOD提供者下载文件
result: DownloadResult = self.vod_provider.download_file(
file_id, output_path, sub_app_id, timeout=600
)
if result.success:
logger.info(f"VOD文件下载成功: {result.local_path}")
return result.local_path
else:
raise Exception(result.message or "VOD下载失败")
except Exception as e:
logger.error(f"VOD文件下载异常: {e}")
raise Exception(f"VOD下载失败: {str(e)}")
class AnyType(str):
"""特殊类型用于在不等比较中始终相等。Credit to pythongosssss"""
def __ne__(self, __value: object) -> bool:
return False
any = AnyType("*")
class UnloadAllModels:
"""
卸载所有模型节点
释放GPU内存卸载所有已加载的模型。
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {"any": (any, {"forceInput": True})},
"optional": {},
}
RETURN_TYPES = ()
FUNCTION = "unload_models"
CATEGORY = "不忘科技-自定义节点🚩/工具"
OUTPUT_NODE = True
def unload_models(self, any=None) -> Tuple:
"""
卸载所有已加载的模型
Args:
any: 输入参数(任意类型)
Returns:
tuple: 空元组
"""
try:
logger.info("开始卸载所有模型...")
# 卸载所有已加载的模型
comfy.model_management.soft_empty_cache()
comfy.model_management.unload_all_models()
comfy.model_management.soft_empty_cache()
logger.info("所有模型已成功卸载")
return ()
except Exception as e:
logger.error(f"模型卸载失败: {e}")
# 即使失败也返回空元组,避免中断工作流
return ()
class TraverseFolder:
"""
文件夹遍历节点
遍历指定文件夹,查找符合条件的文件。
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"folder": (
"STRING",
{"default": r"E:\comfy\ComfyUI\input\s3", "required": True},
),
"subfix": ("STRING", {"default": ".mp4", "required": True}),
"recursive": ("BOOLEAN", {"default": True, "required": True}),
"idx": (
"INT",
{"default": 0, "min": 0, "max": 0xFFFFFF},
),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("文件路径",)
FUNCTION = "compute"
CATEGORY = "不忘科技-自定义节点🚩/工具"
def compute(
self, folder: str, subfix: str, recursive: bool, idx: int
) -> Tuple[str]:
"""
遍历文件夹查找文件
Args:
folder: 文件夹路径
subfix: 文件后缀
recursive: 是否递归查找
idx: 文件索引
Returns:
tuple: 包含文件路径的元组
Raises:
RuntimeError: 未找到文件时抛出异常
"""
try:
# 参数验证
if not folder or not folder.strip():
raise ValueError("文件夹路径不能为空")
if not subfix or not subfix.strip():
raise ValueError("文件后缀不能为空")
folder = folder.strip()
subfix = subfix.strip()
# 构建搜索模式
if recursive:
pattern = os.path.join(folder, f"**/*{subfix}")
else:
pattern = os.path.join(folder, f"*{subfix}")
# 查找文件
files = glob.glob(pattern, recursive=recursive)
if len(files) == 0:
logger.warning(f"在文件夹 {folder} 中未找到后缀为 {subfix} 的文件")
raise RuntimeError("No Files Found")
# 选择文件
selected_file = files[idx % len(files)]
logger.info(
f"找到 {len(files)} 个文件,选择第 {idx % len(files)} 个: {selected_file}"
)
return (str(selected_file),)
except Exception as e:
logger.error(f"文件夹遍历失败: {e}")
if isinstance(e, RuntimeError):
raise
else:
raise RuntimeError(f"文件夹遍历错误: {str(e)}")
class PlugAndPlayWebhook:
"""
即插即用Webhook节点
连上线就能转发数据到指定的Webhook地址。
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"webhook_url": (
"STRING",
{
"default": "http://127.0.0.1:8010/handler/webhook",
"placeholder": "https://your-api.com/webhook",
},
),
"image_url": ("STRING", {"default": "", "placeholder": "图片的url"}),
},
"optional": {
"prompt_id": ("STRING", {"default": ""}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ()
FUNCTION = "send"
OUTPUT_NODE = True
CATEGORY = "不忘科技-自定义节点🚩/工具"
def send(
self,
webhook_url: str,
image_url: str,
prompt_id: str = "",
prompt=None,
extra_pnginfo=None,
unique_id=None,
) -> Tuple:
"""
发送数据到Webhook
Args:
webhook_url: Webhook URL
image_url: 图片URL
prompt_id: 提示ID
prompt: 提示信息
extra_pnginfo: 额外PNG信息
unique_id: 唯一ID
Returns:
tuple: 空元组
Raises:
ValueError: URL为空时抛出异常
"""
try:
# 参数验证
if not webhook_url or not webhook_url.strip():
raise ValueError("❌ 请填写Webhook URL")
# 使用传入的prompt_id如果没有则用unique_id
final_prompt_id = prompt_id or unique_id or "unknown"
# 准备发送的数据
data = {
"img_base64": image_url,
"format": "png",
"image_url": image_url,
"prompt_id": final_prompt_id,
"timestamp": time.time(),
}
logger.info(f"准备发送Webhook数据到: {webhook_url}")
# 发送Webhook
response = requests.post(webhook_url.strip(), json=data, timeout=30)
response.raise_for_status()
logger.info(f"Webhook发送成功响应状态: {response.status_code}")
logger.debug(f"发送的数据: {data}")
except requests.RequestException as e:
logger.error(f"❌ Webhook发送失败: {str(e)}")
# 不抛出异常,避免中断工作流
except Exception as e:
logger.error(f"❌ Webhook发送异常: {str(e)}")
# 不抛出异常,避免中断工作流
# 终端节点,无需返回
return ()
class TaskIdGenerate:
"""
TaskID生成器
用户可传入自定义TaskID或自动生成UUID。
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {},
"optional": {
"custom_task_id": (
"STRING",
{"default": "", "placeholder": "留空则自动生成"},
),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("task_id",)
FUNCTION = "generate_task_id"
OUTPUT_NODE = False
CATEGORY = "不忘科技-自定义节点🚩/工具"
def generate_task_id(self, custom_task_id: str = "") -> Tuple[str]:
"""
生成或使用自定义TaskID
Args:
custom_task_id: 自定义TaskID
Returns:
tuple: 包含TaskID的元组
"""
try:
if custom_task_id and custom_task_id.strip():
# 用户输入了自定义ID
task_id = custom_task_id.strip()
logger.info(f"📝 使用自定义TaskID: {task_id}")
else:
# 自动生成UUID
task_id = str(uuid.uuid4())
logger.info(f"🎲 自动生成TaskID: {task_id}")
return (task_id,)
except Exception as e:
logger.error(f"TaskID生成失败: {e}")
# 发生异常时生成一个基础UUID
fallback_id = str(uuid.uuid4())
logger.warning(f"使用备用TaskID: {fallback_id}")
return (fallback_id,)
# 节点映射字典
NODE_CLASS_MAPPINGS = {
"LogToDB": LogToDB,
"VodToLocalNode": VodToLocalNode,
"UnloadAllModels": UnloadAllModels,
"TraverseFolder": TraverseFolder,
"PlugAndPlayWebhook": PlugAndPlayWebhook,
"TaskIdGenerate": TaskIdGenerate,
}
# 节点显示名称映射
NODE_DISPLAY_NAME_MAPPINGS = {
"LogToDB": "数据库日志记录",
"VodToLocalNode": "VOD文件下载",
"UnloadAllModels": "卸载所有模型",
"TraverseFolder": "文件夹遍历",
"PlugAndPlayWebhook": "Webhook通知",
"TaskIdGenerate": "TaskID生成器",
}