ADD 增加判断逻辑
This commit is contained in:
parent
c1bf38a79b
commit
568b63d310
|
|
@ -1,4 +1,4 @@
|
|||
from .nodes.heygem import HeyGemF2F
|
||||
from .nodes.heygem import HeyGemF2F, HeyGemF2FFromFile
|
||||
from .nodes.s3 import S3Download, S3Upload
|
||||
from .nodes.text import *
|
||||
from .nodes.traverse_folder import TraverseFolder
|
||||
|
|
@ -32,6 +32,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"LoadTextCustom": LoadTextLocal,
|
||||
"LoadTextCustomOnline": LoadTextOnline,
|
||||
"HeyGemF2F": HeyGemF2F,
|
||||
"HeyGemF2FFromFile": HeyGemF2FFromFile,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
|
|
@ -52,5 +53,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"TraverseFolder": "遍历文件夹",
|
||||
"LoadTextCustom": "读取文本文件(本地)",
|
||||
"LoadTextCustomOnline": "读取文本文件(线上)",
|
||||
"HeyGemF2F": "HeyGem口型同步(API)",
|
||||
"HeyGemF2F": "HeyGem口型同步(API, 传入文件Tensor)",
|
||||
"HeyGemF2FFromFile": "HeyGem口型同步(API, 传入文件路径)"
|
||||
}
|
||||
|
|
|
|||
318
nodes/heygem.py
318
nodes/heygem.py
|
|
@ -1,11 +1,12 @@
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import loguru
|
||||
import requests
|
||||
import torchaudio
|
||||
import torchvision
|
||||
from time import sleep
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
|
|
@ -37,6 +38,103 @@ class HeyGemF2F:
|
|||
return path
|
||||
|
||||
def f2f(self, video:Tensor, audio:dict, heygem_url:str, heygem_temp_path:str):
|
||||
def task_submit():
|
||||
"""Submit a task to the API"""
|
||||
task_submit_api = f'{heygem_url}/easy/submit'
|
||||
result_json = {
|
||||
'status': False, 'data': {}, 'msg': ''
|
||||
}
|
||||
try:
|
||||
code = str(uuid.uuid4())
|
||||
data = {
|
||||
"code": uid,
|
||||
"video_url": self.path_convert(os.path.join(os.path.dirname(__file__), "%s.mp4" % uid)),
|
||||
"audio_url": self.path_convert(os.path.join(os.path.dirname(__file__), "%s.wav" % uid)),
|
||||
"chaofen": 1,
|
||||
"watermark_switch": 0,
|
||||
"pn": 1
|
||||
}
|
||||
loguru.logger.info(f'data={data}')
|
||||
with httpx.Client() as client:
|
||||
resp = client.post(task_submit_api, json=data)
|
||||
resp_dict = resp.json()
|
||||
loguru.logger.info(f'submit data: {resp_dict}')
|
||||
if resp_dict['code'] != 10000:
|
||||
result_json['status'] = False
|
||||
result_json['msg'] = result_json['msg']
|
||||
else:
|
||||
result_json['status'] = True
|
||||
result_json['data'] = code
|
||||
result_json['msg'] = '任务提交成功'
|
||||
except Exception as e:
|
||||
loguru.logger.info(f'submit task fail case by:{str(e)}')
|
||||
raise RuntimeError(str(e))
|
||||
|
||||
return result_json
|
||||
|
||||
def query_task_progress(task_id: str, interval: int = 10, timeout: int = 60 * 15):
|
||||
"""Query task progress and wait for completion"""
|
||||
result_json = {'status': False, 'data': {}, 'msg': ''}
|
||||
|
||||
def query_result(t_id: str):
|
||||
tmp_dict = {'status': True, 'data': dict(), 'msg': ''}
|
||||
try:
|
||||
query_task_url = f'{heygem_url}/easy/query'
|
||||
params = {
|
||||
'code': t_id
|
||||
}
|
||||
with httpx.Client() as client:
|
||||
resp = client.get(query_task_url, params=params)
|
||||
resp_dict = resp.json()
|
||||
status_code = resp_dict['code']
|
||||
if status_code in (9999, 10002, 10003, 10001):
|
||||
tmp_dict['status'] = False
|
||||
tmp_dict['msg'] = resp_dict['msg']
|
||||
elif status_code == 10000:
|
||||
loguru.logger.info(f'query task data: {json.dumps(resp_dict)}')
|
||||
status_code = resp_dict['data'].get('status', 1)
|
||||
if status_code == 3:
|
||||
tmp_dict['status'] = False
|
||||
tmp_dict['msg'] = resp_dict['data']['msg']
|
||||
else:
|
||||
process = resp_dict['data'].get('progress', 20)
|
||||
if status_code == 2:
|
||||
process = 100
|
||||
else:
|
||||
process = process
|
||||
result = resp_dict['data'].get('result', '')
|
||||
tmp_dict['data'] = {'progress': process,
|
||||
'path': result,
|
||||
}
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
loguru.logger.info(f'query task fail case by:{str(e)}')
|
||||
raise RuntimeError(str(e))
|
||||
return tmp_dict
|
||||
|
||||
end = time.time() + timeout
|
||||
while time.time() < end:
|
||||
tmp_result = query_result(task_id)
|
||||
if not tmp_result['status'] or tmp_result['data'].__eq__({}):
|
||||
result_json['status'] = False
|
||||
result_json['msg'] = tmp_result['msg']
|
||||
break
|
||||
else:
|
||||
process = tmp_result['data']['progress']
|
||||
loguru.logger.info(f'query task progress :{process}')
|
||||
if tmp_result['data']['progress'] < 100:
|
||||
time.sleep(interval)
|
||||
loguru.logger.info(f'wait next interval:{interval}')
|
||||
else:
|
||||
p = tmp_result['data']['path']
|
||||
p = p.replace('/', '')
|
||||
result_json['data'] = os.path.join(heygem_temp_path, p)
|
||||
result_json['status'] = True
|
||||
return result_json
|
||||
|
||||
return result_json
|
||||
|
||||
uid = str(uuid.uuid4())
|
||||
try:
|
||||
try:
|
||||
|
|
@ -44,49 +142,185 @@ class HeyGemF2F:
|
|||
torchaudio.save(os.path.join(os.path.dirname(__file__),"%s.wav" % uid), audio["waveform"].squeeze(0), audio["sample_rate"], True)
|
||||
except:
|
||||
raise RuntimeError("Save Temp File Error! ")
|
||||
payload = {
|
||||
"code": uid,
|
||||
"video_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid)),
|
||||
"audio_url": self.path_convert(os.path.join(os.path.dirname(__file__),"%s.wav" % uid)),
|
||||
"chaofen": 1,
|
||||
"watermark_switch": 0,
|
||||
"pn": 1
|
||||
}
|
||||
print(payload)
|
||||
r = requests.post(heygem_url+"/easy/submit", json=payload)
|
||||
if r.status_code != 200:
|
||||
raise RuntimeError("Request Error!")
|
||||
submit_result = task_submit()
|
||||
if not submit_result['status']:
|
||||
return {
|
||||
'status': False,
|
||||
'data': {},
|
||||
'msg': f"Task submission failed: {submit_result['msg']}"
|
||||
}
|
||||
|
||||
task_id = submit_result['data']
|
||||
loguru.logger.info(f'Submitted task: {task_id}')
|
||||
|
||||
# Query task progress
|
||||
progress_result = query_task_progress(task_id)
|
||||
|
||||
if not progress_result['status']:
|
||||
raise RuntimeError(f"Task processing failed: {progress_result['msg']}")
|
||||
|
||||
# Return the file for download
|
||||
file_path = progress_result['data']
|
||||
if os.path.exists(file_path):
|
||||
return (file_path,)
|
||||
else:
|
||||
r_json = r.json()
|
||||
if r_json["success"]:
|
||||
loguru.logger.info("Submit Task Success")
|
||||
else:
|
||||
raise RuntimeError("Submit Task Fail")
|
||||
t = 30
|
||||
while t>0:
|
||||
r = requests.get(heygem_url+"/easy/query?code="+uid)
|
||||
if r.status_code == 200:
|
||||
j = r.json()
|
||||
if "msg" in j and j["msg"]=="任务不存在":
|
||||
raise RuntimeError("Task Missing")
|
||||
if "data" in j and "异常" in j["data"]["msg"]:
|
||||
raise RuntimeError("Task Run Error: %s" % j["data"]["msg"])
|
||||
if "data" in j and j["data"]["progress"] < 100 and j["data"]["result"] == "":
|
||||
loguru.logger.info("Waiting Task Finish")
|
||||
elif "data" in j and j["data"]["progress"] == 100 and j["data"]["msg"] == "任务完成" and j["data"]["result"] != "":
|
||||
loguru.logger.info("Task Finished")
|
||||
return ("%s%s" % (heygem_temp_path, j["data"]["result"]),)
|
||||
else:
|
||||
loguru.logger.info("Unknown Status")
|
||||
else:
|
||||
loguru.logger.info("Get Task Status Failed")
|
||||
t -= 1
|
||||
sleep(5)
|
||||
except BaseException as e:
|
||||
raise RuntimeError(e)
|
||||
raise FileNotFoundError(f"Output file not found at {file_path}")
|
||||
except Exception as e:
|
||||
loguru.logger.error(f"Error processing request: {str(e)}")
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
try:
|
||||
os.remove(os.path.join(os.path.dirname(__file__),"%s.mp4" % uid))
|
||||
os.remove(os.path.join(os.path.dirname(__file__),"%s.wav" % uid))
|
||||
except:
|
||||
pass
|
||||
pass
|
||||
|
||||
class HeyGemF2FFromFile:
|
||||
"""HeyGem 嘴型同步 直接读取文件"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"video": ("STRING", {"forceInput": True}),
|
||||
"audio": ("STRING", {"forceInput": True}),
|
||||
"heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
|
||||
"heygem_temp_path": ("STRING", {"default":"/code/data/temp"})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("视频存储路径",)
|
||||
FUNCTION = "f2f"
|
||||
CATEGORY = "不忘科技-自定义节点🚩"
|
||||
|
||||
def f2f(self, video:str, audio:str, heygem_url:str, heygem_temp_path:str):
|
||||
def task_submit():
|
||||
"""Submit a task to the API"""
|
||||
task_submit_api = f'{heygem_url}/easy/submit'
|
||||
result_json = {
|
||||
'status': False, 'data': {}, 'msg': ''
|
||||
}
|
||||
try:
|
||||
if not os.path.exists(video):
|
||||
raise RuntimeError(f"Video file not found at {video}")
|
||||
if not os.path.exists(audio):
|
||||
raise RuntimeError(f"Audio file not found at {audio}")
|
||||
code = str(uuid.uuid4())
|
||||
data = {
|
||||
"code": uid,
|
||||
"video_url": video,
|
||||
"audio_url": audio,
|
||||
"chaofen": 1,
|
||||
"watermark_switch": 0,
|
||||
"pn": 1
|
||||
}
|
||||
loguru.logger.info(f'data={data}')
|
||||
with httpx.Client() as client:
|
||||
resp = client.post(task_submit_api, json=data)
|
||||
resp_dict = resp.json()
|
||||
loguru.logger.info(f'submit data: {resp_dict}')
|
||||
if resp_dict['code'] != 10000:
|
||||
result_json['status'] = False
|
||||
result_json['msg'] = result_json['msg']
|
||||
else:
|
||||
result_json['status'] = True
|
||||
result_json['data'] = code
|
||||
result_json['msg'] = '任务提交成功'
|
||||
except Exception as e:
|
||||
loguru.logger.info(f'submit task fail case by:{str(e)}')
|
||||
raise RuntimeError(str(e))
|
||||
|
||||
return result_json
|
||||
|
||||
def query_task_progress(task_id: str, interval: int = 10, timeout: int = 60 * 15):
|
||||
"""Query task progress and wait for completion"""
|
||||
result_json = {'status': False, 'data': {}, 'msg': ''}
|
||||
|
||||
def query_result(t_id: str):
|
||||
tmp_dict = {'status': True, 'data': dict(), 'msg': ''}
|
||||
try:
|
||||
query_task_url = f'{heygem_url}/easy/query'
|
||||
params = {
|
||||
'code': t_id
|
||||
}
|
||||
with httpx.Client() as client:
|
||||
resp = client.get(query_task_url, params=params)
|
||||
resp_dict = resp.json()
|
||||
status_code = resp_dict['code']
|
||||
if status_code in (9999, 10002, 10003, 10001):
|
||||
tmp_dict['status'] = False
|
||||
tmp_dict['msg'] = resp_dict['msg']
|
||||
elif status_code == 10000:
|
||||
loguru.logger.info(f'query task data: {json.dumps(resp_dict)}')
|
||||
status_code = resp_dict['data'].get('status', 1)
|
||||
if status_code == 3:
|
||||
tmp_dict['status'] = False
|
||||
tmp_dict['msg'] = resp_dict['data']['msg']
|
||||
else:
|
||||
process = resp_dict['data'].get('progress', 20)
|
||||
if status_code == 2:
|
||||
process = 100
|
||||
else:
|
||||
process = process
|
||||
result = resp_dict['data'].get('result', '')
|
||||
tmp_dict['data'] = {'progress': process,
|
||||
'path': result,
|
||||
}
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
loguru.logger.info(f'query task fail case by:{str(e)}')
|
||||
raise RuntimeError(str(e))
|
||||
return tmp_dict
|
||||
|
||||
end = time.time() + timeout
|
||||
while time.time() < end:
|
||||
tmp_result = query_result(task_id)
|
||||
if not tmp_result['status'] or tmp_result['data'].__eq__({}):
|
||||
result_json['status'] = False
|
||||
result_json['msg'] = tmp_result['msg']
|
||||
break
|
||||
else:
|
||||
process = tmp_result['data']['progress']
|
||||
loguru.logger.info(f'query task progress :{process}')
|
||||
if tmp_result['data']['progress'] < 100:
|
||||
time.sleep(interval)
|
||||
loguru.logger.info(f'wait next interval:{interval}')
|
||||
else:
|
||||
p = tmp_result['data']['path']
|
||||
p = p.replace('/', '')
|
||||
result_json['data'] = os.path.join(heygem_temp_path, p)
|
||||
result_json['status'] = True
|
||||
return result_json
|
||||
|
||||
return result_json
|
||||
|
||||
uid = str(uuid.uuid4())
|
||||
try:
|
||||
submit_result = task_submit()
|
||||
if not submit_result['status']:
|
||||
return {
|
||||
'status': False,
|
||||
'data': {},
|
||||
'msg': f"Task submission failed: {submit_result['msg']}"
|
||||
}
|
||||
|
||||
task_id = submit_result['data']
|
||||
loguru.logger.info(f'Submitted task: {task_id}')
|
||||
|
||||
# Query task progress
|
||||
progress_result = query_task_progress(task_id)
|
||||
|
||||
if not progress_result['status']:
|
||||
raise RuntimeError(f"Task processing failed: {progress_result['msg']}")
|
||||
|
||||
# Return the file for download
|
||||
file_path = progress_result['data']
|
||||
if os.path.exists(file_path):
|
||||
return (file_path,)
|
||||
else:
|
||||
raise FileNotFoundError(f"Output file not found at {file_path}")
|
||||
except Exception as e:
|
||||
loguru.logger.error(f"Error processing request: {str(e)}")
|
||||
raise Exception(str(e))
|
||||
Loading…
Reference in New Issue