import json
import os
import secrets
import shutil
import socket
import subprocess
import time
import traceback
import uuid
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any, Optional, Union

import httpx
import loguru
import requests
import starlette.datastructures
import uvicorn
from fastapi import UploadFile, HTTPException, Depends, FastAPI
from fastapi.routing import APIRoute
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status


class HeyGem:
    """FastAPI front-end for a locally launched HeyGem inference server.

    Construction starts ``/root/heygem.py`` in the background, waits for it to
    listen on port 8383, and builds a FastAPI app with two routes:

    * ``POST /``       -> :meth:`api`, accepts a video and an audio input and
      queues an inference job, returning a job ``code``.
    * ``GET /{code}``  -> :meth:`get_result`, returns the stored result for a
      job, or a ``waiting`` placeholder when none is stored.

    Attributes:
        result: Maps job code -> result dict (``status`` / ``msg``). Entries
            are never evicted.
        executor: Single-worker thread pool, so queued jobs run one at a time.
        server: The FastAPI application served by uvicorn.
    """

    def __init__(self):
        """Launch the HeyGem server, wait for its port, and set up the API.

        Raises:
            subprocess.CalledProcessError: If the launch command exits non-zero.
            TimeoutError: If port 8383 is still closed after 30 probes.
        """

        def check_port_in_use(port, host='127.0.0.1'):
            """Return True when a TCP connection to ``host:port`` succeeds."""
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.settimeout(1)
                    s.connect((host, int(port)))
                return True
            except socket.error:
                return False

        uid = str(uuid.uuid4())
        cmd = "/bin/bash -c \"echo client_id: %s && source /root/.bashrc && nohup python /root/heygem.py >> /root/logs/%s.log 2>&1 &\"" % (uid, uid)
        subprocess.run(cmd, shell=True, check=True)

        # Probe up to 30 times, sleeping 2 seconds after every failed probe.
        for _ in range(30):
            if check_port_in_use(8383):
                loguru.logger.success("HeyGem Server started on port 8383")
                break
            time.sleep(2)
        else:
            raise TimeoutError("HeyGem Server Start timed out")

        self.result = {}
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.server = FastAPI()
        self.server.routes.append(APIRoute(path='/', endpoint=self.api, methods=['POST']))
        self.server.routes.append(APIRoute(path='/{code}', endpoint=self.get_result, methods=['GET']))

    def infer(self, code, video_path, audio_path) -> str:
        """Run one job on the HeyGem server and return the output file path.

        Submits the job, polls until it completes, fails, or times out, and
        checks that the reported output file exists. Both input files are
        removed afterwards, whether or not the job succeeded.

        Args:
            code: Job identifier sent to the HeyGem server.
            video_path: Local path of the input video.
            audio_path: Local path of the input audio.

        Returns:
            Path of the generated file under ``/code/data/temp``.

        Raises:
            Exception: On any submission, polling, or output-lookup failure;
                the message is that of the underlying error, which is chained
                as ``__cause__``.
        """

        def task_submit(uid, video, audio, heygem_url):
            """Submit a task to the API"""
            task_submit_api = f'{heygem_url}/easy/submit'
            result_json = {
                'status': False,
                'data': {},
                'msg': ''
            }
            try:
                data = {
                    "code": uid,
                    "video_url": video,
                    "audio_url": audio,
                    "chaofen": 1,
                    "watermark_switch": 0,
                    "pn": 1
                }
                loguru.logger.info(f'data={json.dumps(data, ensure_ascii=False)}')
                with httpx.Client() as client:
                    resp = client.post(task_submit_api, json=data)
                    resp_dict = resp.json()
                loguru.logger.info(f'submit data: {json.dumps(resp_dict, ensure_ascii=False)}')
                if resp_dict['code'] != 10000:
                    result_json['status'] = False
                    # Propagate the server's message; the previous code
                    # assigned result_json['msg'] to itself and lost it.
                    result_json['msg'] = resp_dict.get('msg', '')
                else:
                    result_json['status'] = True
                    result_json['data'] = uid
                    result_json['msg'] = '任务提交成功'
                return result_json
            except Exception as e:
                loguru.logger.info(f'submit task fail case by:{str(e)}')
                raise RuntimeError(str(e)) from e

        def query_task_progress(heygem_url, heygem_temp_path, task_id: str, interval: int = 10,
                                timeout: int = 60 * 35):
            """Query task progress and wait for completion"""
            result_json = {'status': False, 'data': {}, 'msg': ''}

            def query_result(t_id: str):
                """Poll once; return ``{'status', 'data', 'msg'}`` for the task."""
                tmp_dict = {'status': True, 'data': dict(), 'msg': ''}
                try:
                    query_task_url = f'{heygem_url}/easy/query'
                    params = {
                        'code': t_id
                    }
                    with httpx.Client() as client:
                        resp = client.get(query_task_url, params=params)
                        resp_dict = resp.json()
                    loguru.logger.info(f'query task data: {json.dumps(resp_dict, ensure_ascii=False)}')
                    status_code = resp_dict['code']
                    if status_code in (9999, 10002, 10003, 10001):
                        tmp_dict['status'] = False
                        tmp_dict['msg'] = resp_dict['msg']
                    elif status_code == 10000:
                        task_status = resp_dict['data'].get('status', 1)
                        if task_status == 3:
                            # Task status 3 is treated as a failed task.
                            tmp_dict['status'] = False
                            tmp_dict['msg'] = resp_dict['data']['msg']
                        else:
                            # Task status 2 is treated as finished, whatever
                            # progress value the server reports.
                            if task_status == 2:
                                progress = 100
                            else:
                                progress = resp_dict['data'].get('progress', 20)
                            tmp_dict['data'] = {
                                'progress': progress,
                                'path': resp_dict['data'].get('result', ''),
                            }
                    else:
                        # Unknown codes already ended polling (empty data);
                        # now they also say why.
                        tmp_dict['status'] = False
                        tmp_dict['msg'] = f'unexpected response code: {status_code}'
                except Exception as e:
                    loguru.logger.info(f'query task fail case by:{str(e)}')
                    raise RuntimeError(str(e)) from e
                return tmp_dict

            end = time.time() + timeout
            while time.time() < end:
                tmp_result = query_result(task_id)
                if not tmp_result['status'] or not tmp_result['data']:
                    result_json['status'] = False
                    result_json['msg'] = tmp_result['msg']
                    break
                progress = tmp_result['data']['progress']
                loguru.logger.info(f'query task progress :{progress}')
                if progress < 100:
                    time.sleep(interval)
                    loguru.logger.info(f'wait next interval:{interval}')
                else:
                    p = tmp_result['data']['path']
                    # Drop every path separator so the result is a bare name
                    # directly under heygem_temp_path.
                    p = p.replace('/', '').replace('\\', '')
                    result_json['data'] = "%s/%s" % (heygem_temp_path, p)
                    result_json['status'] = True
                    break
            else:
                # Loop ended without a break: the deadline passed.
                result_json['msg'] = f'task {task_id} did not finish within {timeout} seconds'
            return result_json

        try:
            heygem_url = "http://127.0.0.1:8383"
            heygem_temp_path = "/code/data/temp"
            submit_result = task_submit(code, video_path, audio_path, heygem_url)
            if not submit_result['status']:
                raise Exception(f"Task submission failed: {submit_result['msg']}")
            task_id = submit_result['data']
            loguru.logger.info(f'Submitted task: {task_id}')
            # Query task progress
            progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id, interval=5)
            if not progress_result['status']:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")
            # Return the file for download
            file_path = progress_result['data']
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return file_path
        except Exception as e:
            loguru.logger.error(f"Error processing request || {str(e)}")
            raise Exception(str(e)) from e
        finally:
            # Best-effort cleanup; each file is removed independently so a
            # failure on one does not leave the other behind.
            self._remove_quietly(video_path, audio_path)

    @staticmethod
    def _remove_quietly(*paths):
        """Delete each given path, ignoring empty values and OS errors."""
        for path in paths:
            if not path:
                continue
            try:
                os.remove(path)
            except OSError:
                pass

    @staticmethod
    def _verify_token(token):
        """Raise HTTP 401 unless the bearer token equals ``$AUTH_TOKEN``.

        Raises:
            KeyError: If the ``AUTH_TOKEN`` environment variable is unset.
            HTTPException: 401 when the presented token does not match.
        """
        expected = os.environ["AUTH_TOKEN"]
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        if not secrets.compare_digest(token.credentials.encode("utf-8"), expected.encode("utf-8")):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect bearer token",
                headers={"WWW-Authenticate": "Bearer"},
            )

    @staticmethod
    def _input_path(source):
        """Return a fresh ``/root/<uuid>.<ext>`` path for an input, or None.

        ``source`` is either an uploaded file (extension taken from its
        filename) or an ``http...`` URL (extension taken from the last path
        segment, query string ignored). Anything else yields None.
        """
        if isinstance(source, (UploadFile, starlette.datastructures.UploadFile)):
            ext = source.filename.split(".")[-1]
        elif isinstance(source, str) and source.startswith("http"):
            ext = source.split("?")[0].split("/")[-1].split(".")[-1]
        else:
            return None
        # One uuid per file: a shared id made video and audio collide on the
        # same path whenever both had the same extension.
        return "/root/%s.%s" % (uuid.uuid4(), ext)

    def download_file(self, url, path):
        """Stream ``url`` to ``path``.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status,
                so an error page is never stored as media.
        """
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(path, 'wb') as f:
                shutil.copyfileobj(r.raw, f)

    async def get_result(self, code: str, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Return the stored result for ``code``.

        Any code without a stored result yields the ``waiting`` placeholder.

        Raises:
            HTTPException: 401 when the bearer token is wrong.
        """
        self._verify_token(token)
        return self.result.get(code, {"status": "waiting", "msg": ""})

    async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]],
                  token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Store the two inputs under ``/root`` and queue an inference job.

        Each input may be an uploaded file or an ``http...`` URL.

        Returns:
            ``{"status": "success", "code": <job code>}`` once the job is
            queued, otherwise the ``fail`` result (also stored under the job
            code) describing the first problem encountered.

        Raises:
            HTTPException: 401 when the bearer token is wrong.
        """
        self._verify_token(token)
        code = str(uuid.uuid4())
        paths = {}
        failure = None
        try:
            for kind, source in (("video", video_file), ("audio", audio_file)):
                path = self._input_path(source)
                if path is None:
                    failure = f"{kind} file save failed: Not a valid input"
                    break
                # Record the path before writing so a partially written file
                # is still cleaned up if the save fails.
                paths[kind] = path
                if isinstance(source, str):
                    self.download_file(source, path)
                else:
                    with open(path, "wb") as f:
                        data = await source.read()
                        f.write(data)
        except Exception as e:
            failure = f"video or audio file save failed: {str(e)}"

        if failure is None:
            # Submit
            self.executor.submit(self._api, video_path=paths["video"], audio_path=paths["audio"], code=code)
            return {"status": "success", "code": code}

        # Nothing will consume the files already written; remove them.
        self._remove_quietly(*paths.values())
        self.result[code] = {"status": "fail", "msg": failure}
        return self.result[code]

    def _api(self, video_path: str, audio_path: str, code):
        """Run inference, upload the output to S3, and record the outcome.

        Runs on the executor thread. The outcome is stored in
        ``self.result[code]``; exceptions are recorded there, not raised.
        """
        loguru.logger.info("Enter Inference Module")
        try:
            file_path = self.infer(code, video_path, audio_path)
        except Exception as e:
            traceback.print_exc()
            self.result[code] = {"status": "fail", "msg": "Inference module failed: " + str(e)}
            return

        file_name = file_path.split(os.path.sep)[-1]
        try:
            loguru.logger.info("Try move file to S3 manually")
            # S3 Fallback
            import boto3
            import yaml
            # Read-only access is enough; the config is never written here.
            with open("/root/config.yaml", encoding="utf-8", mode="r") as config:
                # NOTE(review): yaml.safe_load would be the more defensive
                # choice if the config holds only plain values - confirm.
                yaml_config = yaml.load(config, Loader=yaml.FullLoader)
            awss3 = boto3.resource('s3',
                                   aws_access_key_id=yaml_config["aws_key_id"],
                                   aws_secret_access_key=yaml_config["aws_access_key"])
            awss3.meta.client.upload_file(file_path, "bw-heygem-output", file_name)
        except Exception as e:
            loguru.logger.error(f"Failed to move file to S3 manually: {str(e)}")
            # Keep the failure; it used to be overwritten by the success
            # result immediately afterwards.
            self.result[code] = {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)}
            return

        self.result[code] = {"status": "success", "msg": f"{file_name}"}


if __name__ == "__main__":
    heygem = HeyGem()
    uvicorn.run(app=heygem.server, host="0.0.0.0", port=6006)