ComfyUI-CustomNode/__init__.py

651 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import glob
import json
import os
import shutil
import traceback
import urllib.request
import uuid
from datetime import datetime
import server
import cv2
import ffmpy
import numpy as np
import torch
import yaml
from comfy import model_management
from qcloud_cos import CosConfig, CosClientError, CosServiceError
from qcloud_cos import CosS3Client
from sqlalchemy import Column, Integer, func, DateTime, ForeignKey, String, create_engine
from sqlalchemy.orm import sessionmaker
from ultralytics import YOLO
from sqlalchemy.ext.declarative import declarative_base
# Declarative base for the SQLAlchemy ORM models of this module (Task subclasses it).
Base = declarative_base()
# Frame-period detector used by FaceDetect.predict.
from .test_single_image import test_node

# Recognised video file extensions.
# NOTE(review): not referenced by any code visible in this module — confirm it is still needed.
video_extensions = "webm mp4 mkv gif mov".split()
class FaceDetect:
    """Face-occlusion detection node (人脸遮挡检测).

    Runs ``test_node`` over the input frames, picks one of the candidate
    frame periods it reports (chosen by ``main_seed`` modulo the number of
    periods) and exposes that period as a JSON clip config, a start frame
    index and a frame count.
    """

    @classmethod
    def INPUT_TYPES(s):
        seed_options = {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}
        length_options = {"default": 10, "min": 3, "max": 60, "step": 1}
        threshold_options = {"default": 94, "min": 55, "max": 99, "step": 0.1}
        required = {
            "image": ("IMAGE",),
            "main_seed": ("INT:seed", seed_options),
            "model": (["convnext_tiny", "convnext_base"],),
            "length": ("INT", length_options),
            "threshold": ("FLOAT", threshold_options),
        }
        return {"required": required}

    # Two images, five strings, two ints - in the order predict() returns them.
    RETURN_TYPES = ("IMAGE",) * 2 + ("STRING",) * 5 + ("INT",) * 2
    RETURN_NAMES = (
        "图像",
        "选中人脸",
        "分类",
        "概率",
        "采用帧序号",
        "全部帧序列",
        "剪辑配置",
        "起始帧序号",
        "帧数量",
    )
    FUNCTION = "predict"
    CATEGORY = "不忘科技-自定义节点🚩"

    def predict(self, image, main_seed, model, length, threshold):
        """Detect occlusion and select one qualifying frame period.

        Raises:
            RuntimeError: when ``test_node`` reports no qualifying period.
        """
        frames, faces, classes, probs, frame_nums, periods = test_node(
            image, length=length, thres=threshold, model_name=model
        )
        print("全部帧序列", periods)
        # len() rather than truthiness: periods may be an array-like.
        if len(periods) == 0:
            raise RuntimeError("未找到符合要求的视频片段")
        first, last = periods[main_seed % len(periods)]
        clip_config = json.dumps({"start": first, "end": last})
        frame_count = last - first + 1
        return (
            frames,
            faces,
            classes,
            probs,
            frame_nums,
            str(periods),
            clip_config,
            first,
            frame_count,
        )
