import copy
import os
import random
import threading
import time
from concurrent.futures.thread import ThreadPoolExecutor
from typing import List

import loguru

from AutoDL.autodl_scheduling.util.audodl_sdk import instance_operate, get_autodl_machines, payg

RETRY_LIMIT = 20  # Number of attempts when creating an instance


class Instance:
    """Bookkeeping record for one remote instance managed by InstancePool.

    Attributes:
        uuid: Identifier passed to ``instance_operate`` for power-off/release.
        active: Whether the instance is busy; decides which timeout
            ``InstancePool.introspection`` applies to it.
        last_active_time: ``time.time()`` timestamp the timeouts are measured from.
        domain: URL of the instance (``"https://" + host`` when created by the pool).
    """

    def __init__(self, uuid: str, active: bool = False, last_active_time: float = -1., domain: str = ""):
        self.uuid: str = uuid
        self.active: bool = active
        self.last_active_time: float = last_active_time
        self.domain: str = domain

    def __str__(self):
        return "uuid:%s active:%s last_active_time:%s domain:%s" % (
            self.uuid, self.active, self.last_active_time, self.domain)


class InstancePool:
    """Elastic pool of instances that grows and shrinks toward a target size.

    Args:
        min_instance: Lower bound the target size is clamped to.
        max_instance: Upper bound the target size is clamped to.
        scaledown_window: Seconds an *inactive* instance may go without
            activity (since ``last_active_time``) before introspection removes it.
        buffer_instance: Extra instances added on top of every requested target.
        timeout: Seconds an *active* instance may run (since
            ``last_active_time``) before introspection removes it. Also used to
            derive wait limits: ``timeout // 6`` s to wait for a previous
            introspection pass, ``timeout // 2`` s per background operation.
    """

    def __init__(self, min_instance=0, max_instance=100, scaledown_window=120, buffer_instance=0, timeout=1200):
        self.min_instance = min_instance
        self.max_instance = max_instance
        self.scaledown_window = scaledown_window
        self.buffer_instance = buffer_instance
        self.timeout = timeout
        self.instances: List[Instance] = []
        # Guards self.instances: add/remove operations run concurrently on the
        # executor, and unsynchronized iterate-and-remove can skip entries.
        self._lock = threading.Lock()
        # os.cpu_count() may return None when the count cannot be determined.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 2)
        self.threads = []        # futures submitted by _scale
        self.intro_threads = []  # futures submitted by introspection

    def scale_instance(self, target_instance, disable_shrink=True):
        """Scale the pool toward ``target_instance + buffer_instance`` instances.

        The desired size is clamped to ``[min_instance, max_instance]``. If the
        unclamped desired size already equals the current pool size, nothing is
        done.

        Args:
            target_instance: Requested number of instances, before the buffer.
            disable_shrink: When True, the pool is never reduced. This also
                applies to the clamped cases.

        Returns:
            True if the pool ends at the requested size (or nothing needed to be
            done), False otherwise.
        """
        desired = target_instance + self.buffer_instance
        # Pass disable_shrink through when clamping; previously it was dropped,
        # so a shrink request below min_instance could never reach the minimum.
        if desired < self.min_instance:
            return self._scale(self.min_instance, disable_shrink)
        if desired > self.max_instance:
            return self._scale(self.max_instance, disable_shrink)
        if desired == len(self.instances):
            return True
        return self._scale(desired, disable_shrink)

    def remove_instance(self, instance: Instance):
        """Power off and release ``instance``, then drop it from the pool.

        The instance is removed from ``self.instances`` only when both
        ``instance_operate(uuid, "power_off")`` and
        ``instance_operate(uuid, "release")`` report success. Otherwise an error
        is logged and the pool is left unchanged. Nothing is raised by this
        method itself.
        """
        if not instance_operate(instance.uuid, "power_off"):
            loguru.logger.error("Instance {} failed to power off".format(instance.uuid))
            return
        if not instance_operate(instance.uuid, "release"):
            loguru.logger.error("Instance {} failed to release".format(instance.uuid))
            return
        # Rebuild in place (other references to the list stay valid), under the
        # lock. Several removals run in parallel, and removing while another
        # thread iterates the list can make that thread skip its target.
        with self._lock:
            self.instances[:] = [i for i in self.instances if i.uuid != instance.uuid]

    def _add_instance(self):
        """Create one pay-as-you-go instance on a randomly chosen machine.

        Makes up to RETRY_LIMIT attempts, sleeping 1 s after each failed one
        (no machine returned, or ``payg`` returned a falsy result). On success a
        new inactive Instance is appended, with ``last_active_time`` set to now
        and ``domain`` set to ``"https://" + result[1]``. Exhausting the retries
        is logged, not raised.
        """
        for _ in range(RETRY_LIMIT):
            machines = get_autodl_machines()
            if machines:
                machine = random.choice(machines)
                result = payg(machine["region_name"], machine["machine_id"])
                if result:
                    # result[0] is used as the instance uuid, result[1] as its host.
                    new_instance = Instance(uuid=result[0], active=False, last_active_time=time.time(),
                                            domain="https://" + result[1])
                    with self._lock:
                        self.instances.append(new_instance)
                    return
            time.sleep(1)
        loguru.logger.error("Fail to Scale[Add] Instance")

    @staticmethod
    def _drain_futures(futures, timeout):
        """Wait on every future in ``futures``, emptying the list as it goes.

        Each future is popped whether it succeeds, fails or times out, so a bad
        future can never linger and poison later calls. After all futures have
        been waited on, the first exception encountered (if any) is re-raised.
        """
        first_error = None
        while futures:
            future = futures.pop(0)
            try:
                future.result(timeout=timeout)
            except Exception as e:
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error

    def introspection(self):
        """Remove instances whose timeout has expired.

        An active instance is removed once more than ``timeout`` seconds have
        passed since its ``last_active_time``. An inactive one is removed once
        more than ``scaledown_window`` seconds have passed. Removals run on the
        executor and are waited for. All errors are logged, not raised.
        """
        # Stop timed-out instances (run timeout and idle timeout).
        # First wait up to timeout // 6 seconds for a previous pass's removals
        # to finish, so two passes do not act on the same instances.
        wait_left = self.timeout // 6
        while self.intro_threads and wait_left > 0:
            time.sleep(1)
            wait_left -= 1
        # Test the pending futures, not the countdown. With timeout < 6 the
        # countdown starts at 0, which used to abort every pass even when
        # nothing was pending.
        if self.intro_threads:
            loguru.logger.error("Fail to Introspect Instances: Timeout")
            return
        try:
            with self._lock:
                before = len(self.instances)
                instance_copy = copy.deepcopy(self.instances)
            flag = set()
            for instance in instance_copy:
                limit = self.timeout if instance.active else self.scaledown_window
                if (time.time() - instance.last_active_time) > limit:
                    flag.add(instance.uuid)
                    self.intro_threads.append(self.executor.submit(self.remove_instance, instance=instance))
            self._drain_futures(self.intro_threads, self.timeout // 2)
            with self._lock:
                after = len(self.instances)
                remaining = [i.uuid for i in self.instances]
            for uuid in remaining:
                if uuid in flag:
                    raise Exception("Instance[%s] Remove Failed" % uuid)
            if flag:
                loguru.logger.info("Instance Num Before Introspecting %d After Introspecting %d" % (before, after))
        except Exception as e:
            loguru.logger.error("Fail to Introspect Instances: %s" % e)

    def _scale(self, target_instance: int, disable_shrink=True):
        """Run introspection, then add or remove instances to reach ``target_instance``.

        Shrinking only removes inactive instances. When ``disable_shrink`` is
        True, a pool larger than the target is left as is.

        Returns:
            True if no change was needed or the pool ends at exactly
            ``target_instance``. False otherwise, including on any error (which
            is logged).
        """
        try:
            self.introspection()
            # Adjust the number of instances.
            with self._lock:
                instance_copy = copy.deepcopy(self.instances)
            dest = target_instance - len(instance_copy)
            if (disable_shrink and dest < 0) or dest == 0:
                return True
            loguru.logger.info("Instance Num Before Scaling %d ; Target %d" % (len(self.instances), target_instance))
            if dest < 0:
                to_remove = -dest
                for instance in instance_copy:
                    if to_remove == 0:
                        break
                    if not instance.active:
                        self.threads.append(self.executor.submit(self.remove_instance, instance=instance))
                        to_remove -= 1
            else:
                for _ in range(dest):
                    self.threads.append(self.executor.submit(self._add_instance))
            # Always empties self.threads, even when an operation fails, so a
            # stale failed future cannot break every later call.
            self._drain_futures(self.threads, self.timeout // 2)
            loguru.logger.info("Instance Num After Scaling %d ; Target %d" % (len(self.instances), target_instance))
            return len(self.instances) == target_instance
        except Exception as e:
            loguru.logger.error("Fail to Scale: %s" % e)
            return False


if __name__ == "__main__":
    ip = InstancePool()
    print(ip.scale_instance(5))
    time.sleep(5)
    # Shrinking is disabled by default. Allow it here so the demo releases the
    # instances it just created instead of leaving them running (and billed).
    print(ip.scale_instance(0, disable_shrink=False))