新增base64格式的文件上传接口

This commit is contained in:
shuohigh@gmail.com 2025-05-21 19:17:57 +08:00
parent 629083210b
commit be72473e4a
4 changed files with 71 additions and 18 deletions

View File

@ -1,5 +1,5 @@
MODAL_ENVIRONMENT=dev MODAL_ENVIRONMENT=test
modal_app_name=cluster-test modal_app_name=bowong-ai-video
S3_mount_dir=/mntS3 S3_mount_dir=/mntS3
S3_bucket_name=modal-media-cache S3_bucket_name=modal-media-cache
S3_region=ap-northeast-2 S3_region=ap-northeast-2

View File

@ -76,7 +76,7 @@ async def alias_middleware(request: Request, call_next):
@web_app.get("/scalar", include_in_schema=False) @web_app.get("/scalar", include_in_schema=False)
async def scalar(): async def scalar():
return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint") return get_scalar_api_reference(openapi_url='/openapi.json', title="Modal worker web endpoint")
web_app.include_router(ffmpeg.router) web_app.include_router(ffmpeg.router)

View File

@ -1,3 +1,4 @@
import base64
import os import os
import re import re
from datetime import datetime from datetime import datetime
@ -6,7 +7,7 @@ from functools import cached_property
from typing import List, Union, Optional, Any, Dict from typing import List, Union, Optional, Any, Dict
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import (BaseModel, Field, field_validator, ValidationError, from pydantic import (BaseModel, Field, field_validator, ValidationError,
field_serializer, SerializationInfo, computed_field, FileUrl) field_serializer, SerializationInfo, computed_field, FileUrl, Base64Str, Base64Bytes)
from pydantic.json_schema import JsonSchemaValue from pydantic.json_schema import JsonSchemaValue
from ..config import WorkerConfig from ..config import WorkerConfig
from ..utils.TimeUtils import TimeDelta from ..utils.TimeUtils import TimeDelta
@ -63,15 +64,14 @@ class MediaSource(BaseModel):
# s3://{endpoint}/{bucket}/{url} # s3://{endpoint}/{bucket}/{url}
paths = media_url[5:].split('/') paths = media_url[5:].split('/')
if len(paths) < 3: if len(paths) < 3:
raise ValidationError("URN-s3 格式错误") raise ValidationError("URN-s3 格式错误")
media_source = MediaSource(path='/'.join(paths[2:]), media_source = MediaSource(path='/'.join(paths[2:]),
protocol=MediaProtocol.s3, protocol=MediaProtocol.s3,
endpoint=paths[0], endpoint=paths[0],
bucket=paths[1], bucket=paths[1],
urn=media_url) urn=media_url)
s3_mount_path = config.S3_mount_dir s3_mount_path = config.S3_mount_dir
cache_path = os.path.join(s3_mount_path, media_source.cache_filepath) cache_path = os.path.join(s3_mount_path, media_source.cache_filepath)
@ -143,7 +143,6 @@ class MediaSource(BaseModel):
case _: case _:
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}" return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
@computed_field(description="文件后缀名") @computed_field(description="文件后缀名")
@cached_property @cached_property
def file_extension(self) -> Optional[str]: def file_extension(self) -> Optional[str]:
@ -209,5 +208,30 @@ class CacheResult(BaseModel):
class DownloadResult(BaseModel): class DownloadResult(BaseModel):
urls: List[str] = Field(description="下载链接") urls: List[str] = Field(description="下载链接")
class UploadResultResponse(BaseModel): class UploadResultResponse(BaseModel):
media: MediaSource = Field(description="上传完成的媒体资源") media: MediaSource = Field(description="上传完成的媒体资源")
class Base64File(BaseModel):
raw_content: Base64Bytes = Field(description="Base64编码的原始内容")
filename: str = Field(description="文件名, 包含文件类型后缀")
content_type: str = Field(description="文件数据类型")
# @computed_field(description="Base64解码后的二进制内容")
# @property
# def content(self) -> Base64Bytes:
# return base64.b64decode(self.raw_content)
@computed_field(description="文件字节大小")
@property
def size(self) -> int:
# b64_str = self.raw_content.strip().replace('\n', '').replace('\r', '')
# padding = b64_str.count('=', -2) # 只统计末尾的 '='
# size = len(b64_str)
return len(self.raw_content)
class UploadBase64Request(BaseModel):
file: Base64File = Field(description="上传的文件")
prefix: Optional[str] = Field(description="文件存在的前缀目录", default=None)

