# ComfyUI-CustomNode/nodes/image_gesture_nodes.py
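"""ComfyUI custom nodes for Jimeng (Doubao Seedance) image-to-video generation.

The nodes in this module upload an input image to Tencent COS, submit an
image-to-video task to the Volcengine Ark API, poll the task until it finishes,
and return the generated video path and/or frames extracted from it as ComfyUI
IMAGE tensors.
"""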

import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep
import cv2
import folder_paths
import numpy as np
import requests
import torch
import yaml
from PIL import Image
from loguru import logger
from qcloud_cos import CosConfig, CosS3Client
from torchvision.transforms import transforms
from tqdm import tqdm


class JMUtils:
    def __init__(self):
        # Prefer credentials from environment variables; otherwise fall back to the
        # config.yaml that sits next to this package.
        if "aws_key_id" in os.environ:
            yaml_config = os.environ
        else:
            with open(
                os.path.join(
                    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"
                ),
                encoding="utf-8",
                mode="r",
            ) as f:
                yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        self.api_key = yaml_config["jm_api_key"]
        self.cos_region = yaml_config["cos_region"]
        self.cos_secret_id = yaml_config["cos_secret_id"]
        self.cos_secret_key = yaml_config["cos_secret_key"]
        self.cos_bucket_name = yaml_config["cos_sucai_bucket_name"]
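
    # Expected config.yaml layout when credentials are not supplied through
    # environment variables. The key names follow the lookups above; the values
    # are placeholders, not real credentials:
    #
    #   jm_api_key: "your-ark-api-key"
    #   cos_region: "ap-guangzhou"
    #   cos_secret_id: "your-cos-secret-id"
    #   cos_secret_key: "your-cos-secret-key"
    #   cos_sucai_bucket_name: "your-material-bucket"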

    def submit_task(self, prompt: str, img_url: str, duration: str = "10", resolution: str = "720p"):
        """Submit an image-to-video generation task to the Ark API and return its job id."""
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            json_data = {
                "model": "doubao-seedance-1-0-pro-250528",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt} --resolution {resolution} --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url,
                        },
                    },
                ],
            }
            response = requests.post(
                "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
                headers=headers, json=json_data)
            logger.info(f"submit task: {json.dumps(response.json())}")
            resp_json = response.json()
            if "id" not in resp_json:
                return {"status": False, "data": img_url, "msg": resp_json["error"]["message"]}
            else:
                job_id = resp_json["id"]
                return {"data": job_id, "status": True, "msg": "task submitted"}
        except Exception as e:
            logger.error(e)
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str):
        """Poll the Ark task; `status` is True only once the task has succeeded."""
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            response = requests.get(
                f"https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks/{job_id}",
                headers=headers)
            resp_json = response.json()
            resp_dict["status"] = resp_json["status"] == "succeeded"
            resp_dict["msg"] = resp_json["status"]
            resp_dict["data"] = resp_json["content"]["video_url"] if "content" in resp_json else None
        except Exception as e:
            logger.error(f"error: {str(e)}")
            resp_dict["msg"] = str(e)
        finally:
            return resp_dict

    def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
        """Upload an in-memory file to Tencent COS and return its public URL."""
        resp_data = {'status': True, 'data': '', 'msg': ''}
        category = mime_type.split('/')[0]
        suffix = mime_type.split('/')[1]
        try:
            object_key = f'tk/{category}/{uuid.uuid4()}.{suffix}'
            config = CosConfig(Region=self.cos_region, SecretId=self.cos_secret_id, SecretKey=self.cos_secret_key)
            client = CosS3Client(config)
            _ = client.upload_file_from_buffer(
                Bucket=self.cos_bucket_name,
                Key=object_key,
                Body=file
            )
            url = f'https://{self.cos_bucket_name}.cos.{self.cos_region}.myqcloud.com/{object_key}'
            resp_data['data'] = url
            resp_data['msg'] = 'upload succeeded'
        except Exception as e:
            logger.error(e)
            resp_data['status'] = False
            resp_data['msg'] = str(e)
        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor):
        # Convert a ComfyUI image tensor to an in-memory PNG buffer.
        img = Image.fromarray(np.clip(255. * tensor.cpu().squeeze().numpy(), 0, 255).astype(np.uint8))
        image_data = io.BytesIO()
        img.save(image_data, format='PNG')
        image_data.seek(0)
        return image_data

    def read_video_last_frame_to_tensor(self, video_path: str) -> torch.Tensor:
        """
        Read the last frame of a video file and convert it to a PyTorch tensor.

        Args:
            video_path (str): Path to the video file.

        Returns:
            torch.Tensor: Tensor of shape [1, H, W, C] (ComfyUI IMAGE layout), where H and W
                are the frame height and width and the channel order is RGB.

        Raises:
            FileNotFoundError: If the video file cannot be opened.
            ValueError: If the video is empty or the last frame cannot be read.
        """
        # Open the video file
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise FileNotFoundError(f"Cannot open video file: {video_path}")
        # Get the total frame count
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames == 0:
            cap.release()
            raise ValueError("Video is empty or its frame count cannot be determined")
        # Seek to the last frame and read it
        cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
        ret, frame = cap.read()
        cap.release()
        if not ret or frame is None:
            raise ValueError("Cannot read the last frame of the video; it may be corrupted")
        # OpenCV decodes frames as BGR; convert to RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # ToTensor yields a [C, H, W] tensor with values in [0, 1]; add a batch
        # dimension and permute to [1, H, W, C]
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
        tensor = transform(frame_rgb).unsqueeze(0).permute(0, 2, 3, 1)
        return tensor

    def download_video(self, url, timeout=30, retries=3, path=None):
        """Download a video to `path` (or a temporary file) and return the file path
        together with its last frame as a tensor."""
        for attempt in range(retries):
            try:
                # Create a temporary file unless an explicit output path was given
                if path:
                    temp_path = path
                else:
                    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
                    temp_path = temp_file.name
                    temp_file.close()
                # Download the video
                print(f"Starting video download (attempt {attempt + 1}/{retries})...")
                response = requests.get(url, stream=True, timeout=timeout)
                response.raise_for_status()
                # Content length for the progress bar
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 KB
                # Stream the response to disk with a tqdm progress bar
                with open(temp_path, 'wb') as f, tqdm(
                    desc=url.split('/')[-1],
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024
                ) as bar:
                    for data in response.iter_content(block_size):
                        size = f.write(data)
                        bar.update(size)
                print(f"Video download finished: {temp_path}")
                return temp_path, self.read_video_last_frame_to_tensor(temp_path)
            except Exception as e:
                print(f"Download error (attempt {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep((attempt + 1) * 2)
                else:
                    raise

    def jpg_to_tensor(self, image_path, channel_first=False):
        """
        Convert a JPG image to a PyTorch tensor.

        Args:
            image_path: Path to the JPG image file.
            channel_first: Unused; the result is always returned as [1, H, W, C].

        Returns:
            torch.Tensor: Tensor of shape [1, H, W, C] with values scaled to [0.0, 1.0].
        """
        try:
            # Open the image and force RGB
            image = Image.open(image_path).convert('RGB')
            # Scale to [0, 1] and add a batch dimension
            tensor = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]
            return tensor
        except Exception as e:
            print(f"Conversion failed: {str(e)}")
            raise

    def get_last_15th_frame_tensor(self, video_url, cleanup=True):
        """
        Extract the 15th-from-last frame of the video at `video_url` and convert it
        to a tensor. The video is first downloaded to a local temporary file.
        """
        try:
            # Download the video
            video_path, _ = self.download_video(video_url)
            # Query the total frame count with ffprobe
            cmd_frames = [
                'ffprobe', '-v', 'error',
                '-select_streams', 'v:0',
                '-show_entries', 'stream=nb_frames',
                '-of', 'default=nokey=1:noprint_wrappers=1',
                video_path
            ]
            result = subprocess.run(
                cmd_frames,
                capture_output=True,
                text=True,
                check=True
            )
            # nb_frames may be missing from the container, so handle non-numeric output
            frame_count = result.stdout.strip()
            if not frame_count.isdigit():
                # Fallback: count frames by decoding the whole video
                print("Cannot read an exact frame count, falling back to decoding...")
                cmd_decode = [
                    'ffmpeg', '-i', video_path,
                    '-f', 'null', '-'
                ]
                decode_result = subprocess.run(
                    cmd_decode,
                    capture_output=True,
                    text=True
                )
                for line in decode_result.stderr.split('\n'):
                    if 'frame=' in line:
                        parts = line.split('frame=')[-1].split()[0]
                        if parts.isdigit():
                            frame_count = int(parts)
                            break
                else:
                    raise ValueError("Cannot determine the video frame count")
            else:
                frame_count = int(frame_count)
            # Compute the target frame index
            target_frame = max(0, frame_count - 15)
            print(f"Total frames: {frame_count}, target frame: {target_frame}")
            # Extract the target frame with ffmpeg; '%03d' in the temporary file name is
            # treated by ffmpeg as a frame-number pattern, so the frame is written to
            # the corresponding '001' file
            with tempfile.NamedTemporaryFile(suffix='%03d.jpg', delete=True) as frame_file:
                frame_path = frame_file.name
                cmd_extract = [
                    'ffmpeg',
                    '-ss', '00:00:00',
                    '-i', video_path,
                    '-vframes', '1',
                    '-vf', f'select=eq(n\\,{target_frame})',  # comma escaped for the filter expression
                    '-vsync', '0',
                    '-an', '-y',
                    frame_path
                ]
                subprocess.run(
                    cmd_extract,
                    capture_output=True,
                    check=True
                )
                # Convert the extracted frame to a tensor
                tensor = self.jpg_to_tensor(frame_path.replace("%03d", "001"))
        except Exception as e:
            raise e
        return tensor


class JMGestureCorrect:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "resolution": (["720p", "1080p"],),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor, resolution: str):
        wait_time = 240
        interval = 2
        client = JMUtils()
        # Upload the input image to COS so the Ark API can fetch it by URL
        image_io = client.tensor_to_io(image)
        upload_data = client.upload_io_to_cos(image_io)
        if upload_data["status"]:
            image_url = upload_data["data"]
        else:
            raise Exception("Image upload to COS failed")
        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(prompt, image_url, duration="5", resolution=resolution)
        if submit_data["status"]:
            job_id = submit_data["data"]
        else:
            raise Exception("Jimeng task submission failed")
        # Poll the task until it succeeds, fails, or the wait time is exhausted
        job_data = None
        for idx, _ in enumerate(range(0, wait_time, interval)):
            logger.info(f"Polling Jimeng task, attempt {idx + 1}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break
            else:
                if "error" in query["msg"] or "失败" in query["msg"] or "fail" in query["msg"]:
                    raise Exception("Jimeng task failed: {}".format(query["msg"]))
            sleep(interval)
        if not job_data:
            raise Exception("Timed out waiting for the Jimeng task")
        return (client.get_last_15th_frame_tensor(job_data),)


class JMCustom:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {
                    "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
                    "multiline": True}),
                "duration": ("INT", {"default": 5, "min": 2, "max": 10}),
                "resolution": (["720p", "1080p"],),
                "wait_time": ("INT", {"default": 180, "min": 60, "max": 600}),
            }
        }

    RETURN_TYPES = ("STRING", "IMAGE",)
    RETURN_NAMES = ("视频存储路径", "视频最后一帧")
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    def gen(self, image: torch.Tensor, prompt: str, duration: int, resolution: str, wait_time: int):
        interval = 2
        client = JMUtils()
        # Upload the input image to COS so the Ark API can fetch it by URL
        image_io = client.tensor_to_io(image)
        upload_data = client.upload_io_to_cos(image_io)
        if upload_data["status"]:
            image_url = upload_data["data"]
        else:
            raise Exception("Image upload to COS failed")
        submit_data = client.submit_task(prompt, image_url, str(duration), resolution=resolution)
        if submit_data["status"]:
            job_id = submit_data["data"]
        else:
            raise Exception("Jimeng task submission failed")
        # Poll the task until it succeeds, fails, or the wait time is exhausted
        job_data = None
        for idx, _ in enumerate(range(0, wait_time, interval)):
            logger.info(f"Polling Jimeng task, attempt {idx + 1}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break
            else:
                if "error" in query["msg"] or "失败" in query["msg"] or "fail" in query["msg"]:
                    raise Exception("Jimeng task failed: {}".format(query["msg"]))
            sleep(interval)
        if not job_data:
            raise Exception("Timed out waiting for the Jimeng task")
        # Save the generated video into the ComfyUI output directory
        video_path, last_scene = client.download_video(
            job_data,
            path=os.path.join(folder_paths.get_output_directory(), f"{uuid.uuid4()}.mp4"))
        return (video_path, last_scene,)
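

# Node registration: a minimal sketch, assuming these two classes are exported
# directly from this module. The actual mappings and display names used by the
# package may differ (they could also live in the package's __init__.py).
NODE_CLASS_MAPPINGS = {
    "JMGestureCorrect": JMGestureCorrect,
    "JMCustom": JMCustom,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "JMGestureCorrect": "JM Gesture Correct",
    "JMCustom": "JM Custom Video",
}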