class FaceExtract:
    """Face extraction node using a YOLO face detector (人脸提取 By YOLO).

    For every input frame with exactly one detected face, a square patch
    centred on the face box is cut out (padding outside the frame with
    ``PAD_VALUE``) and resized to ``OUTPUT_SIZE`` x ``OUTPUT_SIZE``.
    Frames with zero or several detections (or a degenerate box) are
    returned as all-zero images.
    """

    # Edge length, in pixels, of every output image.
    OUTPUT_SIZE = 512
    # Fill value (0-255 scale) for crop pixels that fall outside the frame.
    PAD_VALUE = 20.0
    # Detector loaded lazily once per process and shared by all instances.
    _model = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("图片",)
    FUNCTION = "crop"
    CATEGORY = "不忘科技-自定义节点🚩"

    @classmethod
    def _load_model(cls):
        """Return the cached YOLO face model, loading the weights on first use."""
        if cls._model is None:
            cls._model = YOLO(
                model=os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    "model",
                    "yolov8n-face-lindevs.pt",
                )
            )
        return cls._model

    @staticmethod
    def _square_crop(frame, box, fill=20.0):
        """Cut a square patch centred on ``box`` out of ``frame``.

        Args:
            frame: (H, W, 3) array.
            box: sequence whose first four values are (x_min, y_min, x_max, y_max)
                in pixels, i.e. columns first, as in ultralytics ``boxes.data``.
            fill: value for patch pixels that lie outside ``frame``.

        Returns:
            A float64 (size, size, 3) array where size = int(max(width, height)),
            or ``None`` when that size is 0. Only the first ``2 * (size // 2)``
            rows/columns are copied from the frame; for odd sizes the last row
            and column keep ``fill``.
        """
        col_min, row_min, col_max, row_max = box[:4]
        size = int(max(col_max - col_min, row_max - row_min))
        if size <= 0:
            return None
        half = size // 2
        span = 2 * half
        # Top-left corner of the square in frame coordinates (may be negative).
        top = int((row_min + row_max) // 2 - half)
        left = int((col_min + col_max) // 2 - half)
        patch = np.full((size, size, 3), fill, dtype=np.float64)
        height, width = frame.shape[:2]
        # Intersection of the square with the frame.
        r0, r1 = max(top, 0), min(top + span, height)
        c0, c1 = max(left, 0), min(left + span, width)
        if r0 < r1 and c0 < c1:
            patch[r0 - top:r1 - top, c0 - left:c1 - left] = frame[r0:r1, c0:c1]
        return patch

    def crop(self, image):
        """Detect one face per frame and return the resized square crops.

        Args:
            image: ComfyUI IMAGE tensor of shape (batch, H, W, 3); presumably
                floats in [0, 1], given the 255 scaling here and the /255 below.

        Returns:
            One-element tuple with a float32 tensor (batch, 512, 512, 3).
        """
        device = model_management.get_torch_device()
        frames = 255.0 * image.cpu().numpy()
        model = self._load_model()
        n = self.OUTPUT_SIZE
        # Zero-initialised so frames without exactly one face are well defined.
        out_images = np.zeros((frames.shape[0], n, n, 3))
        for idx, frame in enumerate(frames):
            results = model.predict(
                frame, imgsz=640, conf=0.75, iou=0.7, device=device, verbose=False
            )
            r = results[0]
            boxes = r.boxes.data.cpu().numpy()
            if len(boxes) != 1:
                continue
            patch = self._square_crop(r.orig_img, boxes[0], fill=self.PAD_VALUE)
            if patch is None:
                continue
            out_images[idx] = cv2.resize(patch, (n, n))
        cropped_face = torch.from_numpy(out_images.astype(np.float32) / 255.0)
        return (cropped_face,)
class COSDownload:
    """Tencent Cloud COS download node (腾讯云COS下载).

    Downloads object ``cos_key`` from the bucket configured in this package's
    ``config.yaml`` to ``<package dir>/download/<cos_key>`` and returns that
    local path.
    """

    # TODO: add Tencent Cloud VOD video download.

    # Number of download attempts before giving up.
    MAX_RETRIES = 10

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cos_key": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "download"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _load_config():
        """Read the package's config.yaml (region, bucket, secret_id, secret_key, ...)."""
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "config.yaml"
        )
        with open(config_path, encoding="utf-8", mode="r") as f:
            return yaml.safe_load(f)

    @staticmethod
    def _local_path(cos_key):
        """Map ``cos_key`` to its destination under ``<package dir>/download``.

        Raises:
            ValueError: if the key is empty or would resolve outside the
                download folder (e.g. via ``..``).
        """
        download_root = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "download"
        )
        # A leading separator would make os.path.join discard download_root.
        relative = cos_key.lstrip("/\\")
        target = os.path.normpath(os.path.join(download_root, relative))
        if not target.startswith(download_root + os.sep):
            raise ValueError(f"非法的COS Key: {cos_key!r}")
        return target

    def download(self, cos_key):
        """Download ``cos_key``, retrying COS errors up to ``MAX_RETRIES`` times.

        Returns:
            One-element tuple with the local file path.

        Raises:
            ValueError: if ``cos_key`` does not map inside the download folder.
            RuntimeError: if every attempt failed with a COS error.
        """
        dest_path = self._local_path(cos_key)
        # Always create the parent folder, including for keys without a directory part.
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        cfg = self._load_config()
        client = CosS3Client(
            CosConfig(
                Region=cfg["region"],
                SecretId=cfg["secret_id"],
                SecretKey=cfg["secret_key"],
            )
        )
        last_error = None
        for _ in range(self.MAX_RETRIES):
            try:
                client.download_file(
                    Bucket=cfg["bucket"], Key=cos_key, DestFilePath=dest_path
                )
                return (dest_path,)
            # A tuple is required: `A or B` would evaluate to A and never catch B.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(f"下载失败 {e}")
        raise RuntimeError(f"下载失败, 已重试{self.MAX_RETRIES}次: {cos_key}") from last_error
