269 lines
9.2 KiB
Python
269 lines
9.2 KiB
Python
import json
|
|
import os
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
|
|
import httpx
|
|
import loguru
|
|
import torchaudio
|
|
import torchvision
|
|
from torch import Tensor
|
|
|
|
|
|
def task_submit(uid, video, audio, heygem_url):
    """Submit a lip-sync task to the HeyGem ``/easy/submit`` endpoint.

    Args:
        uid: Task code sent to the service; also returned as the task id on
            success.
        video: Video location as the HeyGem server should read it.
        audio: Audio location as the HeyGem server should read it.
        heygem_url: Base URL of the HeyGem service.

    Returns:
        dict with keys ``status`` (bool), ``data`` (``uid`` on success, ``{}``
        otherwise) and ``msg`` (the server's message when it rejects the task,
        "任务提交成功" on success).

    Raises:
        RuntimeError: if the request fails or the response cannot be parsed.
    """
    task_submit_api = f"{heygem_url}/easy/submit"
    result_json = {"status": False, "data": {}, "msg": ""}
    data = {
        "code": uid,
        "video_url": video,
        "audio_url": audio,
        # NOTE(review): "chaofen" is presumably pinyin for super-resolution and
        # "pn" an opaque service flag; both are fixed here - confirm with the
        # HeyGem API before changing.
        "chaofen": 1,
        "watermark_switch": 0,
        "pn": 1,
    }
    loguru.logger.info(f"data={data}")
    try:
        with httpx.Client() as client:
            resp = client.post(task_submit_api, json=data)
        resp_dict = resp.json()
        loguru.logger.info(f"submit data: {resp_dict}")
        if resp_dict["code"] != 10000:
            # Any code other than 10000 is a rejection; surface the server's
            # own message instead of dropping it.
            result_json["status"] = False
            result_json["msg"] = resp_dict.get("msg", "")
        else:
            result_json["status"] = True
            result_json["data"] = uid
            result_json["msg"] = "任务提交成功"
    except Exception as e:
        loguru.logger.error(f"submit task fail case by:{str(e)}")
        raise RuntimeError(str(e)) from e

    return result_json
|
|
|
|
|
|
def query_task_progress(
    heygem_url: str,
    heygem_temp_path: str,
    task_id: str,
    interval: int = 10,
    timeout: int = 60 * 15,
):
    """Poll HeyGem's ``/easy/query`` endpoint until a task finishes or times out.

    Args:
        heygem_url: Base URL of the HeyGem service.
        heygem_temp_path: Directory prefix under which the finished file name
            is placed in the returned path.
        task_id: Task code previously sent to ``/easy/submit``.
        interval: Seconds to sleep between polls.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        dict with ``status`` (bool), ``data`` (``"<heygem_temp_path>/<file>"``
        on success, ``{}`` otherwise) and ``msg`` (failure or timeout reason).

    Raises:
        RuntimeError: if a query request fails or its response cannot be parsed.
    """
    result_json = {"status": False, "data": {}, "msg": ""}

    def query_result(t_id: str):
        """Query the task once.

        Returns a dict whose ``status`` is False when the service reports a
        failure (or answers with an unknown code), and whose ``data`` holds
        ``progress``/``path`` when progress information is available.
        """
        tmp_dict = {"status": True, "data": {}, "msg": ""}
        try:
            query_task_url = f"{heygem_url}/easy/query"
            with httpx.Client() as client:
                resp = client.get(query_task_url, params={"code": t_id})
            resp_dict = resp.json()
            status_code = resp_dict["code"]
            if status_code in (9999, 10001, 10002, 10003):
                # Service-level error codes.
                tmp_dict["status"] = False
                tmp_dict["msg"] = resp_dict.get("msg", "")
            elif status_code == 10000:
                loguru.logger.info(
                    f"query task data: {json.dumps(resp_dict, ensure_ascii=False)}"
                )
                task_data = resp_dict["data"]
                task_status = task_data.get("status", 1)
                if task_status == 3:
                    # Task-level failure reported by the service.
                    tmp_dict["status"] = False
                    tmp_dict["msg"] = task_data.get("msg", "")
                else:
                    # Status 2 means finished: force 100 so the poll loop stops
                    # regardless of the reported "progress" value.
                    if task_status == 2:
                        progress = 100
                    else:
                        progress = task_data.get("progress", 20)
                    tmp_dict["data"] = {
                        "progress": progress,
                        "path": task_data.get("result", ""),
                    }
            else:
                tmp_dict["status"] = False
                tmp_dict["msg"] = f"unexpected response code: {status_code}"
        except Exception as e:
            loguru.logger.error(f"query task fail case by:{str(e)}")
            raise RuntimeError(str(e)) from e
        return tmp_dict

    # Monotonic clock so wall-clock adjustments cannot shorten/extend the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        tmp_result = query_result(task_id)
        if not tmp_result["status"] or not tmp_result["data"]:
            result_json["status"] = False
            result_json["msg"] = tmp_result["msg"]
            return result_json

        progress = tmp_result["data"]["progress"]
        loguru.logger.info(f"query task progress :{progress}")
        if progress >= 100:
            # Strip every path separator from the reported result so only a
            # bare file name is joined onto heygem_temp_path.
            file_name = tmp_result["data"]["path"].replace("/", "").replace("\\", "")
            result_json["data"] = "%s/%s" % (heygem_temp_path, file_name)
            result_json["status"] = True
            return result_json

        loguru.logger.info(f"wait next interval:{interval}")
        time.sleep(interval)

    result_json["msg"] = f"Task {task_id} timed out after {timeout} seconds"
    return result_json
