# Source file: ComfyUI-CustomNode/nodes/video_lipsync_nodes.py
# File-viewer metadata: 241 lines, 8.9 KiB, Python

import json
import os
import time
import traceback
import uuid
import httpx
import loguru
import torchaudio
import torchvision
from torch import Tensor
def task_submit(uid, video, audio, heygem_url):
    """Submit a lip-sync task to the HeyGem ``/easy/submit`` endpoint.

    Args:
        uid: Task code sent as ``code``; returned in ``data`` on success.
        video: Video location sent as ``video_url``.
        audio: Audio location sent as ``audio_url``.
        heygem_url: Base URL of the HeyGem service.

    Returns:
        A dict with keys ``status`` (bool), ``data`` (``uid`` on success,
        otherwise ``{}``) and ``msg`` (server message on failure).

    Raises:
        RuntimeError: If the request fails or the response cannot be read.
    """
    task_submit_api = f'{heygem_url}/easy/submit'
    result_json = {'status': False, 'data': {}, 'msg': ''}
    try:
        data = {
            "code": uid,
            "video_url": video,
            "audio_url": audio,
            "chaofen": 1,
            "watermark_switch": 0,
            "pn": 1,
        }
        loguru.logger.info(f'data={data}')
        with httpx.Client() as client:
            resp = client.post(task_submit_api, json=data)
            resp_dict = resp.json()
        loguru.logger.info(f'submit data: {resp_dict}')
        if resp_dict['code'] != 10000:
            # Propagate the server's reason; fall back to the code so the
            # caller never reports an empty failure message.
            result_json['status'] = False
            result_json['msg'] = (
                resp_dict.get('msg')
                or f"HeyGem returned code {resp_dict['code']}"
            )
        else:
            result_json['status'] = True
            result_json['data'] = uid
            result_json['msg'] = '任务提交成功'
    except Exception as e:
        loguru.logger.error(f'submit task fail case by:{str(e)}')
        raise RuntimeError(str(e)) from e
    return result_json
def query_task_progress(heygem_url, heygem_temp_path, task_id: str, interval: int = 10, timeout: int = 60 * 15):
    """Poll the HeyGem ``/easy/query`` endpoint until the task ends.

    Args:
        heygem_url: Base URL of the HeyGem service.
        heygem_temp_path: Directory prefix joined with the returned result
            name to build the output path.
        task_id: Task code to query.
        interval: Seconds to sleep between polls.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        A dict with keys ``status``, ``data`` and ``msg``. On success
        ``status`` is True and ``data`` is ``"<heygem_temp_path>/<result>"``
        with every slash and backslash removed from ``<result>``. On
        failure or timeout ``status`` is False and ``msg`` says why.

    Raises:
        RuntimeError: If a query request fails or its response is malformed.
    """
    result_json = {'status': False, 'data': {}, 'msg': ''}

    def query_result(t_id: str):
        """Query the task once and normalise the response."""
        tmp_dict = {'status': True, 'data': {}, 'msg': ''}
        try:
            query_task_url = f'{heygem_url}/easy/query'
            with httpx.Client() as client:
                resp = client.get(query_task_url, params={'code': t_id})
                resp_dict = resp.json()
            response_code = resp_dict['code']
            if response_code in (9999, 10002, 10003, 10001):
                tmp_dict['status'] = False
                tmp_dict['msg'] = resp_dict['msg']
            elif response_code == 10000:
                loguru.logger.info(f'query task data: {json.dumps(resp_dict)}')
                task_status = resp_dict['data'].get('status', 1)
                if task_status == 3:
                    # Task status 3 is reported as a failure.
                    tmp_dict['status'] = False
                    tmp_dict['msg'] = resp_dict['data']['msg']
                else:
                    # Task status 2 counts as finished whatever progress
                    # value the server reports.
                    if task_status == 2:
                        progress = 100
                    else:
                        progress = resp_dict['data'].get('progress', 20)
                    tmp_dict['data'] = {
                        'progress': progress,
                        'path': resp_dict['data'].get('result', ''),
                    }
            else:
                # An unrecognised code used to end the wait with an empty
                # message; report it explicitly instead.
                tmp_dict['status'] = False
                tmp_dict['msg'] = f'Unexpected response code from HeyGem: {response_code}'
        except Exception as e:
            loguru.logger.error(f'query task fail case by:{str(e)}')
            raise RuntimeError(str(e)) from e
        return tmp_dict

    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        tmp_result = query_result(task_id)
        if not tmp_result['status'] or not tmp_result['data']:
            result_json['status'] = False
            result_json['msg'] = tmp_result['msg']
            return result_json

        progress = tmp_result['data']['progress']
        loguru.logger.info(f'query task progress :{progress}')
        if progress >= 100:
            # Drop every path separator so only a bare name is appended to
            # the temp directory.
            name = tmp_result['data']['path'].replace('/', '').replace('\\', '')
            result_json['data'] = "%s/%s" % (heygem_temp_path, name)
            result_json['status'] = True
            return result_json

        loguru.logger.info(f'wait next interval:{interval}')
        time.sleep(interval)

    # Deadline reached without a terminal state: tell the caller why.
    result_json['msg'] = f'Task {task_id} did not finish within {timeout} seconds'
    return result_json
