新增base64格式的文件上传接口

This commit is contained in:
shuohigh@gmail.com 2025-05-21 19:17:57 +08:00
parent 629083210b
commit be72473e4a
4 changed files with 71 additions and 18 deletions

View File

@ -1,5 +1,5 @@
MODAL_ENVIRONMENT=dev
modal_app_name=cluster-test
MODAL_ENVIRONMENT=test
modal_app_name=bowong-ai-video
S3_mount_dir=/mntS3
S3_bucket_name=modal-media-cache
S3_region=ap-northeast-2

View File

@ -76,7 +76,7 @@ async def alias_middleware(request: Request, call_next):
@web_app.get("/scalar", include_in_schema=False)
async def scalar():
return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint")
return get_scalar_api_reference(openapi_url='/openapi.json', title="Modal worker web endpoint")
web_app.include_router(ffmpeg.router)

View File

@ -1,3 +1,4 @@
import base64
import os
import re
from datetime import datetime
@ -6,7 +7,7 @@ from functools import cached_property
from typing import List, Union, Optional, Any, Dict
from urllib.parse import urlparse
from pydantic import (BaseModel, Field, field_validator, ValidationError,
field_serializer, SerializationInfo, computed_field, FileUrl)
field_serializer, SerializationInfo, computed_field, FileUrl, Base64Str, Base64Bytes)
from pydantic.json_schema import JsonSchemaValue
from ..config import WorkerConfig
from ..utils.TimeUtils import TimeDelta
@ -63,7 +64,6 @@ class MediaSource(BaseModel):
# s3://{endpoint}/{bucket}/{url}
paths = media_url[5:].split('/')
if len(paths) < 3:
raise ValidationError("URN-s3 格式错误")
@ -143,7 +143,6 @@ class MediaSource(BaseModel):
case _:
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
@computed_field(description="文件后缀名")
@cached_property
def file_extension(self) -> Optional[str]:
@ -209,5 +208,30 @@ class CacheResult(BaseModel):
class DownloadResult(BaseModel):
urls: List[str] = Field(description="下载链接")
class UploadResultResponse(BaseModel):
media: MediaSource = Field(description="上传完成的媒体资源")
class Base64File(BaseModel):
raw_content: Base64Bytes = Field(description="Base64编码的原始内容")
filename: str = Field(description="文件名, 包含文件类型后缀")
content_type: str = Field(description="文件数据类型")
# @computed_field(description="Base64解码后的二进制内容")
# @property
# def content(self) -> Base64Bytes:
# return base64.b64decode(self.raw_content)
@computed_field(description="文件字节大小")
@property
def size(self) -> int:
# b64_str = self.raw_content.strip().replace('\n', '').replace('\r', '')
# padding = b64_str.count('=', -2) # 只统计末尾的 '='
# size = len(b64_str)
return len(self.raw_content)
class UploadBase64Request(BaseModel):
file: Base64File = Field(description="上传的文件")
prefix: Optional[str] = Field(description="文件存在的前缀目录", default=None)

View File

@ -15,7 +15,9 @@ from ..models.media_model import (MediaSources,
CacheResult,
MediaSource,
MediaCacheStatus,
DownloadResult, UploadResultResponse
DownloadResult,
UploadResultResponse,
UploadBase64Request
)
from ..models.web_model import SentryTransactionInfo
from ..utils.KVCache import KVCache
@ -94,6 +96,7 @@ async def cache(medias: MediaSources) -> CacheResult:
KVCache.batch_update_cloudflare_kv(cache_task_result)
return CacheResult(caches={media.urn: media for media in cache_task_result})
@router.delete("/",
tags=["缓存"],
summary="清除指定的所有缓存",
@ -119,6 +122,7 @@ async def purge_media_kv_file(medias: MediaSources):
KVCache.batch_remove_cloudflare_kv(keys)
return JSONResponse(content={"success": True, "keys": keys})
@router.post("/download",
tags=["缓存"],
summary="批量获取下载地址",
@ -201,9 +205,10 @@ async def purge_media(medias: MediaSources):
return JSONResponse(content={"success": True, "keys": keys})
@router.post("/upload-s3", tags=['缓存'],
@router.post("/upload-s3",
tags=['缓存'],
summary="上传文件到S3",
description="上传文件到S3当文件大于200M",
description="上传文件到S3的文件必须小于200M",
dependencies=[Depends(verify_token)])
async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
@ -223,3 +228,27 @@ async def s3_upload(file: Annotated[UploadFile, File(description="上传的文
media_source.status = MediaCacheStatus.ready
media_source.downloader_id = fn_id
return UploadResultResponse(media=media_source)
@router.post('/upload-s3-b64',
tags=['缓存'],
summary="基于Base64格式上传文件到S3",
description="上传文件到S3当文件必须小于200M",
dependencies=[Depends(verify_token)])
async def s3_upload_base64(body: UploadBase64Request) -> UploadResultResponse:
fn_id = current_function_call_id()
prefix = body.prefix
file = body.file
if file.size > 200 * 1024 * 1024: # 上传文件不大于200M
raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
key = f"upload/{prefix}/{file.filename}" if prefix else f"upload/{file.filename}"
local_path = f"{config.S3_mount_dir}/{key}"
logger.info(f"s3上传到{key}, size={file.size}")
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with open(local_path, 'wb') as f:
f.write(file.raw_content)
logger.info(f"{local_path} 保存成功")
media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
media_source.status = MediaCacheStatus.ready
media_source.downloader_id = fn_id
return UploadResultResponse(media=media_source)