# modalDeploy/server_pure_heygem.py (299 lines, 14 KiB, Python)

# HeyGem template
import json
import os
import shutil
import socket
import subprocess
import time
import traceback
import uuid
from typing import Any, Optional, Union
import httpx
import loguru
import modal
import requests
from fastapi import Depends, HTTPException, status, UploadFile
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# Container image. The Modal-side code in this file runs on Python 3.10; the
# HeyGem server (heygem.py) runs on a separate Python 3.8 that is compiled from
# source below and invoked as `python3.8`.
image = (
    modal.Image.debian_slim(  # start from basic Linux with Python
        python_version="3.10"
    ).apt_install("git")
    .apt_install("gcc")
    .pip_install("loguru")
    .pip_install("fastapi[standard]")
    # Install ffmpeg; `ffmpeg -version` echoes the installed version into the build log.
    .run_commands(
        "apt update && apt install -y ffmpeg && ffmpeg -version"
    )
    # Add Python 3.8 for HeyGem: first the packages needed to compile CPython...
    .run_commands("apt update && apt install -y curl build-essential libssl-dev zlib1g-dev libncurses5-dev libncursesw5-dev libreadline-dev libsqlite3-dev libgdbm-dev libdb5.3-dev libbz2-dev libexpat1-dev lzma liblzma-dev tk-dev libffi-dev")
    # ...then download and unpack the 3.8.12 source tarball...
    .run_commands("curl -O https://www.python.org/ftp/python/3.8.12/Python-3.8.12.tar.xz&&tar -xf Python-3.8.12.tar.xz")
    # ...and build it. `make altinstall` adds `python3.8` without replacing the
    # image's default `python3`.
    .run_commands("cd Python-3.8.12 && ./configure --enable-optimizations && make -j 10 && make altinstall")
    # Copy the packaged HeyGem wheel into the image and install it into Python 3.8.
    .add_local_file("heygem-1.0-py3-none-any.whl","/root/heygem-1.0-py3-none-any.whl", copy=True)
    # Use `/bin/bash -c` as the image shell for the following steps.
    .shell(["/bin/bash", "-c"])
    .run_commands("python3.8 -m pip install /root/heygem-1.0-py3-none-any.whl")
    # Let the dynamic loader find the NVRTC libraries located under Python 3.8's
    # site-packages (nvidia/cuda_nvrtc/lib).
    .env({"LD_LIBRARY_PATH":"/usr/local/lib/python3.8/site-packages/nvidia/cuda_nvrtc/lib"})
    # Add an unversioned libnvrtc.so alias for libnvrtc.so.11.2; presumably some
    # dependency loads the unversioned name -- TODO confirm.
    .run_commands("ln -s /usr/local/lib/python3.8/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.11.2 /usr/local/lib/python3.8/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so")
    # Pinned numexpr 2.8.6 wheel (CPython 3.8, manylinux x86_64) for Python 3.8.
    .run_commands("python3.8 -m pip install https://github.com/pydata/numexpr/releases/download/v2.8.6/numexpr-2.8.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl")
    # The server script launched by HeyGem.start(), and the YAML config (read at
    # least by the boto3 upload fallback in HeyGem.api()).
    .add_local_file("heygem.py", "/root/heygem.py", copy=True)
    .add_local_file("config.yaml", "/root/config.yaml", copy=True)
    # Python 3.10 (Modal-side) dependencies: boto3 + PyYAML for the S3 upload
    # fallback, requests for downloading URL inputs.
    .workdir("/root").pip_install("boto3").pip_install("PyYAML").pip_install("requests")
)
# S3 bucket that receives finished videos (root prefix) and server logs ("logs/" prefix).
bucket_output = "bw-heygem-output"
# FastAPI dependency that extracts the "Authorization: Bearer <token>" header;
# requests without one are rejected before the endpoint body runs.
auth_scheme = HTTPBearer()
# Modal secret named "aws-s3-secret", handed to the S3 bucket mounts and the class.
secret = modal.Secret.from_name("aws-s3-secret")
# The Modal application; every function/class in this file runs in `image`.
app = modal.App(name="heygem", image=image)
@app.cls(
    allow_concurrent_inputs=1,  # handle one request per container at a time
    max_containers=25,  # upper bound on containers Modal may run at once
    gpu="L40S",  # good starter GPU for inference
    cpu=(2, 32),  # (request, limit) in CPU cores
    memory=(6144, 32768),  # (request, limit) in MiB
    timeout=2160,  # seconds; just above the 35-minute polling budget in infer()
    scaledown_window=240,  # seconds an idle container is kept alive
    secrets=[secret, modal.Secret.from_name("web_auth_token")],
    volumes={
        # Root of the output bucket: api() copies finished videos here.
        # NOTE(review): key_prefix="/" while the boto3 fallback in api() uploads
        # keys without a leading slash -- confirm both land where consumers look.
        "/code/data/final": modal.CloudBucketMount(
            bucket_name=bucket_output,
            # bucket_endpoint_url="https://s3.%s.amazonaws.com" % os.environ["AWS_REGION"],
            secret=secret,
            key_prefix="/"
        ),
        # "logs/" prefix of the same bucket: start() appends the HeyGem server log here.
        "/root/logs": modal.CloudBucketMount(
            bucket_name=bucket_output,
            # bucket_endpoint_url="https://s3.%s.amazonaws.com" % os.environ["AWS_REGION"],
            secret=secret,
            key_prefix="/logs/"
        ),
    },
)
class HeyGem:
    """Modal GPU service wrapping a local HeyGem server.

    ``start`` launches the HeyGem server (``/root/heygem.py``, port 8383) when
    the container boots. ``infer`` runs one video+audio job against it.
    ``api`` is the authenticated HTTP entry point: it stages the inputs, calls
    ``infer`` and publishes the result to the ``bucket_output`` S3 bucket.
    """

    @modal.enter()
    def start(self):
        """Launch the HeyGem server in the background and block until it listens.

        The server is started with the image's Python 3.8 interpreter via
        ``nohup ... &`` so it keeps running after this hook returns. Its
        stdout/stderr are appended to ``/root/logs/<uuid>.log`` (the S3-backed
        log mount). Port 8383 on 127.0.0.1 is then polled up to 30 times,
        2 seconds apart.

        Raises:
            TimeoutError: if the port never accepts a connection.
        """
        def check_port_in_use(port, host='127.0.0.1'):
            """Return True if a TCP connection to ``host:port`` succeeds within 1 second."""
            try:
                with socket.create_connection((host, int(port)), timeout=1):
                    return True
            except OSError:  # refused, unreachable or timed out
                return False

        uid = str(uuid.uuid4())
        cmd = "echo client_id: %s && nohup python3.8 /root/heygem.py >> /root/logs/%s.log 2>&1 &" % (uid, uid)
        subprocess.run(cmd, shell=True, check=True)
        for _ in range(30):
            if check_port_in_use(8383):
                loguru.logger.success("HeyGem Server started on port 8383")
                return
            time.sleep(2)
        raise TimeoutError("HeyGem Server Start timed out")

    @modal.method()
    def infer(self, code, video_path, audio_path) -> str:
        """Run one HeyGem job on the local server and return the output file path.

        Args:
            code: Task id sent to HeyGem as ``code`` and used to poll progress.
            video_path: Local path of the input video (sent as ``video_url``).
            audio_path: Local path of the input audio (sent as ``audio_url``).

        Returns:
            Path of the generated file under ``/code/data/temp``.

        Raises:
            RuntimeError: if submission or polling fails, if the task reports
                failure, or if it does not finish within the polling budget.
            FileNotFoundError: if the reported output file does not exist.

        The input files are deleted (best effort) whether or not the job succeeds.
        """
        def task_submit(uid, video, audio, heygem_url):
            """Submit a task to the HeyGem ``/easy/submit`` API.

            Returns ``{'status': True, 'data': uid, 'msg': ...}`` when the server
            answers code 10000. Otherwise returns ``status`` False and the server's
            ``msg``. Transport or parsing errors are raised as RuntimeError.
            """
            task_submit_api = f'{heygem_url}/easy/submit'
            result_json = {'status': False, 'data': {}, 'msg': ''}
            data = {
                "code": uid,
                "video_url": video,
                "audio_url": audio,
                "chaofen": 1,  # presumably toggles super-resolution -- confirm against HeyGem API
                "watermark_switch": 0,
                "pn": 1
            }
            try:
                loguru.logger.info(f'data={json.dumps(data, ensure_ascii=False)}')
                with httpx.Client() as client:
                    resp = client.post(task_submit_api, json=data)
                    resp_dict = resp.json()
                loguru.logger.info(f'submit data: {json.dumps(resp_dict, ensure_ascii=False)}')
                if resp_dict['code'] != 10000:
                    result_json['status'] = False
                    # Surface the server's error message (was previously dropped).
                    result_json['msg'] = resp_dict.get('msg', '')
                else:
                    result_json['status'] = True
                    result_json['data'] = uid
                    result_json['msg'] = '任务提交成功'
                return result_json
            except Exception as e:
                loguru.logger.error(f'submit task fail case by:{str(e)}')
                raise RuntimeError(str(e)) from e

        def query_task_progress(heygem_url, heygem_temp_path, task_id: str, interval: int = 10, timeout: int = 60 * 35):
            """Poll the HeyGem ``/easy/query`` API until the task finishes, fails or times out.

            Returns ``{'status': True, 'data': <output path>, 'msg': ''}`` on
            success, otherwise ``{'status': False, 'data': {}, 'msg': <reason>}``.
            ``interval`` and ``timeout`` are in seconds.
            """
            result_json = {'status': False, 'data': {}, 'msg': ''}

            def query_result(t_id: str):
                """Query once; return status False plus a message on any failure response."""
                tmp_dict = {'status': True, 'data': dict(), 'msg': ''}
                try:
                    query_task_url = f'{heygem_url}/easy/query'
                    with httpx.Client() as client:
                        resp = client.get(query_task_url, params={'code': t_id})
                        resp_dict = resp.json()
                    loguru.logger.info(f'query task data: {json.dumps(resp_dict, ensure_ascii=False)}')
                    api_code = resp_dict['code']
                    if api_code in (9999, 10001, 10002, 10003):
                        tmp_dict['status'] = False
                        tmp_dict['msg'] = resp_dict['msg']
                    elif api_code == 10000:
                        task_status = resp_dict['data'].get('status', 1)
                        if task_status == 3:  # task-level failure reported by the server
                            tmp_dict['status'] = False
                            tmp_dict['msg'] = resp_dict['data']['msg']
                        else:
                            # Status 2 is treated as finished (100%); otherwise use the
                            # server's progress value, defaulting to 20.
                            if task_status == 2:
                                progress = 100
                            else:
                                progress = resp_dict['data'].get('progress', 20)
                            tmp_dict['data'] = {
                                'progress': progress,
                                'path': resp_dict['data'].get('result', ''),
                            }
                    else:
                        tmp_dict['status'] = False
                        tmp_dict['msg'] = f'unexpected response code: {api_code}'
                except Exception as e:
                    loguru.logger.error(f'query task fail case by:{str(e)}')
                    raise RuntimeError(str(e)) from e
                return tmp_dict

            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                tmp_result = query_result(task_id)
                if not tmp_result['status'] or not tmp_result['data']:
                    result_json['msg'] = tmp_result['msg']
                    return result_json
                progress = tmp_result['data']['progress']
                loguru.logger.info(f'query task progress :{progress}')
                if progress >= 100:
                    # NOTE(review): this strips every path separator, so it only
                    # yields a valid name if the server returns a bare (possibly
                    # slash-prefixed) file name -- confirm against heygem.py.
                    name = tmp_result['data']['path'].replace('/', '').replace('\\', '')
                    result_json['data'] = "%s/%s" % (heygem_temp_path, name)
                    result_json['status'] = True
                    return result_json
                time.sleep(interval)
                loguru.logger.info(f'wait next interval:{interval}')
            result_json['msg'] = f'task {task_id} did not finish within {timeout} seconds'
            return result_json

        heygem_url = "http://127.0.0.1:8383"
        heygem_temp_path = "/code/data/temp"
        try:
            submit_result = task_submit(code, video_path, audio_path, heygem_url)
            if not submit_result['status']:
                raise RuntimeError(f"Task submission failed: {submit_result['msg']}")
            task_id = submit_result['data']
            loguru.logger.info(f'Submitted task: {task_id}')
            progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id, interval=5)
            if not progress_result['status']:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")
            file_path = progress_result['data']
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return file_path
        except Exception as e:
            loguru.logger.error(f"Error processing request || {str(e)}")
            raise  # keep the original exception type and traceback
        finally:
            # Best-effort cleanup; each input is removed independently.
            self._remove_quietly(video_path, audio_path)

    def download_file(self, url, path, timeout=60):
        """Stream ``url`` into the local file ``path``.

        ``timeout`` (seconds) bounds the connect and each read wait. Non-2xx
        responses raise ``requests.HTTPError`` so an error page is never saved
        as the media file. ``iter_content`` also undoes any Content-Encoding
        (e.g. gzip), which copying ``r.raw`` did not.
        """
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)

    @staticmethod
    def _remove_quietly(*paths):
        """Best-effort delete of local files; falsy, missing or invalid paths are ignored."""
        for p in paths:
            if not p:
                continue
            try:
                os.remove(p)
            except (OSError, TypeError, ValueError):
                pass

    @staticmethod
    def _local_input_path(code, kind, name):
        """Return a per-request local path under /root for one request input.

        Only the last component of the client-supplied ``name`` is kept, so
        ``../`` cannot escape /root. It is prefixed with the request id and
        ``kind``, so inputs can neither overwrite files shipped in /root
        (e.g. heygem.py, config.yaml) nor collide when video and audio share
        a file name.
        """
        base = os.path.basename(str(name or "").replace("\\", "/")) or "input"
        return "/root/%s_%s_%s" % (code, kind, base)

    async def _save_input(self, source, code, kind):
        """Persist one request input locally and return its path.

        ``source`` may be an ``UploadFile`` (its content is written to disk) or
        a string starting with ``"http"`` (downloaded via ``download_file``).
        Returns None for any other value. On failure the partial file is
        removed and the error re-raised.
        """
        if isinstance(source, UploadFile):
            path = self._local_input_path(code, kind, source.filename)
            try:
                data = await source.read()
                with open(path, "wb") as f:
                    f.write(data)
            except Exception:
                self._remove_quietly(path)
                raise
            return path
        if isinstance(source, str) and source.startswith("http"):
            path = self._local_input_path(code, kind, source.split("?")[0].split("/")[-1])
            try:
                self.download_file(source, path)
            except Exception:
                self._remove_quietly(path)
                raise
            return path
        return None

    @staticmethod
    def _upload_to_s3_fallback(file_path, key):
        """Upload ``file_path`` to ``bucket_output`` under ``key`` using boto3.

        Used only when copying into the bucket mount fails. Credentials come
        from ``/root/config.yaml``. That file is assumed to define
        ``aws_key_id`` and ``aws_access_key`` (the keys read below).
        """
        import boto3
        import yaml
        with open("/root/config.yaml", encoding="utf-8", mode="r") as config:
            yaml_config = yaml.safe_load(config)
        awss3 = boto3.resource('s3', aws_access_key_id=yaml_config["aws_key_id"],
                               aws_secret_access_key=yaml_config["aws_access_key"])
        awss3.meta.client.upload_file(file_path, bucket_output, key)

    @modal.fastapi_endpoint(method="POST")
    async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]], token: HTTPAuthorizationCredentials = Depends(auth_scheme)):
        """Authenticated HTTP entry point: stage inputs, run ``infer`` and publish the result.

        ``video_file`` and ``audio_file`` may each be an uploaded file or a
        string starting with ``http`` that is downloaded. Inputs are written
        under /root with a per-request prefix, and ``infer`` runs in-process.
        The output is then copied into the bucket mount, falling back to a
        direct boto3 upload.

        Returns:
            ``{"status": "success", "msg": <output file name>}`` or
            ``{"status": "fail", "msg": <reason>}``.

        Raises:
            HTTPException(401): if the bearer token does not match ``AUTH_TOKEN``.
        """
        import secrets  # local import: constant-time comparison of the bearer token

        expected_token = os.environ["AUTH_TOKEN"]
        if not secrets.compare_digest(token.credentials.encode("utf-8"), expected_token.encode("utf-8")):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect bearer token",
                headers={"WWW-Authenticate": "Bearer"},
            )
        code = str(uuid.uuid4())
        video_path = audio_path = None
        try:
            video_path = await self._save_input(video_file, code, "video")
            if video_path is None:
                return {"status": "fail", "msg": "video file save failed: Not a valid input"}
            audio_path = await self._save_input(audio_file, code, "audio")
            if audio_path is None:
                self._remove_quietly(video_path)  # don't leak the already-saved video
                return {"status": "fail", "msg": "audio file save failed: Not a valid input"}
        except Exception as e:
            self._remove_quietly(video_path, audio_path)
            return {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}
        loguru.logger.info("Enter Inference Module")
        try:
            # NOTE(review): infer.local() (and download_file above) are blocking
            # calls inside an async endpoint, so they stall the event loop for
            # the whole job.
            file_path = self.infer.local(code, video_path, audio_path)
            file_name = os.path.basename(file_path)
            try:
                shutil.copy(file_path, os.path.join("/code/data/final", file_name))
            except Exception as copy_err:
                loguru.logger.info(f"Moving {file_path} to S3 Failed: {str(copy_err)}")
                try:
                    loguru.logger.info("Try move file to S3 manually")
                    self._upload_to_s3_fallback(file_path, file_name)
                except Exception as upload_err:
                    return {"status": "fail",
                            "msg": "Failed to move file to S3 manually: %s (mount copy error: %s)" % (upload_err, copy_err)}
            return {"status": "success", "msg": file_name}
        except Exception as e:
            traceback.print_exc()
            return {"status": "fail", "msg": "Inference module failed: " + str(e)}