92 lines
3.6 KiB
Python
92 lines
3.6 KiB
Python
import os
|
|
import uuid
|
|
|
|
import loguru
|
|
import requests
|
|
import torchaudio
|
|
import torchvision
|
|
from time import sleep
|
|
from torch import Tensor
|
|
|
|
|
|
class HeyGemF2F:
    """HeyGem lip-sync node.

    Writes the incoming frames and audio to temporary files next to this
    module, submits them to a HeyGem server (``/easy/submit``), polls
    ``/easy/query`` until the task reports completion, and returns the path
    of the generated video inside the server's temp directory.
    """

    # Seconds to wait between two status queries.
    POLL_INTERVAL = 5
    # Maximum number of status queries before giving up (~150 s by default).
    MAX_POLLS = 30
    # Per-request network timeout in seconds, so a dead server cannot hang the node.
    REQUEST_TIMEOUT = 30
    # Frame rate used when encoding the input frames into the temp video.
    FPS = 25

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE", {"forceInput": True}),
                "audio": ("AUDIO", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩"

    def path_convert(self, path):
        """Translate a Windows drive path into its WSL ``/mnt/<drive>/`` form.

        ``C:\\Users\\x\\a.mp4`` or ``C:/Users/x/a.mp4`` becomes
        ``/mnt/c/Users/x/a.mp4``. Paths that do not start with a drive letter
        followed by ``:`` are returned unchanged.

        :param path: local filesystem path.
        :return: the path as it should be passed to the HeyGem server.
        """
        if len(path) >= 2 and path[1] == ":" and path[0].isalpha():
            # A drive-letter path is a Windows path: backslashes are separators
            # regardless of the platform this code happens to run on.
            drive, rest = path.replace("\\", "/").split(":", 1)
            return "/".join(["/mnt", drive.lower(), rest.lstrip("/")])
        return path

    @staticmethod
    def _to_uint8_frames(video):
        """Return a new uint8 CPU copy of ``video`` scaled from [0, 1] to [0, 255].

        Assumes float frames in [0, 1] (ComfyUI IMAGE convention — confirm
        upstream). The input tensor is never modified, because it may be
        shared with other nodes in the graph.
        """
        return (video.detach().cpu() * 255).round().clamp(0, 255).byte()

    @staticmethod
    def _result_path(temp_path, result):
        """Join the server temp directory and a task result with exactly one '/'."""
        if not temp_path:
            return result
        return temp_path.rstrip("/\\") + "/" + result.lstrip("/\\")

    def _submit(self, base_url, payload):
        """POST ``payload`` to ``/easy/submit``.

        :raises RuntimeError: on a non-200 response or when the server does
            not report ``success``.
        """
        loguru.logger.info("Submitting HeyGem task: {}", payload)
        r = requests.post(base_url + "/easy/submit", json=payload, timeout=self.REQUEST_TIMEOUT)
        if r.status_code != 200:
            raise RuntimeError("Request Error!")
        if not r.json().get("success"):
            raise RuntimeError("Submit Task Fail")
        loguru.logger.info("Submit Task Success")

    def _poll(self, base_url, uid, temp_path):
        """Query ``/easy/query`` until task ``uid`` finishes.

        The response fields read here (``msg``, ``data.msg``,
        ``data.progress``, ``data.result``) are the ones this node relies on;
        the full server schema is not visible from this file.

        :return: the result video path joined onto ``temp_path``.
        :raises RuntimeError: if the task is missing, reports an error, or has
            not finished after ``MAX_POLLS`` queries.
        """
        for attempt in range(self.MAX_POLLS):
            if attempt:
                sleep(self.POLL_INTERVAL)
            try:
                r = requests.get(
                    base_url + "/easy/query",
                    params={"code": uid},
                    timeout=self.REQUEST_TIMEOUT,
                )
            except requests.RequestException as e:
                # Transient network trouble: keep polling until MAX_POLLS is hit.
                loguru.logger.info("Get Task Status Failed: {}", e)
                continue
            if r.status_code != 200:
                loguru.logger.info("Get Task Status Failed")
                continue

            j = r.json()
            if j.get("msg") == "任务不存在":  # server: "task does not exist"
                raise RuntimeError("Task Missing")
            data = j.get("data")
            if not isinstance(data, dict):
                loguru.logger.info("Unknown Status")
                continue

            msg = str(data.get("msg") or "")
            if "异常" in msg:  # server: "exception/error" marker in the message
                raise RuntimeError("Task Run Error: %s" % msg)
            progress = data.get("progress")
            result = data.get("result")
            if progress == 100 and msg == "任务完成" and result:  # "task finished"
                loguru.logger.info("Task Finished")
                return self._result_path(temp_path, result)
            if isinstance(progress, (int, float)) and progress < 100 and not result:
                loguru.logger.info("Waiting Task Finish")
            else:
                loguru.logger.info("Unknown Status")
        raise RuntimeError("Task Timeout")

    def f2f(self, video: Tensor, audio: dict, heygem_url: str, heygem_temp_path: str):
        """Run a HeyGem lip-sync task and return the generated video path.

        :param video: IMAGE tensor written as video frames at ``FPS``
            (presumably (frames, H, W, C) floats in [0, 1] — confirm upstream).
        :param audio: AUDIO dict with ``"waveform"`` (leading dim of size 1 is
            squeezed; assumed (1, channels, samples)) and ``"sample_rate"``.
        :param heygem_url: base URL of the HeyGem server.
        :param heygem_temp_path: server-side directory the result is stored in.
        :return: 1-tuple with the result video path.
        :raises RuntimeError: on any failure; the original error is chained.
        """
        uid = str(uuid.uuid4())
        work_dir = os.path.dirname(os.path.abspath(__file__))
        video_file = os.path.join(work_dir, "%s.mp4" % uid)
        audio_file = os.path.join(work_dir, "%s.wav" % uid)
        base_url = heygem_url.rstrip("/")
        try:
            try:
                torchvision.io.write_video(video_file, self._to_uint8_frames(video), self.FPS)
                torchaudio.save(
                    audio_file,
                    audio["waveform"].squeeze(0).cpu(),
                    audio["sample_rate"],
                    channels_first=True,
                )
            except Exception as e:
                raise RuntimeError("Save Temp File Error! ") from e

            payload = {
                "code": uid,
                "video_url": self.path_convert(video_file),
                "audio_url": self.path_convert(audio_file),
                "chaofen": 1,
                "watermark_switch": 0,
                "pn": 1,
            }
            self._submit(base_url, payload)
            return (self._poll(base_url, uid, heygem_temp_path),)
        except RuntimeError:
            raise
        except Exception as e:
            # Callers only ever see RuntimeError; keep the cause for debugging.
            raise RuntimeError(e) from e
        finally:
            # Best-effort cleanup; each file is removed independently.
            for path in (video_file, audio_file):
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass
                except OSError as e:
                    loguru.logger.warning("Failed to remove temp file {}: {}", path, e)