class COSUpload:
    """Tencent Cloud COS upload node (腾讯云COS上传).

    Uploads a local file to ``<subfolder>/<file name>`` in the bucket configured
    in this package's ``config.yaml``. It then optionally POSTs the object's
    public URL to ``report_url`` (an optional config.yaml key) and returns the
    object key.
    """

    # TODO: add Tencent Cloud VOD video upload.

    # Number of upload attempts before giving up.
    MAX_RETRIES = 10
    # Seconds to wait for the status-report endpoint.
    REPORT_TIMEOUT = 30

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "path": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("COS文件Key",)
    FUNCTION = "upload"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _load_config():
        """Read the package's config.yaml (region, bucket, subfolder, secrets, ...)."""
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "config.yaml"
        )
        with open(config_path, encoding="utf-8", mode="r") as f:
            return yaml.safe_load(f)

    @staticmethod
    def _object_key(subfolder, path):
        """Return ``<subfolder>/<last path component of path>``.

        Both POSIX and Windows separators (even mixed) are accepted in ``path``.
        """
        file_name = path.replace("\\", "/").rsplit("/", 1)[-1]
        return "/".join([subfolder, file_name])

    @classmethod
    def _report(cls, cfg, key):
        """POST the uploaded object's public URL to ``cfg['report_url']``, if set.

        Raises:
            RuntimeError: if the request cannot be built or sent.
        """
        report_url = cfg.get("report_url")
        if not report_url:
            # An empty URL is rejected by urllib, so reporting is skipped when unset.
            print("未配置report_url, 跳过上报MQ状态")
            return None
        data = {
            "prompt_id": "",
            "video_url": "https://{}.cos.{}.myqcloud.com/{}".format(
                cfg["bucket"], cfg["region"], key
            ),
        }
        headers = {"Content-Type": "application/json"}
        try:
            req = urllib.request.Request(
                report_url, data=json.dumps(data).encode("utf-8"), headers=headers
            )
            with urllib.request.urlopen(req, timeout=cls.REPORT_TIMEOUT) as response:
                response.read()
        # ValueError: malformed URL; OSError covers URLError/HTTPError/timeouts.
        except (OSError, ValueError) as e:
            raise RuntimeError("上报MQ状态失败") from e
        return None

    def upload(self, path):
        """Upload ``path`` (retrying COS errors) and report it.

        Returns:
            One-element tuple with the COS object key.

        Raises:
            RuntimeError: if every upload attempt failed, or reporting failed.
        """
        cfg = self._load_config()
        key = self._object_key(cfg["subfolder"], path)
        client = CosS3Client(
            CosConfig(
                Region=cfg["region"],
                SecretId=cfg["secret_id"],
                SecretKey=cfg["secret_key"],
            )
        )
        last_error = None
        for _ in range(self.MAX_RETRIES):
            try:
                client.upload_file(Bucket=cfg["bucket"], Key=key, LocalFilePath=path)
                break
            # A tuple is required: `A or B` would evaluate to A and never catch B.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(f"上传失败 {e}")
        else:
            raise RuntimeError("上传失败") from last_error
        self._report(cfg, key)
        return (key,)
