From 1272a33718488e35e92ffc88a60547490bf74923 Mon Sep 17 00:00:00 2001 From: "kyj@bowong.ai" Date: Wed, 16 Apr 2025 18:28:15 +0800 Subject: [PATCH] =?UTF-8?q?ADD=20AutoDL=E8=B0=83=E5=BA=A6=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E5=8A=9F=E8=83=BD=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AutoDL/AutoDL_pure_heygem.py | 57 ++++++--- .../autodl_scheduling/entity/instance_pool.py | 115 ++++++++++++++++++ AutoDL/autodl_scheduling/entity/result_map.py | 36 ++++++ .../autodl_scheduling/entity/running_pool.py | 80 ++++++++++++ .../autodl_scheduling/entity/waiting_queue.py | 21 ++++ AutoDL/autodl_scheduling/server.py | 110 +++++++++++++++++ .../util}/audodl_sdk.py | 65 ++++++---- server_with_s3_auth.py | 2 +- 8 files changed, 446 insertions(+), 40 deletions(-) create mode 100644 AutoDL/autodl_scheduling/entity/instance_pool.py create mode 100644 AutoDL/autodl_scheduling/entity/result_map.py create mode 100644 AutoDL/autodl_scheduling/entity/running_pool.py create mode 100644 AutoDL/autodl_scheduling/entity/waiting_queue.py create mode 100644 AutoDL/autodl_scheduling/server.py rename AutoDL/{ => autodl_scheduling/util}/audodl_sdk.py (80%) diff --git a/AutoDL/AutoDL_pure_heygem.py b/AutoDL/AutoDL_pure_heygem.py index 5b56462..660c454 100644 --- a/AutoDL/AutoDL_pure_heygem.py +++ b/AutoDL/AutoDL_pure_heygem.py @@ -6,12 +6,15 @@ import subprocess import time import traceback import uuid +from concurrent.futures.thread import ThreadPoolExecutor from typing import Any, Optional, Union import httpx import loguru import requests +import starlette.datastructures import uvicorn + from fastapi import UploadFile, HTTPException, Depends, FastAPI from fastapi.routing import APIRoute from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -47,8 +50,12 @@ class HeyGem: time.sleep(2) if timeout == 0: raise TimeoutError("HeyGem Server Start timed out") + self.result = {} + self.executor = ThreadPoolExecutor(max_workers=1) self.server = FastAPI() self.server.routes.append(APIRoute(path='/', endpoint=self.api, methods=['POST'])) + self.server.routes.append(APIRoute(path='/{code}', endpoint=self.get_result, methods=['GET'])) + def infer(self, code, video_path, audio_path) -> str: @@ -182,6 +189,18 @@ class HeyGem: with open(path, 'wb') as f: shutil.copyfileobj(r.raw, f) + async def get_result(self, code:str, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): + if token.credentials != os.environ["AUTH_TOKEN"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect bearer token", + headers={"WWW-Authenticate": "Bearer"}, + ) + if code in self.result: + return self.result[code] + else: + return {"status": "waiting", "msg":""} + async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]], token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): if token.credentials != os.environ["AUTH_TOKEN"]: raise HTTPException( @@ -190,31 +209,41 @@ class HeyGem: headers={"WWW-Authenticate": "Bearer"}, ) code = str(uuid.uuid4()) + video_path=None + audio_path=None try: - if isinstance(video_file, UploadFile): - video_path = "/root/%s" % video_file.filename + file_id = str(uuid.uuid4()) + if isinstance(video_file, UploadFile) or isinstance(video_file, starlette.datastructures.UploadFile): + video_path = "/root/%s.%s" % (file_id, video_file.filename.split(".")[-1]) with open(video_path, "wb") as f: data = await video_file.read() f.write(data) elif isinstance(video_file, str) and video_file.startswith("http"): - video_path = "/root/%s" % video_file.split("?")[0].split("/")[-1] + video_path = "/root/%s.%s" % (file_id, video_file.split("?")[0].split("/")[-1].split(".")[-1]) self.download_file(video_file, video_path) else: - return {"status": "fail", "msg": "video file save failed: Not a valid input"} + self.result[code] = {"status": "fail", "msg": "video file save failed: Not a valid input"} - if isinstance(audio_file, UploadFile): - audio_path = "/root/%s" % audio_file.filename + if isinstance(audio_file, UploadFile) or isinstance(audio_file, starlette.datastructures.UploadFile): + audio_path = "/root/%s.%s" % (file_id, audio_file.filename.split(".")[-1]) with open(audio_path, "wb") as f: data = await audio_file.read() f.write(data) elif isinstance(audio_file, str) and audio_file.startswith("http"): - audio_path = "/root/%s" % audio_file.split("?")[0].split("/")[-1] + audio_path = "/root/%s.%s" % (file_id, audio_file.split("?")[0].split("/")[-1].split(".")[-1]) self.download_file(audio_file, audio_path) else: - return {"status": "fail", "msg": "audio file save failed: Not a valid input"} + self.result[code] = {"status": "fail", "msg": "audio file save failed: Not a valid input"} except Exception as e: - return {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"} + self.result[code] = {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"} + # Submit + if video_path and audio_path: + self.executor.submit(self._api, video_path=video_path, audio_path=audio_path, code=code) + return {"status": "success", "code": code} + else: + return self.result[code] + def _api(self, video_path: str, audio_path: str, code): loguru.logger.info(f"Enter Inference Module") try: file_path = self.infer(code, video_path, audio_path) @@ -230,15 +259,11 @@ class HeyGem: aws_secret_access_key=yaml_config["aws_access_key"]) awss3.meta.client.upload_file(file_path, "bw-heygem-output", file_path.split(os.path.sep)[-1]) except Exception as e: - return {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)} - # resp = FileResponse(path=file_path, filename=Path(file_path).name, status_code=200, headers={ - # "Content-Type": "application/octet-stream", - # "Content-Disposition": f"attachment; filename={Path(file_path).name}", - # }) - return {"status": "success", "msg":f"{file_path.split(os.path.sep)[-1]}"} + self.result[code] = {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)} + self.result[code] = {"status": "success", "msg":f"{file_path.split(os.path.sep)[-1]}"} except Exception as e: traceback.print_exc() - return {"status": "fail", "msg": "Inference module failed: "+str(e)} + self.result[code] = {"status": "fail", "msg": "Inference module failed: "+str(e)} if __name__ == "__main__": heygem = HeyGem() diff --git a/AutoDL/autodl_scheduling/entity/instance_pool.py b/AutoDL/autodl_scheduling/entity/instance_pool.py new file mode 100644 index 0000000..ff8c1d3 --- /dev/null +++ b/AutoDL/autodl_scheduling/entity/instance_pool.py @@ -0,0 +1,115 @@ +import copy +import os +import random +import time +from concurrent.futures.thread import ThreadPoolExecutor +from typing import List + +import loguru + +from AutoDL.autodl_scheduling.util.audodl_sdk import instance_operate, get_autodl_machines, payg + +RETRY_LIMIT = 20 # 创建实例重试次数 + + +class Instance: + def __init__(self, uuid:str, active:bool=False, last_active_time:float=-1., domain:str=""): + self.uuid:str=uuid + self.active:bool=active + self.last_active_time:float=last_active_time + self.domain:str=domain + + def __str__(self): + return "uuid:%s active:%s last_active_time:%s domain:%s" % (self.uuid, self.active, self.last_active_time, self.domain) + + +class InstancePool: + def __init__(self, min_instance=0, max_instance=100, scaledown_window=120, buffer_instance=0, timeout=1200): + self.min_instance = min_instance + self.max_instance = max_instance + self.scaledown_window = scaledown_window + self.buffer_instance = buffer_instance + self.timeout = timeout + self.instances:List[Instance] = [] + self.executor = ThreadPoolExecutor(max_workers=os.cpu_count()*2) + self.threads = [] + + def scale_instance(self, target_instance): + if target_instance + self.buffer_instance < self.min_instance: + return self._scale(self.min_instance) + if target_instance + self.buffer_instance > self.max_instance: + return self._scale(self.max_instance) + if target_instance + self.buffer_instance == len(self.instances): + return True + return self._scale(target_instance + self.buffer_instance) + + def remove_instance(self, instance:Instance): + if instance_operate(instance.uuid, "power_off"): + if instance_operate(instance.uuid, "release"): + for i in self.instances: + if i.uuid == instance.uuid: + self.instances.remove(i) + else: + loguru.logger.error("Instance {} failed to release".format(instance.uuid)) + else: + loguru.logger.error("Instance {} failed to power off".format(instance.uuid)) + + def _add_instance(self): + lim = RETRY_LIMIT + while lim > 0: + machines = get_autodl_machines() + if len(machines) > 0: + m = random.choice(machines) + result = payg(m["region_name"], m["machine_id"]) + if result: + self.instances.append( + Instance(uuid=result[0], active=False, last_active_time=time.time(), domain="https://"+result[1])) + break + else: + time.sleep(1) + lim -= 1 + if lim <= 0: + loguru.logger.error("Fail to Scale[Add] Instance") + + def introspection(self): + # 停止超时实例(运行超时和无任务超时) + instance_copy = copy.deepcopy(self.instances) + for instance in instance_copy: + if instance.active: + if (time.time() - instance.last_active_time) > self.timeout: + self.threads.append(self.executor.submit(self.remove_instance, instance=instance)) + else: + if (time.time() - instance.last_active_time) > self.scaledown_window: + self.threads.append(self.executor.submit(self.remove_instance, instance=instance)) + + def _scale(self, target_instance:int): + loguru.logger.info("Instance Num Before Scaling %d ; Target %d" % (len(self.instances), target_instance)) + self.introspection() + # 调整实例数量 + instance_copy = copy.deepcopy(self.instances) + dest = target_instance - len(instance_copy) + if dest < 0: + dest = abs(dest) + for instance in instance_copy: + if not instance.active and dest > 0: + self.threads.append(self.executor.submit(self.remove_instance, instance=instance)) + dest -= 1 + elif dest > 0: + for i in range(dest): + self.threads.append(self.executor.submit(self._add_instance)) + while len(self.threads) > 0: + for t in self.threads: + t.result(timeout=self.timeout//2) + self.threads.remove(t) + loguru.logger.info("Instance Num After Scaling %d ; Target %d" % (len(self.instances), target_instance)) + if len(self.instances) == target_instance: + return True + else: + return False + + +if __name__ == "__main__": + ip = InstancePool() + print(ip.scale_instance(5)) + time.sleep(5) + print(ip.scale_instance(0)) \ No newline at end of file diff --git a/AutoDL/autodl_scheduling/entity/result_map.py b/AutoDL/autodl_scheduling/entity/result_map.py new file mode 100644 index 0000000..3aafec1 --- /dev/null +++ b/AutoDL/autodl_scheduling/entity/result_map.py @@ -0,0 +1,36 @@ +from typing import Union + + +class ResultMap: + def __init__(self, max_size:int=1000): + self.map = dict() + self.max_size = max_size + + def get(self, uid:str) -> Union[dict,None]: + if uid in self.map: + if self.map[uid]: + r = self.map[uid] + # self.remove(uid) + return r + else: + return None + else: + raise KeyError(uid) + + def set(self, uid:str, result:dict=None) -> None: + if uid in self.map: + raise RuntimeError(f"Key {uid} already exists") + else: + if len(self.map) + 1 > self.max_size: + raise RuntimeError(f"ResultMap exceeds max size {self.max_size}") + self.map[uid] = result + + def update(self, uid:str, result:dict) -> None: + if uid in self.map: + self.map[uid] = result + else: + raise KeyError(uid) + + def remove(self, uid:str) -> None: + if uid in self.map: + del self.map[uid] \ No newline at end of file diff --git a/AutoDL/autodl_scheduling/entity/running_pool.py b/AutoDL/autodl_scheduling/entity/running_pool.py new file mode 100644 index 0000000..5bc8927 --- /dev/null +++ b/AutoDL/autodl_scheduling/entity/running_pool.py @@ -0,0 +1,80 @@ +import time +import uuid +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Union + +import loguru +import requests +from fastapi import UploadFile +SERVER_TIME_OUT = 60 * 35 + +class RunningPool: + def __init__(self): + self.tasks = {} + self._headers = { + "Authorization": "Bearer 7468B79CA2CF53436C06B29A6F16451C" + } + + def run(self, instance_id, uid, base_url:str, video_file:Union[UploadFile,str], audio_file:Union[UploadFile,str]) -> Union[bool, None]: + try: + code = self._req(base_url, video_file, audio_file) + except Exception as e: + loguru.logger.error("%s Task Submit Failed: %s" % (uid, e)) + return None + self.tasks[uid] = {"instance_id": instance_id, "base_url":base_url, "code": code, "status": False, "start_time": time.time()} + return True + + def result(self, uid:str) -> Union[dict, None]: + if uid in self.tasks: + try: + resp = requests.get(self.tasks[uid]["base_url"] + "/" + self.tasks[uid]["code"], headers=self._headers) + if resp.status_code != 200: + raise Exception("%s Get Result Failed: %s" % (uid, resp.status_code)) + else: + if resp.json()["status"] == "waiting": + if self.tasks[uid]["start_time"] + SERVER_TIME_OUT + 10 < time.time(): + self.tasks[uid]["status"] = True + return {"status": "fail", "msg": "Instance Timeout", "instance_id": self.tasks[uid]["instance_id"]} + return None + else: + self.tasks[uid]["status"] = True + return {"status": resp.json()['status'], "msg": resp.json()['msg'], + "instance_id": self.tasks[uid]["instance_id"]} + except Exception as e: + loguru.logger.error("%s Get Result Failed: %s" % (uid, e)) + return None + else: + raise KeyError(uid) + + def get_running_size(self): + return len([i for i in self.tasks.values() if not i["status"]]) + + def _req(self, base_url, video_file, audio_file): + data = { + "video_file": video_file, + "audio_file": audio_file + } + resp = requests.post(base_url, data, headers=self._headers, allow_redirects=True, stream=True) + if resp.status_code == 200: + if resp.json()["status"] == "success": + code = resp.json()["code"] + return code + else: + return None + else: + return None + +if __name__ == "__main__": + rp = RunningPool() + id = rp.run("https://u336391-b31a-07132bc4.westx.seetacloud.com:8443", + "https://sucai-1324682537.cos.ap-shanghai.myqcloud.com/tiktok/video/1111.mp4", + "https://sucai-1324682537.cos.ap-shanghai.myqcloud.com/tiktok/video/XBR037ruAZsA.mp3") + while True: + r = rp.result(id) + if r: + loguru.logger.success(r) + break + else: + loguru.logger.info("waiting...") + time.sleep(5) + diff --git a/AutoDL/autodl_scheduling/entity/waiting_queue.py b/AutoDL/autodl_scheduling/entity/waiting_queue.py new file mode 100644 index 0000000..2e7e8e8 --- /dev/null +++ b/AutoDL/autodl_scheduling/entity/waiting_queue.py @@ -0,0 +1,21 @@ +from queue import Queue + + +class WaitingQueue: + def __init__(self): + self.queue = Queue(maxsize=500) + + def enqueue(self,uid, video_path,audio_path): + data = { + "uid": uid, + "video_path": video_path, + "audio_path": audio_path + } + self.queue.put(data) + + def dequeue(self): + data = self.queue.get() + return data["uid"], data["video_path"], data["audio_path"] + + def get_size(self): + return self.queue.qsize() diff --git a/AutoDL/autodl_scheduling/server.py b/AutoDL/autodl_scheduling/server.py new file mode 100644 index 0000000..8e108f2 --- /dev/null +++ b/AutoDL/autodl_scheduling/server.py @@ -0,0 +1,110 @@ +import time +import traceback +import uuid +from concurrent.futures.thread import ThreadPoolExecutor + +import loguru +import uvicorn +from fastapi import FastAPI, Depends, HTTPException +from fastapi.routing import APIRoute +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from starlette import status + +from AutoDL.autodl_scheduling.entity.instance_pool import InstancePool, Instance +from AutoDL.autodl_scheduling.entity.result_map import ResultMap +from AutoDL.autodl_scheduling.entity.running_pool import RunningPool +from AutoDL.autodl_scheduling.entity.waiting_queue import WaitingQueue + + +class Server: + def __init__(self): + self.app = FastAPI() + self.waiting_queue = WaitingQueue() + self.running_pool = RunningPool() + self.instance_pool = InstancePool(max_instance=29) + self.result_map = ResultMap() + self.executor = ThreadPoolExecutor(max_workers=2) + self.worker_1 = self.executor.submit(self.scaling_worker) + self.worker_2 = self.executor.submit(self.introspect_instance) + self.app.routes.append(APIRoute("/", endpoint=self.submit, methods=['POST'])) + self.app.routes.append(APIRoute("/{uid}", endpoint=self.get_result, methods=['GET'])) + + # async def submit(self, video_url, audio_url, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): + async def submit(self, video_url, audio_url): + # if token.credentials != "7468B79CA2CF53436C06B29A6F16451C": + # raise HTTPException( + # status_code=status.HTTP_401_UNAUTHORIZED, + # detail="Incorrect bearer token", + # headers={"WWW-Authenticate": "Bearer"}, + # ) + uid = str(uuid.uuid4()) + try: + self.waiting_queue.enqueue(uid, video_url, audio_url) + return {"uid": uid} + except: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + async def get_result(self, uid): + try: + return self.result_map.get(uid) + except: + if uid in self.running_pool.tasks: + return {"status": "running", "msg": ""} + return {"status": "queuing", "msg":""} + + def introspect_instance(self): + loguru.logger.info("Introspecting worker started") + while True: + self.instance_pool.introspection() + time.sleep(0.5) + + def scaling_worker(self): + loguru.logger.info("Scaling worker started") + try: + last_target = 0 + while True: + # 提交任务 + if last_target != self.waiting_queue.get_size()+self.running_pool.get_running_size(): + last_target = self.waiting_queue.get_size()+self.running_pool.get_running_size() + self.instance_pool.scale_instance(last_target) + for instance in self.instance_pool.instances: + if not instance.active: + # 从等待队列取出任务 + uid, video_path, audio_path = self.waiting_queue.dequeue() + # 提交任务到运行池 + if self.running_pool.run(instance.uuid, uid, instance.domain, video_path, audio_path): + # 更新实例池状态 + instance.active = True + instance.last_active_time = time.time() + loguru.logger.info("Task Submitted") + else: + loguru.logger.error("Submit Task Failed") + + #查询结果放到结果缓存里 + for uid, task in list(self.running_pool.tasks.items()): + result = self.running_pool.result(uid) + # 任务运行完成或失败 + if result: + self.result_map.set(uid, result) + self.running_pool.tasks.pop(uid) + if result["status"] == "fail": + #删除失败任务的实例 + self.instance_pool.remove_instance(Instance(uuid=task["instance_id"])) + else: + #更新成功任务实例的状态 + for instance in self.instance_pool.instances: + if instance.uuid == task["instance_id"]: + instance.active = False + instance.last_active_time = time.time() + time.sleep(0.5) + except: + traceback.print_exc() + + +if __name__=="__main__": + server = Server() + uvicorn.run(server.app, host="127.0.0.1", port=8888) + + + + diff --git a/AutoDL/audodl_sdk.py b/AutoDL/autodl_scheduling/util/audodl_sdk.py similarity index 80% rename from AutoDL/audodl_sdk.py rename to AutoDL/autodl_scheduling/util/audodl_sdk.py index 670b508..18a2ea1 100644 --- a/AutoDL/audodl_sdk.py +++ b/AutoDL/autodl_scheduling/util/audodl_sdk.py @@ -10,10 +10,11 @@ token = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1aWQiOjMzNjM5MSwidXVpZCI6ImU2MD req_instance_page_size = 1500 chk_instance_page_size = 1500 instances = {} -LIM = 30 # 等待状态时间LIM*5s +LIM = 200 # 等待状态时间LIM*5s START_LIM = 20 #Heygem脚本启动等待超时时间-10 +DEBUG = False -def ssh_try(host,port,pwd): +def ssh_try(host,port,pwd,uid): # 建立连接 trans = paramiko.Transport((host, int(port))) trans.connect(username="root", password=pwd) @@ -29,10 +30,10 @@ def ssh_try(host,port,pwd): if len(out.split("\n")[-2]) > 2: trans.close() raise RuntimeError(out.split("\n")[-2]) + loguru.logger.info(f"[{uid}]Waiting for HeyGem server ready...") while int(out.split("\n")[-2]) <= 1 and start_lim > 0: ssh_stdin, ssh_stdout, ssh_stderr = ssh.exec_command("bash -ic \"sleep 2 && lsof -i:6006|wc -l\"") out = ssh_stdout.read().decode() - loguru.logger.info("waiting for HeyGem server ready...") start_lim -= 1 if start_lim <= 0: loguru.logger.error("HeyGem Server Start Timeout, Please check!") @@ -52,7 +53,7 @@ def get_autodl_machines() -> Union[list, None]: payload = { "charge_type":"payg", "region_sign":"", - "gpu_type_name":["RTX 4090", "RTX 4090D", "RTX 3090", "RTX 3080", "RTX 3080x2", "RTX 3080 Ti", "RTX 3060", "RTX A4000", "RTX 2080 Ti", "RTX 2080 Ti x2", "GTX 1080 Ti"], + "gpu_type_name":["V100-SXM2-32GB", "vGPU-32GB", "RTX 4090", "RTX 4090D", "RTX 3090", "RTX 3080", "RTX 3080x2", "RTX 3080 Ti", "RTX 3060", "RTX A4000", "RTX 2080 Ti", "RTX 2080 Ti x2", "GTX 1080 Ti"], "machine_tag_name":[], "gpu_idle_num":1, "mount_net_disk":False, @@ -70,15 +71,18 @@ def get_autodl_machines() -> Union[list, None]: "chip_corp":["nvidia"], "machine_id":"" } - loguru.logger.info("Req Machine index {}".format(index)) + if DEBUG: + loguru.logger.info("Req Machine index {}".format(index)) rsp = requests.post("https://www.autodl.com/api/v1/sub_user/user/machine/list", json=payload, headers=headers) if rsp.status_code == 200: machine_list = rsp.json() - loguru.logger.info("Machine Result Total {}".format(machine_list["data"]["result_total"])) + if DEBUG: + loguru.logger.info("Machine Result Total {}".format(machine_list["data"]["result_total"])) while index < machine_list["data"]["max_page"]: index += 1 - loguru.logger.info("Req Machine index {}/{}".format(index, machine_list["data"]["max_page"])) + if DEBUG: + loguru.logger.info("Req Machine index {}/{}".format(index, machine_list["data"]["max_page"])) payload["page_index"] = index rsp = requests.post("https://www.autodl.com/api/v1/sub_user/user/machine/list", json=payload, headers=headers) if rsp.status_code == 200: @@ -111,9 +115,9 @@ def get_autodl_machines() -> Union[list, None]: def payg(region_name:str, machine_id:str) -> tuple[Any, Any] | None: region_image = { - "西北": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-232ac04d3b"], - "内蒙": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-06814c02d1"], - "北京": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-e5334cc4f3"] + "西北": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-0d3ab01d64"], + "内蒙": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-bfff5e82ae"], + "北京": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-537b5da0c5"] } headers = { "Authorization": token, @@ -149,10 +153,10 @@ def payg(region_name:str, machine_id:str) -> tuple[Any, Any] | None: if j["code"] == "Success": lim = LIM while lim>0: - time.sleep(5) + time.sleep(3) status, host, port, pwd, domain = check_status(j['data']) if status == "running": - ssh_try(host, port, pwd) + ssh_try(host, port, pwd, j['data']) break else: lim = lim-1 @@ -188,22 +192,22 @@ def instance_operate(instance_uuid:str, operation: Literal["power_off","power_on if operation in dest_dict.keys(): while lim>0: time.sleep(5) - status = check_status(instance_uuid)[0] - if status == dest_dict[operation]: + status = check_status(instance_uuid) + if status and status[0] == dest_dict[operation]: break else: lim = lim-1 if lim > 0: - loguru.logger.success("Operate[%s] Instance Success" % operation) + loguru.logger.success("Operate[%s] Instance[%s] Success" % (operation, instance_uuid)) return True else: - loguru.logger.error("Operate[%s] Instance Error: Timeout, Please Check!!! instance_uuid(%s)" % (operation, instance_uuid)) + loguru.logger.error("Operate[%s] Instance[%s] Error: Timeout, Please Check!!!" % (operation, instance_uuid)) return False else: - loguru.logger.error("Operate[%s] Instance Error: %s" % (operation, j['msg'])) + loguru.logger.error("Operate[%s] Instance[%s] Error: %s" % (operation, instance_uuid, j['msg'])) return False else: - loguru.logger.error("Operate[%s] Instance Error: Status Code[%s]" % (operation, rsp.status_code)) + loguru.logger.error("Operate[%s] Instance[%s] Error: Status Code[%s]" % (operation, instance_uuid, rsp.status_code)) return False def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None: @@ -221,13 +225,27 @@ def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None: "status":[], "charge_type":[] } - # loguru.logger.info("Req Instance index {}".format(index)) - rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers) + if DEBUG: + loguru.logger.info("Req Instance index {}".format(index)) + lim = 3 + while lim > 0: + rsp = None + try: + rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers) + except: + lim -= 1 + if rsp: + break + if lim <= 0: + loguru.logger.error("Get Instance Req Error") + return None if rsp.status_code == 200: instance_list = rsp.json() - # loguru.logger.info("Instance Result Total {}".format(instance_list["data"]["result_total"])) + if DEBUG: + loguru.logger.info("Instance Result Total {}".format(instance_list["data"]["result_total"])) while index < instance_list["data"]["max_page"]: - # loguru.logger.info("Req Instance index {}/{}".format(index, instance_list["data"]["max_page"])) + if DEBUG: + loguru.logger.info("Req Instance index {}/{}".format(index, instance_list["data"]["max_page"])) payload["page_index"] = index rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers) if rsp.status_code == 200: @@ -237,7 +255,8 @@ def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None: return None for l in instance_list["data"]["list"]: if l["uuid"] == instance_uuid: - loguru.logger.info("Instance {} Status {}".format(instance_uuid, l["status"])) + if DEBUG: + loguru.logger.info("Instance {} Status {}".format(instance_uuid, l["status"])) return l["status"], l["proxy_host"], l["ssh_port"], l["root_password"], l["tensorboard_domain"] loguru.logger.warning("Instance {} Not Found".format(instance_uuid)) return None diff --git a/server_with_s3_auth.py b/server_with_s3_auth.py index eb18c5f..682d8a2 100644 --- a/server_with_s3_auth.py +++ b/server_with_s3_auth.py @@ -101,7 +101,7 @@ auth_scheme = HTTPBearer() gpu=["L4", "T4"], cpu=(2,16), # memory=(32768, 32768), # (内存预留量, 内存使用上限) - memory=(20480,40960), + memory=(32768,131072), enable_memory_snapshot=False, secrets=[secret, modal.Secret.from_name("web_auth_token")], volumes={