373 lines
13 KiB
Python
373 lines
13 KiB
Python
import io
import json
import os
import re
import subprocess
import tempfile
import time
import uuid
from time import sleep

import numpy as np
import requests
import torch
import yaml
from loguru import logger
from PIL import Image
from qcloud_cos import CosConfig, CosS3Client
from tqdm import tqdm

import folder_paths
|
|
|
|
|
|
class JMUtils:
    """Helpers for the image-to-video task API and its supporting file handling.

    Covers: submitting and polling generation tasks on the
    ``ark.cn-beijing.volces.com`` tasks endpoint, uploading buffers to a COS
    bucket through ``qcloud_cos``, downloading the resulting video, and
    extracting a frame from it with ``ffprobe``/``ffmpeg``.
    """

    # POST creates a task; GET on ``{TASKS_URL}/{job_id}`` returns its state.
    TASKS_URL = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"

    # Seconds before an API call is abandoned. Without a timeout ``requests``
    # can block forever on a stalled connection.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        """Load credentials from the environment or from ``config.yaml``.

        The environment is used when it contains ``aws_key_id``; otherwise
        ``config.yaml`` located one directory above this file's directory is
        parsed.

        Raises:
            KeyError: if a required setting is missing.
            OSError: if the config file cannot be opened.
        """
        if "aws_key_id" in os.environ:
            yaml_config = os.environ
        else:
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                "config.yaml",
            )
            # Read-only: the config is never written, and "r+" would fail on a
            # read-only file system.
            with open(config_path, encoding="utf-8") as f:
                # NOTE(review): FullLoader can build more than plain data
                # types. If this file can ever come from an untrusted source,
                # switch to yaml.safe_load.
                yaml_config = yaml.load(f, Loader=yaml.FullLoader)

        self.api_key = yaml_config["jm_api_key"]
        self.cos_region = yaml_config["cos_region"]
        self.cos_secret_id = yaml_config["cos_secret_id"]
        self.cos_secret_key = yaml_config["cos_secret_key"]
        self.cos_bucket_name = yaml_config["cos_sucai_bucket_name"]

    def _auth_headers(self) -> dict:
        """Return the JSON + bearer-token headers shared by all API calls."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def submit_task(self, prompt: str, img_url: str, duration: str = "5"):
        """Create an image-to-video generation task.

        Args:
            prompt: Text prompt; generation flags are appended to it.
            img_url: URL of the source image.
            duration: Value passed to the ``--dur`` flag.

        Returns:
            A dict with keys ``status``, ``data`` and ``msg``. On success
            ``data`` is the task id. When the API answers without an id,
            ``data`` is ``img_url``. When the request itself fails, ``data``
            is ``None``. This method never raises ``Exception`` subclasses.
        """
        try:
            json_data = {
                "model": "doubao-seedance-1-0-lite-i2v-250428",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt} --resolution 720p --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url,
                        },
                    },
                ],
            }

            response = requests.post(
                self.TASKS_URL,
                headers=self._auth_headers(),
                json=json_data,
                timeout=self.REQUEST_TIMEOUT,
            )
            resp_json = response.json()
            logger.info(f"submit task: {json.dumps(resp_json)}")

            if "id" not in resp_json:
                # Prefer the API's own message; fall back to the raw payload
                # so the caller still sees why the task was rejected.
                error = resp_json.get("error")
                message = error.get("message") if isinstance(error, dict) else None
                if not message:
                    message = json.dumps(resp_json, ensure_ascii=False)
                return {"status": False, "data": img_url, "msg": message}

            return {"data": resp_json["id"], "status": True, "msg": "任务提交成功"}
        except Exception as e:
            logger.error(e)
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str):
        """Fetch the current state of a generation task.

        Args:
            job_id: Task id returned by ``submit_task``.

        Returns:
            A dict with keys ``status`` (True only when the task status is
            ``"succeeded"``), ``msg`` (the raw task status, or the error text
            when the request failed) and ``data`` (the video URL when the
            response has a ``content`` entry, else ``None``).
        """
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            response = requests.get(
                f"{self.TASKS_URL}/{job_id}",
                headers=self._auth_headers(),
                timeout=self.REQUEST_TIMEOUT,
            )
            resp_json = response.json()
            resp_dict["status"] = resp_json["status"] == "succeeded"
            resp_dict["msg"] = resp_json["status"]
            resp_dict["data"] = resp_json["content"]["video_url"] if "content" in resp_json else None
        except Exception as e:
            logger.error(f"error:{str(e)}")
            resp_dict["msg"] = str(e)
        # Returned after the handler rather than from ``finally`` so that
        # KeyboardInterrupt/SystemExit are not silently discarded.
        return resp_dict

    def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
        """Upload a binary buffer to the configured COS bucket.

        Args:
            file: Readable binary stream to upload.
            mime_type: ``"<category>/<suffix>"``; both parts are used to build
                the object key ``tk/<category>/<uuid>.<suffix>``.

        Returns:
            A dict with keys ``status``, ``data`` (the object URL on success,
            otherwise an empty string) and ``msg``.

        Raises:
            IndexError: if ``mime_type`` contains no ``/``.
        """
        resp_data = {'status': True, 'data': '', 'msg': ''}
        mime_parts = mime_type.split('/')
        category = mime_parts[0]
        suffix = mime_parts[1]
        try:
            object_key = f'tk/{category}/{uuid.uuid4()}.{suffix}'
            config = CosConfig(
                Region=self.cos_region,
                SecretId=self.cos_secret_id,
                SecretKey=self.cos_secret_key,
            )
            client = CosS3Client(config)
            client.upload_file_from_buffer(
                Bucket=self.cos_bucket_name,
                Key=object_key,
                Body=file,
            )
            resp_data['data'] = (
                f'https://{self.cos_bucket_name}.cos.{self.cos_region}.myqcloud.com/{object_key}'
            )
            resp_data['msg'] = '上传成功'
        except Exception as e:
            logger.error(e)
            resp_data['status'] = False
            resp_data['msg'] = str(e)

        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor) -> io.BytesIO:
        """Encode an image tensor as PNG and return it as an in-memory stream.

        Values are scaled by 255 and clipped to ``[0, 255]``; size-1
        dimensions are squeezed away before the array is handed to PIL.

        Args:
            tensor: Image tensor; assumed to hold values in ``[0, 1]`` and to
                squeeze down to a shape PIL accepts (TODO confirm with callers).

        Returns:
            A ``BytesIO`` positioned at offset 0 containing the PNG bytes.
        """
        pixels = tensor.detach().cpu().numpy().squeeze()
        img = Image.fromarray(np.clip(255. * pixels, 0, 255).astype(np.uint8))
        image_data = io.BytesIO()
        img.save(image_data, format='PNG')
        image_data.seek(0)
        return image_data

    def download_video(self, url, timeout=30, retries=3, path=None):
        """Download ``url`` to ``path`` (or to a new temp ``.mp4``) with retries.

        Args:
            url: Video URL.
            timeout: Per-request timeout in seconds.
            retries: Number of attempts; values below 1 are treated as 1.
            path: Destination file. When omitted a temporary ``.mp4`` file is
                created and the caller becomes responsible for deleting it.

        Returns:
            The path of the downloaded file.

        Raises:
            Exception: whatever the last attempt raised, once all attempts
                have failed.
        """
        attempts = max(1, retries)
        owns_file = path is None
        if owns_file:
            # Created once, outside the retry loop, so failed attempts do not
            # leave one orphaned temp file each.
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
                temp_path = temp_file.name
        else:
            temp_path = path

        block_size = 64 * 1024  # larger chunks mean fewer write/progress calls

        for attempt in range(attempts):
            try:
                print(f"开始下载视频 (尝试 {attempt + 1}/{attempts})...")
                # The context manager releases the streamed connection even
                # when the transfer fails midway.
                with requests.get(url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()
                    total_size = int(response.headers.get('content-length', 0))

                    with open(temp_path, 'wb') as f, tqdm(
                        desc=url.split('/')[-1],
                        total=total_size,
                        unit='B',
                        unit_scale=True,
                        unit_divisor=1024,
                    ) as bar:
                        for data in response.iter_content(block_size):
                            bar.update(f.write(data))

                print(f"视频下载完成: {temp_path}")
                return temp_path

            except Exception as e:
                print(f"下载错误 (尝试 {attempt + 1}/{attempts}): {str(e)}")
                if attempt < attempts - 1:
                    # Linear back-off: 2s, 4s, ...
                    time.sleep((attempt + 1) * 2)
                else:
                    if owns_file and os.path.exists(temp_path):
                        os.remove(temp_path)
                    raise

    def jpg_to_tensor(self, image_path, channel_first=False):
        """Load an image file as a float32 tensor with a leading batch axis.

        The image is converted to RGB and scaled to ``[0.0, 1.0]``, giving a
        tensor of shape ``(1, H, W, 3)``.

        Args:
            image_path: Path of the image file.
            channel_first: Accepted for backward compatibility; it is not
                applied and the layout is always channel-last.

        Returns:
            The image as a ``torch.Tensor``.

        Raises:
            Exception: re-raised after logging when the file cannot be read
                or converted.
        """
        try:
            # Close the file handle as soon as the pixel data is converted.
            with Image.open(image_path) as image:
                rgb = image.convert('RGB')
            pixels = np.array(rgb).astype(np.float32) / 255.0
            return torch.from_numpy(pixels).unsqueeze(0)
        except Exception as e:
            print(f"转换失败: {str(e)}")
            raise

    @staticmethod
    def _parse_decoded_frame_count(stderr: str):
        """Return the last ``frame=N`` progress value in ffmpeg stderr, or None."""
        counts = re.findall(r"frame=\s*(\d+)", stderr)
        return int(counts[-1]) if counts else None

    def _count_frames(self, video_path) -> int:
        """Return the frame count of the first video stream of ``video_path``.

        Uses ``ffprobe`` container metadata first and falls back to a full
        ``ffmpeg`` decode when the metadata is not a plain number.
        """
        cmd_frames = [
            'ffprobe', '-v', 'error',
            '-select_streams', 'v:0',
            '-show_entries', 'stream=nb_frames',
            '-of', 'default=nokey=1:noprint_wrappers=1',
            video_path,
        ]
        result = subprocess.run(cmd_frames, capture_output=True, text=True, check=True)

        reported = result.stdout.strip()
        if reported.isdigit():
            return int(reported)

        # Fallback: decode to the null muxer and read the progress counter.
        print("无法获取准确帧数,尝试直接解码...")
        cmd_decode = ['ffmpeg', '-i', video_path, '-f', 'null', '-']
        decode_result = subprocess.run(cmd_decode, capture_output=True, text=True)

        frame_count = self._parse_decoded_frame_count(decode_result.stderr)
        if frame_count is None:
            raise ValueError("无法确定视频帧数")
        return frame_count

    def get_last_15th_frame_tensor(self, video_url, cleanup=True):
        """Download a video and return its 15th-from-last frame as a tensor.

        Args:
            video_url: URL of the video.
            cleanup: When true, delete the downloaded video afterwards.

        Returns:
            The frame in the format produced by ``jpg_to_tensor``. Frame
            index ``max(0, frame_count - 15)`` is used, so videos with fewer
            than 15 frames yield their first frame.

        Raises:
            subprocess.CalledProcessError: if ``ffprobe``/``ffmpeg`` fails.
            ValueError: if the frame count cannot be determined.
        """
        video_path = self.download_video(video_url)
        try:
            frame_count = self._count_frames(video_path)

            target_frame = max(0, frame_count - 15)
            print(f"视频总帧数: {frame_count}, 目标帧: {target_frame}")

            # A private directory holds the extracted frame, so everything
            # ffmpeg writes is removed on exit.
            with tempfile.TemporaryDirectory() as frame_dir:
                frame_pattern = os.path.join(frame_dir, 'frame_%03d.jpg')
                cmd_extract = [
                    'ffmpeg',
                    '-ss', '00:00:00',
                    '-i', video_path,
                    '-vframes', '1',
                    # Raw string: the backslash must reach ffmpeg to escape
                    # the comma inside the filter expression.
                    '-vf', rf'select=eq(n\,{target_frame})',
                    '-vsync', '0',
                    '-an', '-y',
                    frame_pattern,
                ]
                subprocess.run(cmd_extract, capture_output=True, check=True)

                # With a %03d pattern the single output frame is numbered 001.
                return self.jpg_to_tensor(os.path.join(frame_dir, 'frame_001.jpg'))
        finally:
            if cleanup and os.path.exists(video_path):
                os.remove(video_path)
|
|
|
|
|
|
class JMGestureCorrect:
    """Node that regenerates an image as a front-facing, full-body pose.

    The input image is uploaded, an image-to-video task is submitted with a
    fixed pose prompt, and a frame near the end of the resulting video is
    returned. The INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY attributes
    follow the node-declaration convention used by the host application
    (presumably ComfyUI, given the ``folder_paths`` import — confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required ``image`` input."""
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor):
        """Run the pose-correction pipeline for ``image``.

        Args:
            image: Source image tensor.

        Returns:
            A 1-tuple containing the extracted frame tensor.

        Raises:
            Exception: if the upload or task submission fails, the task
                reports a failure, or no result arrives within the polling
                budget (``wait_time`` / ``interval`` polls).
        """
        wait_time = 120  # total polling budget, in seconds of sleep
        interval = 2     # seconds between polls
        client = JMUtils()

        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            # Include the underlying reason so the failure can be diagnosed.
            raise Exception(f"上传失败: {upload_data['msg']}")
        image_url = upload_data["data"]

        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(prompt, image_url)
        if not submit_data["status"]:
            raise Exception(f"即梦任务提交失败: {submit_data['msg']}")
        job_id = submit_data["data"]

        job_data = None
        max_polls = len(range(0, wait_time, interval))
        for poll in range(1, max_polls + 1):
            logger.info(f"查询即梦结果 {poll}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break

            status_msg = query["msg"]
            if any(marker in status_msg for marker in ("error", "失败", "fail")):
                raise Exception("即梦任务失败 {}".format(status_msg))

            # No point sleeping after the final poll; go straight to timeout.
            if poll < max_polls:
                sleep(interval)

        if not job_data:
            raise Exception("即梦任务等待超时")
        return (client.get_last_15th_frame_tensor(job_data),)
|
|
|
|
|
|
class JMCustom:
    """Node that generates a video from an image with a caller-supplied prompt.

    The input image is uploaded, an image-to-video task is submitted, and the
    finished video is downloaded into the output directory reported by
    ``folder_paths``. The INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY
    attributes follow the node-declaration convention used by the host
    application (presumably ComfyUI, given the ``folder_paths`` import —
    confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the required ``image``, ``prompt`` and ``duration`` inputs."""
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {
                    "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
                    "multiline": True}),
                "duration": ("INT", {"default": 5, "min": 2, "max": 30}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    def gen(self, image: torch.Tensor, prompt: str, duration: int):
        """Generate a video and return the path it was saved to.

        Args:
            image: Source image tensor.
            prompt: Text prompt for the generation task.
            duration: Requested duration; sent to the API as a string.

        Returns:
            A 1-tuple containing the path of the downloaded ``.mp4`` file.

        Raises:
            Exception: if the upload or task submission fails, the task
                reports a failure, or no result arrives within the polling
                budget (``wait_time`` / ``interval`` polls).
        """
        wait_time = 120  # total polling budget, in seconds of sleep
        interval = 2     # seconds between polls
        client = JMUtils()

        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            # Include the underlying reason so the failure can be diagnosed.
            raise Exception(f"上传失败: {upload_data['msg']}")
        image_url = upload_data["data"]

        submit_data = client.submit_task(prompt, image_url, str(duration))
        if not submit_data["status"]:
            raise Exception(f"即梦任务提交失败: {submit_data['msg']}")
        job_id = submit_data["data"]

        job_data = None
        max_polls = len(range(0, wait_time, interval))
        for poll in range(1, max_polls + 1):
            logger.info(f"查询即梦结果 {poll}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break

            status_msg = query["msg"]
            if any(marker in status_msg for marker in ("error", "失败", "fail")):
                raise Exception("即梦任务失败 {}".format(status_msg))

            # No point sleeping after the final poll; go straight to timeout.
            if poll < max_polls:
                sleep(interval)

        if not job_data:
            raise Exception("即梦任务等待超时")

        output_path = os.path.join(folder_paths.get_output_directory(), f"{uuid.uuid4()}.mp4")
        return (client.download_video(job_data, path=output_path),)
|