# modalDeploy/AutoDL/autodl_scheduling/entity/instance_pool.py
#
# Source listing metadata: 129 lines, 5.2 KiB, Python

import copy
import os
import random
import threading
import time
from concurrent.futures.thread import ThreadPoolExecutor
from typing import List

import loguru

from AutoDL.autodl_scheduling.util.audodl_sdk import instance_operate, get_autodl_machines, payg
RETRY_LIMIT: int = 20  # retry budget when creating an instance (used by InstancePool._add_instance)
class Instance:
    """One AutoDL instance tracked by the instance pool.

    Holds the identifier used for SDK calls, a busy flag, the timestamp the
    pool compares against its time limits, and the instance's URL.
    """

    def __init__(self, uuid: str, active: bool = False, last_active_time: float = -1., domain: str = ""):
        # URL of the instance; the pool builds it as "https://" + host.
        self.domain: str = domain
        # Reference timestamp the pool compares against its timeout limits.
        self.last_active_time: float = last_active_time
        # Whether the instance is busy; selects which time limit applies.
        self.active: bool = active
        # Identifier passed to the AutoDL SDK (power off / release).
        self.uuid: str = uuid

    def __str__(self):
        """Render all fields on one line for logging."""
        fields = (self.uuid, self.active, self.last_active_time, self.domain)
        return "uuid:%s active:%s last_active_time:%s domain:%s" % fields
