341 lines
11 KiB
Python
341 lines
11 KiB
Python
import glob
|
||
import json
|
||
import os
|
||
import time
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
import comfy.model_management
|
||
import requests
|
||
import server
|
||
import yaml
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
from tencentcloud.common import credential
|
||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||
from tencentcloud.vod.v20180717 import vod_client, models
|
||
|
||
from ..utils.task_table import Task
|
||
|
||
|
||
class LogToDB:
    """Output node that records the running prompt's progress in the ``Task`` table.

    The row is keyed by the queue's ``prompt_id``: it is inserted on the first call
    and updated on later calls for the same prompt.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "job_id": ("STRING", {"forceInput": True}),
                "log": ("STRING", {"forceInput": True}),
                "status": ("INT", {"default": 1, "max": 1}),
                "sql_url": ("STRING", {
                    "default": "mysql+pymysql://root:root@example.com:3306/test"}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            }
        }

    RETURN_TYPES = ("STRING",)

    FUNCTION = "log2db"

    OUTPUT_NODE = True

    # The single STRING output is declared as a list, so the value returned for it
    # must itself be a list (see ``_node_output``).
    OUTPUT_IS_LIST = (True,)

    CATEGORY = "不忘科技-自定义节点🚩/工具"

    @staticmethod
    def _node_output(payload):
        """Build the node's return value for the JSON string ``payload``.

        Both the UI entry and the list-typed output receive a one-element list.
        A bare string would presumably be iterated character by character when the
        framework merges list outputs / UI values (NOTE(review): confirm against
        ComfyUI's execution code).
        """
        return {"ui": {"text": [payload]}, "result": ([payload],)}

    def log2db(self, log, status, sql_url, unique_id, job_id=None):
        """Insert or update the ``Task`` row of the currently running prompt.

        Args:
            log: Text stored as ``node_output`` in the JSON result.
            status: Status value written to the row.
            sql_url: SQLAlchemy database URL; a new engine is created per call
                and disposed before returning.
            unique_id: ID of this node in the prompt graph (hidden input).
            job_id: Job identifier wired into the node. The prompt's
                ``client_id`` from the queue's extra data takes precedence; this
                value is used only when no ``client_id`` is present.

        Returns:
            The node output carrying the JSON-serialized result.

        Raises:
            RuntimeError: If no prompt is currently running, or if more than one
                row already shares this ``prompt_id``.
        """
        # Inspect the ComfyUI queue to find the prompt this node is executing in.
        running = server.PromptServer.instance.prompt_queue.currently_running
        try:
            # The first and last fields of the queue entry are not needed here.
            (_, prompt_id, prompt, extra_data, _) = next(iter(running.values()))
        except StopIteration:
            raise RuntimeError("No running prompt found; cannot record status.") from None

        job_id = (extra_data or {}).get("client_id", job_id)

        result = {
            "curr_node_id": str(unique_id),
            "last_node_id": list(prompt.keys())[-1],
            "node_output": str(log),
        }
        payload = json.dumps(result)

        engine = create_engine(sql_url, echo=True)
        session = sessionmaker(bind=engine)()
        try:
            tasks = session.query(Task).filter(Task.prompt_id == prompt_id).all()
            if not tasks:
                # No row yet for this prompt: insert one.
                session.add(Task(prompt_id=prompt_id, job_id=job_id, result=payload, status=status))
            elif len(tasks) == 1:
                # Row exists: update it in place.
                session.query(Task).filter(Task.prompt_id == prompt_id).update(
                    {"result": payload, "status": status})
            else:
                raise RuntimeError("状态数据库prompt_id不唯一, 无法记录状态!")
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            # Release the connection and the per-call pool so repeated runs do not leak.
            session.close()
            engine.dispose()

        return self._node_output(payload)
|
||
|
||
|
||
class VodToLocalNode:
    """Download a Tencent Cloud VOD media file into a local cache directory.

    Files are stored as ``<module dir>/download/<sub_app_id>/<file_id>.mp4``; an
    existing file at that path is returned without contacting the VOD API.
    """

    def __init__(self):
        # Credentials come from the environment when it actually provides the
        # cos_* keys read below, otherwise from config.yaml one directory above
        # this module.
        if "cos_secret_id" in os.environ:
            yaml_config = os.environ
        else:
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"
            )
            # Opened read-only: the file is never written here.
            with open(config_path, encoding="utf-8", mode="r") as f:
                yaml_config = yaml.safe_load(f) or {}
        self.secret_id = yaml_config["cos_secret_id"]
        self.secret_key = yaml_config["cos_secret_key"]
        self.region = yaml_config["cos_region"]
        self.vod_client = self.init_vod_client()

    def init_vod_client(self):
        """Create the VOD API client from the stored credentials and region.

        Raises:
            RuntimeError: If the client cannot be constructed.
        """
        try:
            http_profile = HttpProfile(endpoint="vod.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            cred = credential.Credential(self.secret_id, self.secret_key)
            return vod_client.VodClient(cred, self.region, client_profile)
        except Exception as e:
            raise RuntimeError(f"VOD client initialization failed: {e}") from e

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "file_id": ("STRING", {"default": ""}),
                "sub_app_id": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("local_path",)
    FUNCTION = "execute"
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    def execute(self, file_id, sub_app_id):
        """Node entry point: download (or reuse) the file and return its local path."""
        local_path = self.download_vod(file_id, sub_app_id)
        print(f"下载成功: {local_path}")
        return (local_path,)

    def _get_download_url(self, file_id, sub_app_id):
        """Return the media URL of ``file_id`` reported by DescribeMediaInfos.

        Raises:
            RuntimeError: On any API error, unknown file, or missing URL.
        """
        try:
            req = models.DescribeMediaInfosRequest()
            req.FileIds = [file_id]
            # SubAppId is optional in the VOD API (default application); only
            # send it when one was given instead of failing on int("").
            if str(sub_app_id).strip():
                req.SubAppId = int(sub_app_id)

            resp = self.vod_client.DescribeMediaInfos(req)
            if not resp.MediaInfoSet:
                raise ValueError("File not found")

            media_info = resp.MediaInfoSet[0]
            if not media_info.BasicInfo.MediaUrl:
                raise ValueError("No download URL available")

            return media_info.BasicInfo.MediaUrl
        except Exception as e:
            raise RuntimeError(f"Tencent API error: {e}") from e

    def create_directory(self, path):
        """Create ``path`` (with parents) if missing and report which case applied."""
        p = Path(path)
        if not p.exists():
            # parents=True creates missing ancestors; exist_ok=True tolerates a
            # concurrent creation between the check and the mkdir.
            p.mkdir(parents=True, exist_ok=True)
            print(f"目录已创建: {path}")
        else:
            print(f"目录已存在: {path}")

    @staticmethod
    def _path_component(value, name, allow_empty=False):
        """Return ``value`` as a string usable as a single path component.

        Raises:
            ValueError: If the value is empty (unless ``allow_empty``), is ``.`` /
                ``..``, or contains a path separator, which would let it escape
                the download directory.
        """
        text = str(value)
        if allow_empty and text == "":
            return text
        if text in ("", ".", "..") or "/" in text or "\\" in text:
            raise ValueError(f"Invalid {name}: {value!r}")
        return text

    def download_vod(self, file_id, sub_app_id):
        """Return the local path of ``file_id``, downloading it if not cached.

        Raises:
            ValueError: If ``file_id`` / ``sub_app_id`` are not safe path components.
            RuntimeError: If the URL lookup or the download fails.
        """
        safe_file_id = self._path_component(file_id, "file_id")
        safe_sub_dir = self._path_component(sub_app_id, "sub_app_id", allow_empty=True)
        cache_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "download", safe_sub_dir
        )
        output_path = os.path.join(cache_dir, f"{safe_file_id}.mp4")

        # Check the cache before calling the API so a hit needs no network access.
        if os.path.exists(output_path):
            return output_path

        media_url = self._get_download_url(file_id=file_id, sub_app_id=sub_app_id)
        print(f"download from url: {media_url}")
        self.create_directory(cache_dir)
        return self._download_file(url=media_url, save_path=output_path, timeout=60 * 10)

    def _download_file(self, url: str, save_path: str, timeout: int = 30):
        """Stream ``url`` to ``save_path`` and return ``save_path``.

        Data is written to a ``.part`` sibling and renamed only after the whole
        body was received, so an interrupted download never leaves a truncated
        file that ``download_vod`` would later treat as a cache hit.

        Raises:
            RuntimeError: On any HTTP or filesystem error.
        """
        tmp_path = f"{save_path}.part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(tmp_path, save_path)
            return save_path
        except Exception as e:
            # Best-effort cleanup of the partial file; the original error wins.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
            raise RuntimeError(f"Download error: {e}") from e
|
||
|
||
|
||
class AnyType(str):
    """``str`` subclass whose ``!=`` comparison never reports a difference.

    Credit to pythongosssss. Presumably used as a wildcard socket type, so that a
    type check written with ``!=`` accepts any connection (NOTE(review): confirm
    against the ComfyUI validation code). ``==`` is inherited from ``str``.
    """

    def __ne__(self, __value: object) -> bool:
        # Whatever the other operand is, claim the two are "not unequal".
        never_different = False
        return never_different
|
||
|
||
|
||
# Wildcard type marker referenced by UnloadAllModels.INPUT_TYPES.
# NOTE(review): this module-level name shadows the builtin ``any()`` for the whole
# module; it is left as-is because other modules may import it under this name.
_WILDCARD_TYPE_NAME = "*"
any = AnyType(_WILDCARD_TYPE_NAME)
|
||
|
||
|
||
class UnloadAllModels:
    """Output node that unloads every loaded model and empties caches when run.

    The value of its single wildcard input is never read. Presumably the input only
    exists so that the node executes after whatever feeds it.
    """

    CATEGORY = "不忘科技-自定义节点🚩/工具"
    FUNCTION = "unload_models"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(cls):
        # ``any`` here is the module-level wildcard type object, not the builtin.
        required_inputs = {"any": (any, {"forceInput": True})}
        return {"required": required_inputs, "optional": {}}

    def unload_models(self, any=None):
        """Empty the cache, unload all models, then empty the cache again."""
        mm = comfy.model_management
        for step in (mm.soft_empty_cache, mm.unload_all_models, mm.soft_empty_cache):
            step()
        return ()
|
||
|
||
|
||
class TraverseFolder:
    """Pick one file with a given suffix from a folder, selected by a wrapping index."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "folder": ("STRING", {"default": r"E:\comfy\ComfyUI\input\s3", "required": True}),
                "subfix": ("STRING", {"default": ".mp4", "required": True}),
                "recursive": ("BOOLEAN", {"default": True, "required": True}),
                "idx": (
                    "INT",
                    {"default": 0, "min": 0, "max": 0xFFFFFF},
                ),
            },
        }

    RETURN_TYPES = ("STRING",)

    RETURN_NAMES = ("文件路径",)

    FUNCTION = "compute"

    CATEGORY = "不忘科技-自定义节点🚩/工具"

    def compute(self, folder, subfix, recursive, idx):
        """Return the path of the ``idx``-th matching file (modulo the match count).

        Args:
            folder: Directory to search; taken literally (glob metacharacters in
                it are escaped).
            subfix: Filename ending to match; it is appended to ``*`` and is
                therefore still interpreted as a glob pattern.
            recursive: Whether to include files in all subdirectories; when False
                only files directly inside ``folder`` are considered.
            idx: Index into the sorted list of matches; wraps around.

        Returns:
            A one-element tuple with the selected file path.

        Raises:
            RuntimeError: If no regular file matches.
        """
        base = glob.escape(folder)
        # Build the pattern with os.path.join so it works on every OS; in
        # recursive mode "**" also matches zero directories (the folder itself).
        if recursive:
            pattern = os.path.join(base, "**", "*%s" % subfix)
        else:
            pattern = os.path.join(base, "*%s" % subfix)
        # Sort so that a given idx always selects the same file; glob's order is
        # filesystem-dependent. Directories whose name happens to match are skipped.
        files = sorted(
            path for path in glob.glob(pattern, recursive=recursive) if os.path.isfile(path)
        )
        if not files:
            raise RuntimeError("No Files Found")
        return (str(files[idx % len(files)]),)
