126 lines
5.6 KiB
Python
126 lines
5.6 KiB
Python
import time
|
|
import traceback
|
|
import uuid
|
|
from concurrent.futures.thread import ThreadPoolExecutor
|
|
|
|
import loguru
|
|
import uvicorn
|
|
from fastapi import FastAPI, Depends, HTTPException
|
|
from fastapi.routing import APIRoute
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from starlette import status
|
|
|
|
from AutoDL.autodl_scheduling.entity.instance_pool import InstancePool, Instance
|
|
from AutoDL.autodl_scheduling.entity.result_map import ResultMap
|
|
from AutoDL.autodl_scheduling.entity.running_pool import RunningPool
|
|
from AutoDL.autodl_scheduling.entity.waiting_queue import WaitingQueue
|
|
|
|
|
|
class Server:
    """HTTP front end plus background scheduler for AutoDL tasks.

    Clients ``POST /`` a (video_url, audio_url) pair and get back a task uid,
    then poll ``GET /{uid}`` (or ``GET /`` for every recorded result).  Two
    background loops run on ``self.executor``:

    * ``scaling_worker`` sizes the instance pool to the current demand, hands
      waiting tasks to idle instances and collects finished results.
    * ``introspect_instance`` calls ``InstancePool.introspection()`` about once
      per second.

    NOTE(review): both loops never return, and ThreadPoolExecutor threads are
    joined before the interpreter exits, so the process will not exit on its
    own after uvicorn stops — confirm how shutdown is meant to work.
    """

    def __init__(self):
        self.app = FastAPI()
        # Tasks accepted by the API but not yet handed to an instance.
        self.waiting_queue = WaitingQueue()
        # Tasks submitted to an instance whose results are still being polled.
        self.running_pool = RunningPool()
        # Account limit: max_instance must not exceed 30.
        self.instance_pool = InstancePool(max_instance=2)
        # Latest status/result per task uid, served by the GET endpoints.
        self.result_map = ResultMap()

        # add_api_route (instead of appending APIRoute objects to app.routes)
        # lets FastAPI wire the routes to this app, e.g. for
        # app.dependency_overrides.
        self.app.add_api_route("/", self.submit, methods=["POST"])
        self.app.add_api_route("/{uid}", self.get_result, methods=["GET"])
        self.app.add_api_route("/", self.get_all_result, methods=["GET"])

        # One thread per long-running loop below; neither loop returns.
        self.executor = ThreadPoolExecutor(max_workers=2)
        # Start the workers last, once every attribute they read exists.
        self.worker_1 = self.executor.submit(self.scaling_worker)
        self.worker_2 = self.executor.submit(self.introspect_instance)

    async def submit(self, video_url: str, audio_url: str):
        """Enqueue a new task and return its uid as ``{"uid": uid}``.

        Raises:
            HTTPException: 500 if the task could not be enqueued.

        NOTE: bearer-token authentication is currently disabled. If it is
        re-enabled (``Depends(HTTPBearer())``), load the expected token from
        configuration/environment instead of hard-coding it in source.
        """
        uid = str(uuid.uuid4())
        try:
            self.waiting_queue.enqueue(uid, video_url, audio_url)
        except Exception as err:
            loguru.logger.exception("Failed to enqueue Task[%s]" % uid)
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from err
        return {"uid": uid}

    async def get_result(self, uid):
        """Return the status/result recorded for task ``uid``.

        Falls back to ``{"status": "queuing", "msg": ""}`` while the task is
        still in the waiting queue, and ``{"status": "not found", "msg": ""}``
        otherwise.
        """
        try:
            return self.result_map.get(uid)
        except Exception:
            # Presumably ResultMap.get raises (e.g. KeyError) for uids it has
            # no entry for — TODO confirm.
            # Snapshot with list() before scanning, because the scaling worker
            # thread dequeues concurrently. Assumes waiting_queue.queue is a
            # queue.Queue whose .queue is its internal deque, holding dicts
            # with a "uid" key — confirm.
            queued = list(self.waiting_queue.queue.queue)
            if any(item["uid"] == uid for item in queued):
                return {"status": "queuing", "msg": ""}
            return {"status": "not found", "msg": ""}

    async def get_all_result(self):
        """Return a snapshot of every recorded result (presumably keyed by uid).

        Raises:
            HTTPException: 500 if the result map cannot be read.
        """
        try:
            # Copy first: FastAPI serializes the response after this returns,
            # while the scaling worker thread may still insert entries, and
            # iterating a dict that changes size raises RuntimeError.
            # Assumes ResultMap.map is a dict/mapping — confirm.
            return dict(self.result_map.map)
        except Exception as err:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from err

    def introspect_instance(self):
        """Background loop: run instance-pool introspection about once per second.

        Never returns. Exceptions from a single pass are logged, and the loop
        continues after the usual pause.
        """
        loguru.logger.info("Introspecting worker started || Scaledown Window: %ds" % self.instance_pool.scaledown_window)
        while True:
            try:
                self.instance_pool.introspection()
            except Exception:
                loguru.logger.exception("Instance introspection failed")
            # Sleep even after a failure; otherwise a persistent error becomes
            # a busy spin that floods the log.
            time.sleep(1)

    def scaling_worker(self):
        """Background loop: scale the pool, dispatch waiting tasks, collect results.

        Runs one pass every 0.5 s and never returns. An exception aborts the
        rest of the current pass; it is logged and the next pass starts after
        the usual pause.
        """
        loguru.logger.info("Scaling worker started")
        while True:
            try:
                # Size the pool for every waiting or running task
                # (disable_shrink=True: this call must not scale down).
                demand = self.waiting_queue.get_size() + self.running_pool.get_running_size()
                self.instance_pool.scale_instance(demand, disable_shrink=True)
                self._dispatch_waiting_tasks()
                self._collect_finished_tasks()
            except Exception:
                loguru.logger.exception("Scaling worker pass failed")
            # Sleep even after a failure; otherwise a persistent error becomes
            # a busy spin that floods the log.
            time.sleep(0.5)

    def _dispatch_waiting_tasks(self):
        """Hand the next waiting task to each idle instance, while tasks remain."""
        for instance in self.instance_pool.instances:
            if instance.active or self.waiting_queue.get_size() <= 0:
                continue
            # Take the next task from the waiting queue.
            uid, video_path, audio_path = self.waiting_queue.dequeue()
            # The task has left the waiting queue but has no result entry yet.
            # Record it so get_result does not report "not found" while
            # running_pool.run() is in progress.
            self.result_map.set(uid, {"status": "queuing", "msg": ""})
            loguru.logger.info("Task[%s] Submitting to Instance[%s]" % (uid, instance.uuid))
            try:
                submitted = self.running_pool.run(instance.uuid, uid, instance.domain, video_path, audio_path)
            except Exception:
                # Treat an exception like a failed submission. The task is
                # already dequeued, so otherwise it would vanish with no result.
                loguru.logger.exception("Submitting Task[%s] raised" % uid)
                submitted = False
            if submitted:
                # Mark the instance busy.
                instance.active = True
                instance.last_active_time = time.time()
                loguru.logger.info("Task[%s] Submitted" % uid)
                self.result_map.set(uid, {"status": "running", "msg": "Task Submitted, Waiting for result"})
            else:
                loguru.logger.error("Submit Task[%s] Failed" % uid)
                self.result_map.set(uid, {"status": "fail", "msg": "Submit Task Failed"})

    def _collect_finished_tasks(self):
        """Store finished results in result_map and free or remove their instances."""
        # Iterate over a snapshot, because finished tasks are popped in the loop.
        for uid, task in list(self.running_pool.tasks.items()):
            result = self.running_pool.result(uid)
            if not result:
                # Falsy result: presumably the task is still running.
                continue
            # The task finished or failed.
            self.result_map.update(uid, result)
            self.running_pool.tasks.pop(uid)
            instance_id = task["instance_id"]
            if result["status"] == "fail":
                # Drop the instance that ran the failed task.
                self.instance_pool.remove_instance(Instance(uuid=instance_id))
                loguru.logger.info("Instance[%s] Removed Due to Task[%s] Failure" % (instance_id, uid))
            else:
                # Mark the instance that ran the successful task as idle again.
                for instance in self.instance_pool.instances:
                    if instance.uuid == instance_id:
                        loguru.logger.info("Instance[%s] Task[%s] Finished" % (instance.uuid, uid))
                        instance.active = False
                        instance.last_active_time = time.time()
                        # Instance uuids are presumably unique.
                        break
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Constructing Server also starts its background
    # worker threads; uvicorn then serves the FastAPI app on localhost only.
    scheduler_server = Server()
    uvicorn.run(scheduler_server.app, port=8888, host="127.0.0.1")
|
|
|
|
|
|
|
|
|