import io
import json
import os
import re
import subprocess
import tempfile
import time
import uuid
from time import sleep
from typing import Any, Dict, Optional, Tuple

import cv2
import folder_paths
import numpy as np
import requests
import torch
from loguru import logger
from PIL import Image
from torchvision.transforms import transforms
from tqdm import tqdm

from ..utils.config_utils import config
from ..utils.object_storage import UploadResult, get_provider

# Endpoint of the Volcengine Ark content-generation task API used by Jimeng.
JM_TASKS_URL = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"

# Model submitted when the caller does not choose one.
JM_DEFAULT_MODEL = "doubao-seedance-1-0-pro-250528"

# Prompt asking the model to turn the subject to face the camera, full body in frame.
JM_GESTURE_PROMPT = (
    "Stand straight ahead, facing the camera, showing your full body, "
    "maintaining a proper posture, keeping the camera still, and ensuring "
    "that your head and feet are all within the frame"
)

# Task states that mean the task is still queued / being processed.
_JM_RUNNING_STATES = frozenset({"queued", "pending", "running", "processing"})
# Task states that mean the task ended without producing a video.
_JM_FAILED_STATES = frozenset({"failed", "error", "cancelled", "canceled"})
_JM_CANCELLED_STATES = frozenset({"cancelled", "canceled"})


