415 lines
15 KiB
Python
415 lines
15 KiB
Python
import glob
|
||
import json
|
||
import os
|
||
import shutil
|
||
import traceback
|
||
import uuid
|
||
from datetime import datetime
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
import yaml
|
||
from ultralytics import YOLO
|
||
from comfy import model_management
|
||
from qcloud_cos import CosConfig, CosClientError, CosServiceError
|
||
from qcloud_cos import CosS3Client
|
||
|
||
from .test_single_image import test_node
|
||
import ffmpy
|
||
|
||
# Known video file extensions (lower-case, without the leading dot).
# NOTE(review): not referenced in the visible part of this file.
video_extensions = "webm mp4 mkv gif mov".split()
|
||
|
||
|
||
class FaceDetect:
    """Face occlusion detection node (人脸遮挡检测).

    Delegates frame classification to ``test_node`` and then selects one of
    the candidate frame periods it reports, chosen by ``main_seed``.
    """

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "image": ("IMAGE",),
            "main_seed": ("INT:seed", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "model": (["convnext_tiny", "convnext_base"],),
            "length": ("INT", {"default": 10, "min": 3, "max": 60, "step": 1}),
            "threshold": ("FLOAT", {"default": 94, "min": 55, "max": 99, "step": 0.1}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("IMAGE",) * 2 + ("STRING",) * 5 + ("INT",) * 2
    RETURN_NAMES = ("图像", "选中人脸", "分类", "概率", "采用帧序号", "全部帧序列", "剪辑配置", "起始帧序号", "帧数量")

    FUNCTION = "predict"

    CATEGORY = "不忘科技-自定义节点🚩"

    def predict(self, image, main_seed, model, length, threshold):
        """Classify the frames and pick one qualifying clip.

        The chosen entry is ``periods[main_seed % len(periods)]``, unpacked as
        ``(start, end)``; the returned frame count is ``end - start + 1``.

        Raises:
            RuntimeError: if ``test_node`` reports no candidate periods.
        """
        (frames, selected_faces, labels, probs,
         used_frames, periods) = test_node(image, length=length, thres=threshold, model_name=model)
        print("全部帧序列", periods)

        # len() rather than truthiness: periods may be an array-like container.
        if len(periods) == 0:
            raise RuntimeError("未找到符合要求的视频片段")

        start, end = periods[main_seed % len(periods)]
        clip_config = {"start": start, "end": end}
        frame_count = end - start + 1
        return (frames, selected_faces, labels, probs, used_frames,
                str(periods), json.dumps(clip_config), start, frame_count)
|
||
|
||
|
||
class FaceExtract:
    """Face extraction node using a YOLO face detector (人脸提取 By YOLO).

    For every input frame with exactly one detected face, a square crop
    centred on that face is resized to ``OUTPUT_SIZE`` x ``OUTPUT_SIZE``.
    Frames with zero or several detections produce an all-zero frame.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("图片",)

    FUNCTION = "crop"

    CATEGORY = "不忘科技-自定义节点🚩"

    # Side length (pixels) of every output crop.
    OUTPUT_SIZE = 512
    # Fill value (0-255 scale) for crop pixels that fall outside the source frame.
    PAD_VALUE = 20.0

    @staticmethod
    def _square_crop(img, box, pad_value=20.0):
        """Return a square float64 crop of ``img`` centred on ``box``, or None.

        Args:
            img: array indexed ``[row, col, channel]`` with 3 channels.
            box: ``(x1, y1, x2, y2, ...)`` in pixel coordinates (YOLO xyxy order);
                extra trailing values (score, class) are ignored.
            pad_value: value used where the square extends past the image.

        Returns:
            An ``(size, size, 3)`` array where ``size`` is the longer box edge,
            or None for a degenerate (sub-pixel) box.
        """
        x1, y1, x2, y2 = (float(v) for v in box[:4])
        size = int(max(x2 - x1, y2 - y1))
        if size < 1:
            return None
        half = size // 2
        # Top-left corner of the square, derived from the floored box centre.
        row0 = int((y1 + y2) // 2) - half
        col0 = int((x1 + x2) // 2) - half

        crop = np.full((size, size, 3), pad_value, dtype=np.float64)
        height, width = img.shape[:2]
        # Clip the square to the image: rows are bounded by height, cols by width.
        r_lo, r_hi = max(row0, 0), min(row0 + size, height)
        c_lo, c_hi = max(col0, 0), min(col0 + size, width)
        if r_lo < r_hi and c_lo < c_hi:
            crop[r_lo - row0:r_hi - row0, c_lo - col0:c_hi - col0] = img[r_lo:r_hi, c_lo:c_hi]
        return crop

    def crop(self, image):
        """Detect one face per frame and return the resized square face crops.

        Args:
            image: image tensor; scaled by 255 for detection and back by 1/255
                on output (presumably ComfyUI's (batch, H, W, C) in [0, 1] — TODO confirm).

        Returns:
            1-tuple with a float32 tensor of shape (batch, OUTPUT_SIZE, OUTPUT_SIZE, 3).
        """
        device = model_management.get_torch_device()
        frames = 255. * image.cpu().numpy()
        weights = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model", "yolov8n-face-lindevs.pt")
        model = YOLO(model=weights)

        size = self.OUTPUT_SIZE
        # np.zeros (not np.ndarray) so frames without a usable face are black,
        # not uninitialized memory.
        out_images = np.zeros((frames.shape[0], size, size, 3))

        for idx, frame in enumerate(frames):
            result = model.predict(
                frame,
                imgsz=640,
                conf=0.75,
                iou=0.7,
                device=device,
                verbose=False
            )[0]
            detections = result.boxes.data.cpu().numpy()
            # Only frames with exactly one detected face are cropped.
            if len(detections) != 1:
                continue
            face = self._square_crop(result.orig_img, detections[0], self.PAD_VALUE)
            if face is None:
                continue
            out_images[idx] = cv2.resize(face, (size, size))

        cropped_face = torch.from_numpy(out_images.astype(np.float32) / 255.0)
        return (cropped_face,)
|
||
|
||
|
||
class COSDownload:
    """Tencent Cloud COS download node (腾讯云COS下载)."""

    # TODO: add Tencent Cloud VOD video download.

    # Number of download attempts before giving up.
    MAX_RETRIES = 10

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cos_key": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "download"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _local_path(cos_key):
        """Return the local path under ``<node dir>/download`` mirroring ``cos_key``."""
        # A leading separator would make os.path.join discard the download dir
        # and write to an absolute path.
        # NOTE(review): ".." components in cos_key are not rejected; the key is
        # user-supplied, so consider validating it.
        relative = cos_key.lstrip("/\\")
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), "download",
                            os.path.dirname(relative), os.path.basename(relative))

    def download(self, cos_key):
        """Download ``cos_key`` from the bucket configured in config.yaml.

        config.yaml must provide ``region``, ``secret_id``, ``secret_key`` and ``bucket``.

        Returns:
            1-tuple with the local file path.

        Raises:
            RuntimeError: if every attempt fails with a COS client/service error.
        """
        dest = self._local_path(cos_key)
        os.makedirs(os.path.dirname(dest), exist_ok=True)

        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
        with open(config_path, encoding="utf-8", mode="r") as f:
            yaml_config = yaml.safe_load(f)
        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self.MAX_RETRIES):
            try:
                client.download_file(
                    Bucket=yaml_config["bucket"],
                    Key=cos_key,
                    DestFilePath=dest)
                return (dest,)
            # A tuple is required: `except A or B` only catches A.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(f"下载失败 {e}")
        raise RuntimeError(
            f"COS download failed after {self.MAX_RETRIES} attempts: {cos_key}") from last_error
|
||
|
||
|
||
class COSUpload:
    """Tencent Cloud COS upload node (腾讯云COS上传)."""

    # TODO: add Tencent Cloud VOD video upload.

    # Number of upload attempts before giving up.
    MAX_RETRIES = 10

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "path": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("COS文件Key",)

    FUNCTION = "upload"
    CATEGORY = "不忘科技-自定义节点🚩"

    @staticmethod
    def _object_key(subfolder, path):
        """Build the COS key ``<subfolder>/<file name>`` for a local path.

        Both ``/`` and ``\\`` are treated as separators, even when mixed.
        """
        file_name = path.replace("\\", "/").rsplit("/", 1)[-1]
        return "/".join([subfolder, file_name])

    def upload(self, path):
        """Upload the local file ``path`` into the configured bucket/subfolder.

        config.yaml must provide ``region``, ``secret_id``, ``secret_key``,
        ``bucket`` and ``subfolder``.

        Returns:
            1-tuple with the COS object key.

        Raises:
            RuntimeError: if every attempt fails with a COS client/service error.
        """
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
        with open(config_path, encoding="utf-8", mode="r") as f:
            yaml_config = yaml.safe_load(f)
        key = self._object_key(yaml_config["subfolder"], path)
        config = CosConfig(Region=yaml_config["region"], SecretId=yaml_config["secret_id"],
                           SecretKey=yaml_config["secret_key"])
        client = CosS3Client(config)

        last_error = None
        for _ in range(self.MAX_RETRIES):
            try:
                client.upload_file(
                    Bucket=yaml_config["bucket"],
                    Key=key,
                    LocalFilePath=path)
                return (key,)
            # A tuple is required: `except A or B` only catches A.
            except (CosClientError, CosServiceError) as e:
                last_error = e
                print(e)
        raise RuntimeError(
            f"COS upload failed after {self.MAX_RETRIES} attempts: {path}") from last_error
|
||
|
||
|
||
class VideoCut:
    """FFmpeg video cutting node -- deprecated: the output stutters (FFMPEG视频剪辑 -- !有卡顿问题 暂废弃)."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "config": ("STRING",),
                "video_path": ("STRING",),
                "mod": ("INT",),
                "fps": ("FLOAT",),
                "period_length": ("INT", {"default": 10, "min": 4, "max": 100, "step": 1, "forceInput": True})
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频路径",)

    FUNCTION = "cut"

    # OUTPUT_NODE = False

    CATEGORY = "不忘科技-自定义节点🚩"

    def cut(self, config, video_path, mod, fps, period_length):
        """Cut ``[start, end]`` out of ``video_path`` with ffmpeg's segment muxer.

        ``config`` is a JSON object with ``start``/``end``; they are multiplied
        by ``mod / fps`` to get seconds (presumably frame indices sampled every
        ``mod`` frames — TODO confirm).

        Returns:
            1-tuple with ``str(list_of_output_files)``, or a message string when
            the config is empty or the input cannot be copied.
        """
        # Original file name without its last extension.
        origin_fname = ".".join(video_path.split(os.sep)[-1].split(".")[:-1])

        mul = mod / fps
        print("fps", fps)
        config = json.loads(config)
        if len(config.keys()) == 0:
            return ("无法生成符合要求的片段",)
        start, end = config["start"], config["end"]

        # Work on a uuid-named copy so ffmpeg sees a simple file name.
        uid = uuid.uuid1()
        temp_fname = os.sep.join([*video_path.split(os.sep)[:-1], "%s.%s" % (str(uid), video_path.split(".")[-1])])
        try:
            shutil.copy(video_path, temp_fname)
        except OSError:
            return ("请检查输入文件权限",)

        # Output name: <uid>_output_%03d_<timestamp>.<ext>; %03d is the segment number.
        name_parts = temp_fname.split(os.sep)[-1].split(".")
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_name = ".".join([*name_parts[:-2], name_parts[-2] + "_output_%03d_" + stamp, name_parts[-1]])
        output = (os.sep.join([*temp_fname.split(os.sep)[:-1], output_name])
                  .replace(os.sep.join(["ComfyUI", "input"]), os.sep.join(["ComfyUI", "output"])).replace(" ", ""))

        try:
            ff = ffmpy.FFmpeg(
                inputs={temp_fname: ['-accurate_seek']},
                outputs={output: [
                    '-f', 'segment',
                    '-ss', str(round(start * mul, 3)),
                    '-to', str(round(end * mul, 3)),
                    '-segment_times', str(period_length),
                    '-c', 'copy',
                    '-map', '0',
                    '-avoid_negative_ts', '1'
                ]}
            )
            print(ff.cmd)
            ff.run()
        finally:
            # Always remove the temporary copy, even when ffmpeg fails.
            try:
                os.remove(temp_fname)
            except OSError:
                pass

        # Rename the segments from the uuid back to the original file name.
        pattern = output.replace("%03d", "*")
        try:
            for file in glob.glob(pattern):
                shutil.move(file, file.replace(str(uid), origin_fname))
            files = glob.glob(output.replace(str(uid), origin_fname).replace("%03d", "*"))
        except OSError:
            traceback.print_exc()
            files = glob.glob(pattern)
        return (str(files),)
|
||
|
||
|
||
# Add custom API routes, using router
|
||
from aiohttp import web
|
||
from server import PromptServer
|
||
|
||
|
||
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
    """Simple test route: GET /hello responds with the JSON-encoded string "hello"."""
    payload = "hello"
    return web.json_response(payload)
|
||
|
||
|
||
|
||
|
||
# 腾讯云 VOD
|
||
|
||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||
from tencentcloud.common import credential
|
||
from tencentcloud.vod.v20180717 import vod_client, models
|
||
import requests
|
||
from pathlib import Path
|
||
import tempfile
|
||
|
||
class VodToLocalNode:
    """Download a Tencent Cloud VOD media file into the node's local download directory."""

    # Environment variables checked first for the Tencent Cloud API credentials.
    ENV_SECRET_ID = "TENCENTCLOUD_SECRET_ID"
    ENV_SECRET_KEY = "TENCENTCLOUD_SECRET_KEY"

    def __init__(self):
        # NOTE(review): secrets were previously hard-coded in this file; treat
        # those keys as leaked and rotate them.
        self.secret_id, self.secret_key = self._load_credentials()
        self.vod_client = self.init_vod_client()

    @classmethod
    def _load_credentials(cls):
        """Return ``(secret_id, secret_key)``.

        Reads the environment variables first, then falls back to the
        ``secret_id`` / ``secret_key`` entries of config.yaml next to this file.

        Raises:
            RuntimeError: if neither source provides both values.
        """
        secret_id = os.environ.get(cls.ENV_SECRET_ID)
        secret_key = os.environ.get(cls.ENV_SECRET_KEY)
        if secret_id and secret_key:
            return secret_id, secret_key

        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
        try:
            with open(config_path, encoding="utf-8") as f:
                yaml_config = yaml.safe_load(f) or {}
        except OSError as e:
            raise RuntimeError(f"VOD credentials not found: set {cls.ENV_SECRET_ID}/"
                               f"{cls.ENV_SECRET_KEY} or provide {config_path}") from e
        secret_id = yaml_config.get("secret_id")
        secret_key = yaml_config.get("secret_key")
        if not (secret_id and secret_key):
            raise RuntimeError(f"VOD credentials missing: set {cls.ENV_SECRET_ID}/"
                               f"{cls.ENV_SECRET_KEY} or secret_id/secret_key in {config_path}")
        return secret_id, secret_key

    def init_vod_client(self):
        """Create the VOD client for the ap-shanghai region."""
        try:
            http_profile = HttpProfile(endpoint="vod.tencentcloudapi.com")
            client_profile = ClientProfile(httpProfile=http_profile)
            cred = credential.Credential(self.secret_id, self.secret_key)
            return vod_client.VodClient(cred, "ap-shanghai", client_profile)
        except Exception as e:
            raise RuntimeError(f"VOD client initialization failed: {e}") from e

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "file_id": ("STRING", {"default": ""}),
                "sub_app_id": ("STRING", {"default": ""}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("local_path",)
    FUNCTION = "execute"
    CATEGORY = "video"

    def execute(self, file_id, sub_app_id):
        """Node entry point: download the media and return its local path."""
        local_path = self.download_vod(file_id, sub_app_id)
        return (local_path,)

    def _get_download_url(self, file_id, sub_app_id):
        """Return the media URL for ``file_id`` via DescribeMediaInfos.

        Raises:
            RuntimeError: wrapping any API error, unknown file, or missing URL.
        """
        try:
            req = models.DescribeMediaInfosRequest()
            req.FileIds = [file_id]
            # Only send SubAppId when one is given; int("") (the input's
            # default) would otherwise raise.
            if str(sub_app_id).strip():
                req.SubAppId = int(sub_app_id)

            resp = self.vod_client.DescribeMediaInfos(req)
            if not resp.MediaInfoSet:
                raise ValueError("File not found")

            media_info = resp.MediaInfoSet[0]
            if not media_info.BasicInfo.MediaUrl:
                raise ValueError("No download URL available")

            return media_info.BasicInfo.MediaUrl
        except Exception as e:
            raise RuntimeError(f"Tencent API error: {e}") from e

    def create_directory(self, path):
        """Create ``path`` (and any missing parents) if it does not exist."""
        p = Path(path)
        if not p.exists():
            p.mkdir(parents=True, exist_ok=True)
            print(f"目录已创建: {path}")
        else:
            print(f"目录已存在: {path}")

    def download_vod(self, file_id, sub_app_id):
        """Download ``file_id`` to ``<node dir>/download/<sub_app_id>/<file_id>.mp4``.

        Returns:
            The absolute local path as a string.
        """
        media_url = self._get_download_url(file_id=file_id, sub_app_id=sub_app_id)
        target_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "download", f"{sub_app_id}")
        self.create_directory(target_dir)
        # NOTE(review): the .mp4 suffix is assumed regardless of the media's real container.
        save_path = os.path.join(target_dir, f"{file_id}.mp4")
        return self._download_file(url=media_url, save_path=save_path, timeout=60 * 10)

    def _download_file(self, url, save_path, timeout=30):
        """Stream ``url`` to ``save_path`` and return the absolute path as a str.

        Data is written to ``<save_path>.part`` and moved into place only after
        the transfer completes, so a failure never leaves a truncated file.

        Raises:
            RuntimeError: on any network or filesystem error.
        """
        # save_path arrives as a str; Path is needed for .resolve() below.
        save_path = Path(save_path)
        partial = save_path.with_name(save_path.name + ".part")
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(partial, save_path)
        except Exception as e:
            try:
                partial.unlink()
            except OSError:
                pass
            raise RuntimeError(f"Download error: {e}") from e
        return str(save_path.resolve())
|
||
|
||
# Registry of every node exported by this module: node type name -> class.
# NOTE: names should be globally unique.
NODE_CLASS_MAPPINGS = dict(
    FaceOccDetect=FaceDetect,
    FaceExtract=FaceExtract,
    COSUpload=COSUpload,
    COSDownload=COSDownload,
    VideoCutCustom=VideoCut,
    VodToLocal=VodToLocalNode,
)

# Human-readable titles for the node types registered above.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    FaceOccDetect="面部遮挡检测",
    FaceExtract="面部提取",
    COSUpload="COS上传",
    COSDownload="COS下载",
    VideoCutCustom="视频剪裁",
    VodToLocal="腾讯云VOD下载",
)
|
||
|