新增base64格式的文件上传接口
This commit is contained in:
parent
629083210b
commit
be72473e4a
|
|
@ -1,5 +1,5 @@
|
||||||
MODAL_ENVIRONMENT=dev
|
MODAL_ENVIRONMENT=test
|
||||||
modal_app_name=cluster-test
|
modal_app_name=bowong-ai-video
|
||||||
S3_mount_dir=/mntS3
|
S3_mount_dir=/mntS3
|
||||||
S3_bucket_name=modal-media-cache
|
S3_bucket_name=modal-media-cache
|
||||||
S3_region=ap-northeast-2
|
S3_region=ap-northeast-2
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,7 @@ async def alias_middleware(request: Request, call_next):
|
||||||
|
|
||||||
@web_app.get("/scalar", include_in_schema=False)
|
@web_app.get("/scalar", include_in_schema=False)
|
||||||
async def scalar():
|
async def scalar():
|
||||||
return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint")
|
return get_scalar_api_reference(openapi_url='/openapi.json', title="Modal worker web endpoint")
|
||||||
|
|
||||||
|
|
||||||
web_app.include_router(ffmpeg.router)
|
web_app.include_router(ffmpeg.router)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import base64
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
@ -6,7 +7,7 @@ from functools import cached_property
|
||||||
from typing import List, Union, Optional, Any, Dict
|
from typing import List, Union, Optional, Any, Dict
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from pydantic import (BaseModel, Field, field_validator, ValidationError,
|
from pydantic import (BaseModel, Field, field_validator, ValidationError,
|
||||||
field_serializer, SerializationInfo, computed_field, FileUrl)
|
field_serializer, SerializationInfo, computed_field, FileUrl, Base64Str, Base64Bytes)
|
||||||
from pydantic.json_schema import JsonSchemaValue
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
from ..config import WorkerConfig
|
from ..config import WorkerConfig
|
||||||
from ..utils.TimeUtils import TimeDelta
|
from ..utils.TimeUtils import TimeDelta
|
||||||
|
|
@ -63,7 +64,6 @@ class MediaSource(BaseModel):
|
||||||
# s3://{endpoint}/{bucket}/{url}
|
# s3://{endpoint}/{bucket}/{url}
|
||||||
paths = media_url[5:].split('/')
|
paths = media_url[5:].split('/')
|
||||||
|
|
||||||
|
|
||||||
if len(paths) < 3:
|
if len(paths) < 3:
|
||||||
raise ValidationError("URN-s3 格式错误")
|
raise ValidationError("URN-s3 格式错误")
|
||||||
|
|
||||||
|
|
@ -143,7 +143,6 @@ class MediaSource(BaseModel):
|
||||||
case _:
|
case _:
|
||||||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||||||
|
|
||||||
|
|
||||||
@computed_field(description="文件后缀名")
|
@computed_field(description="文件后缀名")
|
||||||
@cached_property
|
@cached_property
|
||||||
def file_extension(self) -> Optional[str]:
|
def file_extension(self) -> Optional[str]:
|
||||||
|
|
@ -209,5 +208,30 @@ class CacheResult(BaseModel):
|
||||||
class DownloadResult(BaseModel):
|
class DownloadResult(BaseModel):
|
||||||
urls: List[str] = Field(description="下载链接")
|
urls: List[str] = Field(description="下载链接")
|
||||||
|
|
||||||
|
|
||||||
class UploadResultResponse(BaseModel):
|
class UploadResultResponse(BaseModel):
|
||||||
media: MediaSource = Field(description="上传完成的媒体资源")
|
media: MediaSource = Field(description="上传完成的媒体资源")
|
||||||
|
|
||||||
|
|
||||||
|
class Base64File(BaseModel):
|
||||||
|
raw_content: Base64Bytes = Field(description="Base64编码的原始内容")
|
||||||
|
filename: str = Field(description="文件名, 包含文件类型后缀")
|
||||||
|
content_type: str = Field(description="文件数据类型")
|
||||||
|
|
||||||
|
# @computed_field(description="Base64解码后的二进制内容")
|
||||||
|
# @property
|
||||||
|
# def content(self) -> Base64Bytes:
|
||||||
|
# return base64.b64decode(self.raw_content)
|
||||||
|
|
||||||
|
@computed_field(description="文件字节大小")
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
# b64_str = self.raw_content.strip().replace('\n', '').replace('\r', '')
|
||||||
|
# padding = b64_str.count('=', -2) # 只统计末尾的 '='
|
||||||
|
# size = len(b64_str)
|
||||||
|
return len(self.raw_content)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadBase64Request(BaseModel):
|
||||||
|
file: Base64File = Field(description="上传的文件")
|
||||||
|
prefix: Optional[str] = Field(description="文件存在的前缀目录", default=None)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ from ..models.media_model import (MediaSources,
|
||||||
CacheResult,
|
CacheResult,
|
||||||
MediaSource,
|
MediaSource,
|
||||||
MediaCacheStatus,
|
MediaCacheStatus,
|
||||||
DownloadResult, UploadResultResponse
|
DownloadResult,
|
||||||
|
UploadResultResponse,
|
||||||
|
UploadBase64Request
|
||||||
)
|
)
|
||||||
from ..models.web_model import SentryTransactionInfo
|
from ..models.web_model import SentryTransactionInfo
|
||||||
from ..utils.KVCache import KVCache
|
from ..utils.KVCache import KVCache
|
||||||
|
|
@ -94,6 +96,7 @@ async def cache(medias: MediaSources) -> CacheResult:
|
||||||
KVCache.batch_update_cloudflare_kv(cache_task_result)
|
KVCache.batch_update_cloudflare_kv(cache_task_result)
|
||||||
return CacheResult(caches={media.urn: media for media in cache_task_result})
|
return CacheResult(caches={media.urn: media for media in cache_task_result})
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/",
|
@router.delete("/",
|
||||||
tags=["缓存"],
|
tags=["缓存"],
|
||||||
summary="清除指定的所有缓存",
|
summary="清除指定的所有缓存",
|
||||||
|
|
@ -119,6 +122,7 @@ async def purge_media_kv_file(medias: MediaSources):
|
||||||
KVCache.batch_remove_cloudflare_kv(keys)
|
KVCache.batch_remove_cloudflare_kv(keys)
|
||||||
return JSONResponse(content={"success": True, "keys": keys})
|
return JSONResponse(content={"success": True, "keys": keys})
|
||||||
|
|
||||||
|
|
||||||
@router.post("/download",
|
@router.post("/download",
|
||||||
tags=["缓存"],
|
tags=["缓存"],
|
||||||
summary="批量获取下载地址",
|
summary="批量获取下载地址",
|
||||||
|
|
@ -201,9 +205,10 @@ async def purge_media(medias: MediaSources):
|
||||||
return JSONResponse(content={"success": True, "keys": keys})
|
return JSONResponse(content={"success": True, "keys": keys})
|
||||||
|
|
||||||
|
|
||||||
@router.post("/upload-s3", tags=['缓存'],
|
@router.post("/upload-s3",
|
||||||
|
tags=['缓存'],
|
||||||
summary="上传文件到S3",
|
summary="上传文件到S3",
|
||||||
description="上传文件到S3当文件大于200M",
|
description="上传文件到S3的文件必须小于200M",
|
||||||
dependencies=[Depends(verify_token)])
|
dependencies=[Depends(verify_token)])
|
||||||
async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
|
async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
|
||||||
prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
|
prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
|
||||||
|
|
@ -223,3 +228,27 @@ async def s3_upload(file: Annotated[UploadFile, File(description="上传的文
|
||||||
media_source.status = MediaCacheStatus.ready
|
media_source.status = MediaCacheStatus.ready
|
||||||
media_source.downloader_id = fn_id
|
media_source.downloader_id = fn_id
|
||||||
return UploadResultResponse(media=media_source)
|
return UploadResultResponse(media=media_source)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/upload-s3-b64',
|
||||||
|
tags=['缓存'],
|
||||||
|
summary="基于Base64格式上传文件到S3",
|
||||||
|
description="上传文件到S3当文件必须小于200M",
|
||||||
|
dependencies=[Depends(verify_token)])
|
||||||
|
async def s3_upload_base64(body: UploadBase64Request) -> UploadResultResponse:
|
||||||
|
fn_id = current_function_call_id()
|
||||||
|
prefix = body.prefix
|
||||||
|
file = body.file
|
||||||
|
if file.size > 200 * 1024 * 1024: # 上传文件不大于200M
|
||||||
|
raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
|
||||||
|
key = f"upload/{prefix}/{file.filename}" if prefix else f"upload/{file.filename}"
|
||||||
|
local_path = f"{config.S3_mount_dir}/{key}"
|
||||||
|
logger.info(f"s3上传到{key}, size={file.size}")
|
||||||
|
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||||
|
with open(local_path, 'wb') as f:
|
||||||
|
f.write(file.raw_content)
|
||||||
|
logger.info(f"{local_path} 保存成功")
|
||||||
|
media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
|
||||||
|
media_source.status = MediaCacheStatus.ready
|
||||||
|
media_source.downloader_id = fn_id
|
||||||
|
return UploadResultResponse(media=media_source)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue