import json
import os
import time
import traceback
import uuid

import httpx
import loguru
import torchaudio
import torchvision
from torch import Tensor


def task_submit(uid, video, audio, heygem_url):
    """Submit a lip-sync task to the HeyGem ``/easy/submit`` endpoint.

    Args:
        uid: Task code sent to the server; it is also returned as the task id
            used for polling.
        video: Video path/URL as the HeyGem server should see it.
        audio: Audio path/URL as the HeyGem server should see it.
        heygem_url: Base URL of the HeyGem service (a trailing ``/`` is allowed).

    Returns:
        dict with ``status`` (bool), ``data`` (``uid`` on success, ``{}``
        otherwise) and ``msg`` (server message on failure).

    Raises:
        RuntimeError: on any transport or response-parsing error.
    """
    task_submit_api = f'{heygem_url.rstrip("/")}/easy/submit'
    result_json = {'status': False, 'data': {}, 'msg': ''}
    try:
        data = {
            "code": uid,
            "video_url": video,
            "audio_url": audio,
            "chaofen": 1,
            "watermark_switch": 0,
            "pn": 1,
        }
        loguru.logger.info(f'data={data}')
        with httpx.Client() as client:
            resp = client.post(task_submit_api, json=data)
        resp_dict = resp.json()
        loguru.logger.info(f'submit data: {resp_dict}')
        if resp_dict['code'] != 10000:
            result_json['status'] = False
            # Report the server's reason (previously msg was assigned to itself
            # and the failure reason was always empty).
            result_json['msg'] = resp_dict.get('msg') or f"error code {resp_dict['code']}"
        else:
            result_json['status'] = True
            result_json['data'] = uid
            result_json['msg'] = '任务提交成功'
    except Exception as e:
        loguru.logger.info(f'submit task fail case by:{str(e)}')
        raise RuntimeError(str(e)) from e
    return result_json


def query_task_progress(heygem_url, heygem_temp_path, task_id: str, interval: int = 10,
                        timeout: int = 60 * 15):
    """Poll ``/easy/query`` until the task finishes, fails, or ``timeout`` elapses.

    Args:
        heygem_url: Base URL of the HeyGem service.
        heygem_temp_path: Directory prefix joined with the result file name.
        task_id: Task code returned by :func:`task_submit`.
        interval: Seconds to sleep between polls.
        timeout: Overall polling budget in seconds.

    Returns:
        dict with ``status`` (bool), ``data`` (``"<heygem_temp_path>/<file>"``
        on success) and ``msg`` (failure/timeout reason).

    Raises:
        RuntimeError: on any transport or response-parsing error while polling.
    """
    query_task_url = f'{heygem_url.rstrip("/")}/easy/query'
    result_json = {'status': False, 'data': {}, 'msg': ''}

    def query_result(t_id: str):
        # Returns {'status', 'data': {'progress', 'path'} or {}, 'msg'}.
        tmp_dict = {'status': True, 'data': dict(), 'msg': ''}
        try:
            with httpx.Client() as client:
                resp = client.get(query_task_url, params={'code': t_id})
            resp_dict = resp.json()
            status_code = resp_dict['code']
            if status_code in (9999, 10002, 10003, 10001):
                tmp_dict['status'] = False
                tmp_dict['msg'] = resp_dict['msg']
            elif status_code == 10000:
                loguru.logger.info(f'query task data: {json.dumps(resp_dict)}')
                task_data = resp_dict['data']
                task_status = task_data.get('status', 1)
                if task_status == 3:
                    tmp_dict['status'] = False
                    tmp_dict['msg'] = task_data.get('msg', '')
                else:
                    # status 2 is treated as finished regardless of reported progress.
                    progress = 100 if task_status == 2 else task_data.get('progress', 20)
                    tmp_dict['data'] = {
                        'progress': progress,
                        'path': task_data.get('result', ''),
                    }
            else:
                # Unknown response code: leave data empty so the caller stops,
                # but say why instead of failing with an empty message.
                tmp_dict['msg'] = f'unexpected response code: {status_code}'
        except Exception as e:
            loguru.logger.info(f'query task fail case by:{str(e)}')
            raise RuntimeError(str(e)) from e
        return tmp_dict

    deadline = time.time() + timeout
    while time.time() < deadline:
        tmp_result = query_result(task_id)
        if not tmp_result['status'] or not tmp_result['data']:
            result_json['status'] = False
            result_json['msg'] = tmp_result['msg']
            return result_json

        process = tmp_result['data']['progress']
        result_path = tmp_result['data']['path']
        loguru.logger.info(f'query task progress :{process}')
        # Keep waiting until progress hits 100 AND a result path is present;
        # otherwise an empty path would resolve to the temp directory itself.
        if process < 100 or not result_path:
            time.sleep(interval)
            loguru.logger.info(f'wait next interval:{interval}')
            continue

        file_name = result_path.replace('/', '').replace('\\', '')
        result_json['data'] = "%s/%s" % (heygem_temp_path, file_name)
        result_json['status'] = True
        return result_json

    result_json['msg'] = f'task {task_id} timed out after {timeout}s'
    return result_json


