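"""ComfyUI custom nodes for Jimeng / Doubao-Seedance image-to-video generation.

JMUtils uploads reference images to Tencent COS, submits generation tasks to the
Volcengine Ark API, polls task status, downloads the resulting video and extracts
frames; JMGestureCorrect and JMCustom expose this flow as ComfyUI nodes.
"""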
import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep

import cv2
import folder_paths
import numpy as np
import requests
import torch
import yaml
from PIL import Image
from loguru import logger
from qcloud_cos import CosConfig, CosS3Client
from torchvision.transforms import transforms
from tqdm import tqdm


class JMUtils:
    def __init__(self):
        # Credentials come from environment variables when "aws_key_id" is set
        # (deployment marker); otherwise they are read from config.yaml next to
        # the package root.
        if "aws_key_id" in os.environ:
            yaml_config = os.environ
        else:
            with open(
                os.path.join(
                    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"
                ),
                encoding="utf-8",
                mode="r",
            ) as f:
                yaml_config = yaml.load(f, Loader=yaml.FullLoader)

        self.api_key = yaml_config["jm_api_key"]
        self.cos_region = yaml_config["cos_region"]
        self.cos_secret_id = yaml_config["cos_secret_id"]
        self.cos_secret_key = yaml_config["cos_secret_key"]
        self.cos_bucket_name = yaml_config["cos_sucai_bucket_name"]

    def submit_task(self, prompt: str, img_url: str, duration: str = "10", resolution: str = "720p"):
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }

            json_data = {
                "model": "doubao-seedance-1-0-pro-250528",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt} --resolution {resolution} --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url,
                        },
                    },
                ],
            }

            response = requests.post("https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
                                     headers=headers, json=json_data)
            logger.info(f"submit task: {json.dumps(response.json())}")
            resp_json = response.json()
            if "id" not in resp_json:
                return {"status": False, "data": img_url, "msg": resp_json["error"]["message"]}
            else:
                job_id = resp_json["id"]
                return {"data": job_id, "status": True, "msg": "task submitted successfully"}
        except Exception as e:
            logger.error(e)
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str):
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            response = requests.get(f"https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks/{job_id}",
                                    headers=headers)
            resp_json = response.json()
            resp_dict["status"] = resp_json["status"] == "succeeded"
            resp_dict["msg"] = resp_json["status"]
            resp_dict["data"] = resp_json["content"]["video_url"] if "content" in resp_json else None
        except Exception as e:
            logger.error(f"error: {str(e)}")
            resp_dict["msg"] = str(e)
        return resp_dict

    def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
        resp_data = {'status': True, 'data': '', 'msg': ''}
        category = mime_type.split('/')[0]
        suffix = mime_type.split('/')[1]
        try:
            object_key = f'tk/{category}/{uuid.uuid4()}.{suffix}'
            config = CosConfig(Region=self.cos_region, SecretId=self.cos_secret_id, SecretKey=self.cos_secret_key)
            client = CosS3Client(config)
            _ = client.upload_file_from_buffer(
                Bucket=self.cos_bucket_name,
                Key=object_key,
                Body=file
            )
            url = f'https://{self.cos_bucket_name}.cos.{self.cos_region}.myqcloud.com/{object_key}'
            resp_data['data'] = url
            resp_data['msg'] = 'upload succeeded'

        except Exception as e:
            logger.error(e)
            resp_data['status'] = False
            resp_data['msg'] = str(e)

        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor):
        # Convert the [1, H, W, C] float tensor to a PIL image and write it to an in-memory PNG
        img = Image.fromarray(np.clip(255. * tensor.cpu().squeeze().numpy(), 0, 255).astype(np.uint8))
        image_data = io.BytesIO()
        img.save(image_data, format='PNG')
        image_data.seek(0)
        return image_data

    def read_video_last_frame_to_tensor(self, video_path: str) -> torch.Tensor:
        """
        Read the last frame of a video file and convert it to a PyTorch tensor.

        Args:
            video_path (str): Path to the video file.

        Returns:
            torch.Tensor: Tensor of shape [1, H, W, C], where H and W are the frame
                height and width and the channel order is RGB.

        Raises:
            FileNotFoundError: If the video file cannot be opened.
            ValueError: If the video is empty or the frame cannot be read.
        """
        # Open the video file
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            raise FileNotFoundError(f"Cannot open video file: {video_path}")

        # Total number of frames
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if total_frames == 0:
            cap.release()
            raise ValueError("Video is empty or the frame count cannot be determined")

        # Seek to the last frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)

        # Read the last frame
        ret, frame = cap.read()

        # Release the capture
        cap.release()

        if not ret or frame is None:
            raise ValueError("Cannot read the last frame of the video; the file may be corrupted")

        # Convert BGR to RGB (OpenCV reads frames as BGR)
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Convert to a tensor and rearrange to [1, H, W, C]
        transform = transforms.Compose([
            transforms.ToTensor()  # [C, H, W] tensor with values in [0, 1]
        ])

        tensor = transform(frame_rgb).unsqueeze(0).permute(0, 2, 3, 1)  # add batch dim -> [1, H, W, C]

        return tensor

    def download_video(self, url, timeout=30, retries=3, path=None):
        """Download a video to a temporary file and return the file path plus its last frame as a tensor."""
        for attempt in range(retries):
            try:
                # Create the target file
                if path:
                    temp_path = path
                else:
                    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
                    temp_path = temp_file.name
                    temp_file.close()

                # Download the video
                print(f"Starting video download (attempt {attempt + 1}/{retries})...")
                response = requests.get(url, stream=True, timeout=timeout)
                response.raise_for_status()

                # File size for the progress bar
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 KB

                # Show download progress with tqdm
                with open(temp_path, 'wb') as f, tqdm(
                        desc=url.split('/')[-1],
                        total=total_size,
                        unit='B',
                        unit_scale=True,
                        unit_divisor=1024
                ) as bar:
                    for data in response.iter_content(block_size):
                        size = f.write(data)
                        bar.update(size)

                print(f"Video download finished: {temp_path}")
                return temp_path, self.read_video_last_frame_to_tensor(temp_path)

            except Exception as e:
                print(f"Download error (attempt {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep((attempt + 1) * 2)
                else:
                    raise

    def jpg_to_tensor(self, image_path, channel_first=False):
        """
        Convert a JPG image to a PyTorch tensor.

        Args:
            image_path: Path to the JPG image file.
            channel_first: Reserved flag for a (C, H, W) layout; currently unused,
                the tensor is always returned as [1, H, W, C].

        Returns:
            tensor: Float tensor with pixel values normalized to [0.0, 1.0].
        """
        try:
            # Open the image file
            image = Image.open(image_path).convert('RGB')

            # Convert to a [1, H, W, C] float tensor in [0, 1]
            tensor = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]

            return tensor

        except Exception as e:
            print(f"Conversion failed: {str(e)}")
            raise

    def get_last_15th_frame_tensor(self, video_url, cleanup=True):
        """
        Grab the 15th-from-last frame of the video at `video_url` and return it as a tensor.
        The video is downloaded to a local temporary file before processing.
        (`cleanup` is currently unused.)
        """
        try:
            # Download the video
            video_path, _ = self.download_video(video_url)

            # Total frame count via ffprobe
            cmd_frames = [
                'ffprobe', '-v', 'error',
                '-select_streams', 'v:0',
                '-show_entries', 'stream=nb_frames',
                '-of', 'default=nokey=1:noprint_wrappers=1',
                video_path
            ]

            result = subprocess.run(
                cmd_frames,
                capture_output=True,
                text=True,
                check=True
            )

            # Handle non-numeric ffprobe output
            frame_count = result.stdout.strip()
            if not frame_count.isdigit():
                # Fallback: count frames by decoding the video
                print("Could not get an exact frame count, decoding the video instead...")
                cmd_decode = [
                    'ffmpeg', '-i', video_path,
                    '-f', 'null', '-'
                ]
                decode_result = subprocess.run(
                    cmd_decode,
                    capture_output=True,
                    text=True
                )

                for line in decode_result.stderr.split('\n'):
                    if 'frame=' in line:
                        parts = line.split('frame=')[-1].split()[0]
                        if parts.isdigit():
                            frame_count = int(parts)
                            break
                else:
                    raise ValueError("Unable to determine the video frame count")
            else:
                frame_count = int(frame_count)

            # Target frame index
            target_frame = max(0, frame_count - 15)
            print(f"Total frames: {frame_count}, target frame: {target_frame}")

            # Extract the target frame with ffmpeg
            with tempfile.NamedTemporaryFile(suffix='%03d.jpg', delete=True) as frame_file:
                frame_path = frame_file.name

                cmd_extract = [
                    'ffmpeg',
                    '-ss', '00:00:00',
                    '-i', video_path,
                    '-vframes', '1',
                    '-vf', f'select=eq(n\\,{target_frame})',
                    '-vsync', '0',
                    '-an', '-y',
                    frame_path
                ]

                subprocess.run(
                    cmd_extract,
                    capture_output=True,
                    check=True
                )

                # Convert the extracted frame to a tensor
                tensor = self.jpg_to_tensor(frame_path.replace("%03d", "001"))
        except Exception as e:
            raise e
        return tensor


class JMGestureCorrect:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "resolution": (["720p", "1080p"],),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor, resolution: str):
        wait_time = 240
        interval = 2
        client = JMUtils()
        image_io = client.tensor_to_io(image)
        upload_data = client.upload_io_to_cos(image_io)
        if upload_data["status"]:
            image_url = upload_data["data"]
        else:
            raise Exception("Image upload failed")
        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(prompt, image_url, duration="5", resolution=resolution)
        if submit_data["status"]:
            job_id = submit_data["data"]
        else:
            raise Exception("Jimeng task submission failed")
        job_data = None
        for idx, _ in enumerate(range(0, wait_time, interval)):
            logger.info(f"polling Jimeng result, attempt {idx + 1}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break
            else:
                if "error" in query["msg"] or "失败" in query["msg"] or "fail" in query["msg"]:
                    raise Exception("Jimeng task failed: {}".format(query["msg"]))
                sleep(interval)
        if not job_data:
            raise Exception("Timed out waiting for the Jimeng task")
        return (client.get_last_15th_frame_tensor(job_data),)


class JMCustom:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {
                    "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
                    "multiline": True}),
                "duration": ("INT", {"default": 5, "min": 2, "max": 10}),
                "resolution": (["720p", "1080p"],),
                "wait_time": ("INT", {"default": 180, "min": 60, "max": 600}),
            }
        }

    RETURN_TYPES = ("STRING", "IMAGE",)
    RETURN_NAMES = ("视频存储路径", "视频最后一帧")
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    def gen(self, image: torch.Tensor, prompt: str, duration: int, resolution: str, wait_time: int):
        interval = 2
        client = JMUtils()
        image_io = client.tensor_to_io(image)
        upload_data = client.upload_io_to_cos(image_io)
        if upload_data["status"]:
            image_url = upload_data["data"]
        else:
            raise Exception("Image upload failed")
        submit_data = client.submit_task(prompt, image_url, str(duration), resolution=resolution)
        if submit_data["status"]:
            job_id = submit_data["data"]
        else:
            raise Exception("Jimeng task submission failed")
        job_data = None
        for idx, _ in enumerate(range(0, wait_time, interval)):
            logger.info(f"polling Jimeng result, attempt {idx + 1}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break
            else:
                if "error" in query["msg"] or "失败" in query["msg"] or "fail" in query["msg"]:
                    raise Exception("Jimeng task failed: {}".format(query["msg"]))
                sleep(interval)
        if not job_data:
            raise Exception("Timed out waiting for the Jimeng task")
        video_path, last_scene = client.download_video(job_data, path=os.path.join(folder_paths.get_output_directory(),
                                                                                   f"{uuid.uuid4()}.mp4"))
        return (video_path, last_scene,)
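

if __name__ == "__main__":
    # Minimal manual-test sketch (not part of the node interface). It assumes a
    # ComfyUI environment (so that `folder_paths` imports), valid credentials in
    # config.yaml or the environment, and network access; the black dummy image
    # below is only a placeholder to illustrate the JMUtils call sequence.
    utils = JMUtils()
    dummy_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)  # placeholder [1, H, W, C] tensor
    uploaded = utils.upload_io_to_cos(utils.tensor_to_io(dummy_image))
    if not uploaded["status"]:
        raise SystemExit(uploaded["msg"])
    task = utils.submit_task("A person slowly turns to face the camera", uploaded["data"], duration="5")
    if not task["status"]:
        raise SystemExit(task["msg"])
    video_url = None
    for _ in range(120):  # poll for up to ~4 minutes
        result = utils.query_status(task["data"])
        if result["status"]:
            video_url = result["data"]
            break
        sleep(2)
    if video_url:
        path, last_frame = utils.download_video(video_url)
        print(path, tuple(last_frame.shape))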