ComfyUI-CustomNode/nodes/image_gesture_nodes.py
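"""Jimeng (Doubao Seedance) gesture-correction helpers for ComfyUI.

JMUtils bundles the plumbing: loading credentials from config.yaml or the
environment, uploading images to Tencent COS, submitting image-to-video tasks
to the Volcengine Ark API, polling their status, and pulling a frame back out
of the finished video with ffmpeg/ffprobe.  JMGestureCorrect is the ComfyUI
node that chains these steps to turn an arbitrary pose photo into a
front-facing, full-body image.
"""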

import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep
import numpy as np
import requests
import torch
import yaml
from PIL import Image
from loguru import logger
from qcloud_cos import CosConfig, CosS3Client
from tqdm import tqdm


class JMUtils:
    def __init__(self):
        # Prefer credentials from the environment; otherwise read config.yaml
        # from the package root.
        if "aws_key_id" in os.environ:
            yaml_config = os.environ
        else:
            with open(
                os.path.join(
                    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"
                ),
                encoding="utf-8",
                mode="r",
            ) as f:
                yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        self.api_key = yaml_config["jm_api_key"]
        self.cos_region = yaml_config["cos_region"]
        self.cos_secret_id = yaml_config["cos_secret_id"]
        self.cos_secret_key = yaml_config["cos_secret_key"]
        self.cos_bucket_name = yaml_config["cos_sucai_bucket_name"]

    def submit_task(self, prompt: str, img_url: str, duration: str = "5"):
        """Submit an image-to-video generation task to the Volcengine Ark (Doubao Seedance) endpoint."""
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            json_data = {
                "model": "doubao-seedance-1-0-lite-i2v-250428",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt} --resolution 720p --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url,
                        },
                    },
                ],
            }
            response = requests.post(
                "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
                headers=headers,
                json=json_data,
            )
            logger.info(f"submit task: {json.dumps(response.json())}")
            resp_json = response.json()
            if "id" not in resp_json:
                return {"status": False, "data": img_url, "msg": resp_json["error"]["message"]}
            else:
                job_id = resp_json["id"]
                return {"data": job_id, "status": True, "msg": "task submitted successfully"}
        except Exception as e:
            logger.error(e)
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str):
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            response = requests.get(
                f"https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks/{job_id}",
                headers=headers,
            )
            resp_json = response.json()
            resp_dict["status"] = resp_json["status"] == "succeeded"
            resp_dict["msg"] = resp_json["status"]
            resp_dict["data"] = resp_json["content"]["video_url"] if "content" in resp_json else None
        except Exception as e:
            logger.error(f"error:{str(e)}")
            resp_dict["msg"] = str(e)
        finally:
            return resp_dict

    def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
        """Upload a file-like object to Tencent COS and return its public URL."""
        resp_data = {'status': True, 'data': '', 'msg': ''}
        category = mime_type.split('/')[0]
        suffix = mime_type.split('/')[1]
        try:
            object_key = f'tk/{category}/{uuid.uuid4()}.{suffix}'
            config = CosConfig(Region=self.cos_region, SecretId=self.cos_secret_id, SecretKey=self.cos_secret_key)
            client = CosS3Client(config)
            _ = client.upload_file_from_buffer(
                Bucket=self.cos_bucket_name,
                Key=object_key,
                Body=file
            )
            url = f'https://{self.cos_bucket_name}.cos.{self.cos_region}.myqcloud.com/{object_key}'
            resp_data['data'] = url
            resp_data['msg'] = 'upload succeeded'
        except Exception as e:
            logger.error(e)
            resp_data['status'] = False
            resp_data['msg'] = str(e)
        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor):
        """Convert a ComfyUI IMAGE tensor to an in-memory PNG (BytesIO)."""
        # Convert to a PIL image
        img = Image.fromarray(np.clip(255. * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
        image_data = io.BytesIO()
        img.save(image_data, format='PNG')
        image_data.seek(0)
        return image_data

    def download_video(self, url, timeout=30, retries=3):
        """Download the video to a temporary file and return the file path."""
        for attempt in range(retries):
            try:
                # Create a temporary .mp4 file to stream into
                temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
                temp_path = temp_file.name
                temp_file.close()
                # Stream the download
                print(f"Downloading video (attempt {attempt + 1}/{retries})...")
                response = requests.get(url, stream=True, timeout=timeout)
                response.raise_for_status()
                # Total size for the progress bar (0 if the server does not report it)
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 KB
                # Show download progress with tqdm
                with open(temp_path, 'wb') as f, tqdm(
                    desc=url.split('/')[-1],
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024
                ) as bar:
                    for data in response.iter_content(block_size):
                        size = f.write(data)
                        bar.update(size)
                print(f"Video downloaded to: {temp_path}")
                return temp_path
            except Exception as e:
                print(f"Download error (attempt {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep((attempt + 1) * 2)  # back off before retrying
                else:
                    raise

    def jpg_to_tensor(self, image_path, channel_first=False):
        """
        Load a JPG image as a PyTorch tensor in ComfyUI's IMAGE layout.

        Args:
            image_path: path to the JPG image file
            channel_first: kept for API compatibility but currently unused; the
                result is always shaped (1, H, W, C) with values in [0.0, 1.0]

        Returns:
            tensor: the image as a PyTorch tensor
        """
        try:
            # Open the image file
            image = Image.open(image_path).convert('RGB')
            # Scale to [0, 1] and add a batch dimension -> (1, H, W, C)
            tensor = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]
            return tensor
        except Exception as e:
            print(f"Image conversion failed: {str(e)}")
            raise

    def get_last_15th_frame_tensor(self, video_url, cleanup=True):
        """
        Grab the 15th frame from the end of the video at ``video_url`` and return it as a tensor.
        The video is downloaded to a local temporary file first; when ``cleanup`` is True the
        temporary file is removed afterwards.
        """
        video_path = None
        try:
            # Download the video
            video_path = self.download_video(video_url)
            # Ask ffprobe for the total frame count
            cmd_frames = [
                'ffprobe', '-v', 'error',
                '-select_streams', 'v:0',
                '-show_entries', 'stream=nb_frames',
                '-of', 'default=nokey=1:noprint_wrappers=1',
                video_path
            ]
            result = subprocess.run(
                cmd_frames,
                capture_output=True,
                text=True,
                check=True
            )
            # nb_frames can be missing or non-numeric for some containers
            frame_count = result.stdout.strip()
            if not frame_count.isdigit():
                # Fallback: decode the stream and read the frame counter from ffmpeg's log
                print("Exact frame count unavailable, counting frames by decoding...")
                cmd_decode = [
                    'ffmpeg', '-i', video_path,
                    '-f', 'null', '-'
                ]
                decode_result = subprocess.run(
                    cmd_decode,
                    capture_output=True,
                    text=True
                )
                for line in decode_result.stderr.split('\n'):
                    if 'frame=' in line:
                        parts = line.split('frame=')[-1].split()[0]
                        if parts.isdigit():
                            frame_count = int(parts)
                            break
                else:
                    raise ValueError("Could not determine the video frame count")
            else:
                frame_count = int(frame_count)
            # Work out which frame to extract
            target_frame = max(0, frame_count - 15)
            print(f"Total frames: {frame_count}, target frame: {target_frame}")
            # Extract the target frame; ffmpeg expands %03d in the output name,
            # so the image is actually written as ...001.jpg
            with tempfile.NamedTemporaryFile(suffix='%03d.jpg', delete=True) as frame_file:
                frame_path = frame_file.name
                cmd_extract = [
                    'ffmpeg',
                    '-ss', '00:00:00',
                    '-i', video_path,
                    '-vframes', '1',
                    '-vf', rf'select=eq(n\,{target_frame})',
                    '-vsync', '0',
                    '-an', '-y',
                    frame_path
                ]
                subprocess.run(
                    cmd_extract,
                    capture_output=True,
                    check=True
                )
                # Convert the extracted frame to a tensor
                tensor = self.jpg_to_tensor(frame_path.replace("%03d", "001"))
        finally:
            # Remove the downloaded temporary video if requested
            if cleanup and video_path and os.path.exists(video_path):
                os.remove(video_path)
        return tensor

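# Illustrative sketch (not part of the node): the same pipeline can be driven from
# JMUtils directly, assuming a valid config.yaml and a local image file (the file
# name below is hypothetical):
#
#     utils = JMUtils()
#     with open("pose.png", "rb") as fp:
#         img_url = utils.upload_io_to_cos(fp, "image/png")["data"]
#     task = utils.submit_task("face the camera, full body", img_url)
#     # ...poll utils.query_status(task["data"]) until "status" is True, then
#     # utils.get_last_15th_frame_tensor(video_url) returns the frame tensor.
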
class JMGestureCorrect:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor):
        """Upload the input image, have Jimeng generate a short "face the camera" video,
        then return a frame near the end of that video as the corrected front-facing image."""
        wait_time = 120  # total polling budget in seconds
        interval = 2     # seconds between status queries
        client = JMUtils()
        image_io = client.tensor_to_io(image)
        upload_data = client.upload_io_to_cos(image_io)
        if upload_data["status"]:
            image_url = upload_data["data"]
        else:
            raise Exception("Upload failed")
        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(prompt, image_url)
        if submit_data["status"]:
            job_id = submit_data["data"]
        else:
            raise Exception("Jimeng task submission failed")
        job_data = None
        for idx, _ in enumerate(range(0, wait_time, interval)):
            logger.info(f"Polling Jimeng result, attempt {idx + 1}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break
            else:
                if "error" in query["msg"] or "失败" in query["msg"] or "fail" in query["msg"]:
                    raise Exception("Jimeng task failed: {}".format(query["msg"]))
            sleep(interval)
        if not job_data:
            raise Exception("Timed out waiting for the Jimeng task")
        return (client.get_last_15th_frame_tensor(job_data),)
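
# Registration sketch (assumption): ComfyUI discovers custom nodes through
# NODE_CLASS_MAPPINGS.  In this repo the mapping may already live in the package
# __init__.py; if it does not, something like the following would expose the node:
#
# NODE_CLASS_MAPPINGS = {"JMGestureCorrect": JMGestureCorrect}
# NODE_DISPLAY_NAME_MAPPINGS = {"JMGestureCorrect": "JM Gesture Correct"}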