ComfyUI-CustomNode/nodes/util_nodes.py

341 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import glob
import json
import os
import time
import uuid
from pathlib import Path
import comfy.model_management
import requests
import server
import yaml
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.vod.v20180717 import vod_client, models
from ..utils.task_table import Task
class LogToDB:
    """Persist a node's log output for the currently running prompt to a SQL database.

    The ``Task`` table holds at most one row per ``prompt_id``: the first call for
    a prompt inserts a row, later calls overwrite its ``result`` and ``status``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "job_id": ("STRING", {"forceInput": True}),
                "log": ("STRING", {"forceInput": True}),
                "status": ("INT", {"default": 1, "max": 1}),
                "sql_url": ("STRING", {
                    "default": "mysql+pymysql://root:root@example.com:3306/test"}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "log2db"
    OUTPUT_NODE = True
    # The single output is declared as a list, so log2db wraps its value in a list.
    OUTPUT_IS_LIST = (True,)
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    @staticmethod
    def _build_result(unique_id, prompt, log):
        """Return the status record that is stored (JSON-encoded) in ``Task.result``.

        ``last_node_id`` is the last key of ``prompt`` in insertion order.
        """
        return {
            "curr_node_id": str(unique_id),
            "last_node_id": list(prompt.keys())[-1],
            "node_output": str(log),
        }

    def log2db(self, log, status, sql_url, unique_id, job_id=None):
        """Insert or update the ``Task`` row for the currently running prompt.

        Args:
            log: Value to record; stored via ``str()``.
            status: Status code written to ``Task.status``.
            sql_url: SQLAlchemy database URL.
            unique_id: ID of this node (hidden input).
            job_id: The wired ``job_id`` input. It must be accepted because every
                declared input is passed as a keyword argument; it is only used
                when the running prompt's extra data carries no ``client_id``.

        Returns:
            A ComfyUI result dict: the JSON record is shown in the UI and emitted
            as a one-element list on the STRING output.

        Raises:
            RuntimeError: if no prompt is running, or if several rows share the
                same ``prompt_id``.
        """
        running = server.PromptServer.instance.prompt_queue.currently_running
        if not running:
            raise RuntimeError("No running prompt found; LogToDB must run inside a queued prompt.")
        # Index rather than unpack so additional trailing fields in the queue item are tolerated.
        queue_item = next(iter(running.values()))
        prompt_id, prompt, extra_data = queue_item[1], queue_item[2], queue_item[3]
        job_id = extra_data.get("client_id") or job_id

        print(prompt)
        result = self._build_result(unique_id, prompt, log)
        payload = json.dumps(result)

        engine = create_engine(sql_url, echo=True)
        session = sessionmaker(bind=engine)()
        try:
            tasks = session.query(Task).filter(Task.prompt_id == prompt_id).all()
            if not tasks:
                # No row yet for this prompt: insert one.
                session.add(Task(prompt_id=prompt_id, job_id=job_id, result=payload, status=status))
            elif len(tasks) == 1:
                # Row exists: overwrite result and status.
                session.query(Task).filter(Task.prompt_id == prompt_id).update(
                    {"result": payload, "status": status})
            else:
                raise RuntimeError("状态数据库prompt_id不唯一, 无法记录状态!")
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            # A fresh engine is created on every call, so release its connections here.
            session.close()
            engine.dispose()

        return {"ui": {"text": [payload]}, "result": ([payload],)}
class VodToLocalNode:
    """Download a Tencent Cloud VOD media file into a local cache and return its path.

    Files are cached as ``<this module's dir>/download/<sub_app_id>/<file_id>.mp4``;
    when that file already exists it is returned without contacting VOD.
    """

    def __init__(self):
        # Read credentials from the environment when they are set there,
        # otherwise from config.yaml in the parent directory of this module.
        if "cos_secret_id" in os.environ:
            yaml_config = os.environ
        else:
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"
            )
            # Read-only access is enough; "r+" would fail on a read-only file.
            with open(config_path, encoding="utf-8", mode="r") as f:
                yaml_config = yaml.safe_load(f)
        self.secret_id = yaml_config["cos_secret_id"]
        self.secret_key = yaml_config["cos_secret_key"]
        self.region = yaml_config["cos_region"]
        self.vod_client = self.init_vod_client()

    def init_vod_client(self):
        """Create the VOD API client from the stored credentials and region."""
        try:
            http_profile = HttpProfile(endpoint="vod.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            cred = credential.Credential(self.secret_id, self.secret_key)
            return vod_client.VodClient(cred, self.region, client_profile)
        except Exception as e:
            raise RuntimeError(f"VOD client initialization failed: {e}") from e

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "file_id": ("STRING", {"default": ""}),
                "sub_app_id": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("local_path",)
    FUNCTION = "execute"
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    def execute(self, file_id, sub_app_id):
        """Node entry point: return ``(local_path,)`` for the requested VOD file."""
        local_path = self.download_vod(file_id, sub_app_id)
        print(f"下载成功: {local_path}")
        return (local_path,)

    def _get_download_url(self, file_id, sub_app_id):
        """Return the media URL reported by DescribeMediaInfos for ``file_id``.

        Raises:
            RuntimeError: wrapping any API failure, a missing file, or a missing URL.
        """
        try:
            req = models.DescribeMediaInfosRequest()
            req.FileIds = [file_id]
            # The node's default sub_app_id is "", which int() cannot parse;
            # only send SubAppId when one was actually given.
            if str(sub_app_id).strip():
                req.SubAppId = int(sub_app_id)
            resp = self.vod_client.DescribeMediaInfos(req)
            if not resp.MediaInfoSet:
                raise ValueError("File not found")
            media_info = resp.MediaInfoSet[0]
            if not media_info.BasicInfo.MediaUrl:
                raise ValueError("No download URL available")
            return media_info.BasicInfo.MediaUrl
        except Exception as e:
            raise RuntimeError(f"Tencent API error: {e}") from e

    def create_directory(self, path):
        """Create ``path`` (including parents) if it does not exist yet."""
        p = Path(path)
        if not p.exists():
            p.mkdir(parents=True, exist_ok=True)
            print(f"目录已创建: {path}")
        else:
            print(f"目录已存在: {path}")

    @staticmethod
    def _check_path_component(name, value, allow_empty=False):
        """Reject values that would escape the download directory when joined into a path."""
        text = str(value)
        if (not text and not allow_empty) or text in (".", "..") or os.path.basename(text) != text:
            raise ValueError(f"Invalid {name}: {value!r}")

    def _local_path(self, file_id, sub_app_id):
        """Return the cache path ``<this dir>/download/<sub_app_id>/<file_id>.mp4``."""
        return os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "download",
            f"{sub_app_id}",
            f"{file_id}.mp4",
        )

    def download_vod(self, file_id, sub_app_id):
        """Return a local path for the VOD file, downloading it if it is not cached yet.

        Raises:
            ValueError: if ``file_id`` is empty or either id contains a path component.
            RuntimeError: if the VOD lookup or the download fails.
        """
        self._check_path_component("file_id", file_id)
        self._check_path_component("sub_app_id", sub_app_id, allow_empty=True)
        output_path = self._local_path(file_id, sub_app_id)
        # Check the cache first so an already-downloaded file needs no API round-trip.
        if os.path.exists(output_path):
            return output_path
        media_url = self._get_download_url(file_id=file_id, sub_app_id=sub_app_id)
        print(f"download from url: {media_url}")
        self.create_directory(os.path.dirname(output_path))
        return self._download_file(url=media_url, save_path=output_path, timeout=60 * 10)

    def _download_file(self, url: str, save_path: str, timeout: int = 30):
        """Stream ``url`` to ``save_path`` and return ``save_path``.

        Data is written to ``save_path + ".part"`` and renamed only after the
        whole body was received, so an interrupted download never leaves a
        truncated file at ``save_path`` (which download_vod would treat as cached).
        """
        tmp_path = save_path + ".part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(tmp_path, save_path)
            return save_path
        except Exception as e:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise RuntimeError(f"Download error: {e}") from e
class AnyType(str):
    """Wildcard type string whose ``!=`` comparison always yields False.

    Intended for places where type compatibility is checked with ``!=``: an
    instance of this class never reports a mismatch. Credit to pythongosssss.
    """

    def __ne__(self, __value: object) -> bool:
        # Regardless of the other operand, never report "not equal".
        never_unequal = False
        return never_unequal
# Shared wildcard type instance; used by nodes in this module (e.g. UnloadAllModels)
# to declare an input that accepts any type.
# NOTE(review): this name shadows the builtin any() for the rest of the module.
any = AnyType("*")
class UnloadAllModels:
    """Output node that unloads every loaded model when it executes.

    Its single wildcard input is never read; presumably it exists only to
    sequence this node after another node in the graph.
    """

    CATEGORY = "不忘科技-自定义节点🚩/工具"
    FUNCTION = "unload_models"
    RETURN_TYPES = ()
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(cls):
        wildcard_input = (any, {"forceInput": True})
        return {"required": {"any": wildcard_input}, "optional": {}}

    def unload_models(self, any=None):
        """Call soft_empty_cache() around unload_all_models(); the input is ignored."""
        mm = comfy.model_management
        mm.soft_empty_cache()
        mm.unload_all_models()
        mm.soft_empty_cache()
        return ()
class TraverseFolder:
    """Select one file with a given suffix from a folder, chosen by ``idx``."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "folder": ("STRING", {"default": r"E:\comfy\ComfyUI\input\s3", "required": True}),
                "subfix": ("STRING", {"default": ".mp4", "required": True}),
                "recursive": ("BOOLEAN", {"default": True, "required": True}),
                "idx": (
                    "INT",
                    {"default": 0, "min": 0, "max": 0xFFFFFF},
                ),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("文件路径",)
    FUNCTION = "compute"
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    def compute(self, folder, subfix, recursive, idx):
        """Return ``(path,)`` of the ``idx``-th matching file (wrapping around).

        Args:
            folder: Directory to search; glob metacharacters in it are taken literally.
            subfix: File-name suffix to match, e.g. ``".mp4"``.
            recursive: Search all sub-directories when True, only ``folder`` itself otherwise.
            idx: Index into the sorted match list, taken modulo its length.

        Raises:
            RuntimeError: if no file matches.
        """
        base = glob.escape(folder)
        # Build the pattern with os.path.join so the separator is correct on every OS.
        if recursive:
            pattern = os.path.join(base, "**", f"*{subfix}")
        else:
            pattern = os.path.join(base, f"*{subfix}")
        # Sort so the same idx picks the same file regardless of directory listing order.
        files = sorted(
            path for path in glob.glob(pattern, recursive=recursive) if os.path.isfile(path)
        )
        if not files:
            raise RuntimeError("No Files Found")
        return (str(files[idx % len(files)]),)
