import contextlib
import json
import os
import time
import traceback
import uuid

import httpx
import loguru
import torchaudio
import torchvision
from torch import Tensor

# HeyGem API business code for a successful submit/query response.
_HEYGEM_OK = 10000
# HeyGem API business codes that the query endpoint treats as failure.
_HEYGEM_QUERY_ERRORS = (9999, 10001, 10002, 10003)
# Values of ``data.status`` in a successful query response.
_TASK_DONE = 2
_TASK_FAILED = 3


def task_submit(uid, video, audio, heygem_url):
    """Submit a lip-sync task to the HeyGem ``/easy/submit`` endpoint.

    Args:
        uid: Task code sent to the service; returned as the task id on success.
        video: Video path/URL as the HeyGem service will see it.
        audio: Audio path/URL as the HeyGem service will see it.
        heygem_url: Base URL of the HeyGem service.

    Returns:
        ``{"status": bool, "data": uid | {}, "msg": str}``. When the service
        rejects the task, ``msg`` carries the service's own message.

    Raises:
        RuntimeError: if the HTTP call fails or the response is malformed.
    """
    task_submit_api = f"{heygem_url}/easy/submit"
    result_json = {"status": False, "data": {}, "msg": ""}
    data = {
        "code": uid,
        "video_url": video,
        "audio_url": audio,
        "chaofen": 1,
        "watermark_switch": 0,
        "pn": 1,
    }
    loguru.logger.info(f"data={data}")
    try:
        with httpx.Client() as client:
            resp = client.post(task_submit_api, json=data)
        resp_dict = resp.json()
        loguru.logger.info(f"submit data: {resp_dict}")
        if resp_dict["code"] != _HEYGEM_OK:
            result_json["status"] = False
            # Surface the service's rejection reason to the caller.
            result_json["msg"] = resp_dict.get("msg", "")
        else:
            result_json["status"] = True
            result_json["data"] = uid
            result_json["msg"] = "任务提交成功"
    except Exception as e:
        loguru.logger.error(f"submit task fail case by:{str(e)}")
        raise RuntimeError(str(e)) from e
    return result_json


def query_task_progress(
    heygem_url,
    heygem_temp_path,
    task_id: str,
    interval: int = 10,
    timeout: int = 60 * 15,
):
    """Poll HeyGem ``/easy/query`` until the task finishes, fails, or times out.

    Args:
        heygem_url: Base URL of the HeyGem service.
        heygem_temp_path: Directory that the result file name is joined onto.
        task_id: Task code returned by :func:`task_submit`.
        interval: Seconds to sleep between polls.
        timeout: Overall time budget in seconds.

    Returns:
        On success ``{"status": True, "data": "<heygem_temp_path>/<file>",
        "msg": ""}``. Otherwise ``status`` is False and ``msg`` explains why
        (service error, unexpected response code, or timeout).

    Raises:
        RuntimeError: if an individual query request fails or is malformed.
    """
    result_json = {"status": False, "data": {}, "msg": ""}

    def query_result(t_id: str):
        # Returns {"status", "data": {"progress", "path"} or {}, "msg"}.
        tmp_dict = {"status": True, "data": {}, "msg": ""}
        try:
            with httpx.Client() as client:
                resp = client.get(f"{heygem_url}/easy/query", params={"code": t_id})
            resp_dict = resp.json()
            status_code = resp_dict["code"]
            if status_code in _HEYGEM_QUERY_ERRORS:
                tmp_dict["status"] = False
                tmp_dict["msg"] = resp_dict["msg"]
            elif status_code == _HEYGEM_OK:
                loguru.logger.info(
                    f"query task data: {json.dumps(resp_dict, ensure_ascii=False)}"
                )
                task_data = resp_dict["data"]
                task_status = task_data.get("status", 1)
                if task_status == _TASK_FAILED:
                    tmp_dict["status"] = False
                    tmp_dict["msg"] = task_data["msg"]
                else:
                    # A finished task is reported as 100% regardless of the
                    # progress field; 20 is the fallback when it is missing.
                    if task_status == _TASK_DONE:
                        progress = 100
                    else:
                        progress = task_data.get("progress", 20)
                    tmp_dict["data"] = {
                        "progress": progress,
                        "path": task_data.get("result", ""),
                    }
            else:
                # Unknown code: fail with an explanatory message instead of
                # an empty one.
                tmp_dict["status"] = False
                tmp_dict["msg"] = (
                    f"unexpected response code {status_code}: "
                    f"{resp_dict.get('msg', '')}"
                )
        except Exception as e:
            loguru.logger.error(f"query task fail case by:{str(e)}")
            raise RuntimeError(str(e)) from e
        return tmp_dict

    deadline = time.time() + timeout
    while time.time() < deadline:
        tmp_result = query_result(task_id)
        if not tmp_result["status"] or not tmp_result["data"]:
            result_json["status"] = False
            result_json["msg"] = tmp_result["msg"]
            return result_json
        progress = tmp_result["data"]["progress"]
        loguru.logger.info(f"query task progress :{progress}")
        if progress >= 100:
            # NOTE(review): this strips *every* path separator, not only a
            # leading one. It assumes the service returns a bare file name
            # such as "/<name>.mp4" — confirm against the HeyGem API.
            file_name = tmp_result["data"]["path"].replace("/", "").replace("\\", "")
            result_json["data"] = "%s/%s" % (heygem_temp_path, file_name)
            result_json["status"] = True
            return result_json
        time.sleep(interval)
        loguru.logger.info(f"wait next interval:{interval}")

    result_json["status"] = False
    result_json["msg"] = f"Task {task_id} timed out after {timeout}s"
    return result_json


