ComfyUI-CustomNode/nodes/image_gesture_nodes.py

669 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep
from typing import Any, Dict
import cv2
import folder_paths
import numpy as np
import requests
import torch
from loguru import logger
from PIL import Image
from torchvision.transforms import transforms
from tqdm import tqdm
from ..utils.config_utils import config
from ..utils.object_storage import UploadResult, get_provider
class JMUtils:
    """Client helpers for the Jimeng AI (Volcengine Ark, doubao-seedance) video service.

    Provides:
    - uploading images to cloud object storage (COS),
    - submitting generation tasks and querying their status,
    - downloading generated videos and extracting frames,
    - converting between tensors and images.

    Storage goes through the project's unified storage abstraction
    (``get_provider``), so different cloud backends can be plugged in.
    """

    # Ark content-generation task endpoint, used for both submit (POST) and query (GET).
    TASKS_URL = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
    # Task states treated as "still running" / "terminally failed".
    # "queued" and "cancelled" are assumed from the Ark task API docs — confirm if the API changes.
    _PENDING_STATUSES = ("queued", "pending", "running", "processing")
    _FAILED_STATUSES = ("failed", "error", "cancelled")

    def __init__(self):
        """Read the API key and the COS bucket from config and set up the storage provider.

        Raises:
            ValueError: if the API key or the COS bucket is not configured.
        """
        try:
            # Jimeng API credentials.
            self.api_key = config.get_config("jm_api_key")
            if not self.api_key:
                raise ValueError("即梦API密钥未配置")
            # COS bucket used for uploading source material.
            cos_config = config.get_cos_config()
            self.cos_bucket_name = cos_config.get("bucket_name") or config.get_config(
                "cos_sucai_bucket_name"
            )
            if not self.cos_bucket_name:
                raise ValueError("COS素材存储桶未配置")
            self.storage_provider = get_provider("cos")
            logger.info(f"即梦工具初始化成功,使用存储桶: {self.cos_bucket_name}")
        except Exception as e:
            logger.error(f"即梦工具初始化失败: {e}")
            raise

    @staticmethod
    def _error_message(payload) -> str:
        """Safely extract the API error text from a decoded JSON payload.

        Handles ``{"error": {"message": ...}}`` and ``{"error": "..."}``. Returns ""
        when the payload is not a dict, the error is null/absent, or it has no message.
        """
        if not isinstance(payload, dict):
            return ""
        error = payload.get("error")
        if isinstance(error, dict):
            message = error.get("message")
            return str(message) if message else ""
        if isinstance(error, str):
            return error
        return ""

    @classmethod
    def _http_error_detail(cls, exc) -> str:
        """Return the API error message carried in the body of a failed HTTP response, or ""."""
        response = getattr(exc, "response", None)
        # Compare against None explicitly: a Response with an error status is falsy.
        if response is None:
            return ""
        try:
            return cls._error_message(response.json())
        except ValueError:
            return ""

    def submit_task(
        self, prompt: str, img_url: str, duration: str = "10", resolution: str = "720p"
    ) -> Dict[str, Any]:
        """Submit an image-to-video generation task.

        Args:
            prompt: Text prompt describing the desired motion.
            img_url: URL of the input image.
            duration: Video length in seconds (sent as ``--dur``).
            resolution: Video resolution such as ``"720p"`` (sent as ``--resolution``).

        Returns:
            Dict with keys:
            - status: True if the task was accepted.
            - data: the task id on success; the input image URL if the API answered
              without an id; None on validation, network or unexpected errors.
            - msg: human-readable message. Network failures start with "网络请求失败"
              and include the API's error text when the response body carries one.
        """
        try:
            if not prompt or not prompt.strip():
                return {"status": False, "data": None, "msg": "提示词不能为空"}
            if not img_url or not img_url.strip():
                return {"status": False, "data": None, "msg": "图像URL不能为空"}
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            json_data = {
                "model": "doubao-seedance-1-0-pro-250528",
                "content": [
                    {
                        "type": "text",
                        # Generation parameters are passed inline as text flags.
                        "text": f"{prompt.strip()} --resolution {resolution} --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url.strip(),
                        },
                    },
                ],
            }
            logger.info(
                f"即梦任务提交中: prompt='{prompt[:50]}...', resolution={resolution}, duration={duration}"
            )
            response = requests.post(
                self.TASKS_URL,
                headers=headers,
                json=json_data,
                timeout=30,
            )
            response.raise_for_status()
            resp_json = response.json()
            logger.info(
                f"即梦任务提交响应: {json.dumps(resp_json, ensure_ascii=False)}"
            )
            job_id = resp_json.get("id") if isinstance(resp_json, dict) else None
            if not job_id:
                error_msg = self._error_message(resp_json) or "未知错误"
                return {
                    "status": False,
                    "data": img_url,
                    "msg": f"任务提交失败: {error_msg}",
                }
            logger.info(f"即梦任务提交成功任务ID: {job_id}")
            return {"data": job_id, "status": True, "msg": "任务提交成功"}
        except requests.RequestException as e:
            logger.error(f"即梦API请求失败: {e}")
            msg = f"网络请求失败: {str(e)}"
            detail = self._http_error_detail(e)
            if detail:
                # Surface the API's own reason (e.g. a rejected input image).
                msg = f"{msg} ({detail})"
            return {"data": None, "status": False, "msg": msg}
        except Exception as e:
            logger.error(f"即梦任务提交异常: {e}")
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str) -> Dict[str, Any]:
        """Query the status of a generation task.

        Args:
            job_id: Task id returned by ``submit_task``.

        Returns:
            Dict with keys:
            - status: True only when the task succeeded and a video URL is present.
            - data: the video URL on success, otherwise None.
            - msg: status message. Terminal failures start with "任务失败";
              in-flight tasks start with "任务进行中"; network errors start with
              "网络请求失败".
        """
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            if not job_id or not job_id.strip():
                resp_dict["msg"] = "任务ID不能为空"
                return resp_dict
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            response = requests.get(
                f"{self.TASKS_URL}/{job_id.strip()}",
                headers=headers,
                timeout=15,
            )
            response.raise_for_status()
            resp_json = response.json()
            if not isinstance(resp_json, dict):
                resp_json = {}
            task_status = resp_json.get("status", "unknown")
            if task_status == "succeeded":
                # "content" may be missing or null; guard before looking up the URL.
                content = resp_json.get("content")
                video_url = content.get("video_url") if isinstance(content, dict) else None
                if video_url:
                    resp_dict["status"] = True
                    resp_dict["msg"] = "任务完成"
                    resp_dict["data"] = video_url
                else:
                    resp_dict["msg"] = "任务完成但未找到视频URL"
            elif task_status in self._FAILED_STATUSES:
                error_msg = "任务失败"
                detail = self._error_message(resp_json)
                if not detail and task_status == "cancelled":
                    detail = "任务已取消"
                if detail:
                    error_msg += f": {detail}"
                elif resp_json.get("error"):
                    error_msg += ": 未知错误"
                resp_dict["msg"] = error_msg
            elif task_status in self._PENDING_STATUSES:
                resp_dict["msg"] = f"任务进行中: {task_status}"
            else:
                resp_dict["msg"] = f"未知任务状态: {task_status}"
        except requests.RequestException as e:
            logger.error(f"即梦状态查询网络错误: {e}")
            resp_dict["msg"] = f"网络请求失败: {str(e)}"
        except Exception as e:
            logger.error(f"即梦状态查询异常: {e}")
            resp_dict["msg"] = str(e)
        return resp_dict

    def upload_io_to_cos(
        self, file: io.IOBase, mime_type: str = "image/png"
    ) -> Dict[str, Any]:
        """Upload the contents of a file-like object to the COS bucket.

        Args:
            file: Readable, seekable file object. It is rewound to 0 after reading.
            mime_type: MIME type. Its two halves form the object key
                ``tk/<category>/<uuid>.<suffix>``.

        Returns:
            Dict with keys:
            - status: True on success.
            - data: URL of the uploaded object. If the provider gives no URL, a
              default ``https://<bucket>.cos.<region>.myqcloud.com/<key>`` is built.
            - msg: message.
        """
        resp_data = {"status": True, "data": "", "msg": ""}
        try:
            parts = mime_type.split("/")
            category = parts[0] if len(parts) > 0 else "file"
            suffix = parts[1] if len(parts) > 1 else "bin"
            object_key = f"tk/{category}/{uuid.uuid4()}.{suffix}"
            logger.info(f"开始上传文件到COS: {object_key}")
            file_content = file.read()
            file.seek(0)  # leave the stream reusable for the caller
            result: UploadResult = self.storage_provider.upload_bytes(
                file_content, object_key, bucket_name=self.cos_bucket_name
            )
            if result.success:
                if result.url:
                    resp_data["data"] = result.url
                else:
                    cos_config = config.get_cos_config()
                    region = cos_config.get("region", "ap-beijing")
                    resp_data["data"] = (
                        f"https://{self.cos_bucket_name}.cos.{region}.myqcloud.com/{object_key}"
                    )
                resp_data["msg"] = "上传成功"
                logger.info(f"文件上传成功: {resp_data['data']}")
            else:
                resp_data["status"] = False
                resp_data["msg"] = result.message or "上传失败"
                logger.error(f"文件上传失败: {resp_data['msg']}")
        except Exception as e:
            logger.error(f"上传文件时发生异常: {e}")
            resp_data["status"] = False
            resp_data["msg"] = str(e)
        return resp_data

    @staticmethod
    def _tensor_to_uint8_array(tensor: torch.Tensor) -> np.ndarray:
        """Convert an image tensor with values in [0, 1] to a uint8 HxW or HxWxC array.

        A 4D input is treated as a batch (B, H, W, C) and only the first image is
        used. A trailing channel dimension of size 1 is dropped so the result is
        grayscale.

        Raises:
            ValueError: if the tensor is not 2D, 3D or 4D.
        """
        if tensor.dim() == 4:
            tensor = tensor[0]
        elif tensor.dim() not in (2, 3):
            raise ValueError(f"不支持的张量维度: {tensor.dim()}D")
        # detach/float make half/bfloat16 and grad-tracking tensors convertible to numpy.
        array = tensor.detach().cpu().float().numpy()
        array = np.clip(array, 0.0, 1.0)
        image_array = (array * 255).astype(np.uint8)
        if image_array.ndim == 3 and image_array.shape[-1] == 1:
            image_array = image_array[..., 0]
        return image_array

    def tensor_to_io(self, tensor: torch.Tensor) -> io.BytesIO:
        """Encode an image tensor as PNG in an in-memory stream.

        Args:
            tensor: One of
                - (H, W): grayscale;
                - (H, W, C): image with 1, 3 or 4 channels;
                - (B, H, W, C): batch, of which only the first image is used.

        Returns:
            io.BytesIO: PNG bytes, rewound to position 0.

        Raises:
            ValueError: if the tensor cannot be converted.
        """
        try:
            image_array = self._tensor_to_uint8_array(tensor)
            # Let Pillow infer the mode from the array shape: L / RGB / RGBA.
            img = Image.fromarray(image_array)
            image_data = io.BytesIO()
            img.save(image_data, format="PNG", optimize=True)
            image_data.seek(0)
            logger.debug(f"张量转换为PNG成功大小: {len(image_data.getvalue())} bytes")
            return image_data
        except Exception as e:
            logger.error(f"张量转换失败: {e}")
            raise ValueError(f"张量转换为图像失败: {str(e)}") from e

    def read_video_last_frame_to_tensor(self, video_path: str) -> torch.Tensor:
        """Read the last frame of a video file as a (1, H, W, 3) float32 RGB tensor in [0, 1].

        The frame count reported by OpenCV can be missing or inaccurate, and seeking
        to it can fail. In that case the video is decoded sequentially and the last
        successfully decoded frame is used.

        Raises:
            FileNotFoundError: if the video cannot be opened.
            ValueError: if no frame can be read.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            cap.release()
            raise FileNotFoundError(f"无法打开视频文件: {video_path}")
        frame = None
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames > 0:
                cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
                ok, candidate = cap.read()
                if ok and candidate is not None:
                    frame = candidate
            if frame is None:
                # Fallback: scan from the start and keep the last decodable frame.
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                while True:
                    ok, candidate = cap.read()
                    if not ok or candidate is None:
                        break
                    frame = candidate
        finally:
            cap.release()
        if frame is None:
            if total_frames <= 0:
                raise ValueError("视频文件为空或无法确定帧数")
            raise ValueError("无法读取视频的最后一帧,可能视频已损坏")
        # OpenCV decodes to BGR; convert to RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # uint8 HWC -> float32 in [0, 1] with a leading batch dim: (1, H, W, C).
        return torch.from_numpy(np.ascontiguousarray(frame_rgb)).float().div(255.0).unsqueeze(0)

    @staticmethod
    def _remove_quietly(path):
        """Best-effort deletion of a temporary file; failures are logged, not raised."""
        try:
            os.remove(path)
        except OSError as e:
            logger.warning(f"临时文件删除失败: {path}: {e}")

    @staticmethod
    def _stream_to_file(url, dest_path, timeout, chunk_size=1024 * 1024):
        """Stream ``url`` into ``dest_path``, showing a tqdm progress bar.

        Raises:
            requests.RequestException: on HTTP or network errors.
            OSError: if fewer bytes arrive than the server's Content-Length announced.
        """
        # Drop any query string (e.g. signed-URL parameters) from the progress label.
        label = url.split("?", 1)[0].rsplit("/", 1)[-1]
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            written = 0
            with open(dest_path, "wb") as f, tqdm(
                desc=label,
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size):
                    written += f.write(chunk)
                    bar.update(len(chunk))
        if total_size and written < total_size:
            raise OSError(f"视频下载不完整: {written}/{total_size} bytes")

    def download_video(self, url, timeout=30, retries=3, path=None):
        """Download a video and read its last frame.

        Args:
            url: Video URL.
            timeout: Per-request timeout in seconds passed to ``requests.get``.
            retries: Number of attempts; at least one attempt is always made.
            path: Destination file. When None, a temporary ``.mp4`` file is created.

        Returns:
            Tuple ``(local_path, last_frame)``, where ``last_frame`` comes from
            ``read_video_last_frame_to_tensor``.

        Raises:
            The last error once every attempt has failed. A temporary file created
            by this method is deleted in that case.
        """
        created_temp = path is None
        if created_temp:
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
                temp_path = temp_file.name
        else:
            temp_path = path
        attempts = max(1, retries)
        for attempt in range(1, attempts + 1):
            try:
                logger.info(f"开始下载视频 (尝试 {attempt}/{attempts})...")
                self._stream_to_file(url, temp_path, timeout)
                logger.info(f"视频下载完成: {temp_path}")
                return temp_path, self.read_video_last_frame_to_tensor(temp_path)
            except Exception as e:
                logger.warning(f"下载错误 (尝试 {attempt}/{attempts}): {str(e)}")
                if attempt == attempts:
                    if created_temp:
                        self._remove_quietly(temp_path)
                    raise
                # Linear back-off: 2s, 4s, ...
                time.sleep(attempt * 2)

    def jpg_to_tensor(self, image_path):
        """Load an image file as a (1, H, W, 3) float32 RGB tensor in [0, 1].

        Args:
            image_path: Path to an image file readable by Pillow (JPEG, PNG, ...).

        Returns:
            torch.Tensor of shape (1, H, W, 3).
        """
        try:
            # Close the file handle promptly; convert() loads the pixel data first.
            with Image.open(image_path) as image:
                rgb = image.convert("RGB")
            return torch.from_numpy(np.array(rgb).astype(np.float32) / 255.0)[None,]
        except Exception as e:
            logger.error(f"转换失败: {str(e)}")
            raise

    @staticmethod
    def _last_frame_number(ffmpeg_log):
        """Return the last ``frame=N`` counter in ffmpeg's stderr log, or None if absent.

        ffmpeg rewrites its progress line with carriage returns, so the final
        occurrence holds the total number of frames processed.
        """
        import re

        matches = re.findall(r"frame=\s*(\d+)", ffmpeg_log or "")
        return int(matches[-1]) if matches else None

    def _probe_frame_count(self, video_path):
        """Return the frame count of the first video stream.

        Uses ffprobe's ``nb_frames``. If that is not numeric (e.g. "N/A"), the whole
        file is decoded with ffmpeg and the final frame counter is parsed instead.

        Raises:
            ValueError: if the count cannot be determined.
        """
        cmd_frames = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=nb_frames",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]
        result = subprocess.run(cmd_frames, capture_output=True, text=True, check=True)
        frame_count = result.stdout.strip()
        if frame_count.isdigit():
            return int(frame_count)
        logger.warning("无法获取准确帧数,尝试直接解码...")
        cmd_decode = ["ffmpeg", "-i", video_path, "-f", "null", "-"]
        decode_result = subprocess.run(cmd_decode, capture_output=True, text=True)
        decoded = self._last_frame_number(decode_result.stderr)
        if decoded is None:
            raise ValueError("无法确定视频帧数")
        return decoded

    def _extract_frame_tensor(self, video_path, target_frame):
        """Extract frame ``target_frame`` (0-based) with ffmpeg and return it as a tensor.

        The JPEG is written into a temporary directory that is removed afterwards.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            # image2 sequence pattern; the single output frame becomes frame_001.jpg.
            pattern = os.path.join(tmp_dir, "frame_%03d.jpg")
            cmd_extract = [
                "ffmpeg",
                "-i",
                video_path,
                "-vframes",
                "1",
                "-vf",
                # The comma inside eq() must be escaped for the filtergraph parser.
                rf"select=eq(n\,{target_frame})",
                "-vsync",
                "0",
                "-an",
                "-y",
                pattern,
            ]
            subprocess.run(cmd_extract, capture_output=True, check=True)
            return self.jpg_to_tensor(os.path.join(tmp_dir, "frame_001.jpg"))

    def get_last_15th_frame_tensor(self, video_url, frame_offset: int = 15):
        """Download a video and return the frame ``frame_offset`` frames from the end.

        With the default of 15 this is the frame at index ``frame_count - 15``,
        clamped to 0. The downloaded video is deleted afterwards.

        Args:
            video_url: URL of the video.
            frame_offset: Distance from the end of the video; defaults to 15.

        Returns:
            torch.Tensor of shape (1, H, W, 3), float32 RGB in [0, 1].
        """
        video_path, _ = self.download_video(video_url)
        try:
            frame_count = self._probe_frame_count(video_path)
            target_frame = max(0, frame_count - frame_offset)
            logger.info(f"视频总帧数: {frame_count}, 目标帧: {target_frame}")
            return self._extract_frame_tensor(video_path, target_frame)
        finally:
            self._remove_quietly(video_path)
