From 6595aa46e100640f1c5f9501651540c497c3f9ce Mon Sep 17 00:00:00 2001
From: "kyj@bowong.ai"
Date: Tue, 15 Jul 2025 18:15:16 +0800
Subject: [PATCH] ADD: add Jimeng (即梦) video generation node
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Readme.md                    |  8 ++++
 __init__.py                  |  8 ++--
 nodes/image_gesture_nodes.py | 81 ++++++++++++++++++++++++++++++------
 3 files changed, 81 insertions(+), 16 deletions(-)

diff --git a/Readme.md b/Readme.md
index 1add128..867079f 100644
--- a/Readme.md
+++ b/Readme.md
@@ -59,6 +59,10 @@ pip install -r requirements.txt
   - **输入**:LLM 提供商、提示词、图片张量、温度、最大令牌数、超时时间
   - **输出**:LLM 输出结果
   - **用途**:结合图片张量与 LLM 进行交互
+- **JMCustom**:即梦自定义Prompt生视频
+  - **输入**:参考图、Prompt、视频长度
+  - **输出**:生成视频路径
+  - **用途**:自定义Prompt视频时长生成视频
 - **RandomLineSelector**:从多行文本中随机选择一行。
   - **输入**:多行文本、随机种子
   - **输出**:随机选择的一行文本
@@ -105,6 +109,10 @@ pip install -r requirements.txt
   - **用途**:根据提供的模板和变量,使用 Jinja2 引擎渲染出最终的字符串,常用于生成动态的 prompt
 
 ### 3. 图像和视频处理节点
+- **JMCustom**:即梦自定义Prompt生视频
+  - **输入**:参考图、Prompt、视频长度
+  - **输出**:生成视频路径
+  - **用途**:自定义Prompt视频时长生成视频
 - **LoadImg**:从网络/本地加载图片(网络图片优先)。
   - **输入**:图片 URL、选择本地图片
   - **输出**:图像张量
diff --git a/__init__.py b/__init__.py
index 66471f0..97322b7 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,7 +1,7 @@
 from .nodes.image_modal_nodes import ModalEditCustom, ModalClothesMask, ModalMidJourneyGenerateImage, \
     ModalMidJourneyDescribeImage
 from .nodes.image_face_nodes import FaceDetect, FaceExtract
-from .nodes.image_gesture_nodes import JMGestureCorrect
+from .nodes.image_gesture_nodes import JMGestureCorrect, JMCustom
 from .nodes.image_nodes import SaveImagePath, SaveImageWithOutput, LoadImgOptional
 from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate
 from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL
@@ -45,7 +45,8 @@ NODE_CLASS_MAPPINGS = {
     "ModalClothesMask": ModalClothesMask,
     "ModalEditCustom": ModalEditCustom,
     "ModalMidJourneyGenerateImage": ModalMidJourneyGenerateImage,
-    "ModalMidJourneyDescribeImage": ModalMidJourneyDescribeImage
+    "ModalMidJourneyDescribeImage": ModalMidJourneyDescribeImage,
+    "JMCustom": JMCustom
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -82,5 +83,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ModalClothesMask": "模特指定衣服替换为指定颜色",
     "ModalEditCustom": "自定义Prompt修改图片",
     "ModalMidJourneyGenerateImage": "Prompt修图",
-    "ModalMidJourneyDescribeImage": "反推生图提示词"
+    "ModalMidJourneyDescribeImage": "反推生图提示词",
+    "JMCustom": "Prompt生视频"
 }
diff --git a/nodes/image_gesture_nodes.py b/nodes/image_gesture_nodes.py
index 66467be..c8f628b 100644
--- a/nodes/image_gesture_nodes.py
+++ b/nodes/image_gesture_nodes.py
@@ -7,6 +7,7 @@ import time
 import uuid
 from time import sleep
 
+import folder_paths
 import numpy as np
 import requests
 import torch
@@ -16,6 +17,7 @@ from loguru import logger
 from qcloud_cos import CosConfig, CosS3Client
 from tqdm import tqdm
 
+
 class JMUtils:
     def __init__(self):
         if "aws_key_id" in list(os.environ.keys()):
@@ -59,7 +61,8 @@ class JMUtils:
             ],
         }
 
-        response = requests.post("https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks", headers=headers, json=json_data)
+        response = requests.post("https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
+                                 headers=headers, json=json_data)
         logger.info(f"submit task: {json.dumps(response.json())}")
         resp_json = response.json()
         if "id" not in resp_json:
@@ -78,7 +81,8 @@ class JMUtils:
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.api_key}",
             }
-            response = requests.get(f"https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks/{job_id}", headers=headers)
+            response = requests.get(f"https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks/{job_id}",
+                                     headers=headers)
             resp_json = response.json()
             resp_dict["status"] = resp_json["status"] == "succeeded"
             resp_dict["msg"] = resp_json["status"]
@@ -89,7 +93,6 @@ class JMUtils:
         finally:
             return resp_dict
 
-
     def upload_io_to_cos(self, file: io.IOBase, mime_type: str = "image/png"):
         resp_data = {'status': True, 'data': '', 'msg': ''}
         category = mime_type.split('/')[0]
@@ -122,15 +125,17 @@ class JMUtils:
         image_data.seek(0)
         return image_data
 
-
-    def download_video(self, url, timeout=30, retries=3):
+    def download_video(self, url, timeout=30, retries=3, path=None):
         """下载视频到临时文件并返回文件路径"""
         for attempt in range(retries):
             try:
                 # 创建临时文件
-                temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
-                temp_path = temp_file.name
-                temp_file.close()
+                if path:
+                    temp_path = path
+                else:
+                    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+                    temp_path = temp_file.name
+                    temp_file.close()
 
                 # 下载视频
                 print(f"开始下载视频 (尝试 {attempt + 1}/{retries})...")
@@ -163,7 +168,6 @@ class JMUtils:
                 else:
                     raise
 
-
     def jpg_to_tensor(self, image_path, channel_first=False):
         """
         将JPG图像转换为PyTorch张量
@@ -181,7 +185,7 @@ class JMUtils:
         image = Image.open(image_path).convert('RGB')
 
         # 转换为张量
-        tensor = torch.from_numpy(np.array(image).astype(np.float32)/255.0)[None,]
+        tensor = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,]
 
         return tensor
 
@@ -266,11 +270,12 @@ class JMUtils:
             )
 
             # 转换为Tensor
-            tensor = self.jpg_to_tensor(frame_path.replace("%03d","001"))
+            tensor = self.jpg_to_tensor(frame_path.replace("%03d", "001"))
         except Exception as e:
             raise e
         return tensor
 
+
 class JMGestureCorrect:
     @classmethod
     def INPUT_TYPES(s):
@@ -285,7 +290,7 @@ class JMGestureCorrect:
     FUNCTION = "gen"
     CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"
 
-    def gen(self, image:torch.Tensor):
+    def gen(self, image: torch.Tensor):
         wait_time = 120
         interval = 2
         client = JMUtils()
@@ -303,7 +308,7 @@ class JMGestureCorrect:
             raise Exception("即梦任务提交失败")
         job_data = None
         for idx, _ in enumerate(range(0, wait_time, interval)):
-            logger.info(f"查询即梦结果 {idx+1}")
+            logger.info(f"查询即梦结果 {idx + 1}")
             query = client.query_status(job_id)
             if query["status"]:
                 job_data = query["data"]
@@ -315,3 +320,53 @@ class JMGestureCorrect:
         if not job_data:
             raise Exception("即梦任务等待超时")
         return (client.get_last_15th_frame_tensor(job_data),)
+
+
+class JMCustom:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "prompt": ("STRING", {
+                    "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
+                    "multiline": True}),
+                "duration": ("INT", {"default": 5, "min": 2, "max": 30}),
+            }
+        }
+
+    RETURN_TYPES = ("STRING",)
+    RETURN_NAMES = ("视频存储路径",)
+    FUNCTION = "gen"
+    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"
+
+    def gen(self, image: torch.Tensor, prompt: str, duration: int):
+        wait_time = 120
+        interval = 2
+        client = JMUtils()
+        image_io = client.tensor_to_io(image)
+        upload_data = client.upload_io_to_cos(image_io)
+        if upload_data["status"]:
+            image_url = upload_data["data"]
+        else:
+            raise Exception("上传失败")
+        submit_data = client.submit_task(prompt, image_url, str(duration))
+        if submit_data["status"]:
+            job_id = submit_data["data"]
+        else:
+            raise Exception("即梦任务提交失败")
+        job_data = None
+        for idx, _ in enumerate(range(0, wait_time, interval)):
+            logger.info(f"查询即梦结果 {idx + 1}")
+            query = client.query_status(job_id)
+            if query["status"]:
+                job_data = query["data"]
+                break
+            else:
+                if "error" in query["msg"] or "失败" in query["msg"] or "fail" in query["msg"]:
+                    raise Exception("即梦任务失败 {}".format(query["msg"]))
+            sleep(interval)
+        if not job_data:
+            raise Exception("即梦任务等待超时")
+        return (
+            client.download_video(job_data, path=os.path.join(folder_paths.get_output_directory(), f"{uuid.uuid4()}.mp4")),)
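
Reviewer note: the new JMCustom node wraps a submit/poll/download flow against the Volcengine Ark endpoint used above — upload the reference frame to COS, POST a generation task, poll `GET .../tasks/{id}` until `status == "succeeded"`, then download the resulting mp4 into ComfyUI's output directory. The sketch below exercises that loop outside ComfyUI. It is illustrative only: the `ARK_API_KEY` variable, the model id, the request payload shape, and the field holding the result URL are assumptions, since `JMUtils.submit_task` builds its payload outside the hunks shown in this diff.

```python
import os
import time

import requests

ARK_TASKS = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
API_KEY = os.environ["ARK_API_KEY"]  # assumed env var name; the node reads its key inside JMUtils
HEADERS = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}


def submit(prompt: str, image_url: str, duration: int) -> str:
    # Payload shape is an assumption modeled on Ark's image-to-video tasks API;
    # the exact fields sent by JMUtils.submit_task are not visible in this patch.
    body = {
        "model": "jimeng-video-model",  # hypothetical model id
        "content": [
            {"type": "text", "text": f"{prompt} --duration {duration}"},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    }
    resp = requests.post(ARK_TASKS, headers=HEADERS, json=body, timeout=30).json()
    return resp["id"]  # the node treats a missing "id" as a failed submission


def poll(task_id: str, wait_time: int = 120, interval: int = 2) -> str:
    # Same wait_time/interval defaults as JMGestureCorrect.gen and JMCustom.gen.
    for _ in range(0, wait_time, interval):
        resp = requests.get(f"{ARK_TASKS}/{task_id}", headers=HEADERS, timeout=30).json()
        status = resp.get("status", "")
        if status == "succeeded":
            return resp["content"]["video_url"]  # assumed location of the result URL
        if "fail" in status or "error" in status:
            raise RuntimeError(f"task failed: {resp}")
        time.sleep(interval)
    raise TimeoutError("task did not finish within the polling window")


def download(url: str, path: str) -> str:
    # Streaming download, analogous to JMUtils.download_video with an explicit path.
    with requests.get(url, stream=True, timeout=30) as r:
        r.raise_for_status()
        with open(path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    return path
```

Since `JMCustom.gen` passes the duration to `client.submit_task` as `str(duration)`, the sketch folds it into the text entry of the payload; whether the real request encodes it as text or as a dedicated numeric field is not determinable from this diff.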