# ComfyUI-CustomNode/nodes/image_gesture_nodes.py
#
# (source-viewer metadata: 373 lines, 14 KiB, Python)

import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep
import folder_paths
import numpy as np
import requests
import torch
import yaml
from PIL import Image
from loguru import logger
from qcloud_cos import CosConfig, CosS3Client
from tqdm import tqdm
class JMUtils:
    """Client for the Volcengine Ark (Jimeng / Seedance) video API plus Tencent COS uploads.

    Credentials come from ``os.environ`` when it contains ``aws_key_id``;
    otherwise they are read from ``config.yaml`` in the parent of this
    file's directory.
    """

    # Base endpoint for Ark content-generation tasks (submit = POST, query = GET /{id}).
    _TASKS_URL = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
    # Seconds before an Ark HTTP call is abandoned; without a timeout a stalled
    # connection would block the calling node indefinitely.
    _REQUEST_TIMEOUT = 60

    def __init__(self):
        if "aws_key_id" in os.environ:
            yaml_config = os.environ
        else:
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"
            )
            # Read-only: nothing is written back, and "r+" would fail on a
            # read-only deployment.
            with open(config_path, encoding="utf-8", mode="r") as f:
                # safe_load refuses arbitrary Python object tags; the values
                # read below are used as plain strings.
                yaml_config = yaml.safe_load(f)
        self.api_key = yaml_config["jm_api_key"]
        self.cos_region = yaml_config["cos_region"]
        self.cos_secret_id = yaml_config["cos_secret_id"]
        self.cos_secret_key = yaml_config["cos_secret_key"]
        self.cos_bucket_name = yaml_config["cos_sucai_bucket_name"]

    def _headers(self):
        """Build the JSON + bearer-token headers shared by all Ark requests."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def submit_task(self, prompt: str, img_url: str, duration: str = "10"):
        """Submit an image-to-video generation task to Ark.

        Args:
            prompt: Text prompt. Resolution (1080p), duration and camera flags
                are appended as inline ``--`` parameters.
            img_url: URL of the source image (must be reachable by Ark).
            duration: Clip length in seconds, sent as ``--dur``.

        Returns:
            dict ``{"status", "data", "msg"}``. On success ``data`` is the task
            id. When the API rejects the request, ``data`` is ``img_url`` and
            ``msg`` is the API error message (or the raw body if the message is
            missing). On an exception, ``data`` is ``None`` and ``msg`` is the
            exception text.
        """
        try:
            json_data = {
                "model": "doubao-seedance-1-0-pro-250528",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt} --resolution 1080p --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url,
                        },
                    },
                ],
            }
            response = requests.post(
                self._TASKS_URL,
                headers=self._headers(),
                json=json_data,
                timeout=self._REQUEST_TIMEOUT,
            )
            resp_json = response.json()
            logger.info(f"submit task: {json.dumps(resp_json)}")
            if "id" in resp_json:
                return {"data": resp_json["id"], "status": True, "msg": "任务提交成功"}
            # Report the raw body when the error payload has an unexpected
            # shape, instead of surfacing a bare KeyError.
            error = resp_json.get("error")
            if isinstance(error, dict) and "message" in error:
                message = error["message"]
            else:
                message = json.dumps(resp_json)
            return {"status": False, "data": img_url, "msg": message}
        except Exception as e:
            logger.error(e)
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str):
        """Poll an Ark task once.

        Returns:
            dict ``{"status", "data", "msg"}``. ``status`` is True only when the
            task status is ``"succeeded"``. ``msg`` is the raw task status, or
            the exception text if the request or parsing failed. ``data`` is
            ``content.video_url`` when the response contains ``content``.
        """
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            response = requests.get(
                f"{self._TASKS_URL}/{job_id}",
                headers=self._headers(),
                timeout=self._REQUEST_TIMEOUT,
            )
            resp_json = response.json()
            resp_dict["status"] = resp_json["status"] == "succeeded"
            resp_dict["msg"] = resp_json["status"]
            resp_dict["data"] = resp_json["content"]["video_url"] if "content" in resp_json else None
        except Exception as e:
            logger.error(f"error:{str(e)}")
            resp_dict["msg"] = str(e)
        # Returned outside ``finally``: a return inside ``finally`` would also
        # swallow BaseExceptions such as KeyboardInterrupt.
        return resp_dict

    def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
        """Upload a file-like object to the COS bucket under ``tk/<category>/<uuid>.<suffix>``.

        Args:
            file: Readable binary stream, positioned at the data to upload.
            mime_type: ``<category>/<suffix>``, e.g. ``image/png``. Both parts
                are used to build the object key.

        Returns:
            dict ``{"status", "data", "msg"}`` where ``data`` is the public
            object URL on success.
        """
        resp_data = {'status': True, 'data': '', 'msg': ''}
        category = mime_type.split('/')[0]
        suffix = mime_type.split('/')[1]
        try:
            object_key = f'tk/{category}/{uuid.uuid4()}.{suffix}'
            config = CosConfig(Region=self.cos_region, SecretId=self.cos_secret_id, SecretKey=self.cos_secret_key)
            client = CosS3Client(config)
            client.upload_file_from_buffer(
                Bucket=self.cos_bucket_name,
                Key=object_key,
                Body=file
            )
            resp_data['data'] = f'https://{self.cos_bucket_name}.cos.{self.cos_region}.myqcloud.com/{object_key}'
            resp_data['msg'] = '上传成功'
        except Exception as e:
            logger.error(e)
            resp_data['status'] = False
            resp_data['msg'] = str(e)
        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor):
        """Encode an image tensor as an in-memory PNG, rewound to position 0.

        Values are scaled by 255 and clipped to [0, 255]. A 4-D tensor is
        treated as a batch (the layout ``jpg_to_tensor`` produces is
        ``(1, H, W, 3)``), and only its first image is encoded. Squeezing the
        whole batch used to crash for batch sizes > 1.
        """
        if tensor.dim() == 4:
            tensor = tensor[0]
        pixels = np.clip(255. * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
        img = Image.fromarray(pixels)
        image_data = io.BytesIO()
        img.save(image_data, format='PNG')
        image_data.seek(0)
        return image_data

    def download_video(self, url, timeout=30, retries=3, path=None):
        """Download ``url`` and return the local file path.

        Args:
            url: Video URL.
            timeout: Per-request timeout in seconds passed to ``requests``.
            retries: Number of attempts. Waits 2s, 4s, ... between attempts.
            path: Destination file. When omitted, a temporary ``.mp4`` file is
                created and owned by this call.

        Raises:
            The last download error once every attempt has failed. A temporary
            file this call created is deleted first; a caller-supplied
            ``path`` is left untouched.
        """
        if path:
            temp_path = path
            owns_file = False
        else:
            # Create the temp file once. A fresh file per attempt leaked every
            # file from a failed attempt.
            temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
            temp_path = temp_file.name
            temp_file.close()
            owns_file = True
        block_size = 1024  # 1 KB
        for attempt in range(retries):
            try:
                print(f"开始下载视频 (尝试 {attempt + 1}/{retries})...")
                # The context manager releases the pooled connection even if the
                # stream is interrupted mid-download.
                with requests.get(url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()
                    total_size = int(response.headers.get('content-length', 0))
                    with open(temp_path, 'wb') as f, tqdm(
                            desc=url.split('/')[-1],
                            total=total_size,
                            unit='B',
                            unit_scale=True,
                            unit_divisor=1024
                    ) as bar:
                        for chunk in response.iter_content(block_size):
                            bar.update(f.write(chunk))
                print(f"视频下载完成: {temp_path}")
                return temp_path
            except Exception as e:
                print(f"下载错误 (尝试 {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep((attempt + 1) * 2)
                else:
                    if owns_file:
                        try:
                            os.remove(temp_path)
                        except OSError:
                            pass
                    raise

    def jpg_to_tensor(self, image_path, channel_first=False):
        """Load an image file as a float32 tensor with values in [0.0, 1.0].

        Args:
            image_path: Path to any image PIL can open. It is converted to RGB.
            channel_first: If True, return ``(1, 3, H, W)``. Otherwise return
                the default ``(1, H, W, 3)``.

        Returns:
            torch.Tensor with a leading batch dimension of 1.
        """
        try:
            # Close the file handle promptly so the caller can delete the file
            # (required on Windows, where open files cannot be removed).
            with Image.open(image_path) as image:
                rgb = image.convert('RGB')
            tensor = torch.from_numpy(np.array(rgb).astype(np.float32) / 255.0)[None,]
            if channel_first:
                tensor = tensor.permute(0, 3, 1, 2)
            return tensor
        except Exception as e:
            print(f"转换失败: {str(e)}")
            raise

    @staticmethod
    def _parse_decoded_frame_count(ffmpeg_stderr):
        """Return the final ``frame=`` counter from ffmpeg's progress output.

        ffmpeg redraws its progress line using carriage returns, and
        ``subprocess.run(text=True)`` translates those into newlines. Only the
        LAST progress line holds the complete count; earlier ones are partial.

        Raises:
            ValueError: If no numeric ``frame=`` value is present.
        """
        for line in reversed(ffmpeg_stderr.splitlines()):
            if 'frame=' not in line:
                continue
            tokens = line.split('frame=')[-1].split()
            if tokens and tokens[0].isdigit():
                return int(tokens[0])
        raise ValueError("无法确定视频帧数")

    def _probe_frame_count(self, video_path):
        """Return the number of frames in the first video stream of ``video_path``.

        Uses ffprobe's ``nb_frames`` metadata. When that is not a number
        (e.g. ``N/A``), it decodes the whole file with ffmpeg and reads the
        final progress counter instead.
        """
        cmd_frames = [
            'ffprobe', '-v', 'error',
            '-select_streams', 'v:0',
            '-show_entries', 'stream=nb_frames',
            '-of', 'default=nokey=1:noprint_wrappers=1',
            video_path
        ]
        result = subprocess.run(cmd_frames, capture_output=True, text=True, check=True)
        frame_count = result.stdout.strip()
        if frame_count.isdigit():
            return int(frame_count)
        print("无法获取准确帧数,尝试直接解码...")
        decode_result = subprocess.run(
            ['ffmpeg', '-i', video_path, '-f', 'null', '-'],
            capture_output=True,
            text=True
        )
        return self._parse_decoded_frame_count(decode_result.stderr)

    def get_last_15th_frame_tensor(self, video_url, cleanup=True):
        """Download a video and return its 15th-from-last frame as an image tensor.

        Args:
            video_url: URL of the video to fetch.
            cleanup: Delete the downloaded temporary video afterwards. This flag
                was previously accepted but ignored, so every video leaked.

        Returns:
            Tensor from ``jpg_to_tensor``, shape ``(1, H, W, 3)``. Videos with
            fewer than 15 frames yield frame 0.

        Raises:
            subprocess.CalledProcessError: If ffprobe or the frame extraction fails.
            ValueError: If the frame count cannot be determined.
        """
        video_path = self.download_video(video_url)
        try:
            frame_count = self._probe_frame_count(video_path)
            target_frame = max(0, frame_count - 15)
            print(f"视频总帧数: {frame_count}, 目标帧: {target_frame}")
            # A private directory holds the extracted frame and is removed
            # afterwards. The old NamedTemporaryFile approach left the
            # "...001.jpg" file behind on every call.
            with tempfile.TemporaryDirectory() as frame_dir:
                cmd_extract = [
                    'ffmpeg',
                    '-i', video_path,
                    '-vframes', '1',
                    # "\," escapes the comma for the filtergraph parser; the raw
                    # string keeps the backslash without an invalid-escape warning.
                    '-vf', rf'select=eq(n\,{target_frame})',
                    '-vsync', '0',
                    '-an', '-y',
                    # The image2 muxer expands %03d, so the single frame lands in frame001.jpg.
                    os.path.join(frame_dir, 'frame%03d.jpg')
                ]
                subprocess.run(cmd_extract, capture_output=True, check=True)
                return self.jpg_to_tensor(os.path.join(frame_dir, 'frame001.jpg'))
        finally:
            if cleanup:
                try:
                    os.remove(video_path)
                except OSError:
                    pass
class JMGestureCorrect:
    """ComfyUI node that re-poses a person into a straight, front-facing full-body stance.

    The input image is uploaded to COS and animated by Jimeng with a fixed
    "stand straight facing the camera" prompt. The 15th-from-last frame of the
    generated video is returned as the corrected image.
    """

    # Substrings of a task status / error message meaning the task can never
    # succeed. "cancel" covers Ark's "cancelled" status, which was previously
    # missed, so a cancelled task waited out the whole timeout.
    _FAILURE_MARKERS = ("error", "失败", "fail", "cancel")

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    @classmethod
    def _is_failed(cls, status_msg):
        """Return True when ``status_msg`` contains any terminal-failure marker."""
        return any(marker in status_msg for marker in cls._FAILURE_MARKERS)

    def _wait_for_video(self, client, job_id, wait_time, interval):
        """Poll ``job_id`` until it succeeds, fails, or about ``wait_time`` seconds elapse.

        Uses a wall-clock deadline, so slow status requests count toward the
        wait. Returns the video URL (``None`` on timeout) and raises if the task
        reports a failure.
        """
        deadline = time.monotonic() + wait_time
        attempt = 0
        while True:
            attempt += 1
            logger.info(f"查询即梦结果 {attempt}")
            query = client.query_status(job_id)
            if query["status"]:
                return query["data"]
            if self._is_failed(query["msg"]):
                raise Exception("即梦任务失败 {}".format(query["msg"]))
            if time.monotonic() >= deadline:
                return None
            sleep(interval)

    def gen(self, image: torch.Tensor):
        """Run the pose-correction pipeline and return ``(image_tensor,)``."""
        wait_time = 240
        interval = 2
        client = JMUtils()
        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            # Keep the underlying reason; it was previously discarded.
            raise Exception(f"上传失败 {upload_data['msg']}")
        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(prompt, upload_data["data"])
        if not submit_data["status"]:
            raise Exception(f"即梦任务提交失败 {submit_data['msg']}")
        video_url = self._wait_for_video(client, submit_data["data"], wait_time, interval)
        if not video_url:
            raise Exception("即梦任务等待超时")
        return (client.get_last_15th_frame_tensor(video_url),)
class JMCustom:
    """ComfyUI node that animates an image with Jimeng using a user-supplied prompt.

    The generated video is downloaded into ComfyUI's output directory, and its
    file path is returned.
    """

    # Substrings of a task status / error message meaning the task can never
    # succeed. "cancel" covers Ark's "cancelled" status, which was previously
    # missed, so a cancelled task waited out the whole timeout.
    _FAILURE_MARKERS = ("error", "失败", "fail", "cancel")

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {
                    "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
                    "multiline": True}),
                "duration": ("INT", {"default": 5, "min": 2, "max": 10}),
                "wait_time": ("INT", {"default": 180, "min": 60, "max": 600}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    @classmethod
    def _is_failed(cls, status_msg):
        """Return True when ``status_msg`` contains any terminal-failure marker."""
        return any(marker in status_msg for marker in cls._FAILURE_MARKERS)

    def _wait_for_video(self, client, job_id, wait_time, interval):
        """Poll ``job_id`` until it succeeds, fails, or about ``wait_time`` seconds elapse.

        Uses a wall-clock deadline, so slow status requests count toward the
        wait. Returns the video URL (``None`` on timeout) and raises if the task
        reports a failure.
        """
        deadline = time.monotonic() + wait_time
        attempt = 0
        while True:
            attempt += 1
            logger.info(f"查询即梦结果 {attempt}")
            query = client.query_status(job_id)
            if query["status"]:
                return query["data"]
            if self._is_failed(query["msg"]):
                raise Exception("即梦任务失败 {}".format(query["msg"]))
            if time.monotonic() >= deadline:
                return None
            sleep(interval)

    def gen(self, image: torch.Tensor, prompt: str, duration: int, wait_time: int):
        """Generate a video and return ``(saved_video_path,)``.

        Args:
            image: Source image tensor.
            prompt: Generation prompt.
            duration: Clip length in seconds.
            wait_time: Maximum seconds to wait for the task to finish.
        """
        interval = 2
        client = JMUtils()
        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            # Keep the underlying reason; it was previously discarded.
            raise Exception(f"上传失败 {upload_data['msg']}")
        submit_data = client.submit_task(prompt, upload_data["data"], str(duration))
        if not submit_data["status"]:
            raise Exception(f"即梦任务提交失败 {submit_data['msg']}")
        video_url = self._wait_for_video(client, submit_data["data"], wait_time, interval)
        if not video_url:
            raise Exception("即梦任务等待超时")
        output_path = os.path.join(folder_paths.get_output_directory(), f"{uuid.uuid4()}.mp4")
        return (client.download_video(video_url, path=output_path),)