573 lines
16 KiB
Python
573 lines
16 KiB
Python
"""
|
||
工具节点模块 - 重构版本
|
||
|
||
提供各种实用功能的ComfyUI节点,包括:
|
||
- 数据库日志记录
|
||
- VOD文件下载(使用统一存储抽象层)
|
||
- 模型管理
|
||
- 文件遍历
|
||
- Webhook通知
|
||
- 任务ID生成
|
||
|
||
本模块已经过重构,使用统一的存储抽象层替代直接SDK调用。
|
||
"""
|
||
|
||
import glob
|
||
import json
|
||
import os
|
||
import time
|
||
import uuid
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Tuple
|
||
|
||
import comfy.model_management
|
||
import requests
|
||
import server
|
||
from loguru import logger
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
|
||
from ..utils.object_storage import DownloadResult, get_provider
|
||
from ..utils.task_table import Task
|
||
|
||
|
||
class LogToDB:
|
||
"""
|
||
数据库日志记录节点
|
||
|
||
将ComfyUI工作流的执行结果记录到数据库中,支持任务状态跟踪。
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(s):
|
||
return {
|
||
"required": {
|
||
"job_id": ("STRING", {"forceInput": True}),
|
||
"log": ("STRING", {"forceInput": True}),
|
||
"status": ("INT", {"default": 1, "max": 1}),
|
||
"sql_url": (
|
||
"STRING",
|
||
{"default": "mysql+pymysql://root:root@example.com:3306/test"},
|
||
),
|
||
},
|
||
"hidden": {
|
||
"unique_id": "UNIQUE_ID",
|
||
},
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
FUNCTION = "log2db"
|
||
OUTPUT_NODE = True
|
||
OUTPUT_IS_LIST = (True,)
|
||
CATEGORY = "不忘科技-自定义节点🚩/工具"
|
||
|
||
def log2db(
|
||
self, log: str, status: int, sql_url: str, unique_id: str, **kwargs
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
记录日志到数据库
|
||
|
||
Args:
|
||
log: 日志内容
|
||
status: 状态代码
|
||
sql_url: 数据库连接URL
|
||
unique_id: 唯一ID
|
||
|
||
Returns:
|
||
Dict: 包含结果信息的字典
|
||
"""
|
||
try:
|
||
# 获取comfy服务器队列信息
|
||
(_, prompt_id, prompt, extra_data, outputs_to_execute) = next(
|
||
iter(
|
||
server.PromptServer.instance.prompt_queue.currently_running.values()
|
||
)
|
||
)
|
||
job_id = extra_data["client_id"]
|
||
|
||
engine = create_engine(sql_url, echo=True)
|
||
session = sessionmaker(bind=engine)()
|
||
|
||
# 查询现有任务
|
||
tasks = session.query(Task).filter(Task.prompt_id == prompt_id).all()
|
||
|
||
result = {
|
||
"curr_node_id": str(unique_id),
|
||
"last_node_id": list(prompt.keys())[-1],
|
||
"node_output": str(log),
|
||
}
|
||
|
||
if len(tasks) == 0:
|
||
# 不存在则插入
|
||
task = Task(
|
||
prompt_id=prompt_id,
|
||
job_id=job_id,
|
||
result=json.dumps(result),
|
||
status=status,
|
||
)
|
||
session.add(task)
|
||
logger.info(f"新增任务记录: {prompt_id}")
|
||
elif len(tasks) == 1:
|
||
# 存在则更新
|
||
session.query(Task).filter(Task.prompt_id == prompt_id).update(
|
||
{"result": json.dumps(result), "status": status}
|
||
)
|
||
logger.info(f"更新任务记录: {prompt_id}")
|
||
else:
|
||
# 异常情况
|
||
session.rollback()
|
||
raise RuntimeError("状态数据库prompt_id不唯一, 无法记录状态!")
|
||
|
||
session.commit()
|
||
session.close()
|
||
|
||
return {"ui": {"text": json.dumps(result)}, "result": (json.dumps(result),)}
|
||
|
||
except Exception as e:
|
||
logger.error(f"数据库日志记录失败: {e}")
|
||
return {"ui": {"text": str(e)}, "result": (str(e),)}
|
||
|
||
|
||
class VodToLocalNode:
|
||
"""
|
||
腾讯云VOD文件下载节点
|
||
|
||
使用统一的存储抽象层从腾讯云视频点播(VOD)服务下载媒体文件到本地。
|
||
支持通过文件ID和子应用ID定位和下载视频文件。
|
||
"""
|
||
|
||
def __init__(self):
|
||
"""
|
||
初始化VOD下载节点
|
||
|
||
使用统一的存储管理器获取VOD存储提供者,
|
||
替代直接使用腾讯云VOD SDK的方式。
|
||
"""
|
||
try:
|
||
# 使用统一存储管理器获取VOD提供者
|
||
self.vod_provider = get_provider("vod")
|
||
logger.info("VOD下载节点初始化成功")
|
||
|
||
except Exception as e:
|
||
logger.error(f"VOD下载节点初始化失败: {e}")
|
||
raise RuntimeError(f"VOD节点初始化失败: {e}")
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
return {
|
||
"required": {
|
||
"file_id": ("STRING", {"default": ""}),
|
||
"sub_app_id": ("STRING", {"default": ""}),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("local_path",)
|
||
FUNCTION = "execute"
|
||
CATEGORY = "不忘科技-自定义节点🚩/工具"
|
||
|
||
def execute(self, file_id: str, sub_app_id: str) -> Tuple[str]:
|
||
"""
|
||
执行VOD文件下载
|
||
|
||
Args:
|
||
file_id: VOD文件ID
|
||
sub_app_id: 子应用ID
|
||
|
||
Returns:
|
||
tuple: 包含本地文件路径的元组
|
||
|
||
Raises:
|
||
Exception: 下载失败时抛出异常
|
||
"""
|
||
try:
|
||
# 参数验证
|
||
if not file_id or not file_id.strip():
|
||
raise ValueError("文件ID不能为空")
|
||
if not sub_app_id or not sub_app_id.strip():
|
||
raise ValueError("子应用ID不能为空")
|
||
|
||
# 调用下载逻辑
|
||
local_path = self.download_vod(file_id.strip(), sub_app_id.strip())
|
||
logger.info(f"VOD文件下载成功: {local_path}")
|
||
return (local_path,)
|
||
|
||
except Exception as e:
|
||
logger.error(f"VOD文件下载执行失败: {e}")
|
||
raise Exception(f"VOD文件下载失败: {str(e)}")
|
||
|
||
def create_directory(self, path: str) -> None:
|
||
"""
|
||
创建目录(如果不存在)
|
||
|
||
Args:
|
||
path: 目录路径
|
||
"""
|
||
p = Path(path)
|
||
if not p.exists():
|
||
p.mkdir(parents=True, exist_ok=True)
|
||
logger.info(f"目录已创建: {path}")
|
||
else:
|
||
logger.debug(f"目录已存在: {path}")
|
||
|
||
def download_vod(self, file_id: str, sub_app_id: str) -> str:
|
||
"""
|
||
下载腾讯云VOD文件到本地
|
||
|
||
Args:
|
||
file_id: VOD文件ID
|
||
sub_app_id: 子应用ID
|
||
|
||
Returns:
|
||
str: 本地文件路径
|
||
|
||
Raises:
|
||
Exception: 下载失败时抛出异常
|
||
"""
|
||
try:
|
||
# 生成本地存储路径
|
||
download_dir = os.path.join(
|
||
os.path.dirname(os.path.abspath(__file__)), "download", f"{sub_app_id}"
|
||
)
|
||
self.create_directory(download_dir)
|
||
|
||
output_path = os.path.join(download_dir, f"{file_id}.mp4")
|
||
|
||
# 如果文件已存在,直接返回
|
||
if os.path.exists(output_path):
|
||
logger.info(f"VOD文件已存在,直接返回: {output_path}")
|
||
return output_path
|
||
|
||
logger.info(f"开始下载VOD文件: {file_id} -> {output_path}")
|
||
|
||
# 使用VOD提供者下载文件
|
||
result: DownloadResult = self.vod_provider.download_file(
|
||
file_id, output_path, sub_app_id, timeout=600
|
||
)
|
||
|
||
if result.success:
|
||
logger.info(f"VOD文件下载成功: {result.local_path}")
|
||
return result.local_path
|
||
else:
|
||
raise Exception(result.message or "VOD下载失败")
|
||
|
||
except Exception as e:
|
||
logger.error(f"VOD文件下载异常: {e}")
|
||
raise Exception(f"VOD下载失败: {str(e)}")
|
||
|
||
|
||
class AnyType(str):
|
||
"""特殊类型,用于在不等比较中始终相等。Credit to pythongosssss"""
|
||
|
||
def __ne__(self, __value: object) -> bool:
|
||
return False
|
||
|
||
|
||
any = AnyType("*")
|
||
|
||
|
||
class UnloadAllModels:
|
||
"""
|
||
卸载所有模型节点
|
||
|
||
释放GPU内存,卸载所有已加载的模型。
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
return {
|
||
"required": {"any": (any, {"forceInput": True})},
|
||
"optional": {},
|
||
}
|
||
|
||
RETURN_TYPES = ()
|
||
FUNCTION = "unload_models"
|
||
CATEGORY = "不忘科技-自定义节点🚩/工具"
|
||
OUTPUT_NODE = True
|
||
|
||
def unload_models(self, any=None) -> Tuple:
|
||
"""
|
||
卸载所有已加载的模型
|
||
|
||
Args:
|
||
any: 输入参数(任意类型)
|
||
|
||
Returns:
|
||
tuple: 空元组
|
||
"""
|
||
try:
|
||
logger.info("开始卸载所有模型...")
|
||
|
||
# 卸载所有已加载的模型
|
||
comfy.model_management.soft_empty_cache()
|
||
comfy.model_management.unload_all_models()
|
||
comfy.model_management.soft_empty_cache()
|
||
|
||
logger.info("所有模型已成功卸载")
|
||
return ()
|
||
|
||
except Exception as e:
|
||
logger.error(f"模型卸载失败: {e}")
|
||
# 即使失败也返回空元组,避免中断工作流
|
||
return ()
|
||
|
||
|
||
class TraverseFolder:
|
||
"""
|
||
文件夹遍历节点
|
||
|
||
遍历指定文件夹,查找符合条件的文件。
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(s):
|
||
return {
|
||
"required": {
|
||
"folder": (
|
||
"STRING",
|
||
{"default": r"E:\comfy\ComfyUI\input\s3", "required": True},
|
||
),
|
||
"subfix": ("STRING", {"default": ".mp4", "required": True}),
|
||
"recursive": ("BOOLEAN", {"default": True, "required": True}),
|
||
"idx": (
|
||
"INT",
|
||
{"default": 0, "min": 0, "max": 0xFFFFFF},
|
||
),
|
||
},
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("文件路径",)
|
||
FUNCTION = "compute"
|
||
CATEGORY = "不忘科技-自定义节点🚩/工具"
|
||
|
||
def compute(
|
||
self, folder: str, subfix: str, recursive: bool, idx: int
|
||
) -> Tuple[str]:
|
||
"""
|
||
遍历文件夹查找文件
|
||
|
||
Args:
|
||
folder: 文件夹路径
|
||
subfix: 文件后缀
|
||
recursive: 是否递归查找
|
||
idx: 文件索引
|
||
|
||
Returns:
|
||
tuple: 包含文件路径的元组
|
||
|
||
Raises:
|
||
RuntimeError: 未找到文件时抛出异常
|
||
"""
|
||
try:
|
||
# 参数验证
|
||
if not folder or not folder.strip():
|
||
raise ValueError("文件夹路径不能为空")
|
||
if not subfix or not subfix.strip():
|
||
raise ValueError("文件后缀不能为空")
|
||
|
||
folder = folder.strip()
|
||
subfix = subfix.strip()
|
||
|
||
# 构建搜索模式
|
||
if recursive:
|
||
pattern = os.path.join(folder, f"**/*{subfix}")
|
||
else:
|
||
pattern = os.path.join(folder, f"*{subfix}")
|
||
|
||
# 查找文件
|
||
files = glob.glob(pattern, recursive=recursive)
|
||
|
||
if len(files) == 0:
|
||
logger.warning(f"在文件夹 {folder} 中未找到后缀为 {subfix} 的文件")
|
||
raise RuntimeError("No Files Found")
|
||
|
||
# 选择文件
|
||
selected_file = files[idx % len(files)]
|
||
logger.info(
|
||
f"找到 {len(files)} 个文件,选择第 {idx % len(files)} 个: {selected_file}"
|
||
)
|
||
|
||
return (str(selected_file),)
|
||
|
||
except Exception as e:
|
||
logger.error(f"文件夹遍历失败: {e}")
|
||
if isinstance(e, RuntimeError):
|
||
raise
|
||
else:
|
||
raise RuntimeError(f"文件夹遍历错误: {str(e)}")
|
||
|
||
|
||
class PlugAndPlayWebhook:
|
||
"""
|
||
即插即用Webhook节点
|
||
|
||
连上线就能转发数据到指定的Webhook地址。
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
return {
|
||
"required": {
|
||
"webhook_url": (
|
||
"STRING",
|
||
{
|
||
"default": "http://127.0.0.1:8010/handler/webhook",
|
||
"placeholder": "https://your-api.com/webhook",
|
||
},
|
||
),
|
||
"image_url": ("STRING", {"default": "", "placeholder": "图片的url"}),
|
||
},
|
||
"optional": {
|
||
"prompt_id": ("STRING", {"default": ""}),
|
||
},
|
||
"hidden": {
|
||
"prompt": "PROMPT",
|
||
"extra_pnginfo": "EXTRA_PNGINFO",
|
||
"unique_id": "UNIQUE_ID",
|
||
},
|
||
}
|
||
|
||
RETURN_TYPES = ()
|
||
FUNCTION = "send"
|
||
OUTPUT_NODE = True
|
||
CATEGORY = "不忘科技-自定义节点🚩/工具"
|
||
|
||
def send(
|
||
self,
|
||
webhook_url: str,
|
||
image_url: str,
|
||
prompt_id: str = "",
|
||
prompt=None,
|
||
extra_pnginfo=None,
|
||
unique_id=None,
|
||
) -> Tuple:
|
||
"""
|
||
发送数据到Webhook
|
||
|
||
Args:
|
||
webhook_url: Webhook URL
|
||
image_url: 图片URL
|
||
prompt_id: 提示ID
|
||
prompt: 提示信息
|
||
extra_pnginfo: 额外PNG信息
|
||
unique_id: 唯一ID
|
||
|
||
Returns:
|
||
tuple: 空元组
|
||
|
||
Raises:
|
||
ValueError: URL为空时抛出异常
|
||
"""
|
||
try:
|
||
# 参数验证
|
||
if not webhook_url or not webhook_url.strip():
|
||
raise ValueError("❌ 请填写Webhook URL!")
|
||
|
||
# 使用传入的prompt_id,如果没有则用unique_id
|
||
final_prompt_id = prompt_id or unique_id or "unknown"
|
||
|
||
# 准备发送的数据
|
||
data = {
|
||
"img_base64": image_url,
|
||
"format": "png",
|
||
"image_url": image_url,
|
||
"prompt_id": final_prompt_id,
|
||
"timestamp": time.time(),
|
||
}
|
||
|
||
logger.info(f"准备发送Webhook数据到: {webhook_url}")
|
||
|
||
# 发送Webhook
|
||
response = requests.post(webhook_url.strip(), json=data, timeout=30)
|
||
response.raise_for_status()
|
||
|
||
logger.info(f"Webhook发送成功,响应状态: {response.status_code}")
|
||
logger.debug(f"发送的数据: {data}")
|
||
|
||
except requests.RequestException as e:
|
||
logger.error(f"❌ Webhook发送失败: {str(e)}")
|
||
# 不抛出异常,避免中断工作流
|
||
except Exception as e:
|
||
logger.error(f"❌ Webhook发送异常: {str(e)}")
|
||
# 不抛出异常,避免中断工作流
|
||
|
||
# 终端节点,无需返回
|
||
return ()
|
||
|
||
|
||
class TaskIdGenerate:
|
||
"""
|
||
TaskID生成器
|
||
|
||
用户可传入自定义TaskID或自动生成UUID。
|
||
"""
|
||
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
return {
|
||
"required": {},
|
||
"optional": {
|
||
"custom_task_id": (
|
||
"STRING",
|
||
{"default": "", "placeholder": "留空则自动生成"},
|
||
),
|
||
},
|
||
}
|
||
|
||
RETURN_TYPES = ("STRING",)
|
||
RETURN_NAMES = ("task_id",)
|
||
FUNCTION = "generate_task_id"
|
||
OUTPUT_NODE = False
|
||
CATEGORY = "不忘科技-自定义节点🚩/工具"
|
||
|
||
def generate_task_id(self, custom_task_id: str = "") -> Tuple[str]:
|
||
"""
|
||
生成或使用自定义TaskID
|
||
|
||
Args:
|
||
custom_task_id: 自定义TaskID
|
||
|
||
Returns:
|
||
tuple: 包含TaskID的元组
|
||
"""
|
||
try:
|
||
if custom_task_id and custom_task_id.strip():
|
||
# 用户输入了自定义ID
|
||
task_id = custom_task_id.strip()
|
||
logger.info(f"📝 使用自定义TaskID: {task_id}")
|
||
else:
|
||
# 自动生成UUID
|
||
task_id = str(uuid.uuid4())
|
||
logger.info(f"🎲 自动生成TaskID: {task_id}")
|
||
|
||
return (task_id,)
|
||
|
||
except Exception as e:
|
||
logger.error(f"TaskID生成失败: {e}")
|
||
# 发生异常时生成一个基础UUID
|
||
fallback_id = str(uuid.uuid4())
|
||
logger.warning(f"使用备用TaskID: {fallback_id}")
|
||
return (fallback_id,)
|
||
|
||
|
||
# 节点映射字典
|
||
NODE_CLASS_MAPPINGS = {
|
||
"LogToDB": LogToDB,
|
||
"VodToLocalNode": VodToLocalNode,
|
||
"UnloadAllModels": UnloadAllModels,
|
||
"TraverseFolder": TraverseFolder,
|
||
"PlugAndPlayWebhook": PlugAndPlayWebhook,
|
||
"TaskIdGenerate": TaskIdGenerate,
|
||
}
|
||
|
||
# 节点显示名称映射
|
||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||
"LogToDB": "数据库日志记录",
|
||
"VodToLocalNode": "VOD文件下载",
|
||
"UnloadAllModels": "卸载所有模型",
|
||
"TraverseFolder": "文件夹遍历",
|
||
"PlugAndPlayWebhook": "Webhook通知",
|
||
"TaskIdGenerate": "TaskID生成器",
|
||
}
|