# Source: modalDeploy/AutoDL/autodl_scheduling/server.py
#
# Repository viewer metadata: 126 lines, 5.6 KiB, Python

import time
import traceback
import uuid
from concurrent.futures.thread import ThreadPoolExecutor
import loguru
import uvicorn
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette import status
from AutoDL.autodl_scheduling.entity.instance_pool import InstancePool, Instance
from AutoDL.autodl_scheduling.entity.result_map import ResultMap
from AutoDL.autodl_scheduling.entity.running_pool import RunningPool
from AutoDL.autodl_scheduling.entity.waiting_queue import WaitingQueue
class Server:
    """FastAPI front end plus two background workers that schedule tasks.

    Routes registered on ``self.app``:

    * ``POST /``      -> :meth:`submit`
    * ``GET /{uid}``  -> :meth:`get_result`
    * ``GET /``       -> :meth:`get_all_result`

    Two long-running loops are submitted to a two-thread executor at
    construction time: :meth:`scaling_worker` and :meth:`introspect_instance`.
    """

    def __init__(self):
        """Build the app and scheduling state, start both workers, add routes."""
        self.app = FastAPI()
        self.waiting_queue = WaitingQueue()
        self.running_pool = RunningPool()
        # The account limit means max_instance must not exceed 30.
        self.instance_pool = InstancePool(max_instance=2)
        self.result_map = ResultMap()
        # One thread per background loop; both loops run until the process ends.
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.worker_1 = self.executor.submit(self.scaling_worker)
        self.worker_2 = self.executor.submit(self.introspect_instance)
        self.app.routes.append(APIRoute("/", endpoint=self.submit, methods=['POST']))
        self.app.routes.append(APIRoute("/{uid}", endpoint=self.get_result, methods=['GET']))
        self.app.routes.append(APIRoute("/", endpoint=self.get_all_result, methods=['GET']))

    # NOTE(review): bearer-token authentication on `submit` is currently
    # disabled. If it is re-enabled, load the expected token from configuration
    # instead of embedding it in this file.
    async def submit(self, video_url, audio_url):
        """Enqueue a new task and return its generated id.

        Args:
            video_url: Passed through unchanged to ``waiting_queue.enqueue``.
            audio_url: Passed through unchanged to ``waiting_queue.enqueue``.

        Returns:
            ``{"uid": <uuid4 string>}`` for the newly queued task.

        Raises:
            HTTPException: 500 if enqueueing raises; the original error is
                logged and attached as the cause.
        """
        uid = str(uuid.uuid4())
        try:
            self.waiting_queue.enqueue(uid, video_url, audio_url)
        except Exception as exc:
            loguru.logger.exception("Task[%s] could not be enqueued" % uid)
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc
        return {"uid": uid}

    async def get_result(self, uid):
        """Return the recorded result for ``uid``, or a queuing/not-found status.

        Returns:
            Whatever ``result_map.get(uid)`` returns when that call succeeds.
            Otherwise ``{"status": "queuing", "msg": ""}`` if a queued item
            carries this ``uid``, else ``{"status": "not found", "msg": ""}``.
        """
        try:
            return self.result_map.get(uid)
        except Exception:
            # Presumably `get` raises for ids it has no entry for (verify
            # against ResultMap); any ordinary failure falls back to the queue.
            pass
        # Snapshot the underlying container before scanning it.
        queued_items = list(self.waiting_queue.queue.queue)
        if any(item["uid"] == uid for item in queued_items):
            return {"status": "queuing", "msg":""}
        return {"status": "not found", "msg":""}

    async def get_all_result(self):
        """Return the whole ``result_map.map`` object.

        Raises:
            HTTPException: 500 if reading the map raises; the original error
                is logged and attached as the cause.
        """
        try:
            return self.result_map.map
        except Exception as exc:
            loguru.logger.exception("Reading the result map failed")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc

    def introspect_instance(self):
        """Call ``instance_pool.introspection()`` about once per second, forever.

        Ordinary exceptions are logged and the loop continues. The pause is
        taken after every pass, including failed ones, so a persistent failure
        cannot turn the loop into a busy spin.
        """
        loguru.logger.info("Introspecting worker started || Scaledown Window: %ds" % self.instance_pool.scaledown_window)
        while True:
            try:
                self.instance_pool.introspection()
            except Exception:
                loguru.logger.exception("Instance introspection failed")
            time.sleep(1)

    def scaling_worker(self):
        """Run the scheduling loop forever: dispatch queued tasks, collect results.

        Each pass scales the instance pool, hands waiting tasks to inactive
        instances, then moves finished task results into the result map.
        Ordinary exceptions are logged and the loop continues; the 0.5 s pause
        is taken after every pass, including failed ones.
        """
        loguru.logger.info("Scaling worker started")
        while True:
            try:
                self._dispatch_waiting_tasks()
                self._collect_finished_tasks()
            except Exception:
                loguru.logger.exception("Scaling worker iteration failed")
            time.sleep(0.5)

    def _dispatch_waiting_tasks(self):
        """Scale the pool to current demand and submit queued tasks to inactive instances."""
        # Demand is everything still waiting plus everything already running.
        demand = self.waiting_queue.get_size() + self.running_pool.get_running_size()
        self.instance_pool.scale_instance(demand, disable_shrink=True)
        for instance in self.instance_pool.instances:
            if instance.active or self.waiting_queue.get_size() <= 0:
                continue
            # Take the next task off the waiting queue.
            uid, video_path, audio_path = self.waiting_queue.dequeue()
            loguru.logger.info("Task[%s] Submitting to Instance[%s]" % (uid, instance.uuid))
            # Hand the task to the running pool.
            if self.running_pool.run(instance.uuid, uid, instance.domain, video_path, audio_path):
                # Mark the instance as busy in the instance pool.
                instance.active = True
                instance.last_active_time = time.time()
                loguru.logger.info("Task Submitted")
                self.result_map.set(uid, {"status": "running", "msg": "Task Submitted, Waiting for result"})
            else:
                loguru.logger.error("Submit Task Failed")
                self.result_map.set(uid, {"status":"fail", "msg":"Submit Task Failed"})

    def _collect_finished_tasks(self):
        """Move finished task results into the result map and update their instances."""
        # Iterate over a copy because finished tasks are popped inside the loop.
        for uid, task in list(self.running_pool.tasks.items()):
            result = self.running_pool.result(uid)
            # A falsy result means the task has neither finished nor failed yet.
            if not result:
                continue
            self.result_map.update(uid, result)
            self.running_pool.tasks.pop(uid)
            if result["status"] == "fail":
                # Drop the instance whose task failed.
                self.instance_pool.remove_instance(Instance(uuid=task["instance_id"]))
                loguru.logger.info("Instance[%s] Removed Due to Task Failure" % task["instance_id"])
                continue
            # Mark the instance that finished the task as idle again.
            for instance in self.instance_pool.instances:
                if instance.uuid == task["instance_id"]:
                    loguru.logger.info("Instance[%s] Task Finished" % instance.uuid)
                    instance.active = False
                    instance.last_active_time = time.time()
def main():
    """Construct the scheduling server and serve its app on 127.0.0.1:8888."""
    scheduler = Server()
    uvicorn.run(scheduler.app, host="127.0.0.1", port=8888)


if __name__ == "__main__":
    main()