From c1bf38a79b87ce7f2b680c2cde4fb3432c968287 Mon Sep 17 00:00:00 2001 From: "kyj@bowong.ai" Date: Thu, 3 Apr 2025 18:04:42 +0800 Subject: [PATCH] =?UTF-8?q?ADD=20=E5=A2=9E=E5=8A=A0HeyGem=E5=8F=A3?= =?UTF-8?q?=E5=9E=8B=E5=90=8C=E6=AD=A5(API)=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __init__.py | 7 ++-- nodes/heygem.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 nodes/heygem.py diff --git a/__init__.py b/__init__.py index 17176ce..dabdb9a 100644 --- a/__init__.py +++ b/__init__.py @@ -1,3 +1,4 @@ +from .nodes.heygem import HeyGemF2F from .nodes.s3 import S3Download, S3Upload from .nodes.text import * from .nodes.traverse_folder import TraverseFolder @@ -29,7 +30,8 @@ NODE_CLASS_MAPPINGS = { "unloadAllModels": UnloadAllModels, "TraverseFolder": TraverseFolder, "LoadTextCustom": LoadTextLocal, - "LoadTextCustomOnline": LoadTextOnline + "LoadTextCustomOnline": LoadTextOnline, + "HeyGemF2F": HeyGemF2F, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -49,5 +51,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "unloadAllModels": "卸载所有已加载模型", "TraverseFolder": "遍历文件夹", "LoadTextCustom": "读取文本文件(本地)", - "LoadTextCustomOnline": "读取文本文件(线上)" + "LoadTextCustomOnline": "读取文本文件(线上)", + "HeyGemF2F": "HeyGem口型同步(API)", } diff --git a/nodes/heygem.py b/nodes/heygem.py new file mode 100644 index 0000000..00fd998 --- /dev/null +++ b/nodes/heygem.py @@ -0,0 +1,92 @@ +import os +import uuid + +import loguru +import requests +import torchaudio +import torchvision +from time import sleep +from torch import Tensor + + +class HeyGemF2F: + """HeyGem 嘴型同步""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "video": ("IMAGE", {"forceInput": True}), + "audio": ("AUDIO", {"forceInput": True}), + "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}), + "heygem_temp_path": ("STRING", {"default":"/code/data/temp"}) + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("视频存储路径",) + FUNCTION = "f2f" + CATEGORY = "不忘科技-自定义节点🚩" + + def path_convert(self,path): + if ":" in path: + path = path.replace(os.sep,"/").split(":") + path[0] = path[0].lower() + path[1] = path[1][1:] + path = "/".join(["/mnt",*path]) + return path + + def f2f(self, video:Tensor, audio:dict, heygem_url:str, heygem_temp_path:str): + uid = str(uuid.uuid4()) + try: + try: + torchvision.io.write_video(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid), video.mul_(255).int(),25) + torchaudio.save(os.path.join(os.path.dirname(__file__),"%s.wav" % uid), audio["waveform"].squeeze(0), audio["sample_rate"], True) + except: + raise RuntimeError("Save Temp File Error! ") + payload = { + "code": uid, + "video_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)), + "audio_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)), + "chaofen": 1, + "watermark_switch": 0, + "pn": 1 + } + print(payload) + r = requests.post(heygem_url+"/easy/submit", json=payload) + if r.status_code != 200: + raise RuntimeError("Request Error!") + else: + r_json = r.json() + if r_json["success"]: + loguru.logger.info("Submit Task Success") + else: + raise RuntimeError("Submit Task Fail") + t = 30 + while t>0: + r = requests.get(heygem_url+"/easy/query?code="+uid) + if r.status_code == 200: + j = r.json() + if "msg" in j and j["msg"]=="任务不存在": + raise RuntimeError("Task Missing") + if "data" in j and "异常" in j["data"]["msg"]: + raise RuntimeError("Task Run Error: %s" % j["data"]["msg"]) + if "data" in j and j["data"]["progress"] < 100 and j["data"]["result"] == "": + loguru.logger.info("Waiting Task Finish") + elif "data" in j and j["data"]["progress"] == 100 and j["data"]["msg"] == "任务完成" and j["data"]["result"] != "": + loguru.logger.info("Task Finished") + return ("%s%s" % (heygem_temp_path, j["data"]["result"]),) + else: + loguru.logger.info("Unknown Status") + else: + loguru.logger.info("Get Task Status Failed") + t -= 1 + sleep(5) + except BaseException as e: + raise RuntimeError(e) + finally: + try: + os.remove(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)) + os.remove(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)) + except: + pass \ No newline at end of file