class JMGestureCorrect:
    """ComfyUI node that re-poses a person into a front-facing, full-body stance.

    The input image is uploaded and a Jimeng 5-second video task is submitted with a
    fixed "stand straight, face the camera" prompt. The node waits for the result
    and returns the frame 15 frames before the end of the generated video.
    """

    # Total seconds to wait for the task, and delay between status polls.
    WAIT_TIME = 240
    POLL_INTERVAL = 2
    # Consecutive polling network errors tolerated before giving up.
    MAX_NETWORK_ERRORS = 3

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                # Input specs are tuples (like "IMAGE" above); without the trailing
                # comma the parentheses were a no-op and this was a bare list.
                "resolution": (["720p", "1080p"],),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    @staticmethod
    def _classify_query_message(msg):
        """Classify a non-success ``query_status`` message.

        Returns:
            "network" for network/HTTP errors, which may be transient;
            "fatal" for terminal failures, including cancelled tasks and
            succeeded tasks without a video URL;
            "pending" otherwise.
        """
        if msg.startswith("网络请求失败"):
            return "network"
        lowered = msg.lower()
        if (
            "失败" in msg
            or "未找到视频URL" in msg
            or "error" in lowered
            or "fail" in lowered
            or "cancel" in lowered
        ):
            return "fatal"
        return "pending"

    @classmethod
    def _wait_for_video_url(cls, client, job_id, wait_time, interval):
        """Poll ``client.query_status(job_id)`` until the task yields a video URL.

        Returns:
            The video URL, or None if ``wait_time`` seconds of polling elapse first.

        Raises:
            RuntimeError: on a terminal task failure, or after
                ``MAX_NETWORK_ERRORS`` consecutive network errors.
        """
        polls = len(range(0, wait_time, interval))
        network_errors = 0
        for attempt in range(1, polls + 1):
            logger.info(f"查询即梦结果 {attempt}")
            query = client.query_status(job_id)
            if query["status"]:
                return query["data"]
            msg = query["msg"] or ""
            kind = cls._classify_query_message(msg)
            if kind == "fatal":
                raise RuntimeError("即梦任务失败 {}".format(msg))
            if kind == "network":
                network_errors += 1
                if network_errors >= cls.MAX_NETWORK_ERRORS:
                    raise RuntimeError("即梦任务失败 {}".format(msg))
            else:
                network_errors = 0
            # No need to sleep after the final poll.
            if attempt < polls:
                sleep(interval)
        return None

    def gen(self, image: torch.Tensor, resolution: str):
        """Run the pose-correction pipeline and return ``(frame_tensor,)``."""
        client = JMUtils()
        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            raise RuntimeError(f"上传失败: {upload_data['msg']}")
        image_url = upload_data["data"]

        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(
            prompt, image_url, duration="5", resolution=resolution
        )
        if not submit_data["status"]:
            raise RuntimeError(f"即梦任务提交失败: {submit_data['msg']}")
        job_id = submit_data["data"]

        video_url = self._wait_for_video_url(
            client, job_id, self.WAIT_TIME, self.POLL_INTERVAL
        )
        if not video_url:
            raise RuntimeError("即梦任务等待超时")
        return (client.get_last_15th_frame_tensor(video_url),)
