142 lines
4.5 KiB
Python
142 lines
4.5 KiB
Python
import os
|
||
import shutil
|
||
import tempfile
|
||
from pathlib import Path
|
||
|
||
import aiofiles
|
||
import httpx
|
||
import modal
|
||
from fastapi import FastAPI, HTTPException, UploadFile
|
||
from pydantic import BaseModel, HttpUrl
|
||
|
||
image = (
|
||
modal.Image.debian_slim(
|
||
python_version="3.10"
|
||
).pip_install(
|
||
["fastapi[standard]", "httpx", "aiofiles"]
|
||
)
|
||
)
|
||
app = modal.App(image=image)
|
||
vol = modal.Volume.from_name("comfy_model", create_if_missing=True)
|
||
DOWNLOAD_BASE_DIR = Path("/models")
|
||
|
||
|
||
@app.function(
|
||
cpu=(0.125, 8),
|
||
memory=(128, 4096),
|
||
scaledown_window=360,
|
||
timeout=600,
|
||
max_containers=500,
|
||
min_containers=0,
|
||
region="ap",
|
||
volumes={
|
||
"/models": vol
|
||
}
|
||
)
|
||
@modal.concurrent(max_inputs=20)
|
||
@modal.asgi_app()
|
||
def fastapi_webapp():
|
||
fastapi_app = FastAPI(
|
||
title="文件下载服务",
|
||
description="一个通过URL下载文件到指定目录的API端点",
|
||
)
|
||
|
||
# --- Pydantic 模型 ---
|
||
# 定义请求体的数据结构和验证规则
|
||
class DownloadRequest(BaseModel):
|
||
url: HttpUrl # Pydantic 会自动验证这是否是一个有效的URL
|
||
save_path: str # 用户指定的相对保存路径(例如 "videos/my_video.mp4" 或 "my_document.pdf")
|
||
|
||
# --- FastAPI 端点 ---
|
||
@fastapi_app.post("/download-file/", summary="从URL下载模型")
|
||
async def download_file_from_url(request: DownloadRequest):
|
||
if request.save_path.endswith("/"):
|
||
request.save_path += str(request.url).split("/")[-1].split("?")[0]
|
||
fn_call = await do_download.spawn.aio(request.url, request.save_path)
|
||
return {"task_id": fn_call.object_id}
|
||
|
||
@fastapi_app.post("/upload-file/", summary="上传模型")
|
||
async def upload_file(file: UploadFile, save_path: str):
|
||
if save_path.endswith("/"):
|
||
save_path += str(file.filename)
|
||
destination_path = DOWNLOAD_BASE_DIR.joinpath(save_path).resolve()
|
||
if os.path.exists(destination_path):
|
||
os.remove(destination_path)
|
||
print("删除成功")
|
||
with open(destination_path, "wb") as f:
|
||
f.write(await file.read())
|
||
return {"msg": "上传成功"}
|
||
|
||
return fastapi_app
|
||
|
||
|
||
@app.function(
|
||
cpu=(0.125, 8),
|
||
memory=(128, 4096),
|
||
scaledown_window=300,
|
||
timeout=3600,
|
||
max_containers=500,
|
||
min_containers=0,
|
||
region="ap",
|
||
volumes={
|
||
"/models": vol
|
||
}
|
||
)
|
||
async def do_download(url, save_path):
|
||
print(f"Downloading {url} to {save_path}")
|
||
file_name = os.path.basename(save_path)
|
||
if not file_name:
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail="无效的保存路径:无法提取文件名。"
|
||
)
|
||
|
||
# --- 安全性检查 ---
|
||
# 构建绝对目标路径
|
||
temp_path = "/tmp/" + str(file_name)
|
||
destination_path = DOWNLOAD_BASE_DIR.joinpath(save_path).resolve()
|
||
|
||
# 确保解析后的路径仍然在我们的基础下载目录内,防止目录遍历攻击
|
||
if not destination_path.is_relative_to(DOWNLOAD_BASE_DIR.resolve()):
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail="不安全的路径:禁止在指定的下载目录之外写入文件。"
|
||
)
|
||
# --- 文件下载与保存 ---
|
||
try:
|
||
# 创建目标文件夹(如果不存在)
|
||
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 使用 httpx 进行异步网络请求
|
||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||
# 使用 stream=True 进行流式下载,适合大文件
|
||
async with client.stream("GET", str(url)) as response:
|
||
# 检查请求是否成功
|
||
response.raise_for_status()
|
||
|
||
# 使用 aiofiles 进行异步文件写入
|
||
async with aiofiles.open(temp_path, 'wb') as f:
|
||
async for chunk in response.aiter_bytes():
|
||
await f.write(chunk)
|
||
if os.path.exists(destination_path):
|
||
os.remove(destination_path)
|
||
print("删除成功")
|
||
os.system("cp -f {} {}".format(temp_path, destination_path))
|
||
print("移动成功")
|
||
except httpx.RequestError as e:
|
||
raise HTTPException(
|
||
status_code=502,
|
||
detail=f"下载文件时网络请求失败: {e}"
|
||
)
|
||
except Exception as e:
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail=f"处理文件时发生内部错误: {e}"
|
||
)
|
||
|
||
print({
|
||
"message": "文件下载成功",
|
||
"url": url,
|
||
"saved_at": str(destination_path)
|
||
})
|