|
||
|
||
|
||
class PlugAndPlayWebhook:
    """即插即用Webhook节点:连上线就能转发数据

    Plug-and-play webhook node: POSTs the image URL and a prompt id as JSON to a
    configurable endpoint. Delivery is best-effort: failures are printed, not raised.
    """

    # Seconds allowed for connecting to / reading from the webhook endpoint, so an
    # unresponsive receiver cannot block the workflow indefinitely.
    REQUEST_TIMEOUT = 30

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "webhook_url": ("STRING", {"default": "http://127.0.0.1:8010/handler/webhook",
                                           "placeholder": "https://your-api.com/webhook"}),
                "image_url": ("STRING", {"default": "",
                                         "placeholder": "图片的url"}),
            },
            "optional": {
                "prompt_id": ("STRING", {"default": ""}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ()
    FUNCTION = "send"
    OUTPUT_NODE = True
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    def send(self, webhook_url, image_url, prompt_id="", prompt=None, extra_pnginfo=None, unique_id=None):
        """POST the payload to ``webhook_url``; always returns an empty tuple.

        Raises:
            ValueError: If ``webhook_url`` is empty. Network/HTTP errors are
                printed and swallowed.
        """
        if not webhook_url:
            raise ValueError("❌ 请填写Webhook URL!")

        # Prefer the explicit prompt_id, then this node's unique_id.
        final_prompt_id = prompt_id or unique_id or "unknown"

        data = {
            # NOTE(review): "img_base64" carries the URL, not base64 data; kept
            # as-is because the receiver may depend on this field.
            "img_base64": image_url,
            "format": "png",
            "image_url": image_url,
            "prompt_id": final_prompt_id,
            "timestamp": time.time()
        }

        try:
            response = requests.post(webhook_url, json=data, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            print(f'发送的数据:{data}')
        except Exception as e:
            print(f"❌ 发送失败: {str(e)}")

        # Terminal node: nothing to return.
        return ()
|
||
|
||
|
||
class TaskIdGenerate:
    """Task-ID node: returns the caller's ID (whitespace-trimmed) if non-blank,
    otherwise a freshly generated random UUID4 string.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("task_id",)
    FUNCTION = "generate_task_id"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/工具"

    @classmethod
    def INPUT_TYPES(cls):
        optional_inputs = {
            "custom_task_id": ("STRING", {"default": "", "placeholder": "留空则自动生成"}),
        }
        return {"required": {}, "optional": optional_inputs}

    def generate_task_id(self, custom_task_id=""):
        """Return ``(task_id,)``, preferring the trimmed custom ID when non-blank."""
        trimmed = custom_task_id.strip() if custom_task_id else ""
        if not trimmed:
            # Nothing usable supplied: fall back to a random UUID.
            trimmed = str(uuid.uuid4())
            print(f"🎲 自动生成TaskID: {trimmed}")
        else:
            print(f"📝 使用自定义TaskID: {trimmed}")
        return (trimmed,)
|