# HeyGem template
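# Deploys HeyGem lip-sync inference as a Modal app. The image bundles a Python 3.8
# runtime plus the heygem wheel; each container starts heygem.py in the background,
# which serves an HTTP API on 127.0.0.1:8383. The FastAPI endpoint below accepts a
# video and an audio input (file upload or URL), submits a job to /easy/submit,
# polls /easy/query until it finishes, and copies the rendered video to the S3
# bucket mounted at /code/data/final (with a direct boto3 upload as fallback).
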
import json
import os
import shutil
import socket
import subprocess
import time
import traceback
import uuid
from typing import Any, Optional, Union

import httpx
import loguru
import modal
import requests
from fastapi import Depends, HTTPException, status, UploadFile
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

image = (
    modal.Image.debian_slim(  # start from basic Linux with Python
        python_version="3.10"
    ).apt_install("git")
    .apt_install("gcc")
    .pip_install("loguru")
    .pip_install("fastapi[standard]")
    .run_commands(
        "apt update && apt install -y ffmpeg && ffmpeg -version"
    )
    # Add Python 3.8 for the HeyGem runtime
    .run_commands("apt update && apt install -y curl build-essential libssl-dev zlib1g-dev libncurses5-dev libncursesw5-dev libreadline-dev libsqlite3-dev libgdbm-dev libdb5.3-dev libbz2-dev libexpat1-dev lzma liblzma-dev tk-dev libffi-dev")
    .run_commands("curl -O https://www.python.org/ftp/python/3.8.12/Python-3.8.12.tar.xz && tar -xf Python-3.8.12.tar.xz")
    .run_commands("cd Python-3.8.12 && ./configure --enable-optimizations && make -j 10 && make altinstall")
    .add_local_file("heygem-1.0-py3-none-any.whl", "/root/heygem-1.0-py3-none-any.whl", copy=True)
    .shell(["/bin/bash", "-c"])
    .run_commands("python3.8 -m pip install /root/heygem-1.0-py3-none-any.whl")
    .env({"LD_LIBRARY_PATH": "/usr/local/lib/python3.8/site-packages/nvidia/cuda_nvrtc/lib"})
    .run_commands("ln -s /usr/local/lib/python3.8/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.11.2 /usr/local/lib/python3.8/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so")
    .run_commands("python3.8 -m pip install https://github.com/pydata/numexpr/releases/download/v2.8.6/numexpr-2.8.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl")
    .add_local_file("heygem.py", "/root/heygem.py", copy=True)
    .add_local_file("config.yaml", "/root/config.yaml", copy=True)
    .workdir("/root")
    .pip_install("boto3")
    .pip_install("PyYAML")
    .pip_install("requests")
)
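# Build note: the heygem wheel is installed under a separately compiled Python 3.8.12
# (the Debian base image itself ships Python 3.10), and heygem.py is later run with
# python3.8. LD_LIBRARY_PATH plus the libnvrtc.so symlink point that interpreter at
# the NVRTC library shipped inside its nvidia pip packages.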

app = modal.App(name="heygem", image=image)
secret = modal.Secret.from_name("aws-s3-secret")
auth_scheme = HTTPBearer()
bucket_output = "bw-heygem-output"

@app.cls(
    allow_concurrent_inputs=1,  # each container handles one request at a time
    max_containers=25,  # cap autoscaling at 25 containers
    gpu="L40S",  # good starter GPU for inference
    cpu=(2, 32),
    memory=(6144, 32768),
    timeout=2160,
    scaledown_window=240,
    secrets=[secret, modal.Secret.from_name("web_auth_token")],
    volumes={
        "/code/data/final": modal.CloudBucketMount(
            bucket_name=bucket_output,
            # bucket_endpoint_url="https://s3.%s.amazonaws.com" % os.environ["AWS_REGION"],
            secret=secret,
            key_prefix="/"
        ),
        "/root/logs": modal.CloudBucketMount(
            bucket_name=bucket_output,
            # bucket_endpoint_url="https://s3.%s.amazonaws.com" % os.environ["AWS_REGION"],
            secret=secret,
            key_prefix="/logs/"
        )
    },  # mount the S3 output bucket for results and logs
)
class HeyGem:
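    """Runs the HeyGem server inside each container and exposes an authenticated inference endpoint."""
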
    @modal.enter()
    def start(self):
        """Launch heygem.py in the background and wait for it to listen on port 8383."""
        def check_port_in_use(port, host='127.0.0.1'):
            s = None
            try:
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                s.settimeout(1)
                s.connect((host, int(port)))
                return True
            except socket.error:
                return False
            finally:
                if s:
                    s.close()

        uid = str(uuid.uuid4())
        cmd = "echo client_id: %s && nohup python3.8 /root/heygem.py >> /root/logs/%s.log 2>&1 &" % (uid, uid)
        subprocess.run(cmd, shell=True, check=True)

        timeout = 30
        while timeout > 0:
            if check_port_in_use(8383):
                loguru.logger.success("HeyGem Server started on port 8383")
                break
            timeout -= 1
            time.sleep(2)
        if timeout == 0:
            raise TimeoutError("HeyGem Server start timed out")

    @modal.method()
    def infer(self, code, video_path, audio_path) -> str:
        """Run one lip-sync job: submit it to the local HeyGem server, poll until it finishes, and return the output file path."""
        def task_submit(uid, video, audio, heygem_url):
            """Submit a task to the API"""
            task_submit_api = f'{heygem_url}/easy/submit'
            result_json = {
                'status': False, 'data': {}, 'msg': ''
            }
            try:
                data = {
                    "code": uid,
                    "video_url": video,
                    "audio_url": audio,
                    "chaofen": 1,
                    "watermark_switch": 0,
                    "pn": 1
                }
                loguru.logger.info(f'data={json.dumps(data, ensure_ascii=False)}')
                with httpx.Client() as client:
                    resp = client.post(task_submit_api, json=data)
                    resp_dict = resp.json()
                    loguru.logger.info(f'submit data: {json.dumps(resp_dict, ensure_ascii=False)}')
                    if resp_dict['code'] != 10000:
                        result_json['status'] = False
                        result_json['msg'] = resp_dict['msg']
                    else:
                        result_json['status'] = True
                        result_json['data'] = uid
                        result_json['msg'] = 'Task submitted successfully'
                return result_json
            except Exception as e:
                loguru.logger.info(f'submit task failed: {str(e)}')
                raise RuntimeError(str(e))

        def query_task_progress(heygem_url, heygem_temp_path, task_id: str, interval: int = 10, timeout: int = 60 * 35):
            """Query task progress and wait for completion"""
            result_json = {'status': False, 'data': {}, 'msg': ''}

            def query_result(t_id: str):
                tmp_dict = {'status': True, 'data': dict(), 'msg': ''}
                try:
                    query_task_url = f'{heygem_url}/easy/query'
                    params = {
                        'code': t_id
                    }
                    with httpx.Client() as client:
                        resp = client.get(query_task_url, params=params)
                        resp_dict = resp.json()
                        loguru.logger.info(f'query task data: {json.dumps(resp_dict, ensure_ascii=False)}')
                        status_code = resp_dict['code']
                        if status_code in (9999, 10002, 10003, 10001):
                            # HeyGem error codes: report the failure message
                            tmp_dict['status'] = False
                            tmp_dict['msg'] = resp_dict['msg']
                        elif status_code == 10000:
                            # request succeeded; inspect the task status (2 = done, 3 = failed)
                            status_code = resp_dict['data'].get('status', 1)
                            if status_code == 3:
                                tmp_dict['status'] = False
                                tmp_dict['msg'] = resp_dict['data']['msg']
                            else:
                                process = resp_dict['data'].get('progress', 20)
                                if status_code == 2:
                                    process = 100
                                result = resp_dict['data'].get('result', '')
                                tmp_dict['data'] = {'progress': process,
                                                    'path': result,
                                                    }
                        else:
                            pass
                except Exception as e:
                    loguru.logger.info(f'query task failed: {str(e)}')
                    raise RuntimeError(str(e))
                return tmp_dict

            end = time.time() + timeout
            while time.time() < end:
                tmp_result = query_result(task_id)
                if not tmp_result['status'] or tmp_result['data'] == {}:
                    result_json['status'] = False
                    result_json['msg'] = tmp_result['msg']
                    break
                else:
                    process = tmp_result['data']['progress']
                    loguru.logger.info(f'query task progress: {process}')
                    if tmp_result['data']['progress'] < 100:
                        time.sleep(interval)
                        loguru.logger.info(f'wait next interval: {interval}')
                    else:
                        p = tmp_result['data']['path']
                        # strip path separators so the result is joined as a bare file name
                        p = p.replace('/', '').replace('\\', '')
                        result_json['data'] = "%s/%s" % (heygem_temp_path, p)
                        result_json['status'] = True
                        break
            return result_json

        try:
            heygem_url = "http://127.0.0.1:8383"
            heygem_temp_path = "/code/data/temp"
            submit_result = task_submit(code, video_path, audio_path, heygem_url)
            if not submit_result['status']:
                raise RuntimeError(f"Task submission failed: {submit_result['msg']}")

            task_id = submit_result['data']
            loguru.logger.info(f'Submitted task: {task_id}')

            # Query task progress until completion
            progress_result = query_task_progress(heygem_url, heygem_temp_path, task_id, interval=5)

            if not progress_result['status']:
                raise RuntimeError(f"Task processing failed: {progress_result['msg']}")

            # Return the output file path
            file_path = progress_result['data']
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Output file not found at {file_path}")
            return file_path
        except Exception as e:
            loguru.logger.error(f"Error processing request || {str(e)}")
            raise
        finally:
            # clean up the uploaded/downloaded inputs
            try:
                os.remove(video_path)
                os.remove(audio_path)
            except OSError:
                pass

    def download_file(self, url, path):
        """Stream a remote file to a local path."""
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(path, 'wb') as f:
                shutil.copyfileobj(r.raw, f)

@modal.fastapi_endpoint(method="POST")
|
|
async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]], token: HTTPAuthorizationCredentials = Depends(auth_scheme)):
|
|
if token.credentials != os.environ["AUTH_TOKEN"]:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Incorrect bearer token",
|
|
headers={"WWW-Authenticate": "Bearer"},
|
|
)
|
|
code = str(uuid.uuid4())
|
|
try:
|
|
if isinstance(video_file, UploadFile):
|
|
video_path = "/root/%s" % video_file.filename
|
|
with open(video_path, "wb") as f:
|
|
data = await video_file.read()
|
|
f.write(data)
|
|
elif isinstance(video_file, str) and video_file.startswith("http"):
|
|
video_path = "/root/%s" % video_file.split("?")[0].split("/")[-1]
|
|
self.download_file(video_file, video_path)
|
|
else:
|
|
return {"status": "fail", "msg": "video file save failed: Not a valid input"}
|
|
|
|
if isinstance(audio_file, UploadFile):
|
|
audio_path = "/root/%s" % audio_file.filename
|
|
with open(audio_path, "wb") as f:
|
|
data = await audio_file.read()
|
|
f.write(data)
|
|
elif isinstance(audio_file, str) and audio_file.startswith("http"):
|
|
audio_path = "/root/%s" % audio_file.split("?")[0].split("/")[-1]
|
|
self.download_file(audio_file, audio_path)
|
|
else:
|
|
return {"status": "fail", "msg": "audio file save failed: Not a valid input"}
|
|
except Exception as e:
|
|
return {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}
|
|
|
|
loguru.logger.info(f"Enter Inference Module")
|
|
        try:
            file_path = self.infer.local(code, video_path, audio_path)
            try:
                # copy the rendered video to the mounted S3 output bucket
                shutil.copy(file_path, os.path.join("/code/data/final", os.path.basename(file_path)))
            except Exception as e:
                loguru.logger.info(f"Moving {file_path} to S3 failed: {str(e)}")
                try:
                    loguru.logger.info("Trying to move the file to S3 manually")
                    # S3 fallback: upload with boto3 using the credentials in config.yaml
                    # (expects the keys aws_key_id and aws_access_key)
                    import boto3
                    import yaml
                    with open("/root/config.yaml", encoding="utf-8", mode="r+") as config:
                        yaml_config = yaml.load(config, Loader=yaml.FullLoader)
                    awss3 = boto3.resource('s3', aws_access_key_id=yaml_config["aws_key_id"],
                                           aws_secret_access_key=yaml_config["aws_access_key"])
                    awss3.meta.client.upload_file(file_path, bucket_output, os.path.basename(file_path))
                except Exception as fallback_error:
                    return {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(fallback_error)}
            # resp = FileResponse(path=file_path, filename=Path(file_path).name, status_code=200, headers={
            #     "Content-Type": "application/octet-stream",
            #     "Content-Disposition": f"attachment; filename={Path(file_path).name}",
            # })
            return {"status": "success", "msg": f"{os.path.basename(file_path)}"}
        except Exception as e:
            traceback.print_exc()
            return {"status": "fail", "msg": "Inference module failed: " + str(e)}
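
# Example client call (illustrative sketch, not part of the deployment). The URL is a
# placeholder: `modal deploy` prints the actual web endpoint URL for HeyGem.api, and the
# bearer token must match the AUTH_TOKEN value stored in the web_auth_token secret.
# How FastAPI binds the Union[UploadFile, str] parameters depends on the FastAPI version;
# the multipart form shown here assumes the file-upload path.
#
#   import requests
#
#   ENDPOINT_URL = "https://<workspace>--heygem-heygem-api.modal.run"  # placeholder
#   with open("input.mp4", "rb") as v, open("speech.wav", "rb") as a:
#       resp = requests.post(
#           ENDPOINT_URL,
#           headers={"Authorization": "Bearer <AUTH_TOKEN>"},
#           files={"video_file": v, "audio_file": a},
#           timeout=3600,
#       )
#   print(resp.json())  # e.g. {"status": "success", "msg": "<output file name>"}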