ADD AutoDL调度基本功能完成

This commit is contained in:
kyj@bowong.ai 2025-04-16 18:28:15 +08:00
parent 61106608e0
commit 1272a33718
8 changed files with 446 additions and 40 deletions

View File

@ -6,12 +6,15 @@ import subprocess
import time
import traceback
import uuid
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any, Optional, Union
import httpx
import loguru
import requests
import starlette.datastructures
import uvicorn
from fastapi import UploadFile, HTTPException, Depends, FastAPI
from fastapi.routing import APIRoute
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@ -47,8 +50,12 @@ class HeyGem:
time.sleep(2)
if timeout == 0:
raise TimeoutError("HeyGem Server Start timed out")
self.result = {}
self.executor = ThreadPoolExecutor(max_workers=1)
self.server = FastAPI()
self.server.routes.append(APIRoute(path='/', endpoint=self.api, methods=['POST']))
self.server.routes.append(APIRoute(path='/{code}', endpoint=self.get_result, methods=['GET']))
def infer(self, code, video_path, audio_path) -> str:
@ -182,6 +189,18 @@ class HeyGem:
with open(path, 'wb') as f:
shutil.copyfileobj(r.raw, f)
async def get_result(self, code:str, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
    """Return the finished inference result for `code`, or a waiting placeholder.

    Results are written into self.result keyed by code once _api completes;
    until then this endpoint reports status "waiting".
    Raises HTTP 401 when the bearer token does not match AUTH_TOKEN.
    """
    # Reject callers whose bearer token doesn't match the deployment token
    if token.credentials != os.environ["AUTH_TOKEN"]:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect bearer token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if code in self.result:
        return self.result[code]
    else:
        # Task accepted but not finished yet (self.result is filled by _api)
        return {"status": "waiting", "msg":""}
async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]], token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
if token.credentials != os.environ["AUTH_TOKEN"]:
raise HTTPException(
@ -190,31 +209,41 @@ class HeyGem:
headers={"WWW-Authenticate": "Bearer"},
)
code = str(uuid.uuid4())
video_path=None
audio_path=None
try:
if isinstance(video_file, UploadFile):
video_path = "/root/%s" % video_file.filename
file_id = str(uuid.uuid4())
if isinstance(video_file, UploadFile) or isinstance(video_file, starlette.datastructures.UploadFile):
video_path = "/root/%s.%s" % (file_id, video_file.filename.split(".")[-1])
with open(video_path, "wb") as f:
data = await video_file.read()
f.write(data)
elif isinstance(video_file, str) and video_file.startswith("http"):
video_path = "/root/%s" % video_file.split("?")[0].split("/")[-1]
video_path = "/root/%s.%s" % (file_id, video_file.split("?")[0].split("/")[-1].split(".")[-1])
self.download_file(video_file, video_path)
else:
return {"status": "fail", "msg": "video file save failed: Not a valid input"}
self.result[code] = {"status": "fail", "msg": "video file save failed: Not a valid input"}
if isinstance(audio_file, UploadFile):
audio_path = "/root/%s" % audio_file.filename
if isinstance(audio_file, UploadFile) or isinstance(audio_file, starlette.datastructures.UploadFile):
audio_path = "/root/%s.%s" % (file_id, audio_file.filename.split(".")[-1])
with open(audio_path, "wb") as f:
data = await audio_file.read()
f.write(data)
elif isinstance(audio_file, str) and audio_file.startswith("http"):
audio_path = "/root/%s" % audio_file.split("?")[0].split("/")[-1]
audio_path = "/root/%s.%s" % (file_id, audio_file.split("?")[0].split("/")[-1].split(".")[-1])
self.download_file(audio_file, audio_path)
else:
return {"status": "fail", "msg": "audio file save failed: Not a valid input"}
self.result[code] = {"status": "fail", "msg": "audio file save failed: Not a valid input"}
except Exception as e:
return {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}
self.result[code] = {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}
# Submit
if video_path and audio_path:
self.executor.submit(self._api, video_path=video_path, audio_path=audio_path, code=code)
return {"status": "success", "code": code}
else:
return self.result[code]
def _api(self, video_path: str, audio_path: str, code):
loguru.logger.info(f"Enter Inference Module")
try:
file_path = self.infer(code, video_path, audio_path)
@ -230,15 +259,11 @@ class HeyGem:
aws_secret_access_key=yaml_config["aws_access_key"])
awss3.meta.client.upload_file(file_path, "bw-heygem-output", file_path.split(os.path.sep)[-1])
except Exception as e:
return {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)}
# resp = FileResponse(path=file_path, filename=Path(file_path).name, status_code=200, headers={
# "Content-Type": "application/octet-stream",
# "Content-Disposition": f"attachment; filename={Path(file_path).name}",
# })
return {"status": "success", "msg":f"{file_path.split(os.path.sep)[-1]}"}
self.result[code] = {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)}
self.result[code] = {"status": "success", "msg":f"{file_path.split(os.path.sep)[-1]}"}
except Exception as e:
traceback.print_exc()
return {"status": "fail", "msg": "Inference module failed: "+str(e)}
self.result[code] = {"status": "fail", "msg": "Inference module failed: "+str(e)}
if __name__ == "__main__":
heygem = HeyGem()