class PlugAndPlayWebhook:
    """Plug-and-play webhook node: POST the image URL and a prompt id to a webhook."""

    # Seconds to wait for the webhook; without a timeout a stalled endpoint
    # would block workflow execution indefinitely.
    REQUEST_TIMEOUT = 30

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "webhook_url": ("STRING", {"default": "http://127.0.0.1:8010/handler/webhook",
                                           "placeholder": "https://your-api.com/webhook"}),
                "image_url": ("STRING", {"default": "",
                                         "placeholder": "图片的url"}),
            },
            "optional": {
                "prompt_id": ("STRING", {"default": ""}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ()
    FUNCTION = "send"
    OUTPUT_NODE = True
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    @staticmethod
    def _build_payload(image_url, prompt_id="", unique_id=None):
        """Build the JSON body sent to the webhook.

        The prompt id falls back to ``unique_id`` (this node's id), then to "unknown".
        """
        return {
            # NOTE(review): named img_base64 but carries the URL; kept for receiver compatibility.
            "img_base64": image_url,
            "format": "png",
            "image_url": image_url,
            "prompt_id": prompt_id or unique_id or "unknown",
            "timestamp": time.time(),
        }

    def send(self, webhook_url, image_url, prompt_id="", prompt=None, extra_pnginfo=None, unique_id=None):
        """POST the payload to ``webhook_url``; delivery is best-effort.

        Raises:
            ValueError: if ``webhook_url`` is empty. Delivery errors are only printed.
        """
        if not webhook_url:
            raise ValueError("❌ 请填写Webhook URL")
        data = self._build_payload(image_url, prompt_id, unique_id)
        try:
            response = requests.post(webhook_url, json=data, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            print(f'发送的数据:{data}')
        except Exception as e:
            # Deliberately best-effort: a failed webhook must not fail the workflow.
            print(f"❌ 发送失败: {str(e)}")
        # Output node: nothing to return.
        return ()
class TaskIdGenerate:
    """Task-ID generator: return the caller-supplied ID or a freshly generated UUID4."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {},
            "optional": {
                "custom_task_id": ("STRING", {"default": "", "placeholder": "留空则自动生成"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("task_id",)
    FUNCTION = "generate_task_id"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    @staticmethod
    def _custom_id(custom_task_id):
        """Return the stripped custom ID, or None when it is empty/blank."""
        if custom_task_id and custom_task_id.strip():
            return custom_task_id.strip()
        return None

    @classmethod
    def IS_CHANGED(cls, custom_task_id="", **kwargs):
        """Control ComfyUI's output cache for this node.

        Without a custom ID, return NaN (which never equals itself) so the node
        re-executes on every run and yields a new UUID instead of a cached one.
        """
        custom = cls._custom_id(custom_task_id)
        return custom if custom is not None else float("nan")

    def generate_task_id(self, custom_task_id=""):
        """Return ``(task_id,)``: the stripped custom ID if given, else a new UUID4 string."""
        custom = self._custom_id(custom_task_id)
        if custom is not None:
            task_id = custom
            print(f"📝 使用自定义TaskID: {task_id}")
        else:
            task_id = str(uuid.uuid4())
            print(f"🎲 自动生成TaskID: {task_id}")
        return (task_id,)