class JMCustom:
    """ComfyUI node: generate a Jimeng video from an image and a user prompt.

    Outputs the path of the downloaded video, saved under ComfyUI's output
    directory, and the video's last frame as an IMAGE tensor.
    """

    # Delay between status polls, in seconds.
    POLL_INTERVAL = 2
    # Consecutive polling network errors tolerated before giving up.
    MAX_NETWORK_ERRORS = 3

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": (
                    "STRING",
                    {
                        "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
                        "multiline": True,
                    },
                ),
                "duration": ("INT", {"default": 5, "min": 2, "max": 10}),
                # Input specs are tuples (like the entries around it); without the
                # trailing comma the parentheses were a no-op and this was a bare list.
                "resolution": (["720p", "1080p"],),
                "wait_time": ("INT", {"default": 180, "min": 60, "max": 600}),
            }
        }

    RETURN_TYPES = (
        "STRING",
        "IMAGE",
    )
    RETURN_NAMES = ("视频存储路径", "视频最后一帧")
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    @staticmethod
    def _classify_query_message(msg):
        """Classify a non-success ``query_status`` message.

        Returns:
            "network" for network/HTTP errors, which may be transient;
            "fatal" for terminal failures, including cancelled tasks and
            succeeded tasks without a video URL;
            "pending" otherwise.
        """
        if msg.startswith("网络请求失败"):
            return "network"
        lowered = msg.lower()
        if (
            "失败" in msg
            or "未找到视频URL" in msg
            or "error" in lowered
            or "fail" in lowered
            or "cancel" in lowered
        ):
            return "fatal"
        return "pending"

    @classmethod
    def _wait_for_video_url(cls, client, job_id, wait_time, interval):
        """Poll ``client.query_status(job_id)`` until the task yields a video URL.

        Returns:
            The video URL, or None if ``wait_time`` seconds of polling elapse first.

        Raises:
            RuntimeError: on a terminal task failure, or after
                ``MAX_NETWORK_ERRORS`` consecutive network errors.
        """
        polls = len(range(0, wait_time, interval))
        network_errors = 0
        for attempt in range(1, polls + 1):
            logger.info(f"查询即梦结果 {attempt}")
            query = client.query_status(job_id)
            if query["status"]:
                return query["data"]
            msg = query["msg"] or ""
            kind = cls._classify_query_message(msg)
            if kind == "fatal":
                raise RuntimeError("即梦任务失败 {}".format(msg))
            if kind == "network":
                network_errors += 1
                if network_errors >= cls.MAX_NETWORK_ERRORS:
                    raise RuntimeError("即梦任务失败 {}".format(msg))
            else:
                network_errors = 0
            # No need to sleep after the final poll.
            if attempt < polls:
                sleep(interval)
        return None

    def gen(
        self,
        image: torch.Tensor,
        prompt: str,
        duration: int,
        resolution: str,
        wait_time: int,
    ):
        """Generate the video and return ``(video_path, last_frame)``."""
        client = JMUtils()
        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            raise RuntimeError(f"上传失败: {upload_data['msg']}")
        image_url = upload_data["data"]

        submit_data = client.submit_task(
            prompt, image_url, str(duration), resolution=resolution
        )
        if not submit_data["status"]:
            raise RuntimeError(f"即梦任务提交失败: {submit_data['msg']}")
        job_id = submit_data["data"]

        video_url = self._wait_for_video_url(
            client, job_id, wait_time, self.POLL_INTERVAL
        )
        if not video_url:
            raise RuntimeError("即梦任务等待超时")

        output_dir = folder_paths.get_output_directory()
        video_path, last_scene = client.download_video(
            video_url, path=os.path.join(output_dir, f"{uuid.uuid4()}.mp4")
        )
        return (
            video_path,
            last_scene,
        )