class Task(Base):
    """ORM model for table ``task``: latest status/result recorded per ComfyUI prompt."""

    __tablename__ = "task"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True)
    # Both timestamps default to the database's now(); gmt_modified is also
    # set to now() whenever SQLAlchemy emits an UPDATE for the row.
    gmt_create = Column(DateTime(timezone=True), server_default=func.now())
    gmt_modified = Column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
    # ComfyUI prompt id; unique, so there is at most one row per prompt.
    prompt_id = Column(String, index=True, nullable=False, unique=True)
    # JSON text written by LogToDB.
    result = Column(String, nullable=True)
    # NOTE(review): LogToDB fills this from ComfyUI's extra_data["client_id"];
    # if that is a non-numeric string, an Integer column cannot hold it — confirm schema.
    job_id = Column(Integer, index=True, nullable=False, unique=True)
    # Status code supplied by LogToDB's `status` input.
    status = Column(Integer)

    def __repr__(self):
        shown = (self.id, self.gmt_create, self.gmt_modified, self.prompt_id, self.result)
        return ",".join(str(value) for value in shown)
class LogToDB:
    """Status persistence node (状态持久化DB).

    Upserts one ``Task`` row for the prompt ComfyUI is currently executing,
    storing a JSON result that records this node's id, the prompt's last node
    id and the given ``log`` text.

    The database URL comes from the ``COMFY_STATUS_DB_URL`` environment
    variable, or else the ``db_url`` key of this package's ``config.yaml``.
    """

    # SQLAlchemy engines keyed by URL, so executions reuse one connection pool.
    _engines = {}

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "job_id": ("STRING", {"forceInput": True}),
                "log": ("STRING", {"forceInput": True}),
                "status": ("INT", {"default": 1, "max": 1}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "log2db"
    OUTPUT_NODE = True
    OUTPUT_IS_LIST = (True,)
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _database_url():
        """Return the status DB URL from the environment or config.yaml.

        Raises:
            RuntimeError: if neither source provides a URL.
        """
        url = os.environ.get("COMFY_STATUS_DB_URL")
        if url:
            return url
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "config.yaml"
        )
        try:
            with open(config_path, encoding="utf-8") as f:
                url = (yaml.safe_load(f) or {}).get("db_url")
        except FileNotFoundError:
            url = None
        if not url:
            raise RuntimeError(
                "未配置状态数据库: 请设置环境变量COMFY_STATUS_DB_URL或config.yaml中的db_url"
            )
        return url

    @classmethod
    def _get_engine(cls, url):
        """Return a cached engine for ``url``, creating it on first use."""
        engine = cls._engines.get(url)
        if engine is None:
            # pre_ping drops stale pooled connections on the long-lived engine.
            engine = create_engine(url, pool_pre_ping=True)
            cls._engines[url] = engine
        return engine

    @staticmethod
    def _build_result(unique_id, prompt, log):
        """Build the result dict stored in ``Task.result``."""
        return {
            "curr_node_id": str(unique_id),
            "last_node_id": list(prompt.keys())[-1],
            "node_output": str(log),
        }

    def log2db(self, log, status, unique_id, job_id=None):
        """Insert or update the status row for the running prompt.

        Args:
            log: value to record (stored as ``str(log)``).
            status: integer status written to ``Task.status``.
            unique_id: this node's id (hidden input).
            job_id: the ``job_id`` input. ComfyUI passes declared inputs as
                keyword arguments, so this parameter must exist. It is used only
                when the queue item's extra_data has no ``client_id``.

        Returns:
            ComfyUI output dict. Both the UI text and the list output carry the
            JSON-encoded result, wrapped in lists as ComfyUI merges them by
            iteration.

        Raises:
            RuntimeError: if several rows share the prompt_id, or no DB URL is configured.
        """
        running = server.PromptServer.instance.prompt_queue.currently_running
        queue_item = next(iter(running.values()))
        # Index rather than unpack so extra trailing fields in ComfyUI's queue tuple are tolerated.
        prompt_id, prompt, extra_data = queue_item[1], queue_item[2], queue_item[3]
        job_id = extra_data.get("client_id", job_id)

        payload = json.dumps(self._build_result(unique_id, prompt, log))

        session = sessionmaker(bind=self._get_engine(self._database_url()))()
        try:
            tasks = session.query(Task).filter(Task.prompt_id == prompt_id).all()
            if len(tasks) > 1:
                raise RuntimeError("状态数据库prompt_id不唯一, 无法记录状态!")
            if tasks:
                tasks[0].result = payload
                tasks[0].status = status
            else:
                session.add(
                    Task(prompt_id=prompt_id, job_id=job_id, result=payload, status=status)
                )
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
        return {"ui": {"text": [payload]}, "result": ([payload],)}
