新增base64格式的文件上传接口
This commit is contained in:
parent
629083210b
commit
be72473e4a
|
|
@ -1,5 +1,5 @@
|
|||
MODAL_ENVIRONMENT=dev
|
||||
modal_app_name=cluster-test
|
||||
MODAL_ENVIRONMENT=test
|
||||
modal_app_name=bowong-ai-video
|
||||
S3_mount_dir=/mntS3
|
||||
S3_bucket_name=modal-media-cache
|
||||
S3_region=ap-northeast-2
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ async def alias_middleware(request: Request, call_next):
|
|||
|
||||
@web_app.get("/scalar", include_in_schema=False)
|
||||
async def scalar():
|
||||
return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint")
|
||||
return get_scalar_api_reference(openapi_url='/openapi.json', title="Modal worker web endpoint")
|
||||
|
||||
|
||||
web_app.include_router(ffmpeg.router)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import base64
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
|
@ -6,7 +7,7 @@ from functools import cached_property
|
|||
from typing import List, Union, Optional, Any, Dict
|
||||
from urllib.parse import urlparse
|
||||
from pydantic import (BaseModel, Field, field_validator, ValidationError,
|
||||
field_serializer, SerializationInfo, computed_field, FileUrl)
|
||||
field_serializer, SerializationInfo, computed_field, FileUrl, Base64Str, Base64Bytes)
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from ..config import WorkerConfig
|
||||
from ..utils.TimeUtils import TimeDelta
|
||||
|
|
@ -63,15 +64,14 @@ class MediaSource(BaseModel):
|
|||
# s3://{endpoint}/{bucket}/{url}
|
||||
paths = media_url[5:].split('/')
|
||||
|
||||
|
||||
if len(paths) < 3:
|
||||
raise ValidationError("URN-s3 格式错误")
|
||||
|
||||
media_source = MediaSource(path='/'.join(paths[2:]),
|
||||
protocol=MediaProtocol.s3,
|
||||
endpoint=paths[0],
|
||||
bucket=paths[1],
|
||||
urn=media_url)
|
||||
protocol=MediaProtocol.s3,
|
||||
endpoint=paths[0],
|
||||
bucket=paths[1],
|
||||
urn=media_url)
|
||||
|
||||
s3_mount_path = config.S3_mount_dir
|
||||
cache_path = os.path.join(s3_mount_path, media_source.cache_filepath)
|
||||
|
|
@ -143,7 +143,6 @@ class MediaSource(BaseModel):
|
|||
case _:
|
||||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||||
|
||||
|
||||
@computed_field(description="文件后缀名")
|
||||
@cached_property
|
||||
def file_extension(self) -> Optional[str]:
|
||||
|
|
@ -209,5 +208,30 @@ class CacheResult(BaseModel):
|
|||
class DownloadResult(BaseModel):
|
||||
urls: List[str] = Field(description="下载链接")
|
||||
|
||||
|
||||
class UploadResultResponse(BaseModel):
|
||||
media: MediaSource = Field(description="上传完成的媒体资源")
|
||||
media: MediaSource = Field(description="上传完成的媒体资源")
|
||||
|
||||
|
||||
class Base64File(BaseModel):
|
||||
raw_content: Base64Bytes = Field(description="Base64编码的原始内容")
|
||||
filename: str = Field(description="文件名, 包含文件类型后缀")
|
||||
content_type: str = Field(description="文件数据类型")
|
||||
|
||||
# @computed_field(description="Base64解码后的二进制内容")
|
||||
# @property
|
||||
# def content(self) -> Base64Bytes:
|
||||
# return base64.b64decode(self.raw_content)
|
||||
|
||||
@computed_field(description="文件字节大小")
|
||||
@property
|
||||
def size(self) -> int:
|
||||
# b64_str = self.raw_content.strip().replace('\n', '').replace('\r', '')
|
||||
# padding = b64_str.count('=', -2) # 只统计末尾的 '='
|
||||
# size = len(b64_str)
|
||||
return len(self.raw_content)
|
||||
|
||||
|
||||
class UploadBase64Request(BaseModel):
|
||||
file: Base64File = Field(description="上传的文件")
|
||||
prefix: Optional[str] = Field(description="文件存在的前缀目录", default=None)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ from ..models.media_model import (MediaSources,
|
|||
CacheResult,
|
||||
MediaSource,
|
||||
MediaCacheStatus,
|
||||
DownloadResult, UploadResultResponse
|
||||
DownloadResult,
|
||||
UploadResultResponse,
|
||||
UploadBase64Request
|
||||
)
|
||||
from ..models.web_model import SentryTransactionInfo
|
||||
from ..utils.KVCache import KVCache
|
||||
|
|
@ -94,11 +96,12 @@ async def cache(medias: MediaSources) -> CacheResult:
|
|||
KVCache.batch_update_cloudflare_kv(cache_task_result)
|
||||
return CacheResult(caches={media.urn: media for media in cache_task_result})
|
||||
|
||||
|
||||
@router.delete("/",
|
||||
tags=["缓存"],
|
||||
summary="清除指定的所有缓存",
|
||||
description="清除指定的所有缓存(包括KV记录和S3存储文件)",
|
||||
dependencies=[Depends(verify_token)])
|
||||
tags=["缓存"],
|
||||
summary="清除指定的所有缓存",
|
||||
description="清除指定的所有缓存(包括KV记录和S3存储文件)",
|
||||
dependencies=[Depends(verify_token)])
|
||||
async def purge_media_kv_file(medias: MediaSources):
|
||||
fn_id = current_function_call_id()
|
||||
fn = modal.Function.from_name(config.modal_app_name, "cache_delete", environment_name=config.modal_environment)
|
||||
|
|
@ -119,6 +122,7 @@ async def purge_media_kv_file(medias: MediaSources):
|
|||
KVCache.batch_remove_cloudflare_kv(keys)
|
||||
return JSONResponse(content={"success": True, "keys": keys})
|
||||
|
||||
|
||||
@router.post("/download",
|
||||
tags=["缓存"],
|
||||
summary="批量获取下载地址",
|
||||
|
|
@ -201,9 +205,10 @@ async def purge_media(medias: MediaSources):
|
|||
return JSONResponse(content={"success": True, "keys": keys})
|
||||
|
||||
|
||||
@router.post("/upload-s3", tags=['缓存'],
|
||||
@router.post("/upload-s3",
|
||||
tags=['缓存'],
|
||||
summary="上传文件到S3",
|
||||
description="上传文件到S3当文件大于200M",
|
||||
description="上传文件到S3的文件必须小于200M",
|
||||
dependencies=[Depends(verify_token)])
|
||||
async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
|
||||
prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
|
||||
|
|
@ -223,3 +228,27 @@ async def s3_upload(file: Annotated[UploadFile, File(description="上传的文
|
|||
media_source.status = MediaCacheStatus.ready
|
||||
media_source.downloader_id = fn_id
|
||||
return UploadResultResponse(media=media_source)
|
||||
|
||||
|
||||
@router.post('/upload-s3-b64',
|
||||
tags=['缓存'],
|
||||
summary="基于Base64格式上传文件到S3",
|
||||
description="上传文件到S3当文件必须小于200M",
|
||||
dependencies=[Depends(verify_token)])
|
||||
async def s3_upload_base64(body: UploadBase64Request) -> UploadResultResponse:
|
||||
fn_id = current_function_call_id()
|
||||
prefix = body.prefix
|
||||
file = body.file
|
||||
if file.size > 200 * 1024 * 1024: # 上传文件不大于200M
|
||||
raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
|
||||
key = f"upload/{prefix}/{file.filename}" if prefix else f"upload/{file.filename}"
|
||||
local_path = f"{config.S3_mount_dir}/{key}"
|
||||
logger.info(f"s3上传到{key}, size={file.size}")
|
||||
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
with open(local_path, 'wb') as f:
|
||||
f.write(file.raw_content)
|
||||
logger.info(f"{local_path} 保存成功")
|
||||
media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
|
||||
media_source.status = MediaCacheStatus.ready
|
||||
media_source.downloader_id = fn_id
|
||||
return UploadResultResponse(media=media_source)
|
||||
|
|
|
|||
Loading…
Reference in New Issue