ComfyUI-CustomNode/__init__.py

407 lines
14 KiB
Python

import glob
import json
import os
import shutil
import traceback
import uuid
from datetime import datetime
import cv2
import numpy as np
import torch
import yaml
from ultralytics import YOLO
from comfy import model_management
from qcloud_cos import CosConfig, CosClientError, CosServiceError
from qcloud_cos import CosS3Client
from .test_single_image import test_node
import ffmpy
# Presumably the file extensions treated as video files (inferred from the name).
# NOTE(review): not referenced anywhere else in this file; it may be used by another
# module or be left over — confirm before removing.
video_extensions = ['webm', 'mp4', 'mkv', 'gif', 'mov']
class FaceDetect:
    """Face-occlusion detection node (original description: 人脸遮挡检测).

    The per-frame work is delegated to ``test_node``. Afterwards one of the
    frame periods it reports is selected, indexed by ``main_seed`` modulo the
    number of periods.
    """

    CATEGORY = "不忘科技-自定义节点🚩"
    FUNCTION = "predict"
    RETURN_NAMES = ("图像", "选中人脸", "分类", "概率", "采用帧序号", "全部帧序列", "剪辑配置", "起始帧序号", "帧数量")
    RETURN_TYPES = ("IMAGE", "IMAGE", "STRING", "STRING", "STRING", "STRING", "STRING", "INT", "INT")

    @classmethod
    def INPUT_TYPES(s):
        seed_spec = ("INT:seed", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
        length_spec = ("INT", {"default": 10, "min": 3, "max": 60, "step": 1})
        threshold_spec = ("FLOAT", {"default": 94, "min": 55, "max": 99, "step": 0.1})
        return {
            "required": {
                "image": ("IMAGE",),
                "main_seed": seed_spec,
                "model": (["convnext_tiny", "convnext_base"],),
                "length": length_spec,
                "threshold": threshold_spec,
            },
        }

    def predict(self, image, main_seed, model, length, threshold):
        """Run detection and return the nine outputs declared in ``RETURN_TYPES``.

        The last two outputs are the chosen period's start frame and its
        inclusive frame count (``end - start + 1``).

        Raises:
            RuntimeError: when ``test_node`` reports no usable period.
        """
        detection = test_node(image, length=length, thres=threshold, model_name=model)
        image, image_selected, cls, prob, nums, period = detection
        print("全部帧序列", period)
        if len(period) == 0:
            raise RuntimeError("未找到符合要求的视频片段")
        # main_seed picks one candidate period; the modulo keeps any seed in range.
        start, end = period[main_seed % len(period)]
        # NOTE(review): json.dumps would raise if start/end are numpy integers —
        # confirm the element types test_node returns.
        clip_config = json.dumps({"start": start, "end": end})
        frame_count = end - start + 1
        return (image, image_selected, cls, prob, nums, str(period), clip_config, start, frame_count)
class FaceExtract:
    """Face extraction node using a YOLO face detector.

    Each frame of the batch is processed independently. If exactly one face is
    detected, a square region centred on that face is cut out and resized to
    512x512. Frames with zero or several detections stay black (all zeros) in
    the output batch.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("图片",)
    FUNCTION = "crop"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _crop_square(img, box, pad_value=20):
        """Cut a square centred on ``box`` out of ``img``.

        Args:
            img: array indexed ``img[row, col]`` with 3 channels.
            box: ``(x1, y1, x2, y2)`` pixel coordinates (x = column, y = row),
                i.e. the first four columns of YOLO ``boxes.data``.
            pad_value: value for crop pixels that fall outside ``img``.

        Returns:
            A ``(size, size, 3)`` float64 array, where size is the longer box
            side, or ``None`` when the box is degenerate (size 0).
        """
        bx1, by1, bx2, by2 = box
        size = int(max(bx2 - bx1, by2 - by1))
        if size <= 0:
            return None
        half = size // 2
        # Same arithmetic as the original implementation: the centre comes from
        # floor-division, and the window is [centre - half, centre + half).
        row_c = (by1 + by2) // 2
        col_c = (bx1 + bx2) // 2
        row0, row_end = int(row_c - half), int(row_c + half)
        col0, col_end = int(col_c - half), int(col_c + half)
        n_rows = min(size, row_end - row0)
        n_cols = min(size, col_end - col0)

        template = np.full((size, size, 3), pad_value, dtype=np.float64)
        # Clip the window against the real image bounds: rows against shape[0]
        # (height) and columns against shape[1] (width), half-open on both ends.
        src_r0, src_r1 = max(row0, 0), min(row0 + n_rows, img.shape[0])
        src_c0, src_c1 = max(col0, 0), min(col0 + n_cols, img.shape[1])
        if src_r1 > src_r0 and src_c1 > src_c0:
            template[src_r0 - row0:src_r1 - row0, src_c0 - col0:src_c1 - col0] = \
                img[src_r0:src_r1, src_c0:src_c1]
        return template

    def crop(self, image):
        """Return a batch of 512x512 face crops, one per input frame.

        Args:
            image: batch tensor whose ``.cpu().numpy()`` is iterated along axis 0.
                Presumably (B, H, W, C) floats in [0, 1], following ComfyUI's
                IMAGE convention — TODO confirm.

        Returns:
            1-tuple holding a float32 torch tensor of shape (B, 512, 512, 3)
            scaled back to [0, 1].
        """
        device = model_management.get_torch_device()
        image_np = 255. * image.cpu().numpy()
        model = YOLO(model=os.path.join(os.path.dirname(os.path.abspath(__file__)), "model", "yolov8n-face-lindevs.pt"))
        size = 512
        # zeros rather than np.ndarray: frames without exactly one face must not
        # contain uninitialised memory.
        out_images = np.zeros((image_np.shape[0], size, size, 3))
        print("shape", image_np.shape)
        for idx, image_item in enumerate(image_np):
            results = model.predict(
                image_item,
                imgsz=640,
                conf=0.75,
                iou=0.7,
                device=device,
                verbose=False
            )
            r = results[0]
            detections = r.boxes.data.cpu().numpy()
            if len(detections) != 1:
                continue
            template = self._crop_square(r.orig_img, tuple(detections[0][:4]))
            if template is None:
                continue
            out_images[idx] = cv2.resize(template, (size, size))
        cropped_face = out_images.astype(np.float32) / 255.0
        return (torch.from_numpy(cropped_face),)
class COSDownload:
    """Tencent Cloud COS download node (腾讯云COS下载).

    Downloads ``cos_key`` into this package's ``download`` folder, mirroring the
    key's directory part. Region, credentials and bucket are read from
    ``config.yaml`` next to this file (keys ``region``, ``secret_id``,
    ``secret_key`` and ``bucket``).
    """
    # TODO: add Tencent Cloud VOD video download (增加腾讯云VOD视频下载)

    # Number of download attempts before giving up.
    MAX_ATTEMPTS = 10

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cos_key": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "download"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _local_path(cos_key):
        """Return the local destination for ``cos_key`` under ``<package>/download``."""
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), "download",
                            os.path.dirname(cos_key), os.path.basename(cos_key))

    def download(self, cos_key):
        """Download ``cos_key`` and return ``(local_path,)``.

        Raises:
            RuntimeError: when all ``MAX_ATTEMPTS`` attempts fail with a COS
                client or service error (previously a path to a missing file
                was returned silently).
        """
        dest_path = self._local_path(cos_key)
        # Always create the parent directory. Previously this only happened for
        # keys containing a separator, so bare keys failed on a fresh install.
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)

        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
        with open(config_path, encoding="utf-8") as f:
            yaml_config = yaml.safe_load(f)
        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self.MAX_ATTEMPTS):
            try:
                client.download_file(
                    Bucket=yaml_config["bucket"],
                    Key=cos_key,
                    DestFilePath=dest_path)
                return (dest_path,)
            # A tuple is required: `A or B` evaluates to A, so CosServiceError
            # was never caught by the original clause.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(f"下载失败 {e}")
        raise RuntimeError(
            f"COS download of {cos_key!r} failed after {self.MAX_ATTEMPTS} attempts") from last_error
class COSUpload:
    """Tencent Cloud COS upload node (腾讯云COS上传).

    Uploads the local file ``path`` to ``<subfolder>/<file name>`` in the bucket
    configured in ``config.yaml`` next to this file (keys ``region``,
    ``secret_id``, ``secret_key``, ``bucket`` and ``subfolder``).
    """
    # TODO: add Tencent Cloud VOD video upload (增加腾讯云VOD视频上传)

    # Number of upload attempts before giving up.
    MAX_ATTEMPTS = 10

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "path": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("COS文件Key",)
    FUNCTION = "upload"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _file_name(path):
        """Return the last path component, treating both ``/`` and ``\\`` as separators.

        The previous logic only split on ``\\`` when no ``/`` was present, so
        mixed paths such as ``C:\\out/sub\\a.mp4`` kept part of the directory.
        """
        return path.replace("\\", "/").rsplit("/", 1)[-1]

    def upload(self, path):
        """Upload ``path`` and return ``(cos_key,)``.

        Raises:
            RuntimeError: when all ``MAX_ATTEMPTS`` attempts fail with a COS
                client or service error (previously the key was returned as if
                the upload had succeeded).
        """
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
        with open(config_path, encoding="utf-8") as f:
            yaml_config = yaml.safe_load(f)
        cos_key = "/".join([yaml_config["subfolder"], self._file_name(path)])
        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self.MAX_ATTEMPTS):
            try:
                client.upload_file(
                    Bucket=yaml_config["bucket"],
                    Key=cos_key,
                    LocalFilePath=path)
                return (cos_key,)
            # A tuple is required: `A or B` evaluates to A, so CosServiceError
            # was never caught by the original clause.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(e)
        raise RuntimeError(
            f"COS upload of {path!r} failed after {self.MAX_ATTEMPTS} attempts") from last_error
class VideoCut:
    """FFmpeg video cutting node.

    The original note reads "有卡顿问题 暂废弃": the output stutters, and the
    node is deprecated for now. It is still registered in this file's
    NODE_CLASS_MAPPINGS as "VideoCutCustom".
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "config": ("STRING",),
                "video_path": ("STRING",),
                "mod": ("INT",),
                "fps": ("FLOAT",),
                "period_length": ("INT", {"default": 10, "min": 4, "max": 100, "step": 1, "forceInput": True})
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频路径",)
    FUNCTION = "cut"
    # OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩"

    def cut(self, config, video_path, mod, fps, period_length):
        """Cut the frame range ``start``..``end`` from ``config`` out of ``video_path``.

        Args:
            config: JSON object string. ``{}`` means "nothing to cut"; otherwise
                it must contain ``start`` and ``end`` frame indices.
            video_path: path of the source video.
            mod, fps: ``mod / fps`` converts the frame indices to seconds.
            period_length: passed to ffmpeg as ``-segment_times``.

        Returns:
            1-tuple with ``str(list_of_output_files)``, or a human-readable
            error message string.
        """
        # Original file name without its extension.
        origin_fname = ".".join(video_path.split(os.sep)[-1].split(".")[:-1])
        # Factor that converts the frame indices in `config` to seconds.
        mul = mod / fps
        print("fps", fps)
        config = json.loads(config)
        if len(config.keys()) == 0:
            return ("无法生成符合要求的片段",)
        start, end = config["start"], config["end"]
        # Work on a uuid-named copy of the input so ffmpeg gets a plain file name
        # (presumably to avoid trouble with unusual characters in the original name).
        uid = uuid.uuid1()
        temp_fname = os.sep.join([*video_path.split(os.sep)[:-1], "%s.%s" % (str(uid), video_path.split(".")[-1])])
        try:
            shutil.copy(video_path, temp_fname)
        except OSError:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            return ("请检查输入文件权限",)
        video_path = temp_fname
        # Output pattern <uuid>_output_%03d_<timestamp>.<ext>. It is moved under
        # ComfyUI/output when the input lives under ComfyUI/input, and spaces are stripped.
        name_parts = video_path.split(os.sep)[-1].split(".")
        output_name = ".".join([*name_parts[:-2],
                                name_parts[-2] + "_output_%%03d_%s" % datetime.now().strftime('%Y%m%d_%H%M%S'),
                                name_parts[-1]])
        output = (os.sep.join([*video_path.split(os.sep)[:-1], output_name])
                  .replace(os.sep.join(["ComfyUI", "input"]), os.sep.join(["ComfyUI", "output"]))
                  .replace(" ", ""))
        try:
            ff = ffmpy.FFmpeg(
                inputs={video_path: ['-accurate_seek']},
                outputs={output: [
                    '-f', 'segment',
                    '-ss', str(round(start * mul, 3)),
                    '-to', str(round(end * mul, 3)),
                    # NOTE(review): ffmpeg's `-segment_times` expects a list of split
                    # points; `-segment_time` (fixed segment length) may have been
                    # intended — confirm before relying on this node.
                    '-segment_times', str(period_length),
                    '-c', 'copy',
                    '-map', '0',
                    '-avoid_negative_ts', '1'
                ]}
            )
            print(ff.cmd)
            ff.run()
        finally:
            # Always drop the temporary copy. It used to leak when ffmpeg failed.
            try:
                os.remove(temp_fname)
            except OSError:
                pass
        # Rename the outputs from the uuid back to the original file name.
        try:
            files = glob.glob(output.replace("%03d", "*"))
            for file in files:
                shutil.move(file, file.replace(str(uid), origin_fname))
            files = glob.glob(output.replace(str(uid), origin_fname).replace("%03d", "*"))
            return (str(files),)
        except OSError:
            # Renaming failed: report whatever uuid-named outputs exist.
            files = glob.glob(output.replace("%03d", "*"))
            traceback.print_exc()
            return (str(files),)
# Add custom API routes, using router
from aiohttp import web
from server import PromptServer
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
    """Answer ``GET /hello`` with the JSON-encoded string ``"hello"``.

    ``request`` satisfies the aiohttp handler signature and is not read.
    """
    greeting = "hello"
    return web.json_response(greeting)
# 腾讯云 VOD
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common import credential
from tencentcloud.vod.v20180717 import vod_client, models
import requests
from pathlib import Path
import tempfile
class VodToLocalNode:
    """Download a Tencent Cloud VOD media file into this package's ``download`` folder.

    Credentials are no longer hard-coded. They are read from the environment
    variables ``TENCENTCLOUD_SECRET_ID`` / ``TENCENTCLOUD_SECRET_KEY``, falling
    back to ``secret_id`` / ``secret_key`` in ``config.yaml`` next to this file.
    SECURITY: the key pair that used to be embedded in this source must be
    revoked and rotated.

    NOTE(review): this class is not listed in NODE_CLASS_MAPPINGS in this
    module, so ComfyUI does not expose it.
    """

    # Region passed to the VOD client.
    REGION = "ap-shanghai"

    def __init__(self):
        self.secret_id, self.secret_key = self._load_credentials()
        self.vod_client = self.init_vod_client()

    @staticmethod
    def _load_credentials():
        """Return ``(secret_id, secret_key)`` from the environment or ``config.yaml``.

        Raises:
            RuntimeError: if no complete credential pair can be found.
        """
        secret_id = os.environ.get("TENCENTCLOUD_SECRET_ID")
        secret_key = os.environ.get("TENCENTCLOUD_SECRET_KEY")
        if secret_id and secret_key:
            return secret_id, secret_key
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
        try:
            with open(config_path, encoding="utf-8") as f:
                yaml_config = yaml.safe_load(f) or {}
        except OSError as e:
            raise RuntimeError(
                f"No VOD credentials in environment and cannot read {config_path}: {e}") from e
        secret_id = yaml_config.get("secret_id")
        secret_key = yaml_config.get("secret_key")
        if not (secret_id and secret_key):
            raise RuntimeError("VOD credentials missing: set TENCENTCLOUD_SECRET_ID/TENCENTCLOUD_SECRET_KEY "
                               "or secret_id/secret_key in config.yaml")
        return secret_id, secret_key

    def init_vod_client(self):
        """Initialise the VOD client from ``self.secret_id`` / ``self.secret_key``."""
        try:
            http_profile = HttpProfile(endpoint="vod.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            cred = credential.Credential(self.secret_id, self.secret_key)
            return vod_client.VodClient(cred, self.REGION, client_profile)
        except Exception as e:
            raise RuntimeError(f"VOD client initialization failed: {e}") from e

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "file_id": ("STRING", {"default": ""}),
                "sub_app_id": ("INT", {"default": 0}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("local_path",)
    FUNCTION = "execute"
    CATEGORY = "video"

    def execute(self, file_id, sub_app_id):
        """Node entry point: download ``file_id`` and return ``(local_path,)``."""
        local_path = self.download_vod(file_id, sub_app_id)
        return (local_path,)

    def _get_download_url(self, file_id, sub_app_id):
        """Return ``BasicInfo.MediaUrl`` of ``file_id`` via DescribeMediaInfos.

        Raises:
            RuntimeError: wrapping API errors, an unknown file, or a missing URL.
        """
        try:
            req = models.DescribeMediaInfosRequest()
            req.FileIds = [file_id]
            # NOTE(review): 0 is sent as-is for the default app — confirm the API accepts it.
            req.SubAppId = sub_app_id
            resp = self.vod_client.DescribeMediaInfos(req)
            if not resp.MediaInfoSet:
                raise ValueError("File not found")
            media_info = resp.MediaInfoSet[0]
            if not media_info.BasicInfo.MediaUrl:
                raise ValueError("No download URL available")
            return media_info.BasicInfo.MediaUrl
        except Exception as e:
            raise RuntimeError(f"Tencent API error: {e}") from e

    def download_vod(self, file_id, sub_app_id):
        """Resolve the media URL and download it to ``<package>/download``.

        Returns:
            Absolute path of the downloaded file.
        """
        media_url = self._get_download_url(file_id=file_id, sub_app_id=sub_app_id)
        # This must be the directory: the file name is added by _download_file.
        # Previously the file path was passed here, giving <id>.mp4/<id>.mp4.
        output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "download")
        return self._download_file(url=media_url, output_dir=output_dir, file_id=file_id, timeout=60 * 10)

    @staticmethod
    def _local_filename(url, file_id):
        """Build ``<file_id><ext>`` using the extension of the URL *path* (default ``.mp4``).

        Only the path is inspected so that query strings such as ``?sign=...``
        never end up in the file name.
        """
        from urllib.parse import urlsplit
        extension = os.path.splitext(urlsplit(url).path)[1] or ".mp4"
        return f"{file_id}{extension}"

    def _download_file(self, url, output_dir, file_id, timeout=30):
        """Stream ``url`` into ``output_dir`` and return the absolute file path.

        Data is written to a ``.part`` file that is renamed on success and
        removed on failure, so a failed download leaves no truncated file.

        Raises:
            RuntimeError: wrapping any network or file-system error.
        """
        save_path = Path(output_dir) / self._local_filename(url, file_id)
        part_path = save_path.with_name(save_path.name + ".part")
        try:
            save_path.parent.mkdir(parents=True, exist_ok=True)
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(part_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(part_path, save_path)
            return str(save_path.resolve())
        except Exception as e:
            if part_path.exists():
                part_path.unlink()
            raise RuntimeError(f"Download error: {e}") from e
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
# NOTE(review): VodToLocalNode, defined in this module, has no entry here, so
# ComfyUI will not load it as a node. Add one if it is meant to be exposed.
# Keep these as plain dict literals: tools such as ComfyUI-Manager may scan
# them statically.
NODE_CLASS_MAPPINGS = {
    "FaceOccDetect": FaceDetect,
    "FaceExtract": FaceExtract,
    "COSUpload": COSUpload,
    "COSDownload": COSDownload,
    "VideoCutCustom": VideoCut
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
# (keys must match NODE_CLASS_MAPPINGS).
NODE_DISPLAY_NAME_MAPPINGS = {
    "FaceOccDetect": "面部遮挡检测",
    "FaceExtract": "面部提取",
    "COSUpload": "COS上传",
    "COSDownload": "COS下载",
    "VideoCutCustom": "视频剪裁"
}