import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep

import folder_paths
import numpy as np
import requests
import torch
import yaml
from PIL import Image
from loguru import logger
from qcloud_cos import CosConfig, CosS3Client
from tqdm import tqdm

# Substrings of a task status/message that mark the job as terminally failed.
_FAILURE_MARKERS = ("error", "失败", "fail")


class JMUtils:
    """Client helpers for the image-to-video workflow used by the nodes below.

    Bundles three concerns: submitting/querying generation tasks on the Ark
    HTTP API, uploading in-memory files to Tencent COS, and downloading the
    resulting video / extracting a frame from it with ffmpeg.
    """

    # Endpoint used both to create tasks (POST) and to query them (GET /<id>).
    TASKS_URL = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
    # Seconds to wait for the API before giving up; without it a stalled
    # connection would block the caller forever.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        """Load API and COS credentials.

        Settings come from the process environment when the ``aws_key_id``
        variable is present there; otherwise from ``config.yaml`` located in
        the parent of this file's directory.

        Raises:
            KeyError: if a required setting is missing.
            OSError: if the config file cannot be read.
        """
        if "aws_key_id" in os.environ:
            yaml_config = os.environ
        else:
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                "config.yaml",
            )
            # Read-only access is enough; "r+" would fail on read-only files.
            with open(config_path, encoding="utf-8") as f:
                # NOTE(review): FullLoader is kept for compatibility with the
                # existing config; prefer yaml.safe_load if the file only
                # holds plain scalars/maps.
                yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        self.api_key = yaml_config["jm_api_key"]
        self.cos_region = yaml_config["cos_region"]
        self.cos_secret_id = yaml_config["cos_secret_id"]
        self.cos_secret_key = yaml_config["cos_secret_key"]
        self.cos_bucket_name = yaml_config["cos_sucai_bucket_name"]

    def _auth_headers(self):
        """Return the JSON + bearer-token headers sent with every API call."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def submit_task(self, prompt: str, img_url: str, duration: str = "10"):
        """Create an image-to-video generation task.

        Args:
            prompt: text prompt; generation flags are appended to it.
            img_url: publicly reachable URL of the source image.
            duration: value passed to the ``--dur`` flag.

        Returns:
            dict with ``status`` (bool), ``data`` (job id on success) and
            ``msg``. Never raises for request/parse errors; they are reported
            through ``status=False``.
        """
        try:
            json_data = {
                "model": "doubao-seedance-1-0-pro-250528",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt} --resolution 1080p --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url,
                        },
                    },
                ],
            }
            response = requests.post(
                self.TASKS_URL,
                headers=self._auth_headers(),
                json=json_data,
                timeout=self.REQUEST_TIMEOUT,
            )
            # Parse the body once and reuse it for logging and inspection.
            resp_json = response.json()
            logger.info(f"submit task: {json.dumps(resp_json)}")
            if "id" not in resp_json:
                error = resp_json.get("error") or {}
                # Fall back to the raw payload so the caller always gets a
                # meaningful message even for unexpected error shapes.
                msg = error.get("message") or json.dumps(resp_json, ensure_ascii=False)
                return {"status": False, "data": img_url, "msg": msg}
            return {"data": resp_json["id"], "status": True, "msg": "任务提交成功"}
        except Exception as e:
            logger.error(e)
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str):
        """Query a generation task.

        Args:
            job_id: id returned by :meth:`submit_task`.

        Returns:
            dict with ``status`` (True only when the task status is
            ``"succeeded"``), ``msg`` (the raw status, or the error text when
            the request failed) and ``data`` (the video URL when the response
            has a ``content`` entry, else ``None``).
        """
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            response = requests.get(
                f"{self.TASKS_URL}/{job_id}",
                headers=self._auth_headers(),
                timeout=self.REQUEST_TIMEOUT,
            )
            resp_json = response.json()
            resp_dict["status"] = resp_json["status"] == "succeeded"
            resp_dict["msg"] = resp_json["status"]
            resp_dict["data"] = resp_json["content"]["video_url"] if "content" in resp_json else None
        except Exception as e:
            logger.error(f"error:{str(e)}")
            resp_dict["msg"] = str(e)
        # Returning here (not from a ``finally``) lets KeyboardInterrupt and
        # other BaseExceptions propagate instead of being silently dropped.
        return resp_dict

    def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
        """Upload a file-like object to COS under ``tk/<category>/<uuid>.<suffix>``.

        Args:
            file: readable binary buffer to upload.
            mime_type: ``"<category>/<suffix>"``; both parts are used to build
                the object key.

        Returns:
            dict with ``status`` (bool), ``data`` (object URL on success) and
            ``msg``.

        Raises:
            IndexError: if ``mime_type`` contains no ``/``.
        """
        resp_data = {'status': True, 'data': '', 'msg': ''}
        mime_parts = mime_type.split('/')
        category = mime_parts[0]
        suffix = mime_parts[1]
        try:
            object_key = f'tk/{category}/{uuid.uuid4()}.{suffix}'
            config = CosConfig(
                Region=self.cos_region,
                SecretId=self.cos_secret_id,
                SecretKey=self.cos_secret_key,
            )
            client = CosS3Client(config)
            client.upload_file_from_buffer(
                Bucket=self.cos_bucket_name,
                Key=object_key,
                Body=file,
            )
            resp_data['data'] = (
                f'https://{self.cos_bucket_name}.cos.{self.cos_region}.myqcloud.com/{object_key}'
            )
            resp_data['msg'] = '上传成功'
        except Exception as e:
            logger.error(e)
            resp_data['status'] = False
            resp_data['msg'] = str(e)
        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor):
        """Encode an image tensor as PNG and return it as a rewound ``BytesIO``.

        The tensor values are scaled by 255, clipped to [0, 255] and squeezed
        before conversion, so it is presumably a float image in [0, 1] with
        singleton batch dimension — confirm against callers.
        """
        pixels = np.clip(255. * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
        image_data = io.BytesIO()
        Image.fromarray(pixels).save(image_data, format='PNG')
        image_data.seek(0)
        return image_data

    def download_video(self, url, timeout=30, retries=3, path=None):
        """Download a video and return the path of the local file.

        Args:
            url: video URL.
            timeout: per-request timeout in seconds.
            retries: maximum number of attempts; waits ``2 * attempt`` seconds
                between attempts.
            path: destination file; when falsy a ``.mp4`` temp file is created
                (and reused across retries).

        Returns:
            The local file path, or ``None`` when ``retries`` is not positive.

        Raises:
            Exception: the error of the last attempt once retries run out.
        """
        target_path = path
        owns_temp_file = False
        for attempt in range(retries):
            try:
                if not target_path:
                    # Create the temp file once; creating one per attempt
                    # would leak a file for every failed try.
                    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
                    target_path = temp_file.name
                    temp_file.close()
                    owns_temp_file = True

                print(f"开始下载视频 (尝试 {attempt + 1}/{retries})...")
                # The context manager releases the connection even on errors.
                with requests.get(url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()

                    total_size = int(response.headers.get('content-length', 0))
                    block_size = 64 * 1024  # 64 KiB chunks keep per-chunk overhead low

                    # "wb" truncates, so a retry overwrites any partial data.
                    with open(target_path, 'wb') as f, tqdm(
                            desc=url.split('/')[-1],
                            total=total_size,
                            unit='B',
                            unit_scale=True,
                            unit_divisor=1024
                    ) as bar:
                        for data in response.iter_content(block_size):
                            bar.update(f.write(data))

                print(f"视频下载完成: {target_path}")
                return target_path

            except Exception as e:
                print(f"下载错误 (尝试 {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep((attempt + 1) * 2)
                else:
                    # Do not leave a partial temp file behind; a caller-given
                    # path is left untouched.
                    if owns_temp_file and os.path.exists(target_path):
                        os.remove(target_path)
                    raise

    def jpg_to_tensor(self, image_path, channel_first=False):
        """Load an image file as a float32 tensor of shape (1, H, W, 3) in [0, 1].

        Args:
            image_path: path of the image file.
            channel_first: accepted for backward compatibility; currently
                ignored (the result is always channel-last).

        Returns:
            ``torch.Tensor`` with a leading batch dimension of 1.

        Raises:
            Exception: any error raised while opening or converting the image
                (re-raised after being printed).
        """
        try:
            # The context manager closes the underlying file handle; the RGB
            # copy made by ``convert`` stays valid afterwards.
            with Image.open(image_path) as img:
                rgb_image = img.convert('RGB')
            pixels = np.array(rgb_image).astype(np.float32) / 255.0
            return torch.from_numpy(pixels)[None,]
        except Exception as e:
            print(f"转换失败: {str(e)}")
            raise

    @staticmethod
    def _count_frames(video_path):
        """Return the number of video frames, via ffprobe or a full decode."""
        cmd_frames = [
            'ffprobe',
            '-v', 'error',
            '-select_streams', 'v:0',
            '-show_entries', 'stream=nb_frames',
            '-of', 'default=nokey=1:noprint_wrappers=1',
            video_path
        ]
        probe = subprocess.run(cmd_frames, capture_output=True, text=True, check=True)
        reported = probe.stdout.strip()
        if reported.isdigit():
            return int(reported)

        # Fallback: decode the whole stream and read ffmpeg's progress output.
        print("无法获取准确帧数,尝试直接解码...")
        cmd_decode = ['ffmpeg', '-i', video_path, '-f', 'null', '-']
        decoded = subprocess.run(cmd_decode, capture_output=True, text=True)
        # Text mode turns ffmpeg's '\r'-separated progress updates into
        # separate lines; scan from the end so the FINAL count is used rather
        # than the first partial update.
        for line in reversed(decoded.stderr.splitlines()):
            if 'frame=' not in line:
                continue
            fields = line.split('frame=')[-1].split()
            if fields and fields[0].isdigit():
                return int(fields[0])
        raise ValueError("无法确定视频帧数")

    def get_last_15th_frame_tensor(self, video_url, cleanup=True):
        """Download a video and return its 15th-from-last frame as a tensor.

        Args:
            video_url: URL of the video to download.
            cleanup: when True, delete the downloaded video afterwards. The
                extracted frame file is always temporary.

        Returns:
            The frame as produced by :meth:`jpg_to_tensor`. For videos with
            fewer than 15 frames the first frame is used.

        Raises:
            subprocess.CalledProcessError: if ffprobe or the extraction fails.
            ValueError: if the frame count cannot be determined.
        """
        video_path = self.download_video(video_url)
        try:
            frame_count = self._count_frames(video_path)
            target_frame = max(0, frame_count - 15)
            print(f"视频总帧数: {frame_count}, 目标帧: {target_frame}")

            # ffmpeg expands the %03d pattern itself, so the file it writes
            # is "frame_001.jpg"; a temp directory removes it reliably.
            with tempfile.TemporaryDirectory() as frame_dir:
                frame_pattern = os.path.join(frame_dir, 'frame_%03d.jpg')
                cmd_extract = [
                    'ffmpeg',
                    '-ss', '00:00:00',
                    '-i', video_path,
                    '-vframes', '1',
                    # Raw string: the backslash escapes the comma for ffmpeg's
                    # filter parser, not for Python.
                    '-vf', rf'select=eq(n\,{target_frame})',
                    '-vsync', '0',
                    '-an',
                    '-y',
                    frame_pattern
                ]
                subprocess.run(cmd_extract, capture_output=True, check=True)
                return self.jpg_to_tensor(os.path.join(frame_dir, 'frame_001.jpg'))
        finally:
            if cleanup:
                try:
                    os.remove(video_path)
                except OSError as e:
                    # Failing to delete a temp file must not mask the result.
                    logger.warning(f"failed to remove temporary video {video_path}: {e}")


def _generate_video_url(client, image, prompt, duration=None, wait_time=120, interval=2):
    """Upload ``image``, run a generation task and return the video URL.

    Shared by both nodes so the upload/submit/poll logic lives in one place.

    Args:
        client: a :class:`JMUtils` instance.
        image: image tensor to upload as the first frame.
        prompt: text prompt for the task.
        duration: clip length; ``None`` keeps ``submit_task``'s default.
        wait_time: polling budget in seconds (sleep time only).
        interval: seconds to sleep between polls.

    Raises:
        Exception: on upload failure, submit failure, task failure or timeout.
    """
    upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
    if not upload_data["status"]:
        raise Exception("上传失败")
    image_url = upload_data["data"]

    if duration is None:
        submit_data = client.submit_task(prompt, image_url)
    else:
        submit_data = client.submit_task(prompt, image_url, str(duration))
    if not submit_data["status"]:
        raise Exception("即梦任务提交失败")
    job_id = submit_data["data"]

    video_url = None
    total_polls = len(range(0, wait_time, interval))
    for poll in range(1, total_polls + 1):
        logger.info(f"查询即梦结果 {poll}")
        query = client.query_status(job_id)
        if query["status"]:
            video_url = query["data"]
            break
        if any(marker in query["msg"] for marker in _FAILURE_MARKERS):
            raise Exception("即梦任务失败 {}".format(query["msg"]))
        sleep(interval)

    if not video_url:
        raise Exception("即梦任务等待超时")
    return video_url


class JMGestureCorrect:
    """Node that regenerates a front-facing, full-body view of the input image.

    It generates a short video with a fixed posture prompt and returns the
    15th-from-last frame of that video.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor):
        """Return a 1-tuple holding the corrected-posture image tensor."""
        client = JMUtils()
        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        video_url = _generate_video_url(client, image, prompt)
        return (client.get_last_15th_frame_tensor(video_url),)


class JMCustom:
    """Node that generates a video from an image with a custom prompt."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {
                    "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
                    "multiline": True}),
                "duration": ("INT", {"default": 5, "min": 2, "max": 10}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    def gen(self, image: torch.Tensor, prompt: str, duration: int):
        """Return a 1-tuple holding the path of the saved ``.mp4`` file."""
        client = JMUtils()
        video_url = _generate_video_url(client, image, prompt, duration)
        output_path = os.path.join(folder_paths.get_output_directory(), f"{uuid.uuid4()}.mp4")
        return (client.download_video(video_url, path=output_path),)