class JMUtils:
    """
    Helper for the Jimeng (Ark / Seedance) image-to-video service.

    Provides:
    - uploading images to cloud storage so the API can fetch them,
    - submitting generation tasks and polling their status,
    - downloading the resulting video and extracting frames,
    - converting between tensors and images.

    Storage goes through the shared object-storage abstraction (``get_provider``).
    """

    def __init__(self):
        """
        Read the API key and the storage configuration.

        Raises:
            ValueError: if the API key or the COS material bucket is not configured.
        """
        try:
            self.api_key = config.get_config("jm_api_key")
            if not self.api_key:
                raise ValueError("即梦API密钥未配置")

            # Bucket that input material is uploaded to; falls back to a dedicated key.
            cos_config = config.get_cos_config() or {}
            self.cos_bucket_name = cos_config.get("bucket_name") or config.get_config(
                "cos_sucai_bucket_name"
            )
            if not self.cos_bucket_name:
                raise ValueError("COS素材存储桶未配置")

            self.storage_provider = get_provider("cos")
            logger.info(f"即梦工具初始化成功,使用存储桶: {self.cos_bucket_name}")
        except Exception as e:
            logger.error(f"即梦工具初始化失败: {e}")
            raise

    # ------------------------------------------------------------------ #
    # Small private helpers
    # ------------------------------------------------------------------ #

    def _headers(self) -> Dict[str, str]:
        """Return the JSON + bearer-token headers shared by every API call."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _json_or_none(response: requests.Response) -> Any:
        """Decode a response body as JSON, returning ``None`` if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            return None

    @staticmethod
    def _extract_error_message(resp_json: Any) -> Optional[str]:
        """Return ``resp_json["error"]["message"]`` when present, else ``None``."""
        if not isinstance(resp_json, dict):
            return None
        error = resp_json.get("error")
        if isinstance(error, dict) and error.get("message"):
            return str(error["message"])
        return None

    @staticmethod
    def _remove_quietly(path: Optional[str]) -> None:
        """Delete ``path`` if it exists; log instead of raising on failure."""
        if not path:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"删除临时文件失败 {path}: {e}")

    # ------------------------------------------------------------------ #
    # Task API
    # ------------------------------------------------------------------ #

    def submit_task(
        self,
        prompt: str,
        img_url: str,
        duration: str = "10",
        resolution: str = "720p",
        model: str = JM_DEFAULT_MODEL,
    ) -> Dict[str, Any]:
        """
        Submit an image-to-video generation task.

        Args:
            prompt: text prompt.
            img_url: URL of the input image (must be reachable by the API).
            duration: video length in seconds, sent as ``--dur``.
            resolution: video resolution, sent as ``--resolution``.
            model: model identifier sent to the API.

        Returns:
            Dict with keys:
            - status: True if the API accepted the task.
            - data: the task ID on success; the original image URL when the
              API answered without an ID; ``None`` on validation/network errors.
            - msg: human-readable message.
        """
        resp_json: Any = None
        try:
            if not prompt or not prompt.strip():
                return {"status": False, "data": None, "msg": "提示词不能为空"}
            if not img_url or not img_url.strip():
                return {"status": False, "data": None, "msg": "图像URL不能为空"}

            json_data = {
                "model": model,
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt.strip()} --resolution {resolution} --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": img_url.strip()},
                    },
                ],
            }

            logger.info(
                f"即梦任务提交中: prompt='{prompt[:50]}...', resolution={resolution}, duration={duration}"
            )
            response = requests.post(
                JM_TASKS_URL, headers=self._headers(), json=json_data, timeout=30
            )
            # Parse before raise_for_status(): 4xx bodies carry the API's own
            # error message, which would otherwise be lost.
            resp_json = self._json_or_none(response)
            response.raise_for_status()

            if not isinstance(resp_json, dict):
                raise ValueError("响应不是合法的JSON对象")

            logger.info(
                f"即梦任务提交响应: {json.dumps(resp_json, ensure_ascii=False)}"
            )

            job_id = resp_json.get("id")
            if not job_id:
                error_msg = self._extract_error_message(resp_json) or "未知错误"
                return {
                    "status": False,
                    "data": img_url,
                    "msg": f"任务提交失败: {error_msg}",
                }

            logger.info(f"即梦任务提交成功,任务ID: {job_id}")
            return {"data": job_id, "status": True, "msg": "任务提交成功"}

        except requests.RequestException as e:
            api_msg = self._extract_error_message(resp_json)
            detail = f"{e} ({api_msg})" if api_msg else str(e)
            logger.error(f"即梦API请求失败: {detail}")
            return {"data": None, "status": False, "msg": f"网络请求失败: {detail}"}
        except Exception as e:
            logger.error(f"即梦任务提交异常: {e}")
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str) -> Dict[str, Any]:
        """
        Query the state of a generation task.

        Args:
            job_id: task ID returned by ``submit_task``.

        Returns:
            Dict with keys:
            - status: True only when the task succeeded and a video URL was found.
            - data: the video URL on success, else ``None``.
            - msg: status message; terminal task failures start with "任务失败".
            - state: raw task state reported by the API (e.g. "succeeded",
              "failed", "running"), "network_error" when the request itself
              failed, or "" when no state could be determined.
        """
        resp_dict: Dict[str, Any] = {"status": False, "data": None, "msg": "", "state": ""}
        try:
            if not job_id or not job_id.strip():
                resp_dict["msg"] = "任务ID不能为空"
                return resp_dict

            response = requests.get(
                f"{JM_TASKS_URL}/{job_id.strip()}",
                headers=self._headers(),
                timeout=15,
            )
            response.raise_for_status()
            resp_json = response.json()
            if not isinstance(resp_json, dict):
                raise ValueError("响应不是合法的JSON对象")

            task_status = resp_json.get("status", "unknown")
            resp_dict["state"] = task_status

            if task_status == "succeeded":
                # "content" may be missing or null; only report success with a URL.
                content = resp_json.get("content")
                video_url = content.get("video_url") if isinstance(content, dict) else None
                if video_url:
                    resp_dict["status"] = True
                    resp_dict["data"] = video_url
                    resp_dict["msg"] = "任务完成"
                else:
                    resp_dict["msg"] = "任务完成但未找到视频URL"
            elif task_status in _JM_FAILED_STATES:
                detail = self._extract_error_message(resp_json)
                if detail is None and task_status in _JM_CANCELLED_STATES:
                    detail = "任务已取消"
                elif detail is None and resp_json.get("error"):
                    detail = "未知错误"
                resp_dict["msg"] = f"任务失败: {detail}" if detail else "任务失败"
            elif task_status in _JM_RUNNING_STATES:
                resp_dict["msg"] = f"任务进行中: {task_status}"
            else:
                resp_dict["msg"] = f"未知任务状态: {task_status}"

        except requests.RequestException as e:
            logger.error(f"即梦状态查询网络错误: {e}")
            resp_dict["state"] = "network_error"
            resp_dict["msg"] = f"网络请求失败: {str(e)}"
        except Exception as e:
            logger.error(f"即梦状态查询异常: {e}")
            resp_dict["msg"] = str(e)

        return resp_dict

    def wait_for_task(self, job_id: str, wait_time: float, interval: float = 2) -> str:
        """
        Poll a task until it succeeds, fails, or ``wait_time`` seconds elapse.

        Request-level errors while polling (network errors, unparsable
        responses) are logged and retried until the deadline instead of
        aborting the wait.

        Args:
            job_id: task ID returned by ``submit_task``.
            wait_time: maximum number of seconds to wait.
            interval: seconds to sleep between queries.

        Returns:
            The video URL of the finished task.

        Raises:
            Exception: if the task ends in a failure state (including
                "succeeded" without a video URL), or on timeout.
        """
        deadline = time.monotonic() + wait_time
        attempt = 0
        while True:
            attempt += 1
            logger.info(f"查询即梦结果 {attempt}")
            query = self.query_status(job_id)
            if query["status"]:
                return query["data"]

            state = query.get("state", "")
            if state in _JM_FAILED_STATES or state == "succeeded":
                raise Exception("即梦任务失败 {}".format(query["msg"]))

            if time.monotonic() + interval > deadline:
                break
            sleep(interval)

        raise Exception("即梦任务等待超时")

    def generate_video(
        self,
        image: torch.Tensor,
        prompt: str,
        duration: str,
        resolution: str,
        wait_time: float,
        interval: float = 2,
    ) -> str:
        """
        Upload ``image``, submit a generation task and wait for its video URL.

        Args:
            image: input image tensor (see ``tensor_to_io`` for accepted shapes).
            prompt: text prompt.
            duration: video length in seconds, as a string.
            resolution: video resolution, e.g. "720p".
            wait_time: maximum number of seconds to wait for the result.
            interval: seconds between status queries.

        Returns:
            The URL of the generated video.

        Raises:
            Exception: on upload failure, submission failure, task failure or timeout.
        """
        upload_data = self.upload_io_to_cos(self.tensor_to_io(image))
        if not upload_data["status"]:
            raise Exception(f"上传失败: {upload_data['msg']}")

        submit_data = self.submit_task(
            prompt, upload_data["data"], duration=duration, resolution=resolution
        )
        if not submit_data["status"]:
            raise Exception(f"即梦任务提交失败: {submit_data['msg']}")

        return self.wait_for_task(submit_data["data"], wait_time=wait_time, interval=interval)

    # ------------------------------------------------------------------ #
    # Storage
    # ------------------------------------------------------------------ #

    def upload_io_to_cos(
        self, file: io.IOBase, mime_type: str = "image/png"
    ) -> Dict[str, Any]:
        """
        Upload the contents of a file-like object to the COS material bucket.

        The object key is ``tk/<mime category>/<uuid4>.<mime subtype>``.
        Content is read from the stream's current position; the stream is
        rewound to 0 afterwards.

        Args:
            file: readable binary stream.
            mime_type: MIME type used to build the object key.

        Returns:
            Dict with keys:
            - status: True on success.
            - data: the object URL on success.
            - msg: human-readable message.
        """
        resp_data: Dict[str, Any] = {"status": True, "data": "", "msg": ""}
        try:
            parts = (mime_type or "").split("/")
            category = parts[0] or "file"
            suffix = parts[1] if len(parts) > 1 and parts[1] else "bin"

            object_key = f"tk/{category}/{uuid.uuid4()}.{suffix}"
            logger.info(f"开始上传文件到COS: {object_key}")

            file_content = file.read()
            file.seek(0)

            result: UploadResult = self.storage_provider.upload_bytes(
                file_content, object_key, bucket_name=self.cos_bucket_name
            )

            if result.success:
                if result.url:
                    resp_data["data"] = result.url
                else:
                    # The provider returned no URL: build the default COS URL.
                    cos_config = config.get_cos_config() or {}
                    region = cos_config.get("region", "ap-beijing")
                    resp_data["data"] = (
                        f"https://{self.cos_bucket_name}.cos.{region}.myqcloud.com/{object_key}"
                    )
                resp_data["msg"] = "上传成功"
                logger.info(f"文件上传成功: {resp_data['data']}")
            else:
                resp_data["status"] = False
                resp_data["msg"] = result.message or "上传失败"
                logger.error(f"文件上传失败: {resp_data['msg']}")

        except Exception as e:
            logger.error(f"上传文件时发生异常: {e}")
            resp_data["status"] = False
            resp_data["msg"] = str(e)

        return resp_data

    # ------------------------------------------------------------------ #
    # Image / video conversion
    # ------------------------------------------------------------------ #

    def tensor_to_io(self, tensor: torch.Tensor) -> io.BytesIO:
        """
        Encode an image tensor with values in [0, 1] as PNG.

        Accepted shapes:
        - (H, W)           grayscale
        - (H, W, 1)        grayscale
        - (H, W, 3|4)      RGB / RGBA
        - (B, H, W, C)     batch; only the first image is encoded

        Values are clipped to [0, 1] and scaled to uint8 by truncation.

        Returns:
            A ``BytesIO`` positioned at 0 containing the PNG bytes.

        Raises:
            ValueError: if the shape is unsupported or encoding fails.
        """
        try:
            if tensor.dim() == 4:
                # Batched (B, H, W, C) input: encode the first image.
                tensor = tensor[0]
            elif tensor.dim() not in (2, 3):
                raise ValueError(f"不支持的张量维度: {tensor.dim()}D")

            if tensor.dim() == 3 and tensor.shape[-1] == 1:
                tensor = tensor[..., 0]
            if tensor.dim() == 3 and tensor.shape[-1] not in (3, 4):
                raise ValueError(f"不支持的通道数: {tensor.shape[-1]}")

            # .float() lets half/bfloat16 tensors convert to NumPy.
            numpy_array = tensor.detach().cpu().float().numpy()
            image_array = (np.clip(numpy_array, 0.0, 1.0) * 255).astype(np.uint8)

            # Pillow infers the mode from a uint8 array's shape:
            # 2-D -> "L", 3 channels -> "RGB", 4 channels -> "RGBA".
            img = Image.fromarray(image_array)

            image_data = io.BytesIO()
            img.save(image_data, format="PNG", optimize=True)
            image_data.seek(0)

            logger.debug(f"张量转换为PNG成功,大小: {len(image_data.getvalue())} bytes")
            return image_data

        except Exception as e:
            logger.error(f"张量转换失败: {e}")
            raise ValueError(f"张量转换为图像失败: {str(e)}") from e

    @staticmethod
    def _read_last_frame_sequential(video_path: str) -> np.ndarray:
        """Decode the whole video with OpenCV and return its last frame (BGR)."""
        cap = cv2.VideoCapture(video_path)
        last_frame = None
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                last_frame = frame
        finally:
            cap.release()
        if last_frame is None:
            raise ValueError("无法读取视频的最后一帧,可能视频已损坏")
        return last_frame

    def read_video_last_frame_to_tensor(self, video_path: str) -> torch.Tensor:
        """
        Read the last frame of a video as a ``[1, H, W, C]`` RGB float tensor in [0, 1].

        First seeks to the reported last frame. Containers may only estimate
        the frame count, so if that read fails the video is decoded
        sequentially instead.

        Args:
            video_path: path to the video file.

        Returns:
            Tensor of shape ``[1, H, W, 3]`` (channels last, RGB order).

        Raises:
            FileNotFoundError: if the video cannot be opened.
            ValueError: if the frame count is not positive or no frame can be read.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise FileNotFoundError(f"无法打开视频文件: {video_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                raise ValueError("视频文件为空或无法确定帧数")

            cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
            ret, frame = cap.read()
        finally:
            cap.release()

        if not ret or frame is None:
            frame = self._read_last_frame_sequential(video_path)

        # OpenCV decodes to BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # ToTensor yields [C, H, W] in [0, 1]; add the batch axis and move
        # channels last to get [1, H, W, C].
        transform = transforms.Compose([transforms.ToTensor()])
        return transform(frame_rgb).unsqueeze(0).permute(0, 2, 3, 1)

    def download_video(
        self,
        url,
        timeout=30,
        retries=3,
        path=None,
        read_last_frame: bool = True,
    ) -> Tuple[str, Optional[torch.Tensor]]:
        """
        Download a video, retrying with linear back-off (2 s, 4 s, ...).

        Args:
            url: video URL.
            timeout: per-request timeout in seconds.
            retries: number of attempts (at least one is always made).
            path: destination file; a new temporary ``.mp4`` is used when None.
            read_last_frame: also decode the last frame via
                ``read_video_last_frame_to_tensor``.

        Returns:
            ``(file_path, last_frame)``; ``last_frame`` is None when
            ``read_last_frame`` is False.

        Raises:
            The exception of the final attempt if every attempt fails. A
            temporary file created by this method is deleted in that case.
        """
        retries = max(1, retries)
        owns_file = path is None
        if owns_file:
            # Create the temp file once and reuse it across attempts.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
                temp_path = temp_file.name
        else:
            temp_path = path

        for attempt in range(retries):
            try:
                logger.info(f"开始下载视频 (尝试 {attempt + 1}/{retries})...")
                with requests.get(url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()
                    total_size = int(response.headers.get("content-length", 0))
                    with open(temp_path, "wb") as f, tqdm(
                        desc=url.split("/")[-1],
                        total=total_size or None,
                        unit="B",
                        unit_scale=True,
                        unit_divisor=1024,
                    ) as bar:
                        for chunk in response.iter_content(chunk_size=64 * 1024):
                            bar.update(f.write(chunk))

                logger.info(f"视频下载完成: {temp_path}")
                last_frame = (
                    self.read_video_last_frame_to_tensor(temp_path)
                    if read_last_frame
                    else None
                )
                return temp_path, last_frame

            except Exception as e:
                logger.warning(f"下载错误 (尝试 {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep((attempt + 1) * 2)
                else:
                    if owns_file:
                        self._remove_quietly(temp_path)
                    raise

    def jpg_to_tensor(self, image_path):
        """
        Load an image file as a ``[1, H, W, 3]`` RGB float32 tensor in [0, 1].

        Args:
            image_path: path to the image file.

        Returns:
            torch.Tensor of shape ``[1, H, W, 3]``.
        """
        try:
            # Close the file handle promptly so temporary files can be deleted.
            with Image.open(image_path) as image:
                rgb = image.convert("RGB")
            return torch.from_numpy(np.array(rgb).astype(np.float32) / 255.0)[None,]
        except Exception as e:
            logger.error(f"转换失败: {str(e)}")
            raise

    @staticmethod
    def _count_video_frames(video_path: str) -> int:
        """
        Return the number of frames in the first video stream.

        Uses ffprobe's ``nb_frames``. When that is not a number (e.g. "N/A"),
        falls back to a full ffmpeg decode and uses the final ``frame=`` counter.
        """
        cmd_frames = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=nb_frames",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]
        result = subprocess.run(cmd_frames, capture_output=True, text=True, check=True)
        reported = result.stdout.strip()
        if reported.isdigit():
            return int(reported)

        logger.info("无法获取准确帧数,尝试直接解码...")
        cmd_decode = ["ffmpeg", "-i", video_path, "-f", "null", "-"]
        decode_result = subprocess.run(cmd_decode, capture_output=True, text=True)
        # ffmpeg emits many progressive "frame=" updates (text mode turns their
        # "\r" separators into newlines); the final count is the last one.
        counts = re.findall(r"frame=\s*(\d+)", decode_result.stderr or "")
        if not counts:
            raise ValueError("无法确定视频帧数")
        return int(counts[-1])

    def _extract_frame_tensor(self, video_path: str, frame_index: int) -> torch.Tensor:
        """Decode frame ``frame_index`` (0-based) with ffmpeg as a ``[1, H, W, 3]`` tensor."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            # The image2 muxer expands %03d; "%" in the directory must be escaped.
            pattern = os.path.join(tmp_dir.replace("%", "%%"), "frame%03d.jpg")
            cmd_extract = [
                "ffmpeg",
                "-i",
                video_path,
                "-vframes",
                "1",
                "-vf",
                # "\," escapes the comma inside the filtergraph expression.
                rf"select=eq(n\,{frame_index})",
                "-vsync",
                "0",
                "-an",
                "-y",
                pattern,
            ]
            subprocess.run(cmd_extract, capture_output=True, check=True)

            frame_path = os.path.join(tmp_dir, "frame001.jpg")
            if not os.path.exists(frame_path):
                raise ValueError(f"无法截取视频第 {frame_index} 帧")
            return self.jpg_to_tensor(frame_path)

    def get_last_15th_frame_tensor(self, video_url, offset_from_end: int = 15):
        """
        Download a video and return the ``offset_from_end``-th frame from the end.

        With the default of 15, this is frame index ``frame_count - 15``
        (clamped to 0). The downloaded video is deleted afterwards.

        Args:
            video_url: URL of the video.
            offset_from_end: 1 selects the last frame, 15 the 15th from last.

        Returns:
            torch.Tensor of shape ``[1, H, W, 3]``.
        """
        video_path, _ = self.download_video(video_url, read_last_frame=False)
        try:
            frame_count = self._count_video_frames(video_path)
            target_frame = max(0, frame_count - offset_from_end)
            logger.info(f"视频总帧数: {frame_count}, 目标帧: {target_frame}")
            return self._extract_frame_tensor(video_path, target_frame)
        finally:
            self._remove_quietly(video_path)


