270 lines
12 KiB
Python
270 lines
12 KiB
Python
import json
|
|
import os
|
|
import shutil
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
from concurrent.futures.thread import ThreadPoolExecutor
|
|
from typing import Any, Optional, Union
|
|
|
|
import httpx
|
|
import loguru
|
|
import requests
|
|
import starlette.datastructures
|
|
import uvicorn
|
|
|
|
from fastapi import UploadFile, HTTPException, Depends, FastAPI
|
|
from fastapi.routing import APIRoute
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from starlette import status
|
|
|
|
|
|
class HeyGem:
    """FastAPI front-end for a locally running HeyGem inference server.

    Constructing an instance launches ``/root/heygem.py`` in the background,
    waits for it to listen on port 8383 and builds a FastAPI app exposing:

    * ``POST /``       -- accept a video and an audio input, queue a job and
      return its ``code``.
    * ``GET /{code}``  -- return the recorded result of a job, or a
      ``waiting`` placeholder while none is recorded.

    Attributes:
        result: dict mapping job code -> ``{"status": ..., "msg": ...}``.
            Entries are never evicted, so it grows for the process lifetime.
        executor: single-worker thread pool, so jobs are processed one at a
            time.
        server: the FastAPI application.
    """

    def __init__(self):
        """Start the backend process, wait for its port and build the app.

        Raises:
            subprocess.CalledProcessError: if the launch command exits
                non-zero.
            TimeoutError: if port 8383 does not accept connections after
                30 attempts spaced 2 seconds apart.
        """

        def check_port_in_use(port, host='127.0.0.1'):
            """Return True when a TCP connection to host:port succeeds within 1 s."""
            try:
                # The context manager closes the socket on every path.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.settimeout(1)
                    s.connect((host, int(port)))
                return True
            except socket.error:
                return False

        uid = str(uuid.uuid4())
        # The command only interpolates a freshly generated UUID, never
        # caller-supplied text, so building a shell string is safe here.
        cmd = "/bin/bash -c \"echo client_id: %s && source /root/.bashrc && nohup python /root/heygem.py >> /root/logs/%s.log 2>&1 &\"" % (uid, uid)
        subprocess.run(cmd, shell=True, check=True)

        # Poll until the backend accepts connections; the for/else raises
        # only when every attempt failed.
        for _ in range(30):
            if check_port_in_use(8383):
                loguru.logger.success("HeyGem Server started on port 8383")
                break
            time.sleep(2)
        else:
            raise TimeoutError("HeyGem Server Start timed out")

        self.result = {}
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.server = FastAPI()
        self.server.routes.append(APIRoute(path='/', endpoint=self.api, methods=['POST']))
        self.server.routes.append(APIRoute(path='/{code}', endpoint=self.get_result, methods=['GET']))

    def infer(self, code, video_path, audio_path) -> str:
        """Run one job on the HeyGem backend and return the output file path.

        Submits the job, polls until it finishes and checks that the reported
        output file exists. The two input files are removed afterwards whether
        or not the job succeeded.

        Args:
            code: job identifier sent to the backend as ``code``.
            video_path: path of the input video (sent as ``video_url``).
            audio_path: path of the input audio (sent as ``audio_url``).

        Returns:
            Path of the generated file under ``/code/data/temp``.

        Raises:
            Exception: on any failure (submission rejected, processing failed
                or timed out, output file missing, transport error). The
                message is the text of the underlying error, which is chained
                as ``__cause__``.
        """

        def task_submit(uid, video, audio, heygem_url):
            """Submit a task to the API and return a status/data/msg dict."""
            task_submit_api = f'{heygem_url}/easy/submit'
            result_json = {'status': False, 'data': {}, 'msg': ''}
            try:
                data = {
                    "code": uid,
                    "video_url": video,
                    "audio_url": audio,
                    "chaofen": 1,
                    "watermark_switch": 0,
                    "pn": 1,
                }
                loguru.logger.info(f'data={json.dumps(data, ensure_ascii=False)}')
                with httpx.Client() as client:
                    resp = client.post(task_submit_api, json=data)
                    resp_dict = resp.json()
                loguru.logger.info(f'submit data: {json.dumps(resp_dict, ensure_ascii=False)}')
                if resp_dict['code'] != 10000:
                    result_json['status'] = False
                    # Surface the backend's own message; the previous code
                    # assigned result_json['msg'] to itself and lost it.
                    result_json['msg'] = resp_dict.get('msg', '')
                else:
                    result_json['status'] = True
                    result_json['data'] = uid
                    result_json['msg'] = '任务提交成功'
                return result_json
            except Exception as e:
                loguru.logger.info(f'submit task fail case by:{str(e)}')
                raise RuntimeError(str(e)) from e

        def query_task_progress(heygem_url, heygem_temp_path, task_id: str, interval: int = 10, timeout: int = 60 * 35):
            """Poll the task until it finishes, fails or ``timeout`` seconds pass.

            Returns a dict with ``status`` (bool), ``data`` (output path on
            success) and ``msg`` (failure reason otherwise).
            """
            result_json = {'status': False, 'data': {}, 'msg': ''}

            def query_result(t_id: str):
                """Query the backend once and normalise its answer."""
                tmp_dict = {'status': True, 'data': dict(), 'msg': ''}
                try:
                    query_task_url = f'{heygem_url}/easy/query'
                    params = {'code': t_id}
                    with httpx.Client() as client:
                        resp = client.get(query_task_url, params=params)
                        resp_dict = resp.json()
                    loguru.logger.info(f'query task data: {json.dumps(resp_dict, ensure_ascii=False)}')
                    status_code = resp_dict['code']
                    if status_code in (9999, 10002, 10003, 10001):
                        tmp_dict['status'] = False
                        tmp_dict['msg'] = resp_dict['msg']
                    elif status_code == 10000:
                        task_status = resp_dict['data'].get('status', 1)
                        if task_status == 3:
                            # Task status 3 is handled as a failed task.
                            tmp_dict['status'] = False
                            tmp_dict['msg'] = resp_dict['data']['msg']
                        else:
                            # Task status 2 is handled as finished, so the
                            # progress is forced to 100 regardless of the
                            # reported value.
                            if task_status == 2:
                                progress = 100
                            else:
                                progress = resp_dict['data'].get('progress', 20)
                            tmp_dict['data'] = {
                                'progress': progress,
                                'path': resp_dict['data'].get('result', ''),
                            }
                    else:
                        # Unknown codes used to fall through with an empty
                        # message; report what was received instead.
                        tmp_dict['status'] = False
                        tmp_dict['msg'] = f'unexpected response code: {status_code}'
                except Exception as e:
                    loguru.logger.info(f'query task fail case by:{str(e)}')
                    raise RuntimeError(str(e)) from e
                return tmp_dict

            end = time.time() + timeout
            while time.time() < end:
                tmp_result = query_result(task_id)
                if not tmp_result['status'] or not tmp_result['data']:
                    result_json['status'] = False
                    result_json['msg'] = tmp_result['msg']
                    break

                process = tmp_result['data']['progress']
                loguru.logger.info(f'query task progress :{process}')
                if process < 100:
                    time.sleep(interval)
                    loguru.logger.info(f'wait next interval:{interval}')
                    continue

                # Drop every path separator so the result always resolves to
                # a file directly inside heygem_temp_path.
                p = tmp_result['data']['path']
                p = p.replace('/', '').replace('\\', '')
                result_json['data'] = "%s/%s" % (heygem_temp_path, p)
                result_json['status'] = True
                break
            else:
                # Deadline reached without a terminal answer; previously this
                # produced a failure with an empty message.
                result_json['msg'] = f'task {task_id} did not finish within {timeout} seconds'
            return result_json

        try:
            heygem_url = "http://127.0.0.1:8383"
            heygem_temp_path = "/code/data/temp"
            submit_result = task_submit(code, video_path, audio_path, heygem_url)
            if not submit_result['status']:
                raise Exception(f"Task submission failed: {submit_result['msg']}")

            task_id = submit_result['data']
            loguru.logger.info(f'Submitted task: {task_id}')

            # Query task progress
            progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id, interval=5)

            if not progress_result['status']:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

            # Return the file for download
            file_path = progress_result['data']
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return file_path
        except Exception as e:
            loguru.logger.error(f"Error processing request || {str(e)}")
            raise Exception(str(e)) from e
        finally:
            # Best-effort cleanup. Each input is removed independently so a
            # failure on the video no longer leaves the audio behind.
            for leftover in (video_path, audio_path):
                try:
                    os.remove(leftover)
                except FileNotFoundError:
                    pass
                except (OSError, TypeError) as cleanup_error:
                    loguru.logger.warning(f"Could not remove input file {leftover}: {cleanup_error}")

    def download_file(self, url, path, timeout=60):
        """Stream ``url`` into the file at ``path``.

        Args:
            url: HTTP(S) address to fetch.
            path: destination file, overwritten if it exists.
            timeout: seconds passed to ``requests.get`` as its ``timeout``
                argument, so a stalled server cannot block the caller
                forever.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status;
                the error body is not written to ``path``.
            requests.RequestException: on connection problems or timeouts.
        """
        with requests.get(url, stream=True, timeout=timeout) as r:
            # Fail before opening the destination so an error page is never
            # stored as if it were media.
            r.raise_for_status()
            with open(path, 'wb') as f:
                shutil.copyfileobj(r.raw, f)

    @staticmethod
    def _verify_token(token):
        """Raise HTTP 401 unless the bearer token equals ``AUTH_TOKEN``."""
        import hmac  # local import, matching the file's use of local imports

        expected = os.environ["AUTH_TOKEN"]
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        if not hmac.compare_digest(token.credentials.encode("utf-8"), expected.encode("utf-8")):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect bearer token",
                headers={"WWW-Authenticate": "Bearer"},
            )

    @staticmethod
    def _target_path(file_id, kind, source):
        """Return the local path for an input, or None if it is not usable.

        ``source`` is usable when it is an upload object or a string starting
        with ``http``. ``kind`` ("video"/"audio") is part of the file name so
        the two inputs of one request can never share a path, even when they
        have the same extension.
        """
        if isinstance(source, (UploadFile, starlette.datastructures.UploadFile)):
            name = source.filename or ""
        elif isinstance(source, str) and source.startswith("http"):
            name = source.split("?")[0].split("/")[-1]
        else:
            return None
        # The extension comes from client-controlled text; keep only
        # alphanumerics so it cannot contain path separators.
        ext = "".join(ch for ch in name.split(".")[-1] if ch.isalnum())
        return "/root/%s_%s.%s" % (file_id, kind, ext)

    async def get_result(self, code: str, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Return the recorded result for ``code`` or a ``waiting`` placeholder.

        Raises:
            HTTPException: 401 when the bearer token is wrong.
        """
        self._verify_token(token)
        return self.result.get(code, {"status": "waiting", "msg": ""})

    async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]], token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Store the two inputs locally and queue an inference job.

        Each input may be an uploaded file or a string starting with ``http``
        that is downloaded. Inputs are handled in order (video, then audio)
        and handling stops at the first failure.

        Returns:
            ``{"status": "success", "code": <job code>}`` once the job is
            queued, otherwise the ``{"status": "fail", "msg": ...}`` entry
            recorded for the failure.

        Raises:
            HTTPException: 401 when the bearer token is wrong.
        """
        self._verify_token(token)
        code = str(uuid.uuid4())
        file_id = str(uuid.uuid4())
        saved = {}
        try:
            for kind, source in (("video", video_file), ("audio", audio_file)):
                target = self._target_path(file_id, kind, source)
                if target is None:
                    self.result[code] = {"status": "fail", "msg": f"{kind} file save failed: Not a valid input"}
                    break
                # Recorded before writing so a partially written file is
                # still cleaned up if the save fails midway.
                saved[kind] = target
                if isinstance(source, str):
                    # NOTE: this is a blocking download inside an async
                    # handler; the event loop is stalled while it runs.
                    self.download_file(source, target)
                else:
                    with open(target, "wb") as f:
                        f.write(await source.read())
        except Exception as e:
            self.result[code] = {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}

        # `code` is fresh, so an entry can only exist if a failure was
        # recorded above. Deciding on that (rather than on the paths being
        # set) prevents queuing a job whose inputs were not fully saved.
        if code in self.result:
            for leftover in saved.values():
                try:
                    os.remove(leftover)
                except OSError:
                    pass
            return self.result[code]

        # Submit
        self.executor.submit(self._api, video_path=saved["video"], audio_path=saved["audio"], code=code)
        return {"status": "success", "code": code}

    def _api(self, video_path: str, audio_path: str, code):
        """Worker body: run inference, upload the output to S3, record the result.

        Runs on the executor thread. Never raises; every outcome is stored
        in ``self.result[code]``: ``success`` with the output file name as
        ``msg``, or ``fail`` with the reason.
        """
        loguru.logger.info("Enter Inference Module")
        try:
            file_path = self.infer(code, video_path, audio_path)
            output_name = os.path.basename(file_path)
            try:
                loguru.logger.info("Try move file to S3 manually")
                # S3 Fallback
                import boto3
                import yaml
                # Read-only access is enough; the config is never written.
                with open("/root/config.yaml", encoding="utf-8", mode="r") as config:
                    # NOTE(review): yaml.load with FullLoader is kept as-is;
                    # yaml.safe_load would be the safer choice if this file
                    # can ever be influenced from outside.
                    yaml_config = yaml.load(config, Loader=yaml.FullLoader)
                awss3 = boto3.resource('s3', aws_access_key_id=yaml_config["aws_key_id"],
                                       aws_secret_access_key=yaml_config["aws_access_key"])
                awss3.meta.client.upload_file(file_path, "bw-heygem-output", output_name)
            except Exception as e:
                loguru.logger.error(f"Failed to move file to S3 manually: {e}")
                self.result[code] = {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)}
            else:
                # Only report success when the upload went through; the
                # failure entry above used to be overwritten unconditionally.
                self.result[code] = {"status": "success", "msg": f"{output_name}"}
        except Exception as e:
            traceback.print_exc()
            self.result[code] = {"status": "fail", "msg": "Inference module failed: " + str(e)}
|
|
|
|
if __name__ == "__main__":
    # Constructing the service launches the HeyGem backend and blocks until
    # it is listening; only then is the HTTP API served on port 6006.
    service = HeyGem()
    uvicorn.run(app=service.server, host="0.0.0.0", port=6006)