|
|
|
|
|
|
def path_convert(path):
    """Rewrite a drive-letter path into its WSL ``/mnt/<drive>/...`` form.

    ``C:/Users/a.mp4`` (or the same path using the host's native separator)
    becomes ``/mnt/c/Users/a.mp4``. Only ``os.sep`` is normalised to ``/``,
    so backslashes are converted only when running on Windows. Any further
    colon in the path also turns into a ``/`` separator. Paths containing no
    colon are returned unchanged.
    """
    if ":" not in path:
        return path
    drive, *tail = path.replace(os.sep, "/").split(":")
    # Drop the separator that directly follows the drive letter.
    tail[0] = tail[0][1:]
    return "/".join(["/mnt", drive.lower(), *tail])
|
|
|
|
|
|
def result_path_convert(result_path: str, distro: str = "Debian"):
    """Map an absolute Linux path inside WSL to its Windows UNC path.

    ``/code/a.mp4`` becomes ``\\\\wsl.localhost\\<distro>\\code\\a.mp4``.
    Paths that do not start with ``/`` are returned unchanged.

    Args:
        result_path: Path as seen inside the WSL distribution.
        distro: Name of the WSL distribution hosting the file; defaults to
            "Debian", the value previously hard-coded here.
    """
    if not result_path.startswith("/"):
        return result_path
    unc_root = r"\\wsl.localhost" + "\\" + distro
    return unc_root + result_path.replace("/", "\\")
