ComfyUI-CustomNode/nodes/heygem.py

92 lines
3.6 KiB
Python

import os
import uuid
import loguru
import requests
import torchaudio
import torchvision
from time import sleep
from torch import Tensor
class HeyGemF2F:
    """HeyGem lip-sync node.

    Writes the incoming frames and audio to temporary files next to this
    module, submits a task to a HeyGem service (``/easy/submit``), polls
    ``/easy/query`` until the task finishes, and returns the result path
    reported by the service prefixed with ``heygem_temp_path``. Temporary
    input files are always removed before returning.
    """

    # Frame rate used when encoding the temporary input video.
    FPS = 25
    # Polling schedule for /easy/query: at most MAX_POLLS requests spaced
    # POLL_INTERVAL seconds apart (30 x 5 s, same budget as before).
    POLL_INTERVAL = 5
    MAX_POLLS = 30
    # Per-request HTTP timeout in seconds, so a stalled service cannot block
    # the node indefinitely.
    REQUEST_TIMEOUT = 30

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "video": ("IMAGE", {"forceInput": True}),
                "audio": ("AUDIO", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩"

    def path_convert(self, path: str) -> str:
        """Map a Windows drive path to its ``/mnt/<drive>/...`` form.

        ``C:\\code\\x.mp4`` -> ``/mnt/c/code/x.mp4`` (presumably so a
        HeyGem service running under WSL can read the file — confirm).
        Any other path, including POSIX paths that merely contain ``:``,
        is returned unchanged.
        """
        # Only "<ascii letter>:" at the very start is a drive prefix; testing
        # for ':' anywhere mangled POSIX paths such as "/data/a:b.mp4".
        if len(path) >= 2 and path[1] == ":" and path[0].isascii() and path[0].isalpha():
            drive = path[0].lower()
            # Normalise both separator styles regardless of the host OS.
            rest = path[2:].replace("\\", "/").replace(os.sep, "/").lstrip("/")
            return "/mnt/%s/%s" % (drive, rest)
        return path

    def _save_inputs(self, video: Tensor, audio: dict, video_file: str, audio_file: str) -> None:
        """Write the frames and audio to the given temporary files.

        Raises:
            RuntimeError: if either file cannot be written.
        """
        try:
            # Scale frames from [0, 1] to 0-255 WITHOUT modifying the caller's
            # tensor (the old mul_() mutated the upstream IMAGE output), and
            # clamp before the uint8 cast so out-of-range values don't wrap.
            frames = video.detach().cpu().mul(255).clamp(0, 255).byte()
            torchvision.io.write_video(video_file, frames, self.FPS)
            # Assumes waveform is (1, channels, samples); squeeze(0) drops the
            # batch dim and the trailing True is channels_first.
            waveform = audio["waveform"].squeeze(0).cpu()
            torchaudio.save(audio_file, waveform, audio["sample_rate"], True)
        except Exception as e:
            raise RuntimeError("Save Temp File Error! ") from e

    def _submit(self, base_url: str, uid: str, video_file: str, audio_file: str) -> None:
        """Submit the task to ``/easy/submit``.

        Raises:
            RuntimeError: on a non-200 reply or when ``success`` is falsy.
        """
        payload = {
            "code": uid,
            "video_url": self.path_convert(video_file),
            "audio_url": self.path_convert(audio_file),
            # Service-specific flags; meanings not visible here (presumably
            # "chaofen" = super-resolution, watermark off) — verify with HeyGem.
            "chaofen": 1,
            "watermark_switch": 0,
            "pn": 1,
        }
        loguru.logger.info("Submit Task Payload: {}", payload)
        r = requests.post(base_url + "/easy/submit", json=payload, timeout=self.REQUEST_TIMEOUT)
        if r.status_code != 200:
            raise RuntimeError("Request Error!")
        r_json = r.json()
        if not (isinstance(r_json, dict) and r_json.get("success")):
            raise RuntimeError("Submit Task Fail")
        loguru.logger.info("Submit Task Success")

    def _wait_result(self, base_url: str, uid: str, heygem_temp_path: str) -> tuple:
        """Poll ``/easy/query`` until the task finishes.

        Returns:
            A one-element tuple with ``heygem_temp_path`` + the service's result.

        Raises:
            RuntimeError: if the task is missing, reports an error, or does
                not finish within ``MAX_POLLS`` polls.
        """
        for attempt in range(self.MAX_POLLS):
            if attempt:
                sleep(self.POLL_INTERVAL)
            try:
                r = requests.get(
                    base_url + "/easy/query",
                    params={"code": uid},
                    timeout=self.REQUEST_TIMEOUT,
                )
            except requests.RequestException as e:
                # A transport error is treated like a non-200 reply: retry.
                loguru.logger.info("Get Task Status Failed: {}", e)
                continue
            if r.status_code != 200:
                loguru.logger.info("Get Task Status Failed")
                continue

            j = r.json()
            if not isinstance(j, dict):
                loguru.logger.info("Unknown Status")
                continue
            if j.get("msg") == "任务不存在":
                raise RuntimeError("Task Missing")

            data = j.get("data")
            if not isinstance(data, dict):
                data = {}
            msg = str(data.get("msg") or "")
            progress = data.get("progress")
            result = data.get("result")

            if "异常" in msg:
                raise RuntimeError("Task Run Error: %s" % msg)
            if progress == 100 and msg == "任务完成" and result:
                loguru.logger.info("Task Finished")
                # Plain concatenation: presumably the service returns a result
                # starting with "/" — verify against the HeyGem API.
                return ("%s%s" % (heygem_temp_path, result),)
            if isinstance(progress, (int, float)) and progress < 100 and result == "":
                loguru.logger.info("Waiting Task Finish")
            else:
                loguru.logger.info("Unknown Status")

        # Previously the loop fell through and the node returned None.
        raise RuntimeError("Task Timeout: no result after %d polls" % self.MAX_POLLS)

    def f2f(self, video: Tensor, audio: dict, heygem_url: str, heygem_temp_path: str):
        """Run a HeyGem lip-sync task and return the result video path.

        Args:
            video: frame tensor (ComfyUI IMAGE), scaled by 255 for encoding.
            audio: dict with ``"waveform"`` and ``"sample_rate"`` keys.
            heygem_url: base URL of the HeyGem service.
            heygem_temp_path: prefix prepended to the result path.

        Returns:
            ``(path,)`` where *path* is ``heygem_temp_path`` + the result.

        Raises:
            RuntimeError: on any failure while saving, submitting or polling.
        """
        uid = str(uuid.uuid4())
        base_dir = os.path.dirname(os.path.abspath(__file__))
        video_file = os.path.join(base_dir, "%s.mp4" % uid)
        audio_file = os.path.join(base_dir, "%s.wav" % uid)
        base_url = heygem_url.rstrip("/")
        try:
            self._save_inputs(video, audio, video_file, audio_file)
            self._submit(base_url, uid, video_file, audio_file)
            return self._wait_result(base_url, uid, heygem_temp_path)
        except RuntimeError:
            raise
        except Exception as e:
            # Keep the node's contract of surfacing failures as RuntimeError,
            # but let KeyboardInterrupt/SystemExit propagate untouched.
            raise RuntimeError(e) from e
        finally:
            # Remove each temp file independently; a missing file is fine.
            for temp_file in (video_file, audio_file):
                try:
                    os.remove(temp_file)
                except OSError:
                    pass