class VideoCut:
    """FFmpeg video clipping node - deprecated because the output stutters (有卡顿问题 暂废弃).

    Cuts frames ``[start, end]`` (from a JSON ``config``) out of ``video_path``
    with ffmpeg stream copy. Frame indices are converted to seconds as
    ``index * mod / fps``. ``period_length`` is passed to the segment muxer's
    ``-segment_times`` option. Returns ``str(list_of_output_paths)``, or a
    message string when no clip config is given or the input cannot be copied.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "config": ("STRING",),
                "video_path": ("STRING",),
                "mod": ("INT",),
                "fps": ("FLOAT",),
                "period_length": (
                    "INT",
                    {
                        "default": 10,
                        "min": 4,
                        "max": 100,
                        "step": 1,
                        "forceInput": True,
                    },
                ),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频路径",)
    FUNCTION = "cut"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _output_pattern(video_path):
        """Build ffmpeg's segment output path (containing a literal ``%03d``) for ``video_path``."""
        name_parts = video_path.split(os.sep)[-1].split(".")
        output_name = ".".join(
            [
                *name_parts[:-2],
                name_parts[-2]
                + "_output_%%03d_%s" % datetime.now().strftime("%Y%m%d_%H%M%S"),
                name_parts[-1],
            ]
        )
        # Write next to the input, mapping ComfyUI/input to ComfyUI/output.
        # NOTE(review): the space removal applies to the whole path, directories included.
        return (
            os.sep.join([*video_path.split(os.sep)[:-1], output_name])
            .replace(os.sep.join(["ComfyUI", "input"]), os.sep.join(["ComfyUI", "output"]))
            .replace(" ", "")
        )

    @staticmethod
    def _restore_names(output, uid, origin_fname):
        """Rename segment files from the uuid temp name back to ``origin_fname``; return their paths."""
        pattern = output.replace("%03d", "*")
        try:
            for file in glob.glob(pattern):
                shutil.move(file, file.replace(uid, origin_fname))
            return glob.glob(output.replace(uid, origin_fname).replace("%03d", "*"))
        except OSError:
            traceback.print_exc()
            return glob.glob(pattern)

    def cut(self, config, video_path, mod, fps, period_length):
        # Input file name without its extension; put back on the outputs at the end.
        origin_fname = ".".join(video_path.split(os.sep)[-1].split(".")[:-1])
        # Factor converting a frame index into seconds.
        mul = mod / fps
        print("fps", fps)
        config = json.loads(config)
        if len(config.keys()) == 0:
            return ("无法生成符合要求的片段",)
        start, end = config["start"], config["end"]

        # Copy the input to a uuid-named file in the same folder for ffmpeg.
        uid = uuid.uuid1()
        temp_fname = os.sep.join(
            [
                *video_path.split(os.sep)[:-1],
                "%s.%s" % (str(uid), video_path.split(".")[-1]),
            ]
        )
        try:
            shutil.copy(video_path, temp_fname)
        except OSError:
            return ("请检查输入文件权限",)

        output = self._output_pattern(temp_fname)
        try:
            # NOTE(review): -segment_times takes a list of split points, so one
            # value splits once; -segment_time (every N s) may have been intended.
            ff = ffmpy.FFmpeg(
                inputs={temp_fname: ["-accurate_seek"]},
                outputs={
                    output: [
                        "-f",
                        "segment",
                        "-ss",
                        str(round(start * mul, 3)),
                        "-to",
                        str(round(end * mul, 3)),
                        "-segment_times",
                        str(period_length),
                        "-c",
                        "copy",
                        "-map",
                        "0",
                        "-avoid_negative_ts",
                        "1",
                    ]
                },
            )
            print(ff.cmd)
            ff.run()
        finally:
            # Remove the temp copy even when ffmpeg fails, so failed runs leave nothing behind.
            try:
                os.remove(temp_fname)
            except OSError:
                pass
        return (str(self._restore_names(output, str(uid), origin_fname)),)
