diff --git a/__init__.py b/__init__.py index dabdb9a..bb9688f 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from .nodes.heygem import HeyGemF2F +from .nodes.heygem import HeyGemF2F, HeyGemF2FFromFile from .nodes.s3 import S3Download, S3Upload from .nodes.text import * from .nodes.traverse_folder import TraverseFolder @@ -32,6 +32,7 @@ NODE_CLASS_MAPPINGS = { "LoadTextCustom": LoadTextLocal, "LoadTextCustomOnline": LoadTextOnline, "HeyGemF2F": HeyGemF2F, + "HeyGemF2FFromFile": HeyGemF2FFromFile, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -52,5 +53,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "TraverseFolder": "遍历文件夹", "LoadTextCustom": "读取文本文件(本地)", "LoadTextCustomOnline": "读取文本文件(线上)", - "HeyGemF2F": "HeyGem口型同步(API)", + "HeyGemF2F": "HeyGem口型同步(API, 传入文件Tensor)", + "HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)" } diff --git a/nodes/heygem.py b/nodes/heygem.py index 00fd998..9f4eabf 100644 --- a/nodes/heygem.py +++ b/nodes/heygem.py @@ -1,11 +1,12 @@ +import json import os +import time import uuid +import httpx import loguru -import requests import torchaudio import torchvision -from time import sleep from torch import Tensor @@ -37,6 +38,103 @@ class HeyGemF2F: return path def f2f(self, video:Tensor, audio:dict, heygem_url:str, heygem_temp_path:str): + def task_submit(): + """Submit a task to the API""" + task_submit_api = f'{heygem_url}/easy/submit' + result_json = { + 'status': False, 'data': {}, 'msg': '' + } + try: + code = str(uuid.uuid4()) + data = { + "code": uid, + "video_url": self.path_convert(os.path.join(os.path.dirname(__file__), "%s.mp4" % uid)), + "audio_url": self.path_convert(os.path.join(os.path.dirname(__file__), "%s.wav" % uid)), + "chaofen": 1, + "watermark_switch": 0, + "pn": 1 + } + loguru.logger.info(f'data={data}') + with httpx.Client() as client: + resp = client.post(task_submit_api, json=data) + resp_dict = resp.json() + loguru.logger.info(f'submit data: {resp_dict}') + if resp_dict['code'] != 10000: + result_json['status'] = False + result_json['msg'] = result_json['msg'] + else: + result_json['status'] = True + result_json['data'] = code + result_json['msg'] = '任务提交成功' + except Exception as e: + loguru.logger.info(f'submit task fail case by:{str(e)}') + raise RuntimeError(str(e)) + + return result_json + + def query_task_progress(task_id: str, interval: int = 10, timeout: int = 60 * 15): + """Query task progress and wait for completion""" + result_json = {'status': False, 'data': {}, 'msg': ''} + + def query_result(t_id: str): + tmp_dict = {'status': True, 'data': dict(), 'msg': ''} + try: + query_task_url = f'{heygem_url}/easy/query' + params = { + 'code': t_id + } + with httpx.Client() as client: + resp = client.get(query_task_url, params=params) + resp_dict = resp.json() + status_code = resp_dict['code'] + if status_code in (9999, 10002, 10003, 10001): + tmp_dict['status'] = False + tmp_dict['msg'] = resp_dict['msg'] + elif status_code == 10000: + loguru.logger.info(f'query task data: {json.dumps(resp_dict)}') + status_code = resp_dict['data'].get('status', 1) + if status_code == 3: + tmp_dict['status'] = False + tmp_dict['msg'] = resp_dict['data']['msg'] + else: + process = resp_dict['data'].get('progress', 20) + if status_code == 2: + process = 100 + else: + process = process + result = resp_dict['data'].get('result', '') + tmp_dict['data'] = {'progress': process, + 'path': result, + } + else: + pass + except Exception as e: + loguru.logger.info(f'query task fail case by:{str(e)}') + raise RuntimeError(str(e)) + return tmp_dict + + end = time.time() + timeout + while time.time() < end: + tmp_result = query_result(task_id) + if not tmp_result['status'] or tmp_result['data'].__eq__({}): + result_json['status'] = False + result_json['msg'] = tmp_result['msg'] + break + else: + process = tmp_result['data']['progress'] + loguru.logger.info(f'query task progress :{process}') + if tmp_result['data']['progress'] < 100: + time.sleep(interval) + loguru.logger.info(f'wait next interval:{interval}') + else: + p = tmp_result['data']['path'] + p = p.replace('/', '') + result_json['data'] = os.path.join(heygem_temp_path, p) + result_json['status'] = True + return result_json + + return result_json + uid = str(uuid.uuid4()) try: try: @@ -44,49 +142,185 @@ class HeyGemF2F: torchaudio.save(os.path.join(os.path.dirname(__file__),"%s.wav" % uid), audio["waveform"].squeeze(0), audio["sample_rate"], True) except: raise RuntimeError("Save Temp File Error! ") - payload = { - "code": uid, - "video_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)), - "audio_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)), - "chaofen": 1, - "watermark_switch": 0, - "pn": 1 - } - print(payload) - r = requests.post(heygem_url+"/easy/submit", json=payload) - if r.status_code != 200: - raise RuntimeError("Request Error!") + submit_result = task_submit() + if not submit_result['status']: + return { + 'status': False, + 'data': {}, + 'msg': f"Task submission failed: {submit_result['msg']}" + } + + task_id = submit_result['data'] + loguru.logger.info(f'Submitted task: {task_id}') + + # Query task progress + progress_result = query_task_progress(task_id) + + if not progress_result['status']: + raise RuntimeError(f"Task processing failed: {progress_result['msg']}") + + # Return the file for download + file_path = progress_result['data'] + if os.path.exists(file_path): + return (file_path,) else: - r_json = r.json() - if r_json["success"]: - loguru.logger.info("Submit Task Success") - else: - raise RuntimeError("Submit Task Fail") - t = 30 - while t>0: - r = requests.get(heygem_url+"/easy/query?code="+uid) - if r.status_code == 200: - j = r.json() - if "msg" in j and j["msg"]=="任务不存在": - raise RuntimeError("Task Missing") - if "data" in j and "异常" in j["data"]["msg"]: - raise RuntimeError("Task Run Error: %s" % j["data"]["msg"]) - if "data" in j and j["data"]["progress"] < 100 and j["data"]["result"] == "": - loguru.logger.info("Waiting Task Finish") - elif "data" in j and j["data"]["progress"] == 100 and j["data"]["msg"] == "任务完成" and j["data"]["result"] != "": - loguru.logger.info("Task Finished") - return ("%s%s" % (heygem_temp_path, j["data"]["result"]),) - else: - loguru.logger.info("Unknown Status") - else: - loguru.logger.info("Get Task Status Failed") - t -= 1 - sleep(5) - except BaseException as e: - raise RuntimeError(e) + raise FileNotFoundError(f"Output file not found at {file_path}") + except Exception as e: + loguru.logger.error(f"Error processing request: {str(e)}") + raise Exception(str(e)) finally: try: os.remove(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)) os.remove(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)) except: - pass \ No newline at end of file + pass + +class HeyGemF2FFromFile: + """HeyGem 嘴型同步 直接读取文件""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "video": ("STRING", {"forceInput": True}), + "audio": ("STRING", {"forceInput": True}), + "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}), + "heygem_temp_path": ("STRING", {"default":"/code/data/temp"}) + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("视频存储路径",) + FUNCTION = "f2f" + CATEGORY = "不忘科技-自定义节点🚩" + + def f2f(self, video:str, audio:str, heygem_url:str, heygem_temp_path:str): + def task_submit(): + """Submit a task to the API""" + task_submit_api = f'{heygem_url}/easy/submit' + result_json = { + 'status': False, 'data': {}, 'msg': '' + } + try: + if not os.path.exists(video): + raise RuntimeError(f"Video file not found at {video}") + if not os.path.exists(audio): + raise RuntimeError(f"Audio file not found at {audio}") + code = str(uuid.uuid4()) + data = { + "code": uid, + "video_url": video, + "audio_url": audio, + "chaofen": 1, + "watermark_switch": 0, + "pn": 1 + } + loguru.logger.info(f'data={data}') + with httpx.Client() as client: + resp = client.post(task_submit_api, json=data) + resp_dict = resp.json() + loguru.logger.info(f'submit data: {resp_dict}') + if resp_dict['code'] != 10000: + result_json['status'] = False + result_json['msg'] = result_json['msg'] + else: + result_json['status'] = True + result_json['data'] = code + result_json['msg'] = '任务提交成功' + except Exception as e: + loguru.logger.info(f'submit task fail case by:{str(e)}') + raise RuntimeError(str(e)) + + return result_json + + def query_task_progress(task_id: str, interval: int = 10, timeout: int = 60 * 15): + """Query task progress and wait for completion""" + result_json = {'status': False, 'data': {}, 'msg': ''} + + def query_result(t_id: str): + tmp_dict = {'status': True, 'data': dict(), 'msg': ''} + try: + query_task_url = f'{heygem_url}/easy/query' + params = { + 'code': t_id + } + with httpx.Client() as client: + resp = client.get(query_task_url, params=params) + resp_dict = resp.json() + status_code = resp_dict['code'] + if status_code in (9999, 10002, 10003, 10001): + tmp_dict['status'] = False + tmp_dict['msg'] = resp_dict['msg'] + elif status_code == 10000: + loguru.logger.info(f'query task data: {json.dumps(resp_dict)}') + status_code = resp_dict['data'].get('status', 1) + if status_code == 3: + tmp_dict['status'] = False + tmp_dict['msg'] = resp_dict['data']['msg'] + else: + process = resp_dict['data'].get('progress', 20) + if status_code == 2: + process = 100 + else: + process = process + result = resp_dict['data'].get('result', '') + tmp_dict['data'] = {'progress': process, + 'path': result, + } + else: + pass + except Exception as e: + loguru.logger.info(f'query task fail case by:{str(e)}') + raise RuntimeError(str(e)) + return tmp_dict + + end = time.time() + timeout + while time.time() < end: + tmp_result = query_result(task_id) + if not tmp_result['status'] or tmp_result['data'].__eq__({}): + result_json['status'] = False + result_json['msg'] = tmp_result['msg'] + break + else: + process = tmp_result['data']['progress'] + loguru.logger.info(f'query task progress :{process}') + if tmp_result['data']['progress'] < 100: + time.sleep(interval) + loguru.logger.info(f'wait next interval:{interval}') + else: + p = tmp_result['data']['path'] + p = p.replace('/', '') + result_json['data'] = os.path.join(heygem_temp_path, p) + result_json['status'] = True + return result_json + + return result_json + + uid = str(uuid.uuid4()) + try: + submit_result = task_submit() + if not submit_result['status']: + return { + 'status': False, + 'data': {}, + 'msg': f"Task submission failed: {submit_result['msg']}" + } + + task_id = submit_result['data'] + loguru.logger.info(f'Submitted task: {task_id}') + + # Query task progress + progress_result = query_task_progress(task_id) + + if not progress_result['status']: + raise RuntimeError(f"Task processing failed: {progress_result['msg']}") + + # Return the file for download + file_path = progress_result['data'] + if os.path.exists(file_path): + return (file_path,) + else: + raise FileNotFoundError(f"Output file not found at {file_path}") + except Exception as e: + loguru.logger.error(f"Error processing request: {str(e)}") + raise Exception(str(e)) \ No newline at end of file