115 lines
4.6 KiB
Python
115 lines
4.6 KiB
Python
import copy
import os
import random
import threading
import time
from concurrent.futures.thread import ThreadPoolExecutor
from typing import List

import loguru

from AutoDL.autodl_scheduling.util.audodl_sdk import instance_operate, get_autodl_machines, payg
|
|
|
|
RETRY_LIMIT = 20 # Maximum number of attempts when creating an instance
|
|
|
|
|
|
class Instance:
    """Plain record describing one rented instance.

    Attributes are public and mutated directly by their owner:
        uuid: Identifier of the instance.
        active: Whether the instance is currently marked active.
        last_active_time: Timestamp of the last activity; ``-1.0`` when unset.
        domain: Address under which the instance is reachable.
    """

    def __init__(self, uuid: str, active: bool = False, last_active_time: float = -1., domain: str = ""):
        self.uuid: str = uuid
        self.active: bool = active
        self.last_active_time: float = last_active_time
        self.domain: str = domain

    def __str__(self):
        """Return the human-readable one-line summary used in log output."""
        return "uuid:%s active:%s last_active_time:%s domain:%s" % (
            self.uuid, self.active, self.last_active_time, self.domain)

    def __repr__(self):
        """Return an unambiguous representation so instances inside
        containers print their fields instead of an object address."""
        return "{}(uuid={!r}, active={!r}, last_active_time={!r}, domain={!r})".format(
            type(self).__name__, self.uuid, self.active, self.last_active_time, self.domain)
|
|
|
|
|
|
class InstancePool:
    """Pool of instances that is grown and shrunk towards a target size.

    Add and remove jobs run on an internal thread pool. ``_scale`` waits for
    the jobs it queued before reporting whether the target size was reached.
    """

    def __init__(self, min_instance=0, max_instance=100, scaledown_window=120, buffer_instance=0, timeout=1200):
        """Create an empty pool.

        Args:
            min_instance: Lower bound applied to every scaling request.
            max_instance: Upper bound applied to every scaling request.
            scaledown_window: Seconds an inactive instance may stay idle
                before ``introspection`` queues it for removal.
            buffer_instance: Extra instances added on top of each requested
                target.
            timeout: Seconds since last activity after which an active
                instance is queued for removal; half of it is also the wait
                limit for each queued job in ``_scale``.
        """
        self.min_instance = min_instance
        self.max_instance = max_instance
        self.scaledown_window = scaledown_window
        self.buffer_instance = buffer_instance
        self.timeout = timeout
        self.instances: List["Instance"] = []
        # os.cpu_count() may return None when the count cannot be determined.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 2)
        # Futures of add/remove jobs that have been queued but not awaited yet.
        self.threads = []
        # Worker threads append to / remove from ``instances`` concurrently.
        self._lock = threading.Lock()

    def scale_instance(self, target_instance):
        """Scale the pool to ``target_instance`` plus ``buffer_instance``.

        The request is clamped to ``min_instance`` / ``max_instance``. When
        the unclamped request already equals the current pool size nothing
        is done.

        Returns:
            True if the pool has the requested size afterwards, else False.
        """
        wanted = target_instance + self.buffer_instance
        if wanted < self.min_instance:
            return self._scale(self.min_instance)
        if wanted > self.max_instance:
            return self._scale(self.max_instance)
        if wanted == len(self.instances):
            return True
        return self._scale(wanted)

    def remove_instance(self, instance: "Instance"):
        """Power off and release ``instance``, then drop it from the pool.

        The pool entry is removed only after both remote operations report
        success; a failure of either one is logged and leaves the pool
        untouched.
        """
        if not instance_operate(instance.uuid, "power_off"):
            loguru.logger.error("Instance {} failed to power off".format(instance.uuid))
            return
        if not instance_operate(instance.uuid, "release"):
            loguru.logger.error("Instance {} failed to release".format(instance.uuid))
            return
        with self._lock:
            # Rebuild in place instead of removing while iterating, which
            # skips elements; slice assignment keeps the list object itself.
            self.instances[:] = [i for i in self.instances if i.uuid != instance.uuid]

    def _add_instance(self):
        """Try to create one instance, giving up after ``RETRY_LIMIT`` attempts.

        Each attempt picks a random machine from ``get_autodl_machines()`` and
        orders it through ``payg``. On success the new instance is appended
        to the pool as inactive; if every attempt fails an error is logged.
        """
        for _ in range(RETRY_LIMIT):
            machines = get_autodl_machines()
            if machines:
                machine = random.choice(machines)
                result = payg(machine["region_name"], machine["machine_id"])
                if result:
                    created = Instance(uuid=result[0], active=False,
                                       last_active_time=time.time(),
                                       domain="https://" + result[1])
                    with self._lock:
                        self.instances.append(created)
                    return
            # No machine available or the order failed: pause before retrying.
            time.sleep(1)
        loguru.logger.error("Fail to Scale[Add] Instance")

    def introspection(self):
        """Queue removal of every timed-out instance.

        Active instances time out after ``timeout`` seconds since their last
        activity, inactive ones after ``scaledown_window`` seconds. The jobs
        are only queued here; they are awaited by ``_scale``.
        """
        self._reap_expired()

    def _reap_expired(self):
        """Queue removal jobs for timed-out instances and return their uuids."""
        now = time.time()
        with self._lock:
            snapshot = list(self.instances)
        expired = set()
        for instance in snapshot:
            limit = self.timeout if instance.active else self.scaledown_window
            if now - instance.last_active_time > limit:
                self.threads.append(self.executor.submit(self.remove_instance, instance=instance))
                expired.add(instance.uuid)
        return expired

    def _wait_pending(self):
        """Wait for every queued job, then re-raise the first failure, if any.

        The pending list is emptied up front so that a failing or timed-out
        job never leaves stale futures behind for the next scaling call.
        """
        pending, self.threads = self.threads, []
        first_error = None
        for future in pending:
            try:
                future.result(timeout=self.timeout // 2)
            except Exception as exc:
                loguru.logger.error("Scaling job failed: {!r}".format(exc))
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def _scale(self, target_instance: int):
        """Bring the pool to exactly ``target_instance`` instances.

        Timed-out instances are queued for removal first. The remaining gap
        is then closed by queuing removals of inactive instances or by
        queuing new instance creations, and all queued jobs are awaited.

        Returns:
            True if the pool size equals ``target_instance`` afterwards.
        """
        loguru.logger.info("Instance Num Before Scaling %d ; Target %d" % (len(self.instances), target_instance))
        expired = self._reap_expired()
        with self._lock:
            # Instances already queued for removal must not be counted, nor
            # be queued for removal a second time below.
            survivors = [i for i in self.instances if i.uuid not in expired]
        delta = target_instance - len(survivors)
        if delta < 0:
            # Only inactive instances are eligible for scale-down.
            idle = [i for i in survivors if not i.active]
            for instance in idle[:-delta]:
                self.threads.append(self.executor.submit(self.remove_instance, instance=instance))
        else:
            for _ in range(delta):
                self.threads.append(self.executor.submit(self._add_instance))
        self._wait_pending()
        loguru.logger.info("Instance Num After Scaling %d ; Target %d" % (len(self.instances), target_instance))
        return len(self.instances) == target_instance
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: scale a default pool up to five instances, pause
    # briefly, then scale it back down to zero, printing each outcome.
    pool = InstancePool()
    scaled_up = pool.scale_instance(5)
    print(scaled_up)
    time.sleep(5)
    scaled_down = pool.scale_instance(0)
    print(scaled_down)