import os
import uuid
from time import sleep

import loguru
import requests
import torchaudio
import torchvision
from torch import Tensor


class HeyGemF2F:
    """HeyGem lip-sync node.

    Writes the incoming frames and audio to temporary files next to this
    module, submits them to a HeyGem server (``/easy/submit``), polls
    ``/easy/query`` until the task finishes, and returns the output path.
    """

    # Polling budget: POLL_TIMES status queries, POLL_INTERVAL seconds apart.
    POLL_TIMES = 30
    POLL_INTERVAL = 5
    # Per-request HTTP timeout (seconds) so an unresponsive server cannot
    # block the node forever.
    REQUEST_TIMEOUT = 30

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE", {"forceInput": True}),
                "audio": ("AUDIO", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩"

    def path_convert(self, path):
        """Map a Windows drive path to its ``/mnt/<drive>/...`` form.

        ``C:\\a\\b.mp4`` becomes ``/mnt/c/a/b.mp4``. Any path that does not start
        with ``<letter>:`` is returned unchanged, so POSIX paths that merely
        contain a colon are no longer mangled.
        """
        if len(path) >= 2 and path[1] == ":" and path[0].isalpha():
            drive, rest = path.split(":", 1)
            rest = rest.replace("\\", "/").lstrip("/")
            return "/".join(["/mnt", drive.lower(), rest])
        return path

    def f2f(self, video: Tensor, audio: dict, heygem_url: str, heygem_temp_path: str):
        """Run HeyGem lip-sync on *video* with *audio*.

        Args:
            video: ComfyUI IMAGE tensor, presumably (frames, H, W, C) floats
                in [0, 1] — TODO confirm. It is not modified.
            audio: ComfyUI AUDIO dict with ``"waveform"`` and ``"sample_rate"``.
            heygem_url: Base URL of the HeyGem server.
            heygem_temp_path: Prefix prepended verbatim to the server's result
                path (assumes the result already starts with a separator).

        Returns:
            One-element tuple holding the output video path.

        Raises:
            RuntimeError: on save, submit, task, or timeout failure.
        """
        uid = str(uuid.uuid4())
        base_dir = os.path.dirname(__file__)
        video_file = os.path.join(base_dir, "%s.mp4" % uid)
        audio_file = os.path.join(base_dir, "%s.wav" % uid)
        base_url = heygem_url.rstrip("/")
        try:
            self._save_inputs(video, audio, video_file, audio_file)
            self._submit(base_url, uid, video_file, audio_file)
            result = self._wait_result(base_url, uid)
            return ("%s%s" % (heygem_temp_path, result),)
        except RuntimeError:
            raise
        except Exception as e:
            # Callers expect RuntimeError; keep the original cause chained.
            # BaseException (KeyboardInterrupt, SystemExit) is left to propagate.
            raise RuntimeError(e) from e
        finally:
            # Remove each temp file independently so that one failure does
            # not leak the other file.
            for temp_file in (video_file, audio_file):
                try:
                    os.remove(temp_file)
                except OSError:
                    pass

    def _save_inputs(self, video, audio, video_file, audio_file):
        """Write frames to *video_file* (25 fps) and the waveform to *audio_file*."""
        try:
            # Scale out of place: the IMAGE tensor may be shared with other
            # nodes. Clamp so values slightly outside [0, 1] do not wrap.
            frames = video.mul(255).clamp(0, 255).byte()
            torchvision.io.write_video(video_file, frames, 25)
            torchaudio.save(audio_file, audio["waveform"].squeeze(0), audio["sample_rate"], True)
        except Exception as e:
            raise RuntimeError("Save Temp File Error! %s" % e) from e

    def _submit(self, base_url, uid, video_file, audio_file):
        """Submit the task to HeyGem; raise RuntimeError if it is not accepted."""
        payload = {
            "code": uid,
            "video_url": self.path_convert(video_file),
            "audio_url": self.path_convert(audio_file),
            "chaofen": 1,
            "watermark_switch": 0,
            "pn": 1,
        }
        loguru.logger.info("Submit payload: {}", payload)
        r = requests.post(base_url + "/easy/submit", json=payload, timeout=self.REQUEST_TIMEOUT)
        if r.status_code != 200:
            raise RuntimeError("Request Error!")
        if not r.json().get("success"):
            raise RuntimeError("Submit Task Fail")
        loguru.logger.info("Submit Task Success")

    def _query_status(self, base_url, uid):
        """Return the decoded status JSON for *uid*, or None on a transient failure."""
        try:
            r = requests.get(
                base_url + "/easy/query",
                params={"code": uid},
                timeout=self.REQUEST_TIMEOUT,
            )
            if r.status_code != 200:
                return None
            return r.json()
        except (requests.RequestException, ValueError) as e:
            loguru.logger.info("Query request failed: {}", e)
            return None

    def _wait_result(self, base_url, uid):
        """Poll until the task finishes and return its ``result`` field.

        Raises RuntimeError if the server reports the task missing or failed,
        or if it has not finished within the polling budget.
        """
        for _ in range(self.POLL_TIMES):
            j = self._query_status(base_url, uid)
            if j is None:
                loguru.logger.info("Get Task Status Failed")
            elif j.get("msg") == "任务不存在":
                raise RuntimeError("Task Missing")
            else:
                # "data" / "msg" may be null in the server's reply; treat
                # them as empty rather than crashing.
                data = j.get("data") or {}
                msg = data.get("msg") or ""
                progress = data.get("progress")
                result = data.get("result")
                if "异常" in msg:
                    raise RuntimeError("Task Run Error: %s" % msg)
                if progress == 100 and msg == "任务完成" and result:
                    loguru.logger.info("Task Finished")
                    return result
                if progress is not None and progress < 100 and result == "":
                    loguru.logger.info("Waiting Task Finish")
                else:
                    loguru.logger.info("Unknown Status")
            sleep(self.POLL_INTERVAL)
        raise RuntimeError("Task Timeout")