# Source: modalDeploy/AutoDL/AutoDL_pure_heygem.py (Python; listed as 270 lines, 12 KiB in the repository viewer)

import hmac
import json
import os
import shutil
import socket
import subprocess
import time
import traceback
import uuid
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any, Optional, Union

import httpx
import loguru
import requests
import starlette.datastructures
import uvicorn
from fastapi import UploadFile, HTTPException, Depends, FastAPI
from fastapi.routing import APIRoute
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
class HeyGem:
    """HTTP front-end that queues jobs for a locally started HeyGem server.

    Constructing an instance launches ``/root/heygem.py`` in the background,
    waits until port 8383 accepts connections, and builds a FastAPI app with
    two bearer-token protected routes:

    * ``POST /``       -> :meth:`api`, stores the inputs and queues a job.
    * ``GET /{code}``  -> :meth:`get_result`, returns the job's outcome.

    Attributes:
        result: Maps job code -> outcome dict (``status`` / ``msg``).
        executor: Single-worker pool, so queued jobs run one at a time.
        server: The FastAPI application exposing the two routes.
    """

    # Local HeyGem server started by __init__.
    HEYGEM_HOST = "127.0.0.1"
    HEYGEM_PORT = 8383
    HEYGEM_URL = "http://127.0.0.1:8383"
    # Directory the result file name reported by HeyGem is resolved against.
    HEYGEM_TEMP_PATH = "/code/data/temp"
    # Number of port probes before giving up (each failed probe sleeps 2 s).
    STARTUP_ATTEMPTS = 30
    # (connect, read) timeouts in seconds for input downloads; the read
    # timeout bounds the wait between received chunks, not the whole transfer.
    DOWNLOAD_TIMEOUT = (10, 300)
    # Uploads are copied to disk in pieces of this many bytes.
    UPLOAD_CHUNK_SIZE = 1024 * 1024

    def __init__(self):
        """Start the HeyGem server process and build the FastAPI app.

        Raises:
            subprocess.CalledProcessError: If the launch command fails.
            TimeoutError: If port 8383 never starts accepting connections.
        """
        uid = str(uuid.uuid4())
        # The launch command appends to /root/logs/<uid>.log; without the
        # directory the redirect fails and the server never starts.
        os.makedirs("/root/logs", exist_ok=True)
        script = (
            "echo client_id: %s && source /root/.bashrc && "
            "nohup python /root/heygem.py >> /root/logs/%s.log 2>&1 &" % (uid, uid)
        )
        # Argument list instead of a shell string: same command, one less
        # shell layer and no quoting to get wrong.
        subprocess.run(["/bin/bash", "-c", script], check=True)

        for _ in range(self.STARTUP_ATTEMPTS):
            if self._port_in_use(self.HEYGEM_PORT, self.HEYGEM_HOST):
                loguru.logger.success("HeyGem Server started on port 8383")
                break
            time.sleep(2)
        else:
            raise TimeoutError("HeyGem Server Start timed out")

        self.result = {}
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.server = FastAPI()
        self.server.routes.append(APIRoute(path='/', endpoint=self.api, methods=['POST']))
        self.server.routes.append(APIRoute(path='/{code}', endpoint=self.get_result, methods=['GET']))

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _port_in_use(port, host='127.0.0.1') -> bool:
        """Return True if a TCP connection to ``host:port`` succeeds within 1 s."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.settimeout(1)
            try:
                probe.connect((host, int(port)))
            except socket.error:
                return False
            return True

    @staticmethod
    def _authorize(token: HTTPAuthorizationCredentials) -> None:
        """Validate the bearer token against the ``AUTH_TOKEN`` env variable.

        Raises:
            HTTPException: 500 if ``AUTH_TOKEN`` is unset or empty, 401 if the
                presented token does not match.
        """
        expected = os.environ.get("AUTH_TOKEN")
        if not expected:
            # Fail closed with an explicit error instead of a bare KeyError.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="AUTH_TOKEN is not configured",
            )
        # Constant-time comparison; bytes so non-ASCII tokens do not raise.
        presented = (token.credentials or "").encode("utf-8")
        if not hmac.compare_digest(presented, expected.encode("utf-8")):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect bearer token",
                headers={"WWW-Authenticate": "Bearer"},
            )

    @staticmethod
    def _remove_quietly(path: Optional[str]) -> None:
        """Best-effort delete of *path*; a missing file or ``None`` is ignored."""
        if not path:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            loguru.logger.warning(f"could not remove {path}: {e}")

    @staticmethod
    def _safe_extension(name: str) -> str:
        """Return the text after the last ``.`` in *name*, alphanumerics only.

        The extension comes from a client-supplied file name or URL, so path
        separators and other special characters are dropped. ``bin`` is used
        when nothing usable remains.
        """
        raw = (name or "").split(".")[-1]
        cleaned = "".join(ch for ch in raw if ch.isalnum())[:16]
        return cleaned or "bin"

    async def _save_input(self, source: Any, dest_stem: str) -> Optional[str]:
        """Store an uploaded file or an http(s) URL under ``dest_stem.<ext>``.

        Returns:
            The path written, or ``None`` when *source* is neither an upload
            nor an http(s) URL string.

        Raises:
            Exception: Whatever the read/download raised; the partially
                written file is removed first.
        """
        if isinstance(source, (UploadFile, starlette.datastructures.UploadFile)):
            dest = "%s.%s" % (dest_stem, self._safe_extension(source.filename or ""))
            try:
                with open(dest, "wb") as f:
                    # Copy in chunks so a large upload is not held in memory.
                    while True:
                        chunk = await source.read(self.UPLOAD_CHUNK_SIZE)
                        if not chunk:
                            break
                        f.write(chunk)
            except Exception:
                self._remove_quietly(dest)
                raise
            return dest

        if isinstance(source, str) and source.startswith(("http://", "https://")):
            last_segment = source.split("?")[0].split("/")[-1]
            dest = "%s.%s" % (dest_stem, self._safe_extension(last_segment))
            try:
                # NOTE(review): download_file blocks; called from the async
                # handler it stalls the event loop for the download duration.
                self.download_file(source, dest)
            except Exception:
                self._remove_quietly(dest)
                raise
            return dest

        return None

    # ------------------------------------------------------------------
    # HeyGem client
    # ------------------------------------------------------------------
    @staticmethod
    def _task_submit(uid, video, audio, heygem_url):
        """Submit a task to the HeyGem ``/easy/submit`` endpoint.

        Returns:
            ``{'status': bool, 'data': uid | {}, 'msg': str}``; ``status`` is
            True only when the response ``code`` equals 10000.

        Raises:
            RuntimeError: If the request or response parsing fails.
        """
        task_submit_api = f'{heygem_url}/easy/submit'
        result_json = {'status': False, 'data': {}, 'msg': ''}
        try:
            data = {
                "code": uid,
                "video_url": video,
                "audio_url": audio,
                "chaofen": 1,
                "watermark_switch": 0,
                "pn": 1,
            }
            loguru.logger.info(f'data={json.dumps(data, ensure_ascii=False)}')
            with httpx.Client() as client:
                resp = client.post(task_submit_api, json=data)
                resp_dict = resp.json()
            loguru.logger.info(f'submit data: {json.dumps(resp_dict, ensure_ascii=False)}')
            if resp_dict['code'] != 10000:
                # Surface the server's message (it used to be dropped).
                result_json['msg'] = str(resp_dict.get('msg', ''))
            else:
                result_json['status'] = True
                result_json['data'] = uid
                result_json['msg'] = '任务提交成功'
            return result_json
        except Exception as e:
            loguru.logger.info(f'submit task fail case by:{str(e)}')
            raise RuntimeError(str(e)) from e

    @staticmethod
    def _query_result(heygem_url, t_id: str):
        """Poll ``/easy/query`` once for task *t_id*.

        Returns:
            ``{'status': bool, 'data': dict, 'msg': str}``. On success ``data``
            holds ``progress`` and ``path``; task status 2 forces progress to
            100 and task status 3 is reported as a failure.

        Raises:
            RuntimeError: If the request or response parsing fails.
        """
        tmp_dict = {'status': True, 'data': {}, 'msg': ''}
        try:
            with httpx.Client() as client:
                resp = client.get(f'{heygem_url}/easy/query', params={'code': t_id})
                resp_dict = resp.json()
            loguru.logger.info(f'query task data: {json.dumps(resp_dict, ensure_ascii=False)}')
            response_code = resp_dict['code']
            if response_code in (9999, 10001, 10002, 10003):
                tmp_dict['status'] = False
                tmp_dict['msg'] = resp_dict['msg']
            elif response_code == 10000:
                task = resp_dict['data']
                task_state = task.get('status', 1)
                if task_state == 3:
                    tmp_dict['status'] = False
                    tmp_dict['msg'] = task.get('msg', resp_dict.get('msg', ''))
                else:
                    progress = 100 if task_state == 2 else task.get('progress', 20)
                    tmp_dict['data'] = {
                        'progress': progress,
                        'path': task.get('result', ''),
                    }
            else:
                # Unknown codes were already treated as failures by the caller
                # (empty data); now they also say why.
                tmp_dict['status'] = False
                tmp_dict['msg'] = f'unexpected response code: {response_code}'
        except Exception as e:
            loguru.logger.info(f'query task fail case by:{str(e)}')
            raise RuntimeError(str(e)) from e
        return tmp_dict

    def _query_task_progress(self, heygem_url, heygem_temp_path, task_id: str,
                             interval: int = 10, timeout: int = 60 * 35):
        """Poll the task until it finishes, fails, or *timeout* seconds elapse.

        Returns:
            ``{'status': bool, 'data': str | {}, 'msg': str}``; on success
            ``data`` is ``heygem_temp_path`` joined with the reported file name.
        """
        result_json = {'status': False, 'data': {}, 'msg': ''}
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            tmp_result = self._query_result(heygem_url, task_id)
            if not tmp_result['status'] or not tmp_result['data']:
                result_json['msg'] = tmp_result['msg']
                return result_json

            progress = tmp_result['data']['progress']
            loguru.logger.info(f'query task progress :{progress}')
            if progress >= 100:
                # Keep only the bare file name: separators are stripped so the
                # reported path cannot point outside heygem_temp_path.
                file_name = (tmp_result['data']['path'] or '').replace('/', '').replace('\\', '')
                if not file_name:
                    result_json['msg'] = 'task finished without a result file name'
                    return result_json
                result_json['data'] = "%s/%s" % (heygem_temp_path, file_name)
                result_json['status'] = True
                return result_json

            loguru.logger.info(f'wait next interval:{interval}')
            time.sleep(interval)

        # Previously a timeout surfaced as a failure with an empty message.
        result_json['msg'] = f'task {task_id} did not finish within {timeout} seconds'
        return result_json

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def infer(self, code, video_path, audio_path) -> str:
        """Run one HeyGem job to completion and return the output file path.

        The input files are deleted afterwards whether or not the job succeeds.

        Args:
            code: Task identifier sent to HeyGem.
            video_path: Local path of the input video.
            audio_path: Local path of the input audio.

        Raises:
            Exception: On submission failure, processing failure, timeout, or
                a missing output file; the original error is chained.
        """
        try:
            submit_result = self._task_submit(code, video_path, audio_path, self.HEYGEM_URL)
            if not submit_result['status']:
                raise Exception(f"Task submission failed: {submit_result['msg']}")
            task_id = submit_result['data']
            loguru.logger.info(f'Submitted task: {task_id}')

            progress_result = self._query_task_progress(
                self.HEYGEM_URL, self.HEYGEM_TEMP_PATH, task_id, interval=5)
            if not progress_result['status']:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

            file_path = progress_result['data']
            # isfile, not exists: a directory must not pass as the output.
            if not os.path.isfile(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return file_path
        except Exception as e:
            loguru.logger.error(f"Error processing request || {str(e)}")
            raise Exception(str(e)) from e
        finally:
            # Remove each input independently so one failure cannot keep the
            # other file on disk.
            for leftover in (video_path, audio_path):
                self._remove_quietly(leftover)

    def download_file(self, url, path):
        """Stream *url* to the local file *path*.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
            requests.RequestException: On connection problems or timeouts.
        """
        with requests.get(url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as r:
            # Without this an error page would be saved as the media file.
            r.raise_for_status()
            # Make the raw stream undo gzip/deflate transfer encoding.
            r.raw.decode_content = True
            with open(path, 'wb') as f:
                shutil.copyfileobj(r.raw, f)

    async def get_result(self, code: str, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Return the stored outcome for *code*, or a ``waiting`` placeholder."""
        self._authorize(token)
        return self.result.get(code, {"status": "waiting", "msg": ""})

    async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]], token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Store the inputs and queue an inference job.

        Each input may be an uploaded file or an http(s) URL string.

        Returns:
            ``{"status": "success", "code": <job code>}`` once the job is
            queued, otherwise ``{"status": "fail", "msg": ...}``.
        """
        self._authorize(token)
        code = str(uuid.uuid4())
        file_id = str(uuid.uuid4())
        video_path = None
        audio_path = None
        failure = None
        try:
            # Distinct stems: with a shared one, inputs that have the same
            # extension overwrote each other.
            video_path = await self._save_input(video_file, "/root/%s_video" % file_id)
            if video_path is None:
                failure = {"status": "fail", "msg": "video file save failed: Not a valid input"}
            audio_path = await self._save_input(audio_file, "/root/%s_audio" % file_id)
            if audio_path is None:
                failure = {"status": "fail", "msg": "audio file save failed: Not a valid input"}
        except Exception as e:
            failure = {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}

        if video_path and audio_path:
            self.executor.submit(self._api, video_path=video_path, audio_path=audio_path, code=code)
            return {"status": "success", "code": code}

        # Rejected request: drop whatever was already written to disk. The
        # failure is returned directly because the client never learns `code`.
        for leftover in (video_path, audio_path):
            self._remove_quietly(leftover)
        return failure

    def _api(self, video_path: str, audio_path: str, code):
        """Worker body: run inference, upload the output to S3, record the outcome.

        Runs on the executor thread and never raises; the outcome is written
        to ``self.result[code]``.
        """
        loguru.logger.info("Enter Inference Module")
        try:
            file_path = self.infer(code, video_path, audio_path)
        except Exception as e:
            traceback.print_exc()
            self.result[code] = {"status": "fail", "msg": "Inference module failed: " + str(e)}
            return

        file_name = os.path.basename(file_path)
        try:
            loguru.logger.info("Try move file to S3 manually")
            # S3 Fallback
            import boto3
            import yaml
            # Read-only: the config is never written back.
            with open("/root/config.yaml", encoding="utf-8", mode="r") as config:
                # NOTE(review): FullLoader is only acceptable because this is a
                # local config file; prefer yaml.safe_load if that can change.
                yaml_config = yaml.load(config, Loader=yaml.FullLoader)
            awss3 = boto3.resource('s3', aws_access_key_id=yaml_config["aws_key_id"],
                                   aws_secret_access_key=yaml_config["aws_access_key"])
            awss3.meta.client.upload_file(file_path, "bw-heygem-output", file_name)
        except Exception as e:
            traceback.print_exc()
            # Return here: this failure used to be overwritten by "success".
            self.result[code] = {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)}
            return

        self.result[code] = {"status": "success", "msg": f"{file_name}"}
if __name__ == "__main__":
    # Constructing HeyGem starts the backing server before the API is served;
    # uvicorn then listens on every interface, port 6006.
    uvicorn.run(app=HeyGem().server, host="0.0.0.0", port=6006)