class InstancePool:
    """Elastic pool of AutoDL pay-as-you-go instances.

    The pool is scaled toward ``target + buffer_instance`` instances, clamped to
    ``[min_instance, max_instance]``. Instances are created through ``payg`` and
    powered off / released through ``instance_operate``.

    :meth:`introspection` reaps instances whose ``last_active_time`` is too old:
    more than ``timeout`` seconds for active instances, more than
    ``scaledown_window`` seconds for inactive ones.

    Creation and removal run on a thread pool, so every mutation of
    ``self.instances`` is guarded by ``self._lock``.
    """

    def __init__(self, min_instance=0, max_instance=100, scaledown_window=120, buffer_instance=0, timeout=1200):
        """
        Args:
            min_instance: lower bound on the pool size.
            max_instance: upper bound on the pool size.
            scaledown_window: seconds an inactive instance may go past its
                ``last_active_time`` before it is removed.
            buffer_instance: spare instances added on top of every requested target.
            timeout: seconds an active instance may go past its ``last_active_time``
                before it is removed; half of it also bounds each background-task wait.
        """
        self.min_instance = min_instance
        self.max_instance = max_instance
        self.scaledown_window = scaledown_window
        self.buffer_instance = buffer_instance
        self.timeout = timeout
        self.instances: List[Instance] = []
        # os.cpu_count() returns None when the CPU count cannot be determined.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 2)
        # Pending futures for scaling / introspection tasks (kept as public attributes).
        self.threads = []
        self.intro_threads = []
        # Guards self.instances, which worker threads append to and remove from.
        self._lock = threading.Lock()

    def scale_instance(self, target_instance, disable_shrink=True):
        """Scale the pool toward ``target_instance + buffer_instance`` instances.

        The desired size is clamped to ``[min_instance, max_instance]``. When the
        unclamped desired size already equals the pool size, this returns True
        immediately without running introspection.

        Args:
            target_instance: number of instances the caller needs.
            disable_shrink: when True, never remove instances just to reach the
                target. Expired instances are still reaped by introspection.

        Returns:
            The result of :meth:`_scale`, or True on the no-op fast path.
        """
        desired = target_instance + self.buffer_instance
        # The clamped branches forward disable_shrink so the caller's choice is honoured.
        if desired < self.min_instance:
            return self._scale(self.min_instance, disable_shrink)
        if desired > self.max_instance:
            return self._scale(self.max_instance, disable_shrink)
        if desired == len(self.instances):
            return True
        return self._scale(desired, disable_shrink)

    def remove_instance(self, instance: "Instance"):
        """Power off and release ``instance``, then drop it from the pool.

        The pool entry is matched by ``uuid``. If either SDK call fails, the
        failure is logged and the instance stays in the pool.
        """
        if not instance_operate(instance.uuid, "power_off"):
            loguru.logger.error("Instance {} failed to power off".format(instance.uuid))
            return
        if not instance_operate(instance.uuid, "release"):
            loguru.logger.error("Instance {} failed to release".format(instance.uuid))
            return
        with self._lock:
            # Rebuild in place: calling remove() while iterating can skip entries.
            self.instances[:] = [i for i in self.instances if i.uuid != instance.uuid]

    def _add_instance(self):
        """Create one pay-as-you-go instance on a random available machine.

        Makes up to RETRY_LIMIT attempts and sleeps one second after each failed
        attempt, where failure means no machine was available or ``payg``
        returned a falsy result. The new instance starts inactive with
        ``last_active_time`` set to its creation time.
        """
        for _ in range(RETRY_LIMIT):
            machines = get_autodl_machines()
            if machines:
                m = random.choice(machines)
                result = payg(m["region_name"], m["machine_id"])
                if result:
                    # NOTE(review): assumes payg returns (uuid, host) on success -- confirm in audodl_sdk.
                    new_instance = Instance(uuid=result[0], active=False, last_active_time=time.time(),
                                            domain="https://" + result[1])
                    with self._lock:
                        self.instances.append(new_instance)
                    return
            # Every failed attempt consumes the budget and backs off; otherwise a
            # persistently failing payg call would spin forever.
            time.sleep(1)
        loguru.logger.error("Fail to Scale[Add] Instance")

    def _wait_all(self, futures):
        """Wait for every future in ``futures``, emptying the list as it goes.

        Each future is removed from the list before waiting on it. A failed or
        timed-out task therefore cannot stay in the list and re-raise on later
        calls. Failures are logged; the remaining futures are still awaited.
        """
        while futures:
            future = futures.pop()
            try:
                future.result(timeout=self.timeout // 2)
            except Exception:
                loguru.logger.exception("Instance pool background task failed or timed out")

    def introspection(self):
        """Remove instances whose ``last_active_time`` is too old, and wait for the removals.

        Active instances are bounded by ``timeout``; inactive ones by
        ``scaledown_window``.
        """
        with self._lock:
            snapshot = list(self.instances)
        before = len(snapshot)
        now = time.time()
        removed_any = False
        for inst in snapshot:
            limit = self.timeout if inst.active else self.scaledown_window
            if now - inst.last_active_time > limit:
                removed_any = True
                self.intro_threads.append(self.executor.submit(self.remove_instance, instance=inst))
        self._wait_all(self.intro_threads)
        if removed_any:
            loguru.logger.info("Instance Num Before Introspecting %d After Introspecting %d" % (before, len(self.instances)))

    def _scale(self, target_instance: int, disable_shrink=True):
        """Run introspection, then add or remove instances to reach ``target_instance``.

        Shrinking only removes inactive instances, so the pool can remain above
        the target if too many instances are active.

        Returns:
            True if no change was needed, or if shrinking is disabled and the pool
            is above target. Otherwise, whether the final pool size equals
            ``target_instance``.
        """
        self.introspection()
        with self._lock:
            snapshot = list(self.instances)
        dest = target_instance - len(snapshot)
        if dest == 0 or (disable_shrink and dest < 0):
            return True
        loguru.logger.info("Instance Num Before Scaling %d ; Target %d" % (len(self.instances), target_instance))
        if dest < 0:
            excess = -dest
            idle = [inst for inst in snapshot if not inst.active][:excess]
            for inst in idle:
                self.threads.append(self.executor.submit(self.remove_instance, instance=inst))
        else:
            for _ in range(dest):
                self.threads.append(self.executor.submit(self._add_instance))
        self._wait_all(self.threads)
        loguru.logger.info("Instance Num After Scaling %d ; Target %d" % (len(self.instances), target_instance))
        return len(self.instances) == target_instance
if __name__ == "__main__":
    # Manual smoke test against the live AutoDL API: scale up to five
    # instances, wait briefly, then scale back down to zero.
    pool = InstancePool()
    print(pool.scale_instance(5))
    time.sleep(5)
    # disable_shrink=False is required to actually release the instances.
    # With the default (True), _scale returns early when the pool is above
    # target, and the fresh instances are still inside scaledown_window (120 s).
    # The demo would then leave five paid instances running.
    print(pool.scale_instance(0, disable_shrink=False))
    pool.executor.shutdown(wait=True)