import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep

import numpy as np
import requests
import torch
import yaml
from PIL import Image
from loguru import logger
from qcloud_cos import CosConfig, CosS3Client
from tqdm import tqdm


class JMUtils:
    """Helpers for the generation-task HTTP API, COS uploads and ffmpeg frame grabs.

    The instance holds the API key used as a bearer token against
    ``ark.cn-beijing.volces.com`` and the COS region / credentials / bucket
    used by :meth:`upload_io_to_cos`.
    """

    def __init__(self):
        """Load credentials from the environment or from ``config.yaml``.

        When the ``aws_key_id`` environment variable is present, every setting
        is read from ``os.environ``; otherwise ``config.yaml`` located one
        directory above this file's directory is parsed.

        Raises:
            KeyError: if a required setting is missing.
            OSError: if ``config.yaml`` is needed but cannot be opened.
        """
        if "aws_key_id" in os.environ:
            yaml_config = os.environ
        else:
            config_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                "config.yaml",
            )
            # Read-only access is enough; "r+" would fail on read-only mounts.
            with open(config_path, encoding="utf-8", mode="r") as f:
                # NOTE(review): FullLoader is kept as in the original; if the
                # config can ever come from an untrusted source, switch to
                # yaml.safe_load.
                yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        self.api_key = yaml_config["jm_api_key"]
        self.cos_region = yaml_config["cos_region"]
        self.cos_secret_id = yaml_config["cos_secret_id"]
        self.cos_secret_key = yaml_config["cos_secret_key"]
        self.cos_bucket_name = yaml_config["cos_sucai_bucket_name"]

    @staticmethod
    def _remove_quietly(path):
        """Delete ``path`` if it is set, ignoring filesystem errors."""
        if not path:
            return
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup of a temp file; a failure here must not
            # mask the result or the original error.
            pass

    def submit_task(self, prompt: str, img_url: str, duration: str = "5", timeout: float = 30):
        """Submit a generation task built from a text prompt and an image URL.

        Args:
            prompt: Text prompt; resolution, duration and camera flags are
                appended to it before sending.
            img_url: URL of the input image.
            duration: Value passed after ``--dur`` in the prompt text.
            timeout: Seconds to wait for the HTTP request before giving up.

        Returns:
            dict with keys ``status`` (bool), ``data`` and ``msg``. On success
            ``data`` is the task id from the response; when the response has
            no ``id`` it is ``img_url``; on an exception it is ``None``.
        """
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            json_data = {
                "model": "doubao-seedance-1-0-lite-i2v-250428",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt} --resolution 720p --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url,
                        },
                    },
                ],
            }
            response = requests.post(
                "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
                headers=headers,
                json=json_data,
                timeout=timeout,
            )
            # Parse the body once and reuse it for logging and inspection.
            resp_json = response.json()
            logger.info(f"submit task: {json.dumps(resp_json)}")
            if "id" not in resp_json:
                # Fall back to the raw payload when the error object does not
                # have the expected shape, so the caller still sees the cause.
                error = resp_json.get("error")
                if isinstance(error, dict) and "message" in error:
                    message = error["message"]
                else:
                    message = json.dumps(resp_json)
                return {"status": False, "data": img_url, "msg": message}
            return {"data": resp_json["id"], "status": True, "msg": "任务提交成功"}
        except Exception as e:
            logger.error(e)
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str, timeout: float = 30):
        """Query the state of a previously submitted task.

        Args:
            job_id: Task id returned by :meth:`submit_task`.
            timeout: Seconds to wait for the HTTP request before giving up.

        Returns:
            dict with ``status`` (True only when the reported status equals
            ``"succeeded"``), ``msg`` (the reported status, or the exception
            text on failure) and ``data`` (``content.video_url`` from the
            response when present, else ``None``).
        """
        resp_dict = {"status": False, "data": None, "msg": ""}
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            response = requests.get(
                f"https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks/{job_id}",
                headers=headers,
                timeout=timeout,
            )
            resp_json = response.json()
            # Read everything first so a missing field cannot leave the
            # result half-filled (status=True with no data).
            task_status = resp_json["status"]
            video_url = resp_json["content"]["video_url"] if "content" in resp_json else None
            resp_dict["status"] = task_status == "succeeded"
            resp_dict["msg"] = task_status
            resp_dict["data"] = video_url
        except Exception as e:
            logger.error(f"error:{str(e)}")
            resp_dict["msg"] = str(e)
        # Returned after the try block rather than from `finally`, so that
        # KeyboardInterrupt / SystemExit are no longer silently swallowed.
        return resp_dict

    def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
        """Upload a file-like object to the configured COS bucket.

        Args:
            file: Readable binary stream to upload.
            mime_type: ``"<category>/<suffix>"``; both parts are used to build
                the object key ``tk/<category>/<uuid4>.<suffix>``.

        Returns:
            dict with ``status`` (bool), ``data`` (public URL on success, else
            an empty string) and ``msg``.
        """
        resp_data = {'status': True, 'data': '', 'msg': ''}
        try:
            parts = mime_type.split('/')
            if len(parts) < 2:
                # Report through the result dict like every other failure
                # instead of leaking an IndexError to the caller.
                raise ValueError(f"invalid mime type: {mime_type!r}")
            category, suffix = parts[0], parts[1]
            object_key = f'tk/{category}/{uuid.uuid4()}.{suffix}'
            config = CosConfig(Region=self.cos_region, SecretId=self.cos_secret_id, SecretKey=self.cos_secret_key)
            client = CosS3Client(config)
            client.upload_file_from_buffer(
                Bucket=self.cos_bucket_name,
                Key=object_key,
                Body=file
            )
            url = f'https://{self.cos_bucket_name}.cos.{self.cos_region}.myqcloud.com/{object_key}'
            resp_data['data'] = url
            resp_data['msg'] = '上传成功'
        except Exception as e:
            logger.error(e)
            resp_data['status'] = False
            resp_data['msg'] = str(e)
        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor):
        """Encode an image tensor as PNG and return it as an in-memory stream.

        The tensor is moved to CPU, squeezed, scaled by 255, clipped to
        [0, 255] and cast to ``uint8`` before being handed to PIL.

        Args:
            tensor: Image tensor; assumed to hold values in [0, 1] and to
                squeeze down to a shape PIL accepts — TODO confirm with callers.

        Returns:
            ``io.BytesIO`` positioned at offset 0 containing the PNG bytes.
        """
        # Convert to a PIL image.
        img = Image.fromarray(np.clip(255. * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
        image_data = io.BytesIO()
        img.save(image_data, format='PNG')
        image_data.seek(0)
        return image_data

    def download_video(self, url, timeout=30, retries=3):
        """Download a video to a temporary ``.mp4`` file and return its path.

        Args:
            url: Video URL.
            timeout: Per-request timeout in seconds passed to ``requests``.
            retries: Maximum number of attempts; waits ``2 * attempt`` seconds
                between attempts.

        Returns:
            Path of the downloaded file. The caller owns the file and is
            responsible for deleting it. Returns ``None`` when ``retries`` is
            not positive.

        Raises:
            Exception: the error of the final attempt when all attempts fail.
        """
        block_size = 64 * 1024  # 64 KB chunks keep per-chunk overhead low
        for attempt in range(retries):
            temp_path = None
            try:
                # Create the temporary file.
                temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
                temp_path = temp_file.name
                temp_file.close()

                # Download the video.
                print(f"开始下载视频 (尝试 {attempt + 1}/{retries})...")
                # The context manager releases the streamed connection even
                # when the transfer fails midway.
                with requests.get(url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()

                    # Total size, used for the progress bar (0 if unknown).
                    total_size = int(response.headers.get('content-length', 0))

                    with open(temp_path, 'wb') as f, tqdm(
                            desc=url.split('/')[-1],
                            total=total_size,
                            unit='B',
                            unit_scale=True,
                            unit_divisor=1024
                    ) as bar:
                        for data in response.iter_content(block_size):
                            size = f.write(data)
                            bar.update(size)

                print(f"视频下载完成: {temp_path}")
                return temp_path
            except Exception as e:
                # Drop the partial file of this attempt so failed attempts do
                # not pile up in the temp directory.
                self._remove_quietly(temp_path)
                print(f"下载错误 (尝试 {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep((attempt + 1) * 2)
                else:
                    raise

    def jpg_to_tensor(self, image_path, channel_first=False):
        """Load an image file into a float32 PyTorch tensor scaled to [0, 1].

        Args:
            image_path: Path of the image file; it is converted to RGB.
            channel_first: When True the result is laid out ``(1, 3, H, W)``;
                otherwise ``(1, H, W, 3)``.

        Returns:
            ``torch.Tensor`` of dtype float32 with a leading batch dimension
            of size 1.

        Raises:
            Exception: whatever PIL / NumPy / torch raised, after printing it.
        """
        try:
            # Close the file handle as soon as the pixels are materialised.
            with Image.open(image_path) as image:
                pixels = np.array(image.convert('RGB')).astype(np.float32) / 255.0
            # Add the batch dimension.
            tensor = torch.from_numpy(pixels)[None,]
            if channel_first:
                tensor = tensor.permute(0, 3, 1, 2).contiguous()
            return tensor
        except Exception as e:
            print(f"转换失败: {str(e)}")
            raise

    def _count_frames(self, video_path):
        """Return the frame count of the first video stream of ``video_path``.

        Uses ``ffprobe``'s ``nb_frames`` and falls back to a full ``ffmpeg``
        decode when that value is not numeric.
        """
        cmd_frames = [
            'ffprobe',
            '-v', 'error',
            '-select_streams', 'v:0',
            '-show_entries', 'stream=nb_frames',
            '-of', 'default=nokey=1:noprint_wrappers=1',
            video_path
        ]
        result = subprocess.run(
            cmd_frames,
            capture_output=True,
            text=True,
            check=True
        )
        frame_count = result.stdout.strip()
        if frame_count.isdigit():
            return int(frame_count)

        # Fallback: decode the stream and read the frame counter that ffmpeg
        # prints on stderr.
        print("无法获取准确帧数,尝试直接解码...")
        cmd_decode = [
            'ffmpeg',
            '-i', video_path,
            '-f', 'null',
            '-'
        ]
        decode_result = subprocess.run(
            cmd_decode,
            capture_output=True,
            text=True
        )
        for line in decode_result.stderr.split('\n'):
            if 'frame=' in line:
                # Take the last "frame=" occurrence on the line.
                tokens = line.split('frame=')[-1].split()
                if tokens and tokens[0].isdigit():
                    return int(tokens[0])
        raise ValueError("无法确定视频帧数")

    def _extract_frame(self, video_path, target_frame):
        """Extract frame index ``target_frame`` of ``video_path`` as a tensor."""
        # A private temp directory keeps the numbered output file that ffmpeg
        # creates from the %03d pattern from leaking once we are done.
        with tempfile.TemporaryDirectory() as frame_dir:
            frame_pattern = os.path.join(frame_dir, 'frame_%03d.jpg')
            cmd_extract = [
                'ffmpeg',
                '-ss', '00:00:00',
                '-i', video_path,
                '-vframes', '1',
                # The backslash escapes the comma for ffmpeg's filter parser.
                '-vf', f'select=eq(n\\,{target_frame})',
                '-vsync', '0',
                '-an',
                '-y',
                frame_pattern
            ]
            subprocess.run(
                cmd_extract,
                capture_output=True,
                check=True
            )
            # The first (and only) written image gets sequence number 001.
            return self.jpg_to_tensor(os.path.join(frame_dir, 'frame_001.jpg'))

    def get_last_15th_frame_tensor(self, video_url, cleanup=True):
        """Grab the 15th frame from the end of a remote video as a tensor.

        The video is downloaded to a local temp file first; the target frame
        index is ``max(0, frame_count - 15)``.

        Args:
            video_url: URL of the video.
            cleanup: When True (default) the downloaded video file is deleted
                before returning, whether or not extraction succeeded.

        Returns:
            The frame as produced by :meth:`jpg_to_tensor` (``(1, H, W, 3)``
            float32 in [0, 1]).

        Raises:
            ValueError: if the frame count cannot be determined.
            subprocess.CalledProcessError: if ``ffprobe`` / ``ffmpeg`` fails.
            Exception: download or decoding errors from the helpers.
        """
        video_path = None
        try:
            # Download the video.
            video_path = self.download_video(video_url)

            # Determine the total number of frames and the target index.
            frame_count = self._count_frames(video_path)
            target_frame = max(0, frame_count - 15)
            print(f"视频总帧数: {frame_count}, 目标帧: {target_frame}")

            # Extract the frame and convert it to a tensor.
            return self._extract_frame(video_path, target_frame)
        finally:
            if cleanup:
                self._remove_quietly(video_path)


class JMGestureCorrect:
    """Node that turns an input image into a front-facing full-body image.

    The image is uploaded, submitted as a generation task with a fixed
    posture prompt, polled until done, and a frame near the end of the
    resulting video is returned.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required ``image`` input."""
        return {
            "required": {
                "image": ("IMAGE",)
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor):
        """Run the upload → submit → poll → frame-extraction pipeline.

        Args:
            image: Input image tensor, passed to ``JMUtils.tensor_to_io``.

        Returns:
            One-element tuple holding the extracted frame tensor.

        Raises:
            Exception: on upload failure, submission failure, a task status
                containing ``error`` / ``失败`` / ``fail``, or when no result
                is available after ``wait_time`` seconds of polling.
        """
        wait_time = 120  # total polling budget in seconds
        interval = 2  # seconds between polls
        failure_markers = ("error", "失败", "fail")

        client = JMUtils()
        image_io = client.tensor_to_io(image)
        upload_data = client.upload_io_to_cos(image_io)
        if not upload_data["status"]:
            raise Exception("上传失败")
        image_url = upload_data["data"]

        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(prompt, image_url)
        if not submit_data["status"]:
            raise Exception("即梦任务提交失败")
        job_id = submit_data["data"]

        job_data = None
        for attempt, _ in enumerate(range(0, wait_time, interval), start=1):
            logger.info(f"查询即梦结果 {attempt}")
            query = client.query_status(job_id)
            if query["status"]:
                job_data = query["data"]
                break
            msg = query["msg"]
            if any(marker in msg for marker in failure_markers):
                raise Exception("即梦任务失败 {}".format(msg))
            sleep(interval)
        if not job_data:
            raise Exception("即梦任务等待超时")
        return (client.get_last_15th_frame_tensor(job_data),)