View File

@ -15,7 +15,9 @@ from ..models.media_model import (MediaSources,
CacheResult, CacheResult,
MediaSource, MediaSource,
MediaCacheStatus, MediaCacheStatus,
DownloadResult, UploadResultResponse DownloadResult,
UploadResultResponse,
UploadBase64Request
) )
from ..models.web_model import SentryTransactionInfo from ..models.web_model import SentryTransactionInfo
from ..utils.KVCache import KVCache from ..utils.KVCache import KVCache
@ -94,11 +96,12 @@ async def cache(medias: MediaSources) -> CacheResult:
KVCache.batch_update_cloudflare_kv(cache_task_result) KVCache.batch_update_cloudflare_kv(cache_task_result)
return CacheResult(caches={media.urn: media for media in cache_task_result}) return CacheResult(caches={media.urn: media for media in cache_task_result})
@router.delete("/", @router.delete("/",
tags=["缓存"], tags=["缓存"],
summary="清除指定的所有缓存", summary="清除指定的所有缓存",
description="清除指定的所有缓存(包括KV记录和S3存储文件)", description="清除指定的所有缓存(包括KV记录和S3存储文件)",
dependencies=[Depends(verify_token)]) dependencies=[Depends(verify_token)])
async def purge_media_kv_file(medias: MediaSources): async def purge_media_kv_file(medias: MediaSources):
fn_id = current_function_call_id() fn_id = current_function_call_id()
fn = modal.Function.from_name(config.modal_app_name, "cache_delete", environment_name=config.modal_environment) fn = modal.Function.from_name(config.modal_app_name, "cache_delete", environment_name=config.modal_environment)
@ -119,6 +122,7 @@ async def purge_media_kv_file(medias: MediaSources):
KVCache.batch_remove_cloudflare_kv(keys) KVCache.batch_remove_cloudflare_kv(keys)
return JSONResponse(content={"success": True, "keys": keys}) return JSONResponse(content={"success": True, "keys": keys})
@router.post("/download", @router.post("/download",
tags=["缓存"], tags=["缓存"],
summary="批量获取下载地址", summary="批量获取下载地址",
@ -201,9 +205,10 @@ async def purge_media(medias: MediaSources):
return JSONResponse(content={"success": True, "keys": keys}) return JSONResponse(content={"success": True, "keys": keys})
@router.post("/upload-s3", tags=['缓存'], @router.post("/upload-s3",
tags=['缓存'],
summary="上传文件到S3", summary="上传文件到S3",
description="上传文件到S3当文件大于200M", description="上传文件到S3的文件必须小于200M",
dependencies=[Depends(verify_token)]) dependencies=[Depends(verify_token)])
async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")], async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse: prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
@ -223,3 +228,27 @@ async def s3_upload(file: Annotated[UploadFile, File(description="上传的文
media_source.status = MediaCacheStatus.ready media_source.status = MediaCacheStatus.ready
media_source.downloader_id = fn_id media_source.downloader_id = fn_id
return UploadResultResponse(media=media_source) return UploadResultResponse(media=media_source)
@router.post('/upload-s3-b64',
tags=['缓存'],
summary="基于Base64格式上传文件到S3",
description="上传文件到S3当文件必须小于200M",
dependencies=[Depends(verify_token)])
async def s3_upload_base64(body: UploadBase64Request) -> UploadResultResponse:
fn_id = current_function_call_id()
prefix = body.prefix
file = body.file
if file.size > 200 * 1024 * 1024: # 上传文件不大于200M
raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
key = f"upload/{prefix}/{file.filename}" if prefix else f"upload/{file.filename}"
local_path = f"{config.S3_mount_dir}/{key}"
logger.info(f"s3上传到{key}, size={file.size}")
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with open(local_path, 'wb') as f:
f.write(file.raw_content)
logger.info(f"{local_path} 保存成功")
media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
media_source.status = MediaCacheStatus.ready
media_source.downloader_id = fn_id
return UploadResultResponse(media=media_source)