def path_convert(path):
    """Convert a Windows drive path into its WSL ``/mnt/<drive>/...`` form.

    Paths without a ``:`` are returned unchanged. Example (on Windows, where
    ``os.sep`` is a backslash): ``C:\\a\\b.mp4`` -> ``/mnt/c/a/b.mp4``.
    """
    if ":" in path:
        parts = path.replace(os.sep, "/").split(":")
        parts[0] = parts[0].lower()
        # Drop the separator that immediately followed the drive colon.
        parts[1] = parts[1][1:]
        path = "/".join(["/mnt", *parts])
    return path


def result_path_convert(result_path: str, distro: str = "Debian"):
    """Map an absolute POSIX path inside WSL to its Windows UNC form.

    ``/code/x.mp4`` -> ``\\\\wsl.localhost\\<distro>\\code\\x.mp4``. Paths not
    starting with ``/`` are returned unchanged.

    Args:
        result_path: Path as produced inside the WSL distribution.
        distro: Name of the WSL distribution hosting HeyGem.
    """
    if result_path.startswith("/"):
        result_path = (
            "\\\\wsl.localhost\\" + distro + result_path.replace("/", "\\")
        )
    return result_path


def _run_f2f_task(
    uid,
    video_url,
    audio_url,
    heygem_url,
    heygem_temp_path,
    is_windows,
    interval=10,
):
    """Submit a task, wait for completion, and return ``(local_result_path,)``.

    Raises:
        RuntimeError: if submission or processing fails.
        FileNotFoundError: if the reported result file is not reachable.
    """
    submit_result = task_submit(uid, video_url, audio_url, heygem_url)
    if not submit_result["status"]:
        # Raise rather than return a dict: ComfyUI node functions must
        # return a tuple matching RETURN_TYPES.
        raise RuntimeError(f"Task submission failed: {submit_result['msg']}")

    task_id = submit_result["data"]
    loguru.logger.info(f"Submitted task: {task_id}")

    progress_result = query_task_progress(
        heygem_url, heygem_temp_path, task_id, interval=interval
    )
    if not progress_result["status"]:
        raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

    file_path = progress_result["data"]
    if is_windows:
        file_path = result_path_convert(file_path)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Output file not found at {file_path}")
    return (file_path,)


class HeyGemF2F:
    """HeyGem lip sync from in-memory video frames and audio."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE", {"forceInput": True}),
                "audio": ("AUDIO", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    def f2f(
        self,
        video: Tensor,
        audio: dict,
        heygem_url: str,
        heygem_temp_path: str,
        is_Windows: bool,
    ):
        """Write temp media files, run a HeyGem task on them, return the result path.

        Args:
            video: Frames written at 25 fps. NOTE(review): presumably a
                ComfyUI IMAGE, i.e. float [frames, H, W, C] in [0, 1] — confirm.
            audio: Dict with ``"waveform"`` (leading batch dim squeezed away)
                and ``"sample_rate"``.
            heygem_url: Base URL of the HeyGem service.
            heygem_temp_path: Directory where HeyGem writes its result.
            is_Windows: Convert the result path to a ``\\\\wsl.localhost`` UNC path.

        Returns:
            ``(result_file_path,)``.
        """
        uid = str(uuid.uuid4())
        base_dir = os.path.dirname(__file__)
        video_path = os.path.join(base_dir, "%s.mp4" % uid)
        audio_path = os.path.join(base_dir, "%s.wav" % uid)
        try:
            try:
                # Out-of-place multiply: the input tensor is shared with other
                # nodes and must not be mutated. Clamp before uint8 conversion
                # so values slightly above 1.0 do not wrap around.
                frames = video.mul(255).clamp(0, 255).byte()
                torchvision.io.write_video(video_path, frames, 25)
                torchaudio.save(
                    audio_path, audio["waveform"].squeeze(0), audio["sample_rate"], True
                )
            except Exception as e:
                traceback.print_exc()
                raise RuntimeError("Save Temp File Error! ") from e

            # The service sees the temp files via their WSL (/mnt/...) paths.
            return _run_f2f_task(
                uid,
                path_convert(video_path),
                path_convert(audio_path),
                heygem_url,
                heygem_temp_path,
                is_Windows,
                interval=5,
            )
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            raise
        finally:
            # Remove each temp file independently, so one failure does not
            # leak the other.
            for tmp_path in (video_path, audio_path):
                with contextlib.suppress(OSError):
                    os.remove(tmp_path)


class HeyGemF2FFromFile:
    """HeyGem lip sync reading existing video/audio files directly."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("STRING", {"forceInput": True}),
                "audio": ("STRING", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    def f2f(
        self,
        video: str,
        audio: str,
        heygem_url: str,
        heygem_temp_path: str,
        is_Windows: bool,
    ):
        """Run a HeyGem task on existing files and return ``(result_file_path,)``.

        ``video`` and ``audio`` are forwarded to HeyGem unchanged (no path
        conversion). They must already be valid from the service's side.
        """
        uid = str(uuid.uuid4())
        try:
            return _run_f2f_task(
                uid, video, audio, heygem_url, heygem_temp_path, is_Windows
            )
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            raise