|
|
|
|
|
|
class HeyGemF2F:
    """HeyGem lip-sync node taking in-memory frames and audio.

    The frames and waveform are written to temporary ``<uuid>.mp4`` /
    ``<uuid>.wav`` files next to this module. Their drive-letter paths are
    converted to WSL form and submitted to the HeyGem service. Once the task
    completes, the path of the generated video is returned. The temporary
    files are removed afterwards.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "video": ("IMAGE", {"forceInput": True}),
                "audio": ("AUDIO", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)  # "video storage path"
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    @staticmethod
    def _to_uint8_frames(video: Tensor) -> Tensor:
        """Scale frames by 255, clamp to [0, 255] and convert to uint8.

        Works out of place: an in-place ``mul_`` would overwrite the caller's
        IMAGE tensor, which other nodes may still read. The clamp prevents
        uint8 wrap-around for values slightly outside the expected range.
        """
        return video.mul(255).clamp(0, 255).byte()

    def f2f(
        self,
        video: Tensor,
        audio: dict,
        heygem_url: str,
        heygem_temp_path: str,
        is_Windows: bool,
    ):
        """Run a HeyGem lip-sync task and return ``(output_video_path,)``.

        Args:
            video: Frame batch; presumably ComfyUI IMAGE layout
                (frames, H, W, C) with floats in [0, 1] - confirm.
            audio: Dict with "waveform" and "sample_rate" keys; the waveform's
                leading batch dimension is squeezed away before saving.
            heygem_url: Base URL of the HeyGem service.
            heygem_temp_path: Directory prefix for the returned output path.
            is_Windows: When True, convert the result to a WSL UNC path.

        Raises:
            RuntimeError: if saving temp files, submission or processing fails.
            FileNotFoundError: if the reported output file does not exist.
        """
        uid = str(uuid.uuid4())
        base_dir = os.path.dirname(__file__)
        video_path = os.path.join(base_dir, "%s.mp4" % uid)
        audio_path = os.path.join(base_dir, "%s.wav" % uid)
        try:
            try:
                # NOTE(review): frame rate is fixed at 25 fps - confirm it
                # matches the source clip.
                torchvision.io.write_video(
                    video_path, self._to_uint8_frames(video), 25
                )
                torchaudio.save(
                    audio_path, audio["waveform"].squeeze(0), audio["sample_rate"], True
                )
            except Exception as e:
                traceback.print_exc()
                raise RuntimeError("Save Temp File Error! ") from e

            submit_result = task_submit(
                uid, path_convert(video_path), path_convert(audio_path), heygem_url
            )
            if not submit_result["status"]:
                # Raise rather than return a dict: the node must return a
                # tuple, so a dict would make the failure silent downstream.
                raise RuntimeError(f"Task submission failed: {submit_result['msg']}")

            task_id = submit_result["data"]
            loguru.logger.info(f"Submitted task: {task_id}")

            progress_result = query_task_progress(
                heygem_url, heygem_temp_path, task_id, interval=5
            )
            if not progress_result["status"]:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

            file_path = progress_result["data"]
            if is_Windows:
                file_path = result_path_convert(file_path)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return (file_path,)
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            # Re-raise unchanged to keep the original type and traceback.
            raise
        finally:
            # Best-effort cleanup; each file is removed independently so one
            # failure does not leave the other behind.
            for temp_path in (video_path, audio_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
|
|
|
|
|
|
class HeyGemF2FFromFile:
    """HeyGem lip-sync node that submits existing video/audio paths directly.

    Unlike the in-memory variant, nothing is written locally. The given
    ``video`` and ``audio`` strings are sent to HeyGem unchanged, so they must
    already be locations the HeyGem server can read.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "video": ("STRING", {"forceInput": True}),
                "audio": ("STRING", {"forceInput": True}),
                "heygem_url": ("STRING", {"default": "http://127.0.0.1:8383"}),
                "heygem_temp_path": ("STRING", {"default": "/code/data/temp"}),
                "is_Windows": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("视频存储路径",)  # "video storage path"
    FUNCTION = "f2f"
    CATEGORY = "不忘科技-自定义节点🚩/视频/口型"

    def f2f(
        self,
        video: str,
        audio: str,
        heygem_url: str,
        heygem_temp_path: str,
        is_Windows: bool,
    ):
        """Run a HeyGem lip-sync task and return ``(output_video_path,)``.

        Args:
            video: Video location passed to HeyGem as-is.
            audio: Audio location passed to HeyGem as-is.
            heygem_url: Base URL of the HeyGem service.
            heygem_temp_path: Directory prefix for the returned output path.
            is_Windows: When True, convert the result to a WSL UNC path.

        Raises:
            RuntimeError: if submission or processing fails.
            FileNotFoundError: if the reported output file does not exist.
        """
        uid = str(uuid.uuid4())
        try:
            submit_result = task_submit(uid, video, audio, heygem_url)
            if not submit_result["status"]:
                # Raise rather than return a dict: the node must return a
                # tuple, so a dict would make the failure silent downstream.
                raise RuntimeError(f"Task submission failed: {submit_result['msg']}")

            task_id = submit_result["data"]
            loguru.logger.info(f"Submitted task: {task_id}")

            progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id)
            if not progress_result["status"]:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

            file_path = progress_result["data"]
            if is_Windows:
                file_path = result_path_convert(file_path)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return (file_path,)
        except Exception as e:
            loguru.logger.error(f"Error processing request: {str(e)}")
            # Re-raise unchanged to keep the original type and traceback.
            raise
|