class JMGestureCorrect:
    """ComfyUI node: regenerate the subject facing the camera and return one frame."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                # Combo inputs must be a 1-tuple wrapping the option list.
                "resolution": (["720p", "1080p"],),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor, resolution: str):
        """Generate a 5 s video with the gesture prompt; return its 15th-from-last frame."""
        wait_time = 240
        interval = 2
        client = JMUtils()
        video_url = client.generate_video(
            image,
            JM_GESTURE_PROMPT,
            duration="5",
            resolution=resolution,
            wait_time=wait_time,
            interval=interval,
        )
        return (client.get_last_15th_frame_tensor(video_url),)


class JMCustom:
    """ComfyUI node: image-to-video with a custom prompt; saves the video to the output dir."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": (
                    "STRING",
                    {
                        "default": JM_GESTURE_PROMPT,
                        "multiline": True,
                    },
                ),
                "duration": ("INT", {"default": 5, "min": 2, "max": 10}),
                # Combo inputs must be a 1-tuple wrapping the option list.
                "resolution": (["720p", "1080p"],),
                "wait_time": ("INT", {"default": 180, "min": 60, "max": 600}),
            }
        }

    RETURN_TYPES = (
        "STRING",
        "IMAGE",
    )
    RETURN_NAMES = ("视频存储路径", "视频最后一帧")
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    def gen(
        self,
        image: torch.Tensor,
        prompt: str,
        duration: int,
        resolution: str,
        wait_time: int,
    ):
        """Generate a video and return (saved file path, last frame tensor)."""
        interval = 2
        client = JMUtils()
        video_url = client.generate_video(
            image,
            prompt,
            duration=str(duration),
            resolution=resolution,
            wait_time=wait_time,
            interval=interval,
        )

        output_dir = folder_paths.get_output_directory()
        video_path, last_scene = client.download_video(
            video_url, path=os.path.join(output_dir, f"{uuid.uuid4()}.mp4")
        )
        return (
            video_path,
            last_scene,
        )