# Add custom API routes, using router
from aiohttp import web
from server import PromptServer
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
    """Liveness endpoint: answer GET /hello with the JSON string "hello"."""
    greeting = "hello"
    return web.json_response(greeting)
# 腾讯云 VOD
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common import credential
from tencentcloud.vod.v20180717 import vod_client, models
import requests
from pathlib import Path
class VodToLocalNode:
    """Tencent Cloud VOD download node (腾讯云VOD下载).

    Resolves a VOD media file's download URL (optionally within ``sub_app_id``),
    downloads it to ``<package dir>/download/<sub_app_id>/<file_id>.mp4`` and
    returns that path. An already downloaded file is returned without calling
    the API.

    Credentials are read from the TENCENTCLOUD_SECRET_ID /
    TENCENTCLOUD_SECRET_KEY environment variables, or else from the
    ``secret_id`` / ``secret_key`` keys of this package's ``config.yaml``
    (the same keys the COS nodes read; they need VOD permissions).
    """

    def __init__(self):
        self.secret_id, self.secret_key = self._load_credentials()
        self.vod_client = self.init_vod_client()

    @staticmethod
    def _load_credentials():
        """Return (secret_id, secret_key) from the environment or config.yaml.

        Raises:
            RuntimeError: if neither source provides both values.
        """
        secret_id = os.environ.get("TENCENTCLOUD_SECRET_ID")
        secret_key = os.environ.get("TENCENTCLOUD_SECRET_KEY")
        if secret_id and secret_key:
            return secret_id, secret_key
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "config.yaml"
        )
        try:
            with open(config_path, encoding="utf-8") as f:
                cfg = yaml.safe_load(f) or {}
        except FileNotFoundError:
            cfg = {}
        secret_id, secret_key = cfg.get("secret_id"), cfg.get("secret_key")
        if not (secret_id and secret_key):
            raise RuntimeError(
                "VOD credentials not configured: set TENCENTCLOUD_SECRET_ID/"
                "TENCENTCLOUD_SECRET_KEY or secret_id/secret_key in config.yaml"
            )
        return secret_id, secret_key

    def init_vod_client(self):
        """Create the VOD API client (region ap-shanghai)."""
        try:
            http_profile = HttpProfile(endpoint="vod.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            cred = credential.Credential(self.secret_id, self.secret_key)
            return vod_client.VodClient(cred, "ap-shanghai", client_profile)
        except Exception as e:
            raise RuntimeError(f"VOD client initialization failed: {e}") from e

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "file_id": ("STRING", {"default": ""}),
                "sub_app_id": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("local_path",)
    FUNCTION = "execute"
    CATEGORY = "video"

    def execute(self, file_id, sub_app_id):
        local_path = self.download_vod(file_id, sub_app_id)
        print(f"下载成功: {local_path}")
        return (local_path,)

    def _get_download_url(self, file_id, sub_app_id):
        """Return the media URL of ``file_id`` via DescribeMediaInfos.

        ``SubAppId`` is sent only when ``sub_app_id`` is non-blank.

        Raises:
            RuntimeError: wrapping any API error, a missing file or a missing URL.
        """
        try:
            req = models.DescribeMediaInfosRequest()
            req.FileIds = [file_id]
            if str(sub_app_id).strip():
                req.SubAppId = int(sub_app_id)
            resp = self.vod_client.DescribeMediaInfos(req)
            if not resp.MediaInfoSet:
                raise ValueError("File not found")
            media_info = resp.MediaInfoSet[0]
            if not media_info.BasicInfo.MediaUrl:
                raise ValueError("No download URL available")
            return media_info.BasicInfo.MediaUrl
        except Exception as e:
            raise RuntimeError(f"Tencent API error: {e}") from e

    def create_directory(self, path):
        """Create ``path`` (and its parents) if missing, printing which case applied."""
        p = Path(path)
        if not p.exists():
            # parents=True creates missing ancestors; exist_ok=True tolerates a concurrent create.
            p.mkdir(parents=True, exist_ok=True)
            print(f"目录已创建: {path}")
        else:
            print(f"目录已存在: {path}")

    @staticmethod
    def _local_path(file_id, sub_app_id):
        """Return the cache path for a VOD file.

        Raises:
            ValueError: if either id contains a path separator or is ``..``,
                which would place the file outside the download folder.
        """
        for name, value in (("file_id", file_id), ("sub_app_id", sub_app_id)):
            text = str(value)
            if "/" in text or "\\" in text or text == "..":
                raise ValueError(f"invalid {name}: {text!r}")
        return os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "download",
            f"{sub_app_id}",
            f"{file_id}.mp4",
        )

    def download_vod(self, file_id, sub_app_id):
        """Return the local path of the VOD file, downloading it if not cached."""
        save_path = self._local_path(file_id, sub_app_id)
        if os.path.exists(save_path):
            return save_path
        media_url = self._get_download_url(file_id=file_id, sub_app_id=sub_app_id)
        print(f"download from url: {media_url}")
        self.create_directory(os.path.dirname(save_path))
        return self._download_file(url=media_url, save_path=save_path, timeout=60 * 10)

    def _download_file(self, url: str, save_path: str, timeout: int = 30):
        """Stream ``url`` into ``save_path`` atomically and return ``save_path``.

        The body is written to ``save_path + '.part'`` and renamed on success, so
        an interrupted download never appears as a complete cached file.

        Raises:
            RuntimeError: wrapping any HTTP or I/O error.
        """
        part_path = save_path + ".part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(part_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(part_path, save_path)
            return save_path
        except Exception as e:
            try:
                os.remove(part_path)
            except OSError:
                pass
            raise RuntimeError(f"Download error: {e}") from e
# Node type name -> implementing class, as registered with ComfyUI.
# NOTE: the names must be globally unique across all installed custom nodes.
NODE_CLASS_MAPPINGS = dict(
    FaceOccDetect=FaceDetect,
    FaceExtract=FaceExtract,
    COSUpload=COSUpload,
    COSDownload=COSDownload,
    VideoCutCustom=VideoCut,
    VodToLocal=VodToLocalNode,
    LogToDB=LogToDB,
)

# Node type name -> human-readable title shown in the ComfyUI UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    FaceOccDetect="面部遮挡检测",
    FaceExtract="面部提取",
    COSUpload="COS上传",
    COSDownload="COS下载",
    VideoCutCustom="视频剪裁",
    VodToLocal="腾讯云VOD下载",
    LogToDB="状态持久化DB",
)