ADD 增加HeyGem口型同步(API)节点
This commit is contained in:
parent
7f38cd8b62
commit
c1bf38a79b
|
|
@ -1,3 +1,4 @@
|
|||
from .nodes.heygem import HeyGemF2F
|
||||
from .nodes.s3 import S3Download, S3Upload
|
||||
from .nodes.text import *
|
||||
from .nodes.traverse_folder import TraverseFolder
|
||||
|
|
@ -29,7 +30,8 @@ NODE_CLASS_MAPPINGS = {
|
|||
"unloadAllModels": UnloadAllModels,
|
||||
"TraverseFolder": TraverseFolder,
|
||||
"LoadTextCustom": LoadTextLocal,
|
||||
"LoadTextCustomOnline": LoadTextOnline
|
||||
"LoadTextCustomOnline": LoadTextOnline,
|
||||
"HeyGemF2F": HeyGemF2F,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
|
|
@ -49,5 +51,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"unloadAllModels": "卸载所有已加载模型",
|
||||
"TraverseFolder": "遍历文件夹",
|
||||
"LoadTextCustom": "读取文本文件(本地)",
|
||||
"LoadTextCustomOnline": "读取文本文件(线上)"
|
||||
"LoadTextCustomOnline": "读取文本文件(线上)",
|
||||
"HeyGemF2F": "HeyGem口型同步(API)",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,92 @@
|
|||
import os
|
||||
import uuid
|
||||
|
||||
import loguru
|
||||
import requests
|
||||
import torchaudio
|
||||
import torchvision
|
||||
from time import sleep
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class HeyGemF2F:
|
||||
"""HeyGem 嘴型同步"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"video": ("IMAGE", {"forceInput": True}),
|
||||
"audio": ("AUDIO", {"forceInput": True}),
|
||||
"heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
|
||||
"heygem_temp_path": ("STRING", {"default":"/code/data/temp"})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("视频存储路径",)
|
||||
FUNCTION = "f2f"
|
||||
CATEGORY = "不忘科技-自定义节点🚩"
|
||||
|
||||
def path_convert(self,path):
|
||||
if ":" in path:
|
||||
path = path.replace(os.sep,"/").split(":")
|
||||
path[0] = path[0].lower()
|
||||
path[1] = path[1][1:]
|
||||
path = "/".join(["/mnt",*path])
|
||||
return path
|
||||
|
||||
def f2f(self, video:Tensor, audio:dict, heygem_url:str, heygem_temp_path:str):
|
||||
uid = str(uuid.uuid4())
|
||||
try:
|
||||
try:
|
||||
torchvision.io.write_video(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid), video.mul_(255).int(),25)
|
||||
torchaudio.save(os.path.join(os.path.dirname(__file__),"%s.wav" % uid), audio["waveform"].squeeze(0), audio["sample_rate"], True)
|
||||
except:
|
||||
raise RuntimeError("Save Temp File Error! ")
|
||||
payload = {
|
||||
"code": uid,
|
||||
"video_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)),
|
||||
"audio_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)),
|
||||
"chaofen": 1,
|
||||
"watermark_switch": 0,
|
||||
"pn": 1
|
||||
}
|
||||
print(payload)
|
||||
r = requests.post(heygem_url+"/easy/submit", json=payload)
|
||||
if r.status_code != 200:
|
||||
raise RuntimeError("Request Error!")
|
||||
else:
|
||||
r_json = r.json()
|
||||
if r_json["success"]:
|
||||
loguru.logger.info("Submit Task Success")
|
||||
else:
|
||||
raise RuntimeError("Submit Task Fail")
|
||||
t = 30
|
||||
while t>0:
|
||||
r = requests.get(heygem_url+"/easy/query?code="+uid)
|
||||
if r.status_code == 200:
|
||||
j = r.json()
|
||||
if "msg" in j and j["msg"]=="任务不存在":
|
||||
raise RuntimeError("Task Missing")
|
||||
if "data" in j and "异常" in j["data"]["msg"]:
|
||||
raise RuntimeError("Task Run Error: %s" % j["data"]["msg"])
|
||||
if "data" in j and j["data"]["progress"] < 100 and j["data"]["result"] == "":
|
||||
loguru.logger.info("Waiting Task Finish")
|
||||
elif "data" in j and j["data"]["progress"] == 100 and j["data"]["msg"] == "任务完成" and j["data"]["result"] != "":
|
||||
loguru.logger.info("Task Finished")
|
||||
return ("%s%s" % (heygem_temp_path, j["data"]["result"]),)
|
||||
else:
|
||||
loguru.logger.info("Unknown Status")
|
||||
else:
|
||||
loguru.logger.info("Get Task Status Failed")
|
||||
t -= 1
|
||||
sleep(5)
|
||||
except BaseException as e:
|
||||
raise RuntimeError(e)
|
||||
finally:
|
||||
try:
|
||||
os.remove(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid))
|
||||
os.remove(os.path.join(os.path.dirname(__file__),"%s.wav" % uid))
|
||||
except:
|
||||
pass
|
||||
Loading…
Reference in New Issue