def path_convert(path):
    """Convert a Windows drive path to its ``/mnt/<drive>/...`` form.

    ``C:\\Users\\a.mp4`` and ``C:/Users/a.mp4`` both become
    ``/mnt/c/Users/a.mp4``. Anything that does not start with a drive
    letter followed by ``:`` is returned unchanged, so POSIX paths that
    merely contain a colon are not rewritten.

    Args:
        path: Path string to convert.

    Returns:
        The converted path, or ``path`` itself when it has no drive prefix.
    """
    has_drive = (
        len(path) >= 2
        and path[1] == ":"
        and path[0].isascii()
        and path[0].isalpha()
    )
    if not has_drive:
        return path
    drive = path[0].lower()
    # Normalise backslashes on every host (not only where os.sep is "\\")
    # and strip only real leading separators after the drive.
    remainder = path[2:].replace("\\", "/").lstrip("/")
    return "/".join(["/mnt", drive, remainder])
def result_path_convert(result_path: str, distro: str = "Debian"):
    r"""Map an absolute POSIX path to a ``\\wsl.localhost\<distro>`` UNC path.

    Args:
        result_path: Path to convert. Only paths starting with ``/`` are
            converted; anything else is returned unchanged.
        distro: Name used after ``\\wsl.localhost\``. Defaults to
            ``"Debian"``, the value that was previously hard-coded.

    Returns:
        The UNC path with every ``/`` replaced by ``\``, or ``result_path``
        unchanged when it does not start with ``/``.
    """
    if not result_path.startswith("/"):
        return result_path
    unc_root = "\\\\wsl.localhost\\" + distro
    return unc_root + result_path.replace("/", "\\")
class HeyGemF2F:
    """HeyGem lip-sync node taking in-memory video frames and audio.

    The inputs are written to temporary ``.mp4`` / ``.wav`` files next to
    this module, submitted to the HeyGem service, and the path of the
    generated video is returned.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's required inputs and their defaults."""
        return {
            "required": {
                "video": ("IMAGE", {"forceInput": True}),
                "audio": ("AUDIO", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    def f2f(self, video: Tensor, audio: dict, heygem_url: str, heygem_temp_path: str, is_Windows: bool):
        """Run a lip-sync job and return the path of the resulting video.

        Args:
            video: Frame tensor; it is scaled by 255 before encoding
                (presumably values in [0, 1] - confirm against the caller).
            audio: Dict providing ``"waveform"`` and ``"sample_rate"``.
            heygem_url: Base URL of the HeyGem service.
            heygem_temp_path: Directory prefix used to build the result path.
            is_Windows: When True, the result path is converted with
                ``result_path_convert`` before it is checked.

        Returns:
            A one-element tuple holding the output file path.

        Raises:
            Exception: If saving the inputs, submitting, processing or
                locating the output fails. The original error is chained
                as ``__cause__``.
        """
        uid = str(uuid.uuid4())
        module_dir = os.path.dirname(__file__)
        video_path = os.path.join(module_dir, "%s.mp4" % uid)
        audio_path = os.path.join(module_dir, "%s.wav" % uid)
        try:
            try:
                # Out-of-place multiply: the in-place ``mul_`` used before
                # rescaled the caller's tensor as a side effect.
                frames = (video * 255).int()
                torchvision.io.write_video(video_path, frames, 25)
                torchaudio.save(audio_path, audio["waveform"].squeeze(0), audio["sample_rate"], True)
            except Exception as e:
                traceback.print_exc()
                raise RuntimeError("Save Temp File Error! ") from e

            submit_result = task_submit(uid, path_convert(video_path), path_convert(audio_path), heygem_url)
            if not submit_result['status']:
                # Raise like every other failure path; a status dict does
                # not match the declared ("STRING",) return.
                raise RuntimeError(f"Task submission failed: {submit_result['msg']}")

            task_id = submit_result['data']
            loguru.logger.info(f'Submitted task: {task_id}')

            # Query task progress
            progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id, interval=5)
            if not progress_result['status']:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

            # Resolve the output file path
            file_path = progress_result['data']
            if is_Windows:
                file_path = result_path_convert(file_path)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return (file_path,)
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            raise Exception(str(e)) from e
        finally:
            # Best-effort cleanup. Each file is removed independently so a
            # missing .mp4 no longer leaves the .wav behind.
            for temp_path in (video_path, audio_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
class HeyGemF2FFromFile:
    """HeyGem lip-sync node that takes video and audio as path strings.

    The given strings are passed to the HeyGem service as-is; no temporary
    files are written by this node.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's required inputs and their defaults."""
        return {
            "required": {
                "video": ("STRING", {"forceInput": True}),
                "audio": ("STRING", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    def f2f(self, video: str, audio: str, heygem_url: str, heygem_temp_path: str, is_Windows: bool):
        """Run a lip-sync job on existing files and return the result path.

        Args:
            video: Video location forwarded to the service.
            audio: Audio location forwarded to the service.
            heygem_url: Base URL of the HeyGem service.
            heygem_temp_path: Directory prefix used to build the result path.
            is_Windows: When True, the result path is converted with
                ``result_path_convert`` before it is checked.

        Returns:
            A one-element tuple holding the output file path.

        Raises:
            Exception: If submitting, processing or locating the output
                fails. The original error is chained as ``__cause__``.
        """
        uid = str(uuid.uuid4())
        try:
            submit_result = task_submit(uid, video, audio, heygem_url)
            if not submit_result['status']:
                # Raise like every other failure path; a status dict does
                # not match the declared ("STRING",) return.
                raise RuntimeError(f"Task submission failed: {submit_result['msg']}")

            task_id = submit_result['data']
            loguru.logger.info(f'Submitted task: {task_id}')

            # Query task progress
            progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id)
            if not progress_result['status']:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

            # Resolve the output file path
            file_path = progress_result['data']
            if is_Windows:
                file_path = result_path_convert(file_path)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return (file_path,)
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            raise Exception(str(e)) from e