ComfyUI-CustomNode/ext/modal_downloader_deploy.py

142 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
import tempfile
from pathlib import Path
import aiofiles
import httpx
import modal
from fastapi import FastAPI, HTTPException, UploadFile
from pydantic import BaseModel, HttpUrl
# Container image shared by every function in this app: slim Debian on
# Python 3.10 plus the web framework, async HTTP client and async file I/O.
_PYTHON_VERSION = "3.10"
_IMAGE_PACKAGES = ["fastapi[standard]", "httpx", "aiofiles"]

image = modal.Image.debian_slim(python_version=_PYTHON_VERSION).pip_install(
    _IMAGE_PACKAGES
)

# Modal application; functions registered on it run inside `image`.
app = modal.App(image=image)

# Named model volume (created on first lookup if absent) and the directory
# it is mounted at inside the containers.
vol = modal.Volume.from_name("comfy_model", create_if_missing=True)
DOWNLOAD_BASE_DIR = Path("/models")
@app.function(
    cpu=(0.125, 8),
    memory=(128, 4096),
    scaledown_window=360,
    timeout=600,
    max_containers=500,
    min_containers=0,
    region="ap",
    # Mount point is tied to DOWNLOAD_BASE_DIR so path checks and the mount agree.
    volumes={str(DOWNLOAD_BASE_DIR): vol},
)
@modal.concurrent(max_inputs=20)
@modal.asgi_app()
def fastapi_webapp():
    """Build the FastAPI application that fronts the model volume.

    Endpoints:
        POST /download-file/ -- validate the target path, spawn
            ``do_download`` in the background and return its Modal call id.
        POST /upload-file/ -- stream an uploaded file into the volume.
    """
    from urllib.parse import urlsplit

    fastapi_app = FastAPI(
        title="文件下载服务",
        description="一个通过URL下载文件到指定目录的API端点",
    )

    # Uploads are copied in 1 MiB pieces so memory stays bounded even for
    # multi-GB model files.
    upload_chunk_size = 1024 * 1024

    def resolve_destination(save_path: str) -> Path:
        """Map a user-supplied relative path to an absolute path in the volume.

        Raises:
            HTTPException: 400 if the path has no file name, or if it resolves
                anywhere other than strictly inside DOWNLOAD_BASE_DIR
                (``../`` traversal, absolute paths, the root itself).
        """
        if not os.path.basename(save_path):
            raise HTTPException(
                status_code=400,
                detail="无效的保存路径:无法提取文件名。"
            )
        destination = DOWNLOAD_BASE_DIR.joinpath(save_path).resolve()
        if DOWNLOAD_BASE_DIR.resolve() not in destination.parents:
            raise HTTPException(
                status_code=400,
                detail="不安全的路径:禁止在指定的下载目录之外写入文件。"
            )
        return destination

    # --- Pydantic models ---
    class DownloadRequest(BaseModel):
        # Validated by pydantic as an http(s) URL.
        url: HttpUrl
        # Save path relative to the volume root, e.g. "videos/my_video.mp4" or
        # "my_document.pdf". A trailing "/" means "use the URL's file name".
        save_path: str

    # --- FastAPI endpoints ---
    @fastapi_app.post("/download-file/", summary="从URL下载模型")
    async def download_file_from_url(request: DownloadRequest):
        # Pass a plain str across the Modal boundary rather than a pydantic object.
        url = str(request.url)
        save_path = request.save_path
        if save_path.endswith("/"):
            # Last segment of the URL path; query string and fragment are ignored.
            save_path += os.path.basename(urlsplit(url).path)
        # Reject bad paths now, so the client gets a 400 instead of a task id
        # for a job that is bound to fail (do_download re-checks anyway).
        resolve_destination(save_path)
        fn_call = await do_download.spawn.aio(url, save_path)
        return {"task_id": fn_call.object_id}

    @fastapi_app.post("/upload-file/", summary="上传模型")
    async def upload_file(file: UploadFile, save_path: str):
        if save_path.endswith("/"):
            # Keep only the final component of the client-supplied name.
            # A missing name leaves the path empty and is rejected below.
            save_path += os.path.basename(file.filename or "")
        destination_path = resolve_destination(save_path)
        destination_path.parent.mkdir(parents=True, exist_ok=True)
        if destination_path.exists():
            destination_path.unlink()
            print("删除成功")
        async with aiofiles.open(destination_path, "wb") as f:
            while chunk := await file.read(upload_chunk_size):
                await f.write(chunk)
        return {"msg": "上传成功"}

    return fastapi_app
