ADD 增加HeyGem口型同步(API)节点
This commit is contained in:
parent
7f38cd8b62
commit
c1bf38a79b
|
|
@ -1,3 +1,4 @@
|
||||||
|
from .nodes.heygem import HeyGemF2F
|
||||||
from .nodes.s3 import S3Download, S3Upload
|
from .nodes.s3 import S3Download, S3Upload
|
||||||
from .nodes.text import *
|
from .nodes.text import *
|
||||||
from .nodes.traverse_folder import TraverseFolder
|
from .nodes.traverse_folder import TraverseFolder
|
||||||
|
|
@ -29,7 +30,8 @@ NODE_CLASS_MAPPINGS = {
|
||||||
"unloadAllModels": UnloadAllModels,
|
"unloadAllModels": UnloadAllModels,
|
||||||
"TraverseFolder": TraverseFolder,
|
"TraverseFolder": TraverseFolder,
|
||||||
"LoadTextCustom": LoadTextLocal,
|
"LoadTextCustom": LoadTextLocal,
|
||||||
"LoadTextCustomOnline": LoadTextOnline
|
"LoadTextCustomOnline": LoadTextOnline,
|
||||||
|
"HeyGemF2F": HeyGemF2F,
|
||||||
}
|
}
|
||||||
|
|
||||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
|
|
@ -49,5 +51,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"unloadAllModels": "卸载所有已加载模型",
|
"unloadAllModels": "卸载所有已加载模型",
|
||||||
"TraverseFolder": "遍历文件夹",
|
"TraverseFolder": "遍历文件夹",
|
||||||
"LoadTextCustom": "读取文本文件(本地)",
|
"LoadTextCustom": "读取文本文件(本地)",
|
||||||
"LoadTextCustomOnline": "读取文本文件(线上)"
|
"LoadTextCustomOnline": "读取文本文件(线上)",
|
||||||
|
"HeyGemF2F": "HeyGem口型同步(API)",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import loguru
|
||||||
|
import requests
|
||||||
|
import torchaudio
|
||||||
|
import torchvision
|
||||||
|
from time import sleep
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class HeyGemF2F:
|
||||||
|
"""HeyGem 嘴型同步"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"video": ("IMAGE", {"forceInput": True}),
|
||||||
|
"audio": ("AUDIO", {"forceInput": True}),
|
||||||
|
"heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
|
||||||
|
"heygem_temp_path": ("STRING", {"default":"/code/data/temp"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("视频存储路径",)
|
||||||
|
FUNCTION = "f2f"
|
||||||
|
CATEGORY = "不忘科技-自定义节点🚩"
|
||||||
|
|
||||||
|
def path_convert(self,path):
|
||||||
|
if ":" in path:
|
||||||
|
path = path.replace(os.sep,"/").split(":")
|
||||||
|
path[0] = path[0].lower()
|
||||||
|
path[1] = path[1][1:]
|
||||||
|
path = "/".join(["/mnt",*path])
|
||||||
|
return path
|
||||||
|
|
||||||
|
def f2f(self, video:Tensor, audio:dict, heygem_url:str, heygem_temp_path:str):
|
||||||
|
uid = str(uuid.uuid4())
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
torchvision.io.write_video(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid), video.mul_(255).int(),25)
|
||||||
|
torchaudio.save(os.path.join(os.path.dirname(__file__),"%s.wav" % uid), audio["waveform"].squeeze(0), audio["sample_rate"], True)
|
||||||
|
except:
|
||||||
|
raise RuntimeError("Save Temp File Error! ")
|
||||||
|
payload = {
|
||||||
|
"code": uid,
|
||||||
|
"video_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)),
|
||||||
|
"audio_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)),
|
||||||
|
"chaofen": 1,
|
||||||
|
"watermark_switch": 0,
|
||||||
|
"pn": 1
|
||||||
|
}
|
||||||
|
print(payload)
|
||||||
|
r = requests.post(heygem_url+"/easy/submit", json=payload)
|
||||||
|
if r.status_code != 200:
|
||||||
|
raise RuntimeError("Request Error!")
|
||||||
|
else:
|
||||||
|
r_json = r.json()
|
||||||
|
if r_json["success"]:
|
||||||
|
loguru.logger.info("Submit Task Success")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Submit Task Fail")
|
||||||
|
t = 30
|
||||||
|
while t>0:
|
||||||
|
r = requests.get(heygem_url+"/easy/query?code="+uid)
|
||||||
|
if r.status_code == 200:
|
||||||
|
j = r.json()
|
||||||
|
if "msg" in j and j["msg"]=="任务不存在":
|
||||||
|
raise RuntimeError("Task Missing")
|
||||||
|
if "data" in j and "异常" in j["data"]["msg"]:
|
||||||
|
raise RuntimeError("Task Run Error: %s" % j["data"]["msg"])
|
||||||
|
if "data" in j and j["data"]["progress"] < 100 and j["data"]["result"] == "":
|
||||||
|
loguru.logger.info("Waiting Task Finish")
|
||||||
|
elif "data" in j and j["data"]["progress"] == 100 and j["data"]["msg"] == "任务完成" and j["data"]["result"] != "":
|
||||||
|
loguru.logger.info("Task Finished")
|
||||||
|
return ("%s%s" % (heygem_temp_path, j["data"]["result"]),)
|
||||||
|
else:
|
||||||
|
loguru.logger.info("Unknown Status")
|
||||||
|
else:
|
||||||
|
loguru.logger.info("Get Task Status Failed")
|
||||||
|
t -= 1
|
||||||
|
sleep(5)
|
||||||
|
except BaseException as e:
|
||||||
|
raise RuntimeError(e)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.remove(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid))
|
||||||
|
os.remove(os.path.join(os.path.dirname(__file__),"%s.wav" % uid))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
Loading…
Reference in New Issue