ADD 增加HeyGem口型同步(API)节点

This commit is contained in:
kyj@bowong.ai 2025-04-03 18:04:42 +08:00
parent 7f38cd8b62
commit c1bf38a79b
2 changed files with 97 additions and 2 deletions

View File

@ -1,3 +1,4 @@
from .nodes.heygem import HeyGemF2F
from .nodes.s3 import S3Download, S3Upload
from .nodes.text import *
from .nodes.traverse_folder import TraverseFolder
@ -29,7 +30,8 @@ NODE_CLASS_MAPPINGS = {
"unloadAllModels": UnloadAllModels,
"TraverseFolder": TraverseFolder,
"LoadTextCustom": LoadTextLocal,
"LoadTextCustomOnline": LoadTextOnline
"LoadTextCustomOnline": LoadTextOnline,
"HeyGemF2F": HeyGemF2F,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
@ -49,5 +51,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"unloadAllModels": "卸载所有已加载模型",
"TraverseFolder": "遍历文件夹",
"LoadTextCustom": "读取文本文件(本地)",
"LoadTextCustomOnline": "读取文本文件(线上)"
"LoadTextCustomOnline": "读取文本文件(线上)",
"HeyGemF2F": "HeyGem口型同步(API)",
}

92
nodes/heygem.py Normal file
View File

@ -0,0 +1,92 @@
import os
import uuid
import loguru
import requests
import torchaudio
import torchvision
from time import sleep
from torch import Tensor
class HeyGemF2F:
"""HeyGem 嘴型同步"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"video": ("IMAGE", {"forceInput": True}),
"audio": ("AUDIO", {"forceInput": True}),
"heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
"heygem_temp_path": ("STRING", {"default":"/code/data/temp"})
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("视频存储路径",)
FUNCTION = "f2f"
CATEGORY = "不忘科技-自定义节点🚩"
def path_convert(self,path):
if ":" in path:
path = path.replace(os.sep,"/").split(":")
path[0] = path[0].lower()
path[1] = path[1][1:]
path = "/".join(["/mnt",*path])
return path
def f2f(self, video:Tensor, audio:dict, heygem_url:str, heygem_temp_path:str):
uid = str(uuid.uuid4())
try:
try:
torchvision.io.write_video(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid), video.mul_(255).int(),25)
torchaudio.save(os.path.join(os.path.dirname(__file__),"%s.wav" % uid), audio["waveform"].squeeze(0), audio["sample_rate"], True)
except:
raise RuntimeError("Save Temp File Error! ")
payload = {
"code": uid,
"video_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)),
"audio_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)),
"chaofen": 1,
"watermark_switch": 0,
"pn": 1
}
print(payload)
r = requests.post(heygem_url+"/easy/submit", json=payload)
if r.status_code != 200:
raise RuntimeError("Request Error!")
else:
r_json = r.json()
if r_json["success"]:
loguru.logger.info("Submit Task Success")
else:
raise RuntimeError("Submit Task Fail")
t = 30
while t>0:
r = requests.get(heygem_url+"/easy/query?code="+uid)
if r.status_code == 200:
j = r.json()
if "msg" in j and j["msg"]=="任务不存在":
raise RuntimeError("Task Missing")
if "data" in j and "异常" in j["data"]["msg"]:
raise RuntimeError("Task Run Error: %s" % j["data"]["msg"])
if "data" in j and j["data"]["progress"] < 100 and j["data"]["result"] == "":
loguru.logger.info("Waiting Task Finish")
elif "data" in j and j["data"]["progress"] == 100 and j["data"]["msg"] == "任务完成" and j["data"]["result"] != "":
loguru.logger.info("Task Finished")
return ("%s%s" % (heygem_temp_path, j["data"]["result"]),)
else:
loguru.logger.info("Unknown Status")
else:
loguru.logger.info("Get Task Status Failed")
t -= 1
sleep(5)
except BaseException as e:
raise RuntimeError(e)
finally:
try:
os.remove(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid))
os.remove(os.path.join(os.path.dirname(__file__),"%s.wav" % uid))
except:
pass