def path_convert(path):
    """Convert a Windows drive path into its WSL ``/mnt/<drive>/...`` form.

    ``C:\\work\\a.mp4`` -> ``/mnt/c/work/a.mp4``. Paths without ``:`` are
    returned unchanged.
    """
    if ":" not in path:
        return path
    # Split only on the drive separator; normalise both separator styles.
    drive, rest = path.replace("\\", "/").split(":", 1)
    return "/".join(["/mnt", drive.lower(), rest.lstrip("/")])


def result_path_convert(result_path: str, distro: str = "Debian"):
    """Map an absolute POSIX path inside WSL to its Windows UNC path.

    ``/code/data/temp/x.mp4`` -> ``\\\\wsl.localhost\\<distro>\\code\\data\\temp\\x.mp4``.
    Non-absolute paths are returned unchanged.
    """
    if result_path.startswith("/"):
        result_path = r"\\wsl.localhost" + "\\" + distro + result_path.replace("/", "\\")
    return result_path


def _run_heygem_task(uid, video, audio, heygem_url, heygem_temp_path, is_Windows,
                     interval, wsl_distro="Debian"):
    """Submit a task, wait for completion, and return ``(local_result_path,)``.

    Raises:
        RuntimeError: if submission or processing fails.
        FileNotFoundError: if the resolved result file does not exist locally.
    """
    submit_result = task_submit(uid, video, audio, heygem_url)
    if not submit_result['status']:
        # A ComfyUI node must return a tuple; signal failure by raising.
        raise RuntimeError(f"Task submission failed: {submit_result['msg']}")
    task_id = submit_result['data']
    loguru.logger.info(f'Submitted task: {task_id}')

    progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id,
                                          interval=interval)
    if not progress_result['status']:
        raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

    file_path = progress_result['data']
    if is_Windows:
        file_path = result_path_convert(file_path, wsl_distro)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Output file not found at {file_path}")
    return (file_path,)


class HeyGemF2F:
    """HeyGem lip-sync node: writes the IMAGE frames and AUDIO to temp files,
    submits them to HeyGem, and returns the path of the generated video."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "video": ("IMAGE", {"forceInput": True}),
                "audio": ("AUDIO", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                # WSL distro hosting HeyGem; only used when is_Windows is set.
                "wsl_distro": ("STRING", {"default": "Debian"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    def f2f(self, video: Tensor, audio: dict, heygem_url: str, heygem_temp_path: str,
            is_Windows: bool, wsl_distro: str = "Debian"):
        """Run lip sync on in-memory frames/audio; returns ``(video_path,)``.

        Temp files are written next to this module, passed to HeyGem via
        :func:`path_convert`, and removed afterwards.
        """
        uid = str(uuid.uuid4())
        base_dir = os.path.dirname(__file__)
        video_path = os.path.join(base_dir, "%s.mp4" % uid)
        audio_path = os.path.join(base_dir, "%s.wav" % uid)
        try:
            try:
                # Out-of-place scale: mul_ would overwrite the upstream IMAGE
                # tensor. Clamp so values > 1.0 do not wrap around in uint8.
                frames = video.mul(255).clamp(0, 255).byte()
                torchvision.io.write_video(video_path, frames, 25)
                torchaudio.save(audio_path, audio["waveform"].squeeze(0),
                                audio["sample_rate"], True)
            except Exception as e:
                traceback.print_exc()
                raise RuntimeError("Save Temp File Error! ") from e

            return _run_heygem_task(uid, path_convert(video_path), path_convert(audio_path),
                                    heygem_url, heygem_temp_path, is_Windows,
                                    interval=5, wsl_distro=wsl_distro)
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            raise
        finally:
            # Remove each temp file independently so one failure can't leak the other.
            for tmp_path in (video_path, audio_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass


class HeyGemF2FFromFile:
    """HeyGem lip-sync node that passes existing video/audio paths straight to
    HeyGem and returns the path of the generated video."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "video": ("STRING", {"forceInput": True}),
                "audio": ("STRING", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                # WSL distro hosting HeyGem; only used when is_Windows is set.
                "wsl_distro": ("STRING", {"default": "Debian"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    def f2f(self, video: str, audio: str, heygem_url: str, heygem_temp_path: str,
            is_Windows: bool, wsl_distro: str = "Debian"):
        """Run lip sync on paths as given (no conversion); returns ``(video_path,)``."""
        uid = str(uuid.uuid4())
        try:
            return _run_heygem_task(uid, video, audio, heygem_url, heygem_temp_path,
                                    is_Windows, interval=10, wsl_distro=wsl_distro)
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            raise