@app.function(
    cpu=(0.125, 8),
    memory=(128, 4096),
    scaledown_window=300,
    timeout=3600,
    max_containers=500,
    min_containers=0,
    region="ap",
    # Mount point is tied to DOWNLOAD_BASE_DIR so path checks and the mount agree.
    volumes={str(DOWNLOAD_BASE_DIR): vol},
)
async def do_download(url, save_path):
    """Download ``url`` into the model volume at ``save_path``.

    The download endpoint spawns this in the background. The response body
    is streamed to a private scratch file first. It is copied into the
    volume only after the whole body has arrived, and the scratch file is
    always removed afterwards.

    NOTE(review): no explicit ``vol.commit()`` is made here. Confirm that the
    Volume's background/shutdown commit is what persistence relies on.

    Args:
        url: Source URL (anything whose ``str()`` is the URL).
        save_path: Destination path relative to ``DOWNLOAD_BASE_DIR``.

    Returns:
        A dict with a success message, the source URL and the absolute
        saved path.

    Raises:
        HTTPException: 400 for a path with no file name or one that leaves
            the volume; 502 for network errors or an upstream 4xx/5xx;
            500 for any other failure.
    """
    print(f"Downloading {url} to {save_path}")
    if not os.path.basename(save_path):
        raise HTTPException(
            status_code=400,
            detail="无效的保存路径:无法提取文件名。"
        )

    # --- Safety check ---
    # The resolved target must lie strictly inside the base directory. This
    # blocks "../" traversal, absolute paths and the volume root itself.
    destination_path = DOWNLOAD_BASE_DIR.joinpath(save_path).resolve()
    if DOWNLOAD_BASE_DIR.resolve() not in destination_path.parents:
        raise HTTPException(
            status_code=400,
            detail="不安全的路径:禁止在指定的下载目录之外写入文件。"
        )

    # --- Download and save ---
    temp_path = None
    try:
        destination_path.parent.mkdir(parents=True, exist_ok=True)
        # Unique scratch file, so reused containers never collide on a name.
        # Paths are only handed to Python APIs, never to a shell.
        fd, temp_path = tempfile.mkstemp(prefix="download_")
        os.close(fd)
        async with httpx.AsyncClient(follow_redirects=True) as client:
            async with client.stream("GET", str(url)) as response:
                # Upstream 4xx/5xx raise httpx.HTTPStatusError (an httpx.HTTPError).
                response.raise_for_status()
                async with aiofiles.open(temp_path, "wb") as f:
                    async for chunk in response.aiter_bytes():
                        await f.write(chunk)
        if destination_path.exists():
            destination_path.unlink()
            print("删除成功")
        # Raises on failure, unlike the previous unchecked shell `cp`.
        shutil.copyfile(temp_path, destination_path)
        print("移动成功")
    except httpx.HTTPError as e:
        raise HTTPException(
            status_code=502,
            detail=f"下载文件时网络请求失败: {e}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"处理文件时发生内部错误: {e}"
        ) from e
    finally:
        # Always reclaim scratch disk space, including after failures.
        if temp_path is not None:
            Path(temp_path).unlink(missing_ok=True)

    result = {
        "message": "文件下载成功",
        "url": str(url),
        "saved_at": str(destination_path),
    }
    print(result)
    return result