View File

@ -0,0 +1,115 @@
import copy
import os
import random
import time
from concurrent.futures.thread import ThreadPoolExecutor
from typing import List
import loguru
from AutoDL.autodl_scheduling.util.audodl_sdk import instance_operate, get_autodl_machines, payg
RETRY_LIMIT = 20 # 创建实例重试次数
class Instance:
    """Lightweight record describing one AutoDL instance tracked by the pool."""

    def __init__(self, uuid: str, active: bool = False, last_active_time: float = -1., domain: str = ""):
        # Instance UUID as reported by the AutoDL API
        self.uuid: str = uuid
        # True while an inference task is assigned to this instance
        self.active: bool = active
        # time.time() of the most recent state change (-1.0 when unset)
        self.last_active_time: float = last_active_time
        # Base URL ("https://...") used to reach the instance's HTTP service
        self.domain: str = domain

    def __str__(self):
        return (
            f"uuid:{self.uuid} active:{self.active} "
            f"last_active_time:{self.last_active_time} domain:{self.domain}"
        )
class InstancePool:
    """Manages a pool of AutoDL cloud instances, scaling them up and down on demand.

    Scaling and removal are executed on a thread pool; pending futures are
    collected in self.threads and joined inside _scale.
    """

    def __init__(self, min_instance=0, max_instance=100, scaledown_window=120, buffer_instance=0, timeout=1200):
        self.min_instance = min_instance
        self.max_instance = max_instance
        # Seconds an idle instance may linger before being released
        self.scaledown_window = scaledown_window
        # Extra spare instances kept above current demand
        self.buffer_instance = buffer_instance
        # Seconds an active instance may run before it is considered stuck
        self.timeout = timeout
        self.instances: List[Instance] = []
        # os.cpu_count() can return None on exotic platforms; fall back to 1
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 2)
        self.threads = []

    def scale_instance(self, target_instance):
        """Scale the pool to target_instance + buffer, clamped to [min_instance, max_instance].

        Returns True when the pool already matches or reaches the target.
        """
        desired = target_instance + self.buffer_instance
        if desired < self.min_instance:
            return self._scale(self.min_instance)
        if desired > self.max_instance:
            return self._scale(self.max_instance)
        if desired == len(self.instances):
            return True
        return self._scale(desired)

    def remove_instance(self, instance: Instance):
        """Power off and release `instance`, then drop it from the pool."""
        if not instance_operate(instance.uuid, "power_off"):
            loguru.logger.error("Instance {} failed to power off".format(instance.uuid))
            return
        if not instance_operate(instance.uuid, "release"):
            loguru.logger.error("Instance {} failed to release".format(instance.uuid))
            return
        # Rebuild the list instead of removing while iterating over it
        # (the original remove-during-iteration skipped elements)
        self.instances = [i for i in self.instances if i.uuid != instance.uuid]

    def _add_instance(self):
        """Create one pay-as-you-go instance, retrying up to RETRY_LIMIT times."""
        for _ in range(RETRY_LIMIT):
            machines = get_autodl_machines()
            # get_autodl_machines may return None on API failure — treat like empty
            if machines:
                m = random.choice(machines)
                result = payg(m["region_name"], m["machine_id"])
                if result:
                    self.instances.append(
                        Instance(uuid=result[0], active=False, last_active_time=time.time(),
                                 domain="https://" + result[1]))
                    return
            # Back off before retrying (also avoids hammering the API when payg fails)
            time.sleep(1)
        loguru.logger.error("Fail to Scale[Add] Instance")

    def introspection(self):
        """Release timed-out instances: active ones past `timeout`, idle ones past `scaledown_window`."""
        now = time.time()
        # Snapshot to tolerate concurrent mutation by remove_instance workers
        for instance in list(self.instances):
            limit = self.timeout if instance.active else self.scaledown_window
            if (now - instance.last_active_time) > limit:
                self.threads.append(self.executor.submit(self.remove_instance, instance=instance))

    def _scale(self, target_instance: int):
        """Add or remove instances until the pool holds `target_instance` entries.

        Only idle instances are removed when shrinking. Blocks until all
        submitted operations finish; returns True iff the target was reached.
        """
        loguru.logger.info("Instance Num Before Scaling %d ; Target %d" % (len(self.instances), target_instance))
        self.introspection()
        snapshot = list(self.instances)
        delta = target_instance - len(snapshot)
        if delta < 0:
            to_remove = -delta
            for instance in snapshot:
                if not instance.active and to_remove > 0:
                    self.threads.append(self.executor.submit(self.remove_instance, instance=instance))
                    to_remove -= 1
        elif delta > 0:
            for _ in range(delta):
                self.threads.append(self.executor.submit(self._add_instance))
        # Join all pending futures; swap the list out first so we never
        # mutate self.threads while iterating it
        pending, self.threads = self.threads, []
        for fut in pending:
            fut.result(timeout=self.timeout // 2)
        loguru.logger.info("Instance Num After Scaling %d ; Target %d" % (len(self.instances), target_instance))
        return len(self.instances) == target_instance
# Manual smoke test: scale the pool up to 5 instances, wait, then back to 0.
# NOTE(review): this creates and releases real (billed) AutoDL instances.
if __name__ == "__main__":
    ip = InstancePool()
    print(ip.scale_instance(5))
    time.sleep(5)
    print(ip.scale_instance(0))

View File

@ -0,0 +1,36 @@
from typing import Union
class ResultMap:
    """Bounded uid -> result store for finished inference jobs."""

    def __init__(self, max_size: int = 1000):
        # Maps uid to its result dict (or None while no result is stored)
        self.map = {}
        self.max_size = max_size

    def get(self, uid: str) -> Union[dict, None]:
        """Return the result for `uid`, or None when nothing is stored yet.

        Raises KeyError when `uid` was never registered.
        """
        if uid not in self.map:
            raise KeyError(uid)
        stored = self.map[uid]
        return stored or None

    def set(self, uid: str, result: dict = None) -> None:
        """Register `uid` with an optional initial result.

        Raises RuntimeError on duplicate keys or when the map is full.
        """
        if uid in self.map:
            raise RuntimeError(f"Key {uid} already exists")
        if len(self.map) + 1 > self.max_size:
            raise RuntimeError(f"ResultMap exceeds max size {self.max_size}")
        self.map[uid] = result

    def update(self, uid: str, result: dict) -> None:
        """Replace the stored result for an existing uid; KeyError when absent."""
        if uid not in self.map:
            raise KeyError(uid)
        self.map[uid] = result

    def remove(self, uid: str) -> None:
        """Delete `uid` if present; unknown uids are silently ignored."""
        self.map.pop(uid, None)

View File

@ -0,0 +1,80 @@
import time
import uuid
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Union
import loguru
import requests
from fastapi import UploadFile
SERVER_TIME_OUT = 60 * 35
class RunningPool:
    """Tracks inference tasks submitted to remote HeyGem instances and polls their results."""

    def __init__(self):
        # uid -> {"instance_id", "base_url", "code", "status", "start_time"};
        # "status" flips to True once the task reaches a terminal state
        self.tasks = {}
        # NOTE(review): hard-coded bearer token — move to env var / configuration
        self._headers = {
            "Authorization": "Bearer 7468B79CA2CF53436C06B29A6F16451C"
        }

    def run(self, instance_id, uid, base_url:str, video_file:Union[UploadFile,str], audio_file:Union[UploadFile,str]) -> Union[bool, None]:
        """Submit a task to the instance at `base_url`.

        Returns True on success, None on failure (submission error or
        the remote not returning a task code).
        """
        try:
            code = self._req(base_url, video_file, audio_file)
        except Exception as e:
            loguru.logger.error("%s Task Submit Failed: %s" % (uid, e))
            return None
        # BUG FIX: _req returns None when the remote rejects the submission;
        # previously a None code was stored and run() still returned True,
        # making result() crash later on `base_url + "/" + None`.
        if code is None:
            loguru.logger.error("%s Task Submit Failed: remote returned no code" % uid)
            return None
        self.tasks[uid] = {"instance_id": instance_id, "base_url": base_url, "code": code,
                           "status": False, "start_time": time.time()}
        return True

    def result(self, uid:str) -> Union[dict, None]:
        """Poll the remote for uid's result.

        Returns None while the task is still running (or on a transient poll
        error), a result dict once terminal. Raises KeyError for unknown uids.
        """
        if uid not in self.tasks:
            raise KeyError(uid)
        task = self.tasks[uid]
        try:
            resp = requests.get(task["base_url"] + "/" + task["code"], headers=self._headers)
            if resp.status_code != 200:
                raise Exception("%s Get Result Failed: %s" % (uid, resp.status_code))
            payload = resp.json()
            if payload["status"] == "waiting":
                # Grace period beyond the remote's own timeout before declaring failure
                if task["start_time"] + SERVER_TIME_OUT + 10 < time.time():
                    task["status"] = True
                    return {"status": "fail", "msg": "Instance Timeout", "instance_id": task["instance_id"]}
                return None
            task["status"] = True
            return {"status": payload['status'], "msg": payload['msg'],
                    "instance_id": task["instance_id"]}
        except Exception as e:
            loguru.logger.error("%s Get Result Failed: %s" % (uid, e))
            return None

    def get_running_size(self):
        """Number of tracked tasks that have not yet reached a terminal state."""
        return len([i for i in self.tasks.values() if not i["status"]])

    def _req(self, base_url, video_file, audio_file):
        """POST the file references to the instance; return the task code or None."""
        data = {
            "video_file": video_file,
            "audio_file": audio_file
        }
        resp = requests.post(base_url, data, headers=self._headers, allow_redirects=True, stream=True)
        if resp.status_code != 200:
            return None
        payload = resp.json()  # parse once instead of twice
        if payload.get("status") == "success":
            return payload["code"]
        return None
# Manual smoke test against a live instance (requires network access).
if __name__ == "__main__":
    rp = RunningPool()
    # BUG FIX: RunningPool.run takes (instance_id, uid, base_url, video, audio);
    # the original call passed only 3 arguments and then fed run()'s boolean
    # return value into result(), which expects the uid.
    uid = str(uuid.uuid4())
    ok = rp.run(
        "manual-test-instance",
        uid,
        "https://u336391-b31a-07132bc4.westx.seetacloud.com:8443",
        "https://sucai-1324682537.cos.ap-shanghai.myqcloud.com/tiktok/video/1111.mp4",
        "https://sucai-1324682537.cos.ap-shanghai.myqcloud.com/tiktok/video/XBR037ruAZsA.mp3")
    if ok:
        while True:
            r = rp.result(uid)
            if r:
                loguru.logger.success(r)
                break
            loguru.logger.info("waiting...")
            time.sleep(5)
    else:
        loguru.logger.error("Task submission failed")

View File

@ -0,0 +1,21 @@
from queue import Queue
class WaitingQueue:
    """FIFO of pending inference jobs, bounded at 500 entries."""

    def __init__(self):
        self.queue = Queue(maxsize=500)

    def enqueue(self, uid, video_path, audio_path):
        """Append a job; blocks when the queue is full."""
        self.queue.put((uid, video_path, audio_path))

    def dequeue(self):
        """Pop and return the oldest job as (uid, video_path, audio_path); blocks when empty."""
        return self.queue.get()

    def get_size(self):
        """Current number of queued jobs."""
        return self.queue.qsize()

View File

@ -0,0 +1,110 @@
import time
import traceback
import uuid
from concurrent.futures.thread import ThreadPoolExecutor
import loguru
import uvicorn
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette import status
from AutoDL.autodl_scheduling.entity.instance_pool import InstancePool, Instance
from AutoDL.autodl_scheduling.entity.result_map import ResultMap
from AutoDL.autodl_scheduling.entity.running_pool import RunningPool
from AutoDL.autodl_scheduling.entity.waiting_queue import WaitingQueue
class Server:
    """FastAPI front-end that queues inference jobs and schedules them onto AutoDL instances.

    Two background workers run for the lifetime of the process: one scales the
    instance pool and dispatches queued jobs, the other reaps timed-out instances.
    """

    def __init__(self):
        self.app = FastAPI()
        self.waiting_queue = WaitingQueue()
        self.running_pool = RunningPool()
        self.instance_pool = InstancePool(max_instance=29)
        self.result_map = ResultMap()
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.worker_1 = self.executor.submit(self.scaling_worker)
        self.worker_2 = self.executor.submit(self.introspect_instance)
        self.app.routes.append(APIRoute("/", endpoint=self.submit, methods=['POST']))
        self.app.routes.append(APIRoute("/{uid}", endpoint=self.get_result, methods=['GET']))

    # async def submit(self, video_url, audio_url, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
    async def submit(self, video_url, audio_url):
        """Enqueue a job and return its uid for later polling."""
        # NOTE(review): bearer-token auth is commented out — re-enable before
        # exposing this endpoint publicly.
        uid = str(uuid.uuid4())
        try:
            self.waiting_queue.enqueue(uid, video_url, audio_url)
            return {"uid": uid}
        except Exception:  # narrowed from bare except
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

    async def get_result(self, uid):
        """Return the stored result, or a running/queuing placeholder."""
        try:
            return self.result_map.get(uid)
        except KeyError:  # uid not registered in the result map yet
            if uid in self.running_pool.tasks:
                return {"status": "running", "msg": ""}
            return {"status": "queuing", "msg": ""}

    def introspect_instance(self):
        """Background worker: periodically release timed-out instances."""
        loguru.logger.info("Introspecting worker started")
        while True:
            try:
                self.instance_pool.introspection()
            except Exception:
                # A transient error must not kill the worker thread
                traceback.print_exc()
            time.sleep(0.5)

    def scaling_worker(self):
        """Background worker: scale the pool, dispatch queued jobs, collect results."""
        loguru.logger.info("Scaling worker started")
        last_target = 0
        while True:
            # BUG FIX: the try now wraps a single iteration instead of the whole
            # loop, so one exception no longer terminates the scheduler forever.
            try:
                demand = self.waiting_queue.get_size() + self.running_pool.get_running_size()
                if last_target != demand:
                    last_target = demand
                    self.instance_pool.scale_instance(last_target)
                # Dispatch queued jobs onto idle instances
                for instance in self.instance_pool.instances:
                    if instance.active:
                        continue
                    # BUG FIX: dequeue() blocks on an empty queue, which froze
                    # this loop (and result collection below) whenever an idle
                    # instance existed with nothing queued.
                    if self.waiting_queue.get_size() == 0:
                        break
                    uid, video_path, audio_path = self.waiting_queue.dequeue()
                    if self.running_pool.run(instance.uuid, uid, instance.domain, video_path, audio_path):
                        instance.active = True
                        instance.last_active_time = time.time()
                        loguru.logger.info("Task Submitted")
                    else:
                        loguru.logger.error("Submit Task Failed")
                        # Re-queue the job so it can be retried on another
                        # instance instead of being silently lost
                        self.waiting_queue.enqueue(uid, video_path, audio_path)
                # Move finished results into the result cache
                for uid, task in list(self.running_pool.tasks.items()):
                    result = self.running_pool.result(uid)
                    if not result:
                        continue  # still running or transient poll error
                    self.result_map.set(uid, result)
                    self.running_pool.tasks.pop(uid)
                    if result["status"] == "fail":
                        # Tear down the instance of a failed task
                        self.instance_pool.remove_instance(Instance(uuid=task["instance_id"]))
                    else:
                        # Mark the instance of a successful task idle again
                        for instance in self.instance_pool.instances:
                            if instance.uuid == task["instance_id"]:
                                instance.active = False
                                instance.last_active_time = time.time()
            except Exception:
                traceback.print_exc()
            time.sleep(0.5)
# Entry point: run the scheduler API locally. Binds to 127.0.0.1, so the
# service is reachable only from this host unless proxied.
if __name__=="__main__":
    server = Server()
    uvicorn.run(server.app, host="127.0.0.1", port=8888)

View File

@ -10,10 +10,11 @@ token = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1aWQiOjMzNjM5MSwidXVpZCI6ImU2MD
req_instance_page_size = 1500
chk_instance_page_size = 1500
instances = {}
LIM = 30 # 等待状态时间LIM*5s
LIM = 200 # 等待状态时间LIM*5s
START_LIM = 20 #Heygem脚本启动等待超时时间-10
DEBUG = False
def ssh_try(host,port,pwd):
def ssh_try(host,port,pwd,uid):
# 建立连接
trans = paramiko.Transport((host, int(port)))
trans.connect(username="root", password=pwd)
@ -29,10 +30,10 @@ def ssh_try(host,port,pwd):
if len(out.split("\n")[-2]) > 2:
trans.close()
raise RuntimeError(out.split("\n")[-2])
loguru.logger.info(f"[{uid}]Waiting for HeyGem server ready...")
while int(out.split("\n")[-2]) <= 1 and start_lim > 0:
ssh_stdin, ssh_stdout, ssh_stderr = ssh.exec_command("bash -ic \"sleep 2 && lsof -i:6006|wc -l\"")
out = ssh_stdout.read().decode()
loguru.logger.info("waiting for HeyGem server ready...")
start_lim -= 1
if start_lim <= 0:
loguru.logger.error("HeyGem Server Start Timeout, Please check!")
@ -52,7 +53,7 @@ def get_autodl_machines() -> Union[list, None]:
payload = {
"charge_type":"payg",
"region_sign":"",
"gpu_type_name":["RTX 4090", "RTX 4090D", "RTX 3090", "RTX 3080", "RTX 3080x2", "RTX 3080 Ti", "RTX 3060", "RTX A4000", "RTX 2080 Ti", "RTX 2080 Ti x2", "GTX 1080 Ti"],
"gpu_type_name":["V100-SXM2-32GB", "vGPU-32GB", "RTX 4090", "RTX 4090D", "RTX 3090", "RTX 3080", "RTX 3080x2", "RTX 3080 Ti", "RTX 3060", "RTX A4000", "RTX 2080 Ti", "RTX 2080 Ti x2", "GTX 1080 Ti"],
"machine_tag_name":[],
"gpu_idle_num":1,
"mount_net_disk":False,
@ -70,15 +71,18 @@ def get_autodl_machines() -> Union[list, None]:
"chip_corp":["nvidia"],
"machine_id":""
}
loguru.logger.info("Req Machine index {}".format(index))
if DEBUG:
loguru.logger.info("Req Machine index {}".format(index))
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/user/machine/list", json=payload, headers=headers)
if rsp.status_code == 200:
machine_list = rsp.json()
loguru.logger.info("Machine Result Total {}".format(machine_list["data"]["result_total"]))
if DEBUG:
loguru.logger.info("Machine Result Total {}".format(machine_list["data"]["result_total"]))
while index < machine_list["data"]["max_page"]:
index += 1
loguru.logger.info("Req Machine index {}/{}".format(index, machine_list["data"]["max_page"]))
if DEBUG:
loguru.logger.info("Req Machine index {}/{}".format(index, machine_list["data"]["max_page"]))
payload["page_index"] = index
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/user/machine/list", json=payload, headers=headers)
if rsp.status_code == 200:
@ -111,9 +115,9 @@ def get_autodl_machines() -> Union[list, None]:
def payg(region_name:str, machine_id:str) -> tuple[Any, Any] | None:
region_image = {
"西北": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-232ac04d3b"],
"内蒙": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-06814c02d1"],
"北京": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-e5334cc4f3"]
"西北": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-0d3ab01d64"],
"内蒙": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-bfff5e82ae"],
"北京": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-537b5da0c5"]
}
headers = {
"Authorization": token,
@ -149,10 +153,10 @@ def payg(region_name:str, machine_id:str) -> tuple[Any, Any] | None:
if j["code"] == "Success":
lim = LIM
while lim>0:
time.sleep(5)
time.sleep(3)
status, host, port, pwd, domain = check_status(j['data'])
if status == "running":
ssh_try(host, port, pwd)
ssh_try(host, port, pwd, j['data'])
break
else:
lim = lim-1
@ -188,22 +192,22 @@ def instance_operate(instance_uuid:str, operation: Literal["power_off","power_on
if operation in dest_dict.keys():
while lim>0:
time.sleep(5)
status = check_status(instance_uuid)[0]
if status == dest_dict[operation]:
status = check_status(instance_uuid)
if status and status[0] == dest_dict[operation]:
break
else:
lim = lim-1
if lim > 0:
loguru.logger.success("Operate[%s] Instance Success" % operation)
loguru.logger.success("Operate[%s] Instance[%s] Success" % (operation, instance_uuid))
return True
else:
loguru.logger.error("Operate[%s] Instance Error: Timeout, Please Check!!! instance_uuid(%s)" % (operation, instance_uuid))
loguru.logger.error("Operate[%s] Instance[%s] Error: Timeout, Please Check!!!" % (operation, instance_uuid))
return False
else:
loguru.logger.error("Operate[%s] Instance Error: %s" % (operation, j['msg']))
loguru.logger.error("Operate[%s] Instance[%s] Error: %s" % (operation, instance_uuid, j['msg']))
return False
else:
loguru.logger.error("Operate[%s] Instance Error: Status Code[%s]" % (operation, rsp.status_code))
loguru.logger.error("Operate[%s] Instance[%s] Error: Status Code[%s]" % (operation, instance_uuid, rsp.status_code))
return False
def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None:
@ -221,13 +225,27 @@ def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None:
"status":[],
"charge_type":[]
}
# loguru.logger.info("Req Instance index {}".format(index))
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers)
if DEBUG:
loguru.logger.info("Req Instance index {}".format(index))
lim = 3
while lim > 0:
rsp = None
try:
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers)
except:
lim -= 1
if rsp:
break
if lim <= 0:
loguru.logger.error("Get Instance Req Error")
return None
if rsp.status_code == 200:
instance_list = rsp.json()
# loguru.logger.info("Instance Result Total {}".format(instance_list["data"]["result_total"]))
if DEBUG:
loguru.logger.info("Instance Result Total {}".format(instance_list["data"]["result_total"]))
while index < instance_list["data"]["max_page"]:
# loguru.logger.info("Req Instance index {}/{}".format(index, instance_list["data"]["max_page"]))
if DEBUG:
loguru.logger.info("Req Instance index {}/{}".format(index, instance_list["data"]["max_page"]))
payload["page_index"] = index
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers)
if rsp.status_code == 200:
@ -237,7 +255,8 @@ def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None:
return None
for l in instance_list["data"]["list"]:
if l["uuid"] == instance_uuid:
loguru.logger.info("Instance {} Status {}".format(instance_uuid, l["status"]))
if DEBUG:
loguru.logger.info("Instance {} Status {}".format(instance_uuid, l["status"]))
return l["status"], l["proxy_host"], l["ssh_port"], l["root_password"], l["tensorboard_domain"]
loguru.logger.warning("Instance {} Not Found".format(instance_uuid))
return None

View File

@ -101,7 +101,7 @@ auth_scheme = HTTPBearer()
gpu=["L4", "T4"],
cpu=(2,16),
# memory=(32768, 32768), # (内存预留量, 内存使用上限)
memory=(20480,40960),
memory=(32768,131072),
enable_memory_snapshot=False,
secrets=[secret, modal.Secret.from_name("web_auth_token")],
volumes={