ADD AutoDL调度基本功能完成
This commit is contained in:
parent
61106608e0
commit
1272a33718
|
|
@ -6,12 +6,15 @@ import subprocess
|
|||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
import loguru
|
||||
import requests
|
||||
import starlette.datastructures
|
||||
import uvicorn
|
||||
|
||||
from fastapi import UploadFile, HTTPException, Depends, FastAPI
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
|
@ -47,8 +50,12 @@ class HeyGem:
|
|||
time.sleep(2)
|
||||
if timeout == 0:
|
||||
raise TimeoutError("HeyGem Server Start timed out")
|
||||
self.result = {}
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.server = FastAPI()
|
||||
self.server.routes.append(APIRoute(path='/', endpoint=self.api, methods=['POST']))
|
||||
self.server.routes.append(APIRoute(path='/{code}', endpoint=self.get_result, methods=['GET']))
|
||||
|
||||
|
||||
|
||||
def infer(self, code, video_path, audio_path) -> str:
|
||||
|
|
@ -182,6 +189,18 @@ class HeyGem:
|
|||
with open(path, 'wb') as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
|
||||
async def get_result(self, code:str, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
|
||||
if token.credentials != os.environ["AUTH_TOKEN"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect bearer token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if code in self.result:
|
||||
return self.result[code]
|
||||
else:
|
||||
return {"status": "waiting", "msg":""}
|
||||
|
||||
async def api(self, video_file: Optional[Union[UploadFile, str]], audio_file: Optional[Union[UploadFile, str]], token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
|
||||
if token.credentials != os.environ["AUTH_TOKEN"]:
|
||||
raise HTTPException(
|
||||
|
|
@ -190,31 +209,41 @@ class HeyGem:
|
|||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
code = str(uuid.uuid4())
|
||||
video_path=None
|
||||
audio_path=None
|
||||
try:
|
||||
if isinstance(video_file, UploadFile):
|
||||
video_path = "/root/%s" % video_file.filename
|
||||
file_id = str(uuid.uuid4())
|
||||
if isinstance(video_file, UploadFile) or isinstance(video_file, starlette.datastructures.UploadFile):
|
||||
video_path = "/root/%s.%s" % (file_id, video_file.filename.split(".")[-1])
|
||||
with open(video_path, "wb") as f:
|
||||
data = await video_file.read()
|
||||
f.write(data)
|
||||
elif isinstance(video_file, str) and video_file.startswith("http"):
|
||||
video_path = "/root/%s" % video_file.split("?")[0].split("/")[-1]
|
||||
video_path = "/root/%s.%s" % (file_id, video_file.split("?")[0].split("/")[-1].split(".")[-1])
|
||||
self.download_file(video_file, video_path)
|
||||
else:
|
||||
return {"status": "fail", "msg": "video file save failed: Not a valid input"}
|
||||
self.result[code] = {"status": "fail", "msg": "video file save failed: Not a valid input"}
|
||||
|
||||
if isinstance(audio_file, UploadFile):
|
||||
audio_path = "/root/%s" % audio_file.filename
|
||||
if isinstance(audio_file, UploadFile) or isinstance(audio_file, starlette.datastructures.UploadFile):
|
||||
audio_path = "/root/%s.%s" % (file_id, audio_file.filename.split(".")[-1])
|
||||
with open(audio_path, "wb") as f:
|
||||
data = await audio_file.read()
|
||||
f.write(data)
|
||||
elif isinstance(audio_file, str) and audio_file.startswith("http"):
|
||||
audio_path = "/root/%s" % audio_file.split("?")[0].split("/")[-1]
|
||||
audio_path = "/root/%s.%s" % (file_id, audio_file.split("?")[0].split("/")[-1].split(".")[-1])
|
||||
self.download_file(audio_file, audio_path)
|
||||
else:
|
||||
return {"status": "fail", "msg": "audio file save failed: Not a valid input"}
|
||||
self.result[code] = {"status": "fail", "msg": "audio file save failed: Not a valid input"}
|
||||
except Exception as e:
|
||||
return {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}
|
||||
self.result[code] = {"status": "fail", "msg": f"video or audio file save failed: {str(e)}"}
|
||||
# Submit
|
||||
if video_path and audio_path:
|
||||
self.executor.submit(self._api, video_path=video_path, audio_path=audio_path, code=code)
|
||||
return {"status": "success", "code": code}
|
||||
else:
|
||||
return self.result[code]
|
||||
|
||||
def _api(self, video_path: str, audio_path: str, code):
|
||||
loguru.logger.info(f"Enter Inference Module")
|
||||
try:
|
||||
file_path = self.infer(code, video_path, audio_path)
|
||||
|
|
@ -230,15 +259,11 @@ class HeyGem:
|
|||
aws_secret_access_key=yaml_config["aws_access_key"])
|
||||
awss3.meta.client.upload_file(file_path, "bw-heygem-output", file_path.split(os.path.sep)[-1])
|
||||
except Exception as e:
|
||||
return {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)}
|
||||
# resp = FileResponse(path=file_path, filename=Path(file_path).name, status_code=200, headers={
|
||||
# "Content-Type": "application/octet-stream",
|
||||
# "Content-Disposition": f"attachment; filename={Path(file_path).name}",
|
||||
# })
|
||||
return {"status": "success", "msg":f"{file_path.split(os.path.sep)[-1]}"}
|
||||
self.result[code] = {"status": "fail", "msg": "Failed to move file to S3 manually: " + str(e)}
|
||||
self.result[code] = {"status": "success", "msg":f"{file_path.split(os.path.sep)[-1]}"}
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return {"status": "fail", "msg": "Inference module failed: "+str(e)}
|
||||
self.result[code] = {"status": "fail", "msg": "Inference module failed: "+str(e)}
|
||||
|
||||
if __name__ == "__main__":
|
||||
heygem = HeyGem()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
import copy
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import List
|
||||
|
||||
import loguru
|
||||
|
||||
from AutoDL.autodl_scheduling.util.audodl_sdk import instance_operate, get_autodl_machines, payg
|
||||
|
||||
RETRY_LIMIT = 20 # 创建实例重试次数
|
||||
|
||||
|
||||
class Instance:
    """One AutoDL compute instance tracked by the pool."""

    def __init__(self, uuid: str, active: bool = False, last_active_time: float = -1., domain: str = ""):
        # uuid: AutoDL instance identifier used by the SDK operations.
        self.uuid: str = uuid
        # active: True while a task is believed to be running on this instance.
        self.active: bool = active
        # last_active_time: unix timestamp of the last state change (-1. = never).
        self.last_active_time: float = last_active_time
        # domain: public base URL of the instance (empty until provisioned).
        self.domain: str = domain

    def __str__(self):
        """Human-readable one-line summary for logging."""
        return f"uuid:{self.uuid} active:{self.active} last_active_time:{self.last_active_time} domain:{self.domain}"
|
||||
|
||||
|
||||
class InstancePool:
    """Elastic pool of AutoDL instances: scales up/down on demand and reaps
    timed-out or idle instances.

    NOTE(review): methods are submitted to a thread pool without any locking
    around ``self.instances`` — presumably callers tolerate the races; confirm.
    """

    def __init__(self, min_instance=0, max_instance=100, scaledown_window=120, buffer_instance=0, timeout=1200):
        # Hard bounds on the pool size.
        self.min_instance = min_instance
        self.max_instance = max_instance
        # Seconds an idle instance may live before it is released.
        self.scaledown_window = scaledown_window
        # Spare instances kept on top of the measured demand.
        self.buffer_instance = buffer_instance
        # Seconds a busy instance may run before it is considered stuck.
        self.timeout = timeout
        self.instances: List[Instance] = []
        # `or 1` guards against os.cpu_count() returning None on exotic platforms.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 2)
        # Futures of in-flight add/remove operations, drained in _scale().
        self.threads = []

    def scale_instance(self, target_instance):
        """Scale the pool to target_instance + buffer, clamped to [min, max].

        Returns True when the pool already has (or reaches) the requested size.
        """
        wanted = target_instance + self.buffer_instance
        if wanted < self.min_instance:
            return self._scale(self.min_instance)
        if wanted > self.max_instance:
            return self._scale(self.max_instance)
        if wanted == len(self.instances):
            return True
        return self._scale(wanted)

    def remove_instance(self, instance: Instance):
        """Power off, release and forget one instance (matched by uuid)."""
        if instance_operate(instance.uuid, "power_off"):
            if instance_operate(instance.uuid, "release"):
                # Rebuild the list instead of calling list.remove() inside a
                # `for` over the same list, which skips elements when the list
                # is mutated mid-iteration.
                self.instances = [i for i in self.instances if i.uuid != instance.uuid]
            else:
                loguru.logger.error("Instance {} failed to release".format(instance.uuid))
        else:
            loguru.logger.error("Instance {} failed to power off".format(instance.uuid))

    def _add_instance(self):
        # Try up to RETRY_LIMIT times to rent a pay-as-you-go machine.
        lim = RETRY_LIMIT
        while lim > 0:
            machines = get_autodl_machines()
            # get_autodl_machines() may return None on request failure; the old
            # `len(machines) > 0` check raised TypeError in that case.
            if machines:
                m = random.choice(machines)
                result = payg(m["region_name"], m["machine_id"])
                if result:
                    self.instances.append(
                        Instance(uuid=result[0], active=False, last_active_time=time.time(), domain="https://" + result[1]))
                    break
            else:
                time.sleep(1)
            lim -= 1
        if lim <= 0:
            loguru.logger.error("Fail to Scale[Add] Instance")

    def introspection(self):
        """Queue removal of timed-out instances.

        Busy instances expire after ``timeout``; idle ones after
        ``scaledown_window`` seconds since their last state change.
        """
        now = time.time()
        # Iterate over a snapshot; remove_instance mutates self.instances.
        for instance in copy.deepcopy(self.instances):
            window = self.timeout if instance.active else self.scaledown_window
            if (now - instance.last_active_time) > window:
                self.threads.append(self.executor.submit(self.remove_instance, instance=instance))

    def _scale(self, target_instance: int):
        """Add or remove instances until the pool holds ``target_instance``.

        Blocks until all queued add/remove operations finish; returns True when
        the target size was reached.
        """
        loguru.logger.info("Instance Num Before Scaling %d ; Target %d" % (len(self.instances), target_instance))
        self.introspection()
        # Adjust the instance count against a snapshot of the current pool.
        snapshot = copy.deepcopy(self.instances)
        dest = target_instance - len(snapshot)
        if dest < 0:
            # Shrink: remove idle instances only.
            dest = abs(dest)
            for instance in snapshot:
                if not instance.active and dest > 0:
                    self.threads.append(self.executor.submit(self.remove_instance, instance=instance))
                    dest -= 1
        elif dest > 0:
            # Grow.
            for _ in range(dest):
                self.threads.append(self.executor.submit(self._add_instance))
        # Drain pending operations. pop() avoids the old while/for pattern that
        # removed entries from self.threads while iterating it, skipping futures
        # on every pass.
        while self.threads:
            worker = self.threads.pop()
            worker.result(timeout=self.timeout // 2)
        loguru.logger.info("Instance Num After Scaling %d ; Target %d" % (len(self.instances), target_instance))
        return len(self.instances) == target_instance
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: scale the pool up to 5 instances, then back down to 0.
    pool = InstancePool()
    print(pool.scale_instance(5))
    time.sleep(5)
    print(pool.scale_instance(0))
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
from typing import Union
|
||||
|
||||
|
||||
class ResultMap:
    """Bounded in-memory store of task results keyed by task uid.

    A uid can be registered first with a placeholder (``set(uid)`` stores
    ``None``) and filled in later via ``update``; ``get`` reports ``None``
    while the stored value is still falsy.
    """

    def __init__(self, max_size: int = 1000):
        self.map = dict()
        # Hard cap on tracked uids; set() refuses new keys beyond it.
        self.max_size = max_size

    def get(self, uid: str) -> Union[dict, None]:
        """Return the stored result, or None while it is still pending.

        Raises KeyError for a uid that was never registered (the plain
        subscript below raises KeyError(uid), matching the old explicit raise).
        """
        stored = self.map[uid]
        return stored if stored else None

    def set(self, uid: str, result: dict = None) -> None:
        """Register a uid, optionally with its result; rejects duplicates and overflow."""
        if uid in self.map:
            raise RuntimeError(f"Key {uid} already exists")
        if len(self.map) + 1 > self.max_size:
            raise RuntimeError(f"ResultMap exceeds max size {self.max_size}")
        self.map[uid] = result

    def update(self, uid: str, result: dict) -> None:
        """Replace the result of an already-registered uid."""
        if uid not in self.map:
            raise KeyError(uid)
        self.map[uid] = result

    def remove(self, uid: str) -> None:
        """Forget a uid; silently ignores unknown uids."""
        self.map.pop(uid, None)
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import time
|
||||
import uuid
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import Union
|
||||
|
||||
import loguru
|
||||
import requests
|
||||
from fastapi import UploadFile
|
||||
SERVER_TIME_OUT = 60 * 35  # seconds a remote inference may run before we give up


class RunningPool:
    """Tracks inference tasks submitted to remote HeyGem worker instances."""

    def __init__(self):
        # uid -> {"instance_id", "base_url", "code", "status", "start_time"};
        # "status" flips to True once a final result (or timeout) was observed.
        self.tasks = {}
        # SECURITY: hardcoded bearer token shared with all workers — should be
        # loaded from an environment variable / secret store instead.
        self._headers = {
            "Authorization": "Bearer 7468B79CA2CF53436C06B29A6F16451C"
        }

    def run(self, instance_id, uid, base_url: str, video_file: Union[UploadFile, str], audio_file: Union[UploadFile, str]) -> Union[bool, None]:
        """Submit one task to the worker at base_url.

        Returns True on success, None on failure (nothing is tracked then).
        """
        try:
            code = self._req(base_url, video_file, audio_file)
        except Exception as e:
            loguru.logger.error("%s Task Submit Failed: %s" % (uid, e))
            return None
        if code is None:
            # The worker answered but rejected the job; the old code tracked a
            # dead task with code=None, which later crashed result().
            loguru.logger.error("%s Task Submit Failed: no task code returned" % uid)
            return None
        self.tasks[uid] = {"instance_id": instance_id, "base_url": base_url, "code": code, "status": False, "start_time": time.time()}
        return True

    def result(self, uid: str) -> Union[dict, None]:
        """Poll the remote worker for uid's result.

        Returns the final result dict (with "instance_id" merged in), or None
        while the task is still running or the poll failed.
        Raises KeyError for an unknown uid.
        """
        if uid not in self.tasks:
            raise KeyError(uid)
        task = self.tasks[uid]
        try:
            resp = requests.get(task["base_url"] + "/" + task["code"], headers=self._headers)
            if resp.status_code != 200:
                raise Exception("%s Get Result Failed: %s" % (uid, resp.status_code))
            payload = resp.json()  # parse once; the old code re-parsed per field
            if payload["status"] == "waiting":
                # Give the worker a small grace period beyond its own timeout.
                if task["start_time"] + SERVER_TIME_OUT + 10 < time.time():
                    task["status"] = True
                    return {"status": "fail", "msg": "Instance Timeout", "instance_id": task["instance_id"]}
                return None
            task["status"] = True
            return {"status": payload['status'], "msg": payload['msg'],
                    "instance_id": task["instance_id"]}
        except Exception as e:
            loguru.logger.error("%s Get Result Failed: %s" % (uid, e))
            return None

    def get_running_size(self):
        """Number of tracked tasks that have not reached a final state yet."""
        return len([i for i in self.tasks.values() if not i["status"]])

    def _req(self, base_url, video_file, audio_file):
        # POST the job to the worker; return its task code, or None on rejection.
        data = {
            "video_file": video_file,
            "audio_file": audio_file
        }
        resp = requests.post(base_url, data, headers=self._headers, allow_redirects=True, stream=True)
        if resp.status_code == 200:
            body = resp.json()
            if body["status"] == "success":
                return body["code"]
        return None
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test against a live worker instance.
    rp = RunningPool()
    # run() takes (instance_id, uid, base_url, video_file, audio_file); the old
    # demo passed only three positionals (TypeError) and then polled result()
    # with run()'s boolean return value instead of the task uid.
    task_uid = str(uuid.uuid4())
    submitted = rp.run("manual-test",
                       task_uid,
                       "https://u336391-b31a-07132bc4.westx.seetacloud.com:8443",
                       "https://sucai-1324682537.cos.ap-shanghai.myqcloud.com/tiktok/video/1111.mp4",
                       "https://sucai-1324682537.cos.ap-shanghai.myqcloud.com/tiktok/video/XBR037ruAZsA.mp3")
    if not submitted:
        loguru.logger.error("task submission failed")
    else:
        while True:
            r = rp.result(task_uid)
            if r:
                loguru.logger.success(r)
                break
            loguru.logger.info("waiting...")
            time.sleep(5)
|
||||
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from queue import Queue
|
||||
|
||||
|
||||
class WaitingQueue:
    """FIFO of pending inference jobs, each a (uid, video_path, audio_path) triple."""

    def __init__(self):
        # Bounded to 500 entries; put() blocks once the queue is full.
        self.queue = Queue(maxsize=500)

    def enqueue(self, uid, video_path, audio_path):
        """Append one job; blocks while the queue is at capacity."""
        self.queue.put({"uid": uid, "video_path": video_path, "audio_path": audio_path})

    def dequeue(self):
        """Pop the oldest job as a (uid, video_path, audio_path) tuple; blocks when empty."""
        item = self.queue.get()
        return item["uid"], item["video_path"], item["audio_path"]

    def get_size(self):
        """Approximate number of queued jobs (Queue.qsize is advisory)."""
        return self.queue.qsize()
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
import loguru
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from starlette import status
|
||||
|
||||
from AutoDL.autodl_scheduling.entity.instance_pool import InstancePool, Instance
|
||||
from AutoDL.autodl_scheduling.entity.result_map import ResultMap
|
||||
from AutoDL.autodl_scheduling.entity.running_pool import RunningPool
|
||||
from AutoDL.autodl_scheduling.entity.waiting_queue import WaitingQueue
|
||||
|
||||
|
||||
class Server:
    """FastAPI front end that queues jobs, autoscales AutoDL instances, and
    collects their results.

    Routes: POST "/" -> submit, GET "/{uid}" -> get_result.
    """

    def __init__(self):
        self.app = FastAPI()
        self.waiting_queue = WaitingQueue()    # jobs not yet assigned to an instance
        self.running_pool = RunningPool()      # jobs running on remote instances
        self.instance_pool = InstancePool(max_instance=29)
        self.result_map = ResultMap()          # finished (or failed) results by uid
        # Two long-lived background workers: scaling/dispatch and introspection.
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.worker_1 = self.executor.submit(self.scaling_worker)
        self.worker_2 = self.executor.submit(self.introspect_instance)
        self.app.routes.append(APIRoute("/", endpoint=self.submit, methods=['POST']))
        self.app.routes.append(APIRoute("/{uid}", endpoint=self.get_result, methods=['GET']))

    # async def submit(self, video_url, audio_url, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
    async def submit(self, video_url, audio_url):
        """Enqueue one job and return its tracking uid.

        NOTE(review): bearer-token auth is currently disabled (commented out in
        the original) — re-enable before exposing this publicly.
        """
        uid = str(uuid.uuid4())
        try:
            self.waiting_queue.enqueue(uid, video_url, audio_url)
            return {"uid": uid}
        # `except Exception` instead of a bare except, which would also swallow
        # SystemExit / KeyboardInterrupt.
        except Exception:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

    async def get_result(self, uid):
        """Return the stored result, or a running/queuing placeholder.

        ResultMap.get raises KeyError for unknown uids and returns None while a
        registered result is still pending; both fall through to the
        placeholder answers below.
        """
        try:
            return self.result_map.get(uid)
        except Exception:
            if uid in self.running_pool.tasks:
                return {"status": "running", "msg": ""}
            return {"status": "queuing", "msg": ""}

    def introspect_instance(self):
        """Background loop: reap timed-out / idle instances every 0.5s."""
        loguru.logger.info("Introspecting worker started")
        while True:
            self.instance_pool.introspection()
            time.sleep(0.5)

    def scaling_worker(self):
        """Background loop: scale the pool, dispatch queued jobs, harvest results."""
        loguru.logger.info("Scaling worker started")
        last_target = 0
        while True:
            try:
                # Rescale only when demand (queued + running) changed.
                demand = self.waiting_queue.get_size() + self.running_pool.get_running_size()
                if last_target != demand:
                    last_target = demand
                    self.instance_pool.scale_instance(last_target)

                # Dispatch queued jobs to idle instances. Skip dequeue() when the
                # queue is empty: the old code called the blocking dequeue() for
                # every idle instance, freezing this loop (including result
                # harvesting below) until a new job arrived.
                for instance in self.instance_pool.instances:
                    if instance.active or self.waiting_queue.get_size() == 0:
                        continue
                    uid, video_path, audio_path = self.waiting_queue.dequeue()
                    if self.running_pool.run(instance.uuid, uid, instance.domain, video_path, audio_path):
                        instance.active = True
                        instance.last_active_time = time.time()
                        loguru.logger.info("Task Submitted")
                    else:
                        loguru.logger.error("Submit Task Failed")

                # Harvest finished results into the result cache.
                for uid, task in list(self.running_pool.tasks.items()):
                    result = self.running_pool.result(uid)
                    # result is None while the task is still running.
                    if result:
                        self.result_map.set(uid, result)
                        self.running_pool.tasks.pop(uid)
                        if result["status"] == "fail":
                            # Drop the instance a failed task ran on.
                            self.instance_pool.remove_instance(Instance(uuid=task["instance_id"]))
                        else:
                            # Mark the instance idle again after a successful task.
                            for instance in self.instance_pool.instances:
                                if instance.uuid == task["instance_id"]:
                                    instance.active = False
                                    instance.last_active_time = time.time()
                time.sleep(0.5)
            except Exception:
                # Log and keep the worker alive; the old version wrapped the
                # whole loop in one try, so the first error killed scaling and
                # result collection for the rest of the process lifetime.
                traceback.print_exc()
                time.sleep(0.5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start the scheduling API on localhost:8888.
    scheduler = Server()
    uvicorn.run(scheduler.app, host="127.0.0.1", port=8888)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -10,10 +10,11 @@ token = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1aWQiOjMzNjM5MSwidXVpZCI6ImU2MD
|
|||
req_instance_page_size = 1500
|
||||
chk_instance_page_size = 1500
|
||||
instances = {}
|
||||
LIM = 30 # 等待状态时间LIM*5s
|
||||
LIM = 200 # 等待状态时间LIM*5s
|
||||
START_LIM = 20 #Heygem脚本启动等待超时时间-10
|
||||
DEBUG = False
|
||||
|
||||
def ssh_try(host,port,pwd):
|
||||
def ssh_try(host,port,pwd,uid):
|
||||
# 建立连接
|
||||
trans = paramiko.Transport((host, int(port)))
|
||||
trans.connect(username="root", password=pwd)
|
||||
|
|
@ -29,10 +30,10 @@ def ssh_try(host,port,pwd):
|
|||
if len(out.split("\n")[-2]) > 2:
|
||||
trans.close()
|
||||
raise RuntimeError(out.split("\n")[-2])
|
||||
loguru.logger.info(f"[{uid}]Waiting for HeyGem server ready...")
|
||||
while int(out.split("\n")[-2]) <= 1 and start_lim > 0:
|
||||
ssh_stdin, ssh_stdout, ssh_stderr = ssh.exec_command("bash -ic \"sleep 2 && lsof -i:6006|wc -l\"")
|
||||
out = ssh_stdout.read().decode()
|
||||
loguru.logger.info("waiting for HeyGem server ready...")
|
||||
start_lim -= 1
|
||||
if start_lim <= 0:
|
||||
loguru.logger.error("HeyGem Server Start Timeout, Please check!")
|
||||
|
|
@ -52,7 +53,7 @@ def get_autodl_machines() -> Union[list, None]:
|
|||
payload = {
|
||||
"charge_type":"payg",
|
||||
"region_sign":"",
|
||||
"gpu_type_name":["RTX 4090", "RTX 4090D", "RTX 3090", "RTX 3080", "RTX 3080x2", "RTX 3080 Ti", "RTX 3060", "RTX A4000", "RTX 2080 Ti", "RTX 2080 Ti x2", "GTX 1080 Ti"],
|
||||
"gpu_type_name":["V100-SXM2-32GB", "vGPU-32GB", "RTX 4090", "RTX 4090D", "RTX 3090", "RTX 3080", "RTX 3080x2", "RTX 3080 Ti", "RTX 3060", "RTX A4000", "RTX 2080 Ti", "RTX 2080 Ti x2", "GTX 1080 Ti"],
|
||||
"machine_tag_name":[],
|
||||
"gpu_idle_num":1,
|
||||
"mount_net_disk":False,
|
||||
|
|
@ -70,15 +71,18 @@ def get_autodl_machines() -> Union[list, None]:
|
|||
"chip_corp":["nvidia"],
|
||||
"machine_id":""
|
||||
}
|
||||
loguru.logger.info("Req Machine index {}".format(index))
|
||||
if DEBUG:
|
||||
loguru.logger.info("Req Machine index {}".format(index))
|
||||
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/user/machine/list", json=payload, headers=headers)
|
||||
|
||||
if rsp.status_code == 200:
|
||||
machine_list = rsp.json()
|
||||
loguru.logger.info("Machine Result Total {}".format(machine_list["data"]["result_total"]))
|
||||
if DEBUG:
|
||||
loguru.logger.info("Machine Result Total {}".format(machine_list["data"]["result_total"]))
|
||||
while index < machine_list["data"]["max_page"]:
|
||||
index += 1
|
||||
loguru.logger.info("Req Machine index {}/{}".format(index, machine_list["data"]["max_page"]))
|
||||
if DEBUG:
|
||||
loguru.logger.info("Req Machine index {}/{}".format(index, machine_list["data"]["max_page"]))
|
||||
payload["page_index"] = index
|
||||
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/user/machine/list", json=payload, headers=headers)
|
||||
if rsp.status_code == 200:
|
||||
|
|
@ -111,9 +115,9 @@ def get_autodl_machines() -> Union[list, None]:
|
|||
|
||||
def payg(region_name:str, machine_id:str) -> tuple[Any, Any] | None:
|
||||
region_image = {
|
||||
"西北": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-232ac04d3b"],
|
||||
"内蒙": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-06814c02d1"],
|
||||
"北京": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-e5334cc4f3"]
|
||||
"西北": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-0d3ab01d64"],
|
||||
"内蒙": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-bfff5e82ae"],
|
||||
"北京": ["hub.kce.ksyun.com/autodl-image/miniconda:cuda11.8-cudnn8-devel-ubuntu20.04-py38","image-537b5da0c5"]
|
||||
}
|
||||
headers = {
|
||||
"Authorization": token,
|
||||
|
|
@ -149,10 +153,10 @@ def payg(region_name:str, machine_id:str) -> tuple[Any, Any] | None:
|
|||
if j["code"] == "Success":
|
||||
lim = LIM
|
||||
while lim>0:
|
||||
time.sleep(5)
|
||||
time.sleep(3)
|
||||
status, host, port, pwd, domain = check_status(j['data'])
|
||||
if status == "running":
|
||||
ssh_try(host, port, pwd)
|
||||
ssh_try(host, port, pwd, j['data'])
|
||||
break
|
||||
else:
|
||||
lim = lim-1
|
||||
|
|
@ -188,22 +192,22 @@ def instance_operate(instance_uuid:str, operation: Literal["power_off","power_on
|
|||
if operation in dest_dict.keys():
|
||||
while lim>0:
|
||||
time.sleep(5)
|
||||
status = check_status(instance_uuid)[0]
|
||||
if status == dest_dict[operation]:
|
||||
status = check_status(instance_uuid)
|
||||
if status and status[0] == dest_dict[operation]:
|
||||
break
|
||||
else:
|
||||
lim = lim-1
|
||||
if lim > 0:
|
||||
loguru.logger.success("Operate[%s] Instance Success" % operation)
|
||||
loguru.logger.success("Operate[%s] Instance[%s] Success" % (operation, instance_uuid))
|
||||
return True
|
||||
else:
|
||||
loguru.logger.error("Operate[%s] Instance Error: Timeout, Please Check!!! instance_uuid(%s)" % (operation, instance_uuid))
|
||||
loguru.logger.error("Operate[%s] Instance[%s] Error: Timeout, Please Check!!!" % (operation, instance_uuid))
|
||||
return False
|
||||
else:
|
||||
loguru.logger.error("Operate[%s] Instance Error: %s" % (operation, j['msg']))
|
||||
loguru.logger.error("Operate[%s] Instance[%s] Error: %s" % (operation, instance_uuid, j['msg']))
|
||||
return False
|
||||
else:
|
||||
loguru.logger.error("Operate[%s] Instance Error: Status Code[%s]" % (operation, rsp.status_code))
|
||||
loguru.logger.error("Operate[%s] Instance[%s] Error: Status Code[%s]" % (operation, instance_uuid, rsp.status_code))
|
||||
return False
|
||||
|
||||
def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None:
|
||||
|
|
@ -221,13 +225,27 @@ def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None:
|
|||
"status":[],
|
||||
"charge_type":[]
|
||||
}
|
||||
# loguru.logger.info("Req Instance index {}".format(index))
|
||||
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers)
|
||||
if DEBUG:
|
||||
loguru.logger.info("Req Instance index {}".format(index))
|
||||
lim = 3
|
||||
while lim > 0:
|
||||
rsp = None
|
||||
try:
|
||||
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers)
|
||||
except:
|
||||
lim -= 1
|
||||
if rsp:
|
||||
break
|
||||
if lim <= 0:
|
||||
loguru.logger.error("Get Instance Req Error")
|
||||
return None
|
||||
if rsp.status_code == 200:
|
||||
instance_list = rsp.json()
|
||||
# loguru.logger.info("Instance Result Total {}".format(instance_list["data"]["result_total"]))
|
||||
if DEBUG:
|
||||
loguru.logger.info("Instance Result Total {}".format(instance_list["data"]["result_total"]))
|
||||
while index < instance_list["data"]["max_page"]:
|
||||
# loguru.logger.info("Req Instance index {}/{}".format(index, instance_list["data"]["max_page"]))
|
||||
if DEBUG:
|
||||
loguru.logger.info("Req Instance index {}/{}".format(index, instance_list["data"]["max_page"]))
|
||||
payload["page_index"] = index
|
||||
rsp = requests.post("https://www.autodl.com/api/v1/sub_user/instance", json=payload, headers=headers)
|
||||
if rsp.status_code == 200:
|
||||
|
|
@ -237,7 +255,8 @@ def check_status(instance_uuid:str) -> tuple[Any, Any, Any, Any, Any] | None:
|
|||
return None
|
||||
for l in instance_list["data"]["list"]:
|
||||
if l["uuid"] == instance_uuid:
|
||||
loguru.logger.info("Instance {} Status {}".format(instance_uuid, l["status"]))
|
||||
if DEBUG:
|
||||
loguru.logger.info("Instance {} Status {}".format(instance_uuid, l["status"]))
|
||||
return l["status"], l["proxy_host"], l["ssh_port"], l["root_password"], l["tensorboard_domain"]
|
||||
loguru.logger.warning("Instance {} Not Found".format(instance_uuid))
|
||||
return None
|
||||
|
|
@ -101,7 +101,7 @@ auth_scheme = HTTPBearer()
|
|||
gpu=["L4", "T4"],
|
||||
cpu=(2,16),
|
||||
# memory=(32768, 32768), # (内存预留量, 内存使用上限)
|
||||
memory=(20480,40960),
|
||||
memory=(32768,131072),
|
||||
enable_memory_snapshot=False,
|
||||
secrets=[secret, modal.Secret.from_name("web_auth_token")],
|
||||
volumes={
|
||||
|
|
|
|||
Loading…
Reference in New Issue