Add small-file upload endpoint

shuohigh@gmail.com 2025-05-20 13:54:15 +08:00
parent 668a5a33d4
commit d3afe850ed
6 changed files with 92 additions and 27 deletions

View File

@@ -88,7 +88,7 @@ async def alias_middleware(request: Request, call_next):
 @web_app.get("/scalar", include_in_schema=False)
 async def scalar():
-    return get_scalar_api_reference(openapi_url=web_app.openapi_schema, title="Modal worker web endpoint")
+    return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint")
 web_app.include_router(ffmpeg.router)

View File

@@ -3,10 +3,9 @@ from pydantic import BaseModel, Field, computed_field, field_validator
 from pydantic.json_schema import JsonSchemaValue
 from ..utils.TimeUtils import TimeDelta
 class FFMpegSliceSegment(BaseModel):
-    start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
-    end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
+    start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者为标准格式的时间戳")
+    end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者标准格式的时间戳")
     @computed_field
     @property
@@ -15,11 +14,13 @@ class FFMpegSliceSegment(BaseModel):
     @field_validator('start', mode='before')
     @classmethod
-    def parse_start(cls, v: Union[float, TimeDelta]):
+    def parse_start(cls, v: Union[float, str, TimeDelta]):
         if isinstance(v, float):
             return TimeDelta(seconds=v)
         elif isinstance(v, int):
             return TimeDelta(seconds=v)
+        elif isinstance(v, str):
+            return TimeDelta.from_format_string(v)
         elif isinstance(v, TimeDelta):
             return v
         else:
@@ -32,6 +33,8 @@ class FFMpegSliceSegment(BaseModel):
             return TimeDelta(seconds=v)
         elif isinstance(v, int):
             return TimeDelta(seconds=v)
+        elif isinstance(v, str):
+            return TimeDelta.from_format_string(v)
         elif isinstance(v, TimeDelta):
             return v
         else:
@@ -45,11 +48,11 @@ class FFMpegSliceSegment(BaseModel):
         "properties": {
             "start": {
                 "type": "number",
-                "examples": [5, 10.5]
+                "examples": [5, 10.5, '00:00:10.500'],
             },
             "end": {
                 "type": "number",
-                "examples": [8, 12.5]
+                "examples": [8, 12.5, '00:00:12.500'],
             }
         },
         "required": [

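The new str branch in the validators lets slice boundaries be given as a timestamp string as well as numeric seconds. Below is a minimal, self-contained sketch of the assumed parsing behaviour; the exact formats accepted by TimeDelta.from_format_string are not shown in this diff, so the "HH:MM:SS.mmm" form (matching the new schema examples) is an assumption.

from datetime import timedelta

def parse_time_value(v: int | float | str) -> timedelta:
    # Numeric input is taken as seconds, mirroring TimeDelta(seconds=v) in the diff.
    if isinstance(v, (int, float)):
        return timedelta(seconds=v)
    # Assumed "HH:MM:SS.mmm" format, matching the new schema examples such as '00:00:10.500'.
    hours, minutes, seconds = v.split(":")
    return timedelta(hours=int(hours), minutes=int(minutes), seconds=float(seconds))

assert parse_time_value("00:00:12.500") == parse_time_value(12.5)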
View File

@@ -5,9 +5,10 @@ from functools import cached_property
 from typing import List, Union, Optional, Any, Dict
 from urllib.parse import urlparse
 from pydantic import (BaseModel, Field, field_validator, ValidationError,
-                      field_serializer, SerializationInfo, computed_field)
+                      field_serializer, SerializationInfo, computed_field, FileUrl)
 from pydantic.json_schema import JsonSchemaValue
 from ..config import WorkerConfig
+from ..utils.TimeUtils import TimeDelta
 config = WorkerConfig()
@@ -145,7 +146,7 @@ class MediaSource(BaseModel):
     def __str__(self):
         match self.protocol:
             case MediaProtocol.http:
-                return f"{self.protocol.value}://{self.path}"
+                return self.urn
             case MediaProtocol.hls:
                 return f"{self.protocol.value}://{self.path[8:]}"  # strip "https://" from url
             case _:
@@ -187,3 +188,6 @@ class CacheResult(BaseModel):
 class DownloadResult(BaseModel):
     urls: List[str] = Field(description="下载链接")
+class UploadResultResponse(BaseModel):
+    media: MediaSource = Field(description="上传完成的媒体资源")

View File

@@ -1,17 +1,22 @@
 import asyncio
+import os
+from typing import Annotated, Optional
 import modal
+from loguru import logger
 from modal import current_function_call_id
 import sentry_sdk
-from fastapi import APIRouter, Depends
-from fastapi.responses import JSONResponse
+from fastapi import APIRouter, Depends, UploadFile, HTTPException, File, Form
+from fastapi.responses import JSONResponse, RedirectResponse
 from starlette import status
-from starlette.responses import RedirectResponse
 from ..config import WorkerConfig
 from ..middleware.authorization import verify_token
-from ..models.media_model import MediaSources, CacheResult, MediaSource, MediaCacheStatus, DownloadResult
+from ..models.media_model import (MediaSources,
+                                  CacheResult,
+                                  MediaSource,
+                                  MediaCacheStatus,
+                                  DownloadResult, UploadResultResponse
+                                  )
 from ..models.web_model import SentryTransactionInfo
 from ..utils.KVCache import KVCache
 from ..utils.SentryUtils import SentryUtils
@@ -52,15 +57,11 @@ async def cache(medias: MediaSources) -> CacheResult:
                 queue_publish_span.set_data("messaging.destination.name",
                                             "video-downloader.cache_submit")
                 queue_publish_span.set_data("messaging.message.body.size", 0)
-            # video_cache = MediaCache(status=MediaCacheStatus.downloading,
-            #                          downloader_id=fn_task.object_id)
             media.status = MediaCacheStatus.downloading
             media.downloader_id = fn_task.object_id
-            # video_cache_status_json = video_cache.model_dump_json()
             modal_kv_cache.set_cache(media)
         else:
             media = cached_media
-            # video_cache = MediaCache.model_validate_json(video_cache_status_json)
             match media.status:
                 case MediaCacheStatus.ready:
                     cache_hit = True
@@ -79,14 +80,9 @@ async def cache(medias: MediaSources) -> CacheResult:
                         queue_publish_span.set_data("messaging.message.body.size", 0)
                     media.status = MediaCacheStatus.downloading
                     media.downloader_id = fn_task.object_id
-                    # video_cache = MediaCache(status=MediaCacheStatus.downloading,
-                    #                          downloader_id=fn_task.object_id)
-                    # video_cache_status_json = video_cache.model_dump_json()
                     modal_kv_cache.set_cache(media)
                     cache_hit = False
-                    # caches[media.urn] = video_cache
-        # logger.info(f"Media cache hit ? {cache_hit}")
         cache_span.set_data("cache.hit", cache_hit)
         return media
@@ -98,6 +94,30 @@ async def cache(medias: MediaSources) -> CacheResult:
     KVCache.batch_update_cloudflare_kv(cache_task_result)
     return CacheResult(caches={media.urn: media for media in cache_task_result})
+@router.delete("/",
+               tags=["缓存"],
+               summary="清除指定的所有缓存",
+               description="清除指定的所有缓存(包括KV记录和S3存储文件)",
+               dependencies=[Depends(verify_token)])
+async def purge_media_kv_file(medias: MediaSources):
+    fn_id = current_function_call_id()
+    fn = modal.Function.from_name(config.modal_app_name, "cache_delete", environment_name=config.modal_environment)
+    @SentryUtils.sentry_tracker(name="清除媒体源缓存", op="cache.purge", fn_id=fn_id,
+                                sentry_trace_id=None, sentry_baggage=None)
+    async def purge_handle(media: MediaSource):
+        cache_media = modal_kv_cache.pop(media.urn)
+        if cache_media:
+            deleted_cache: MediaSource = await fn.remote.aio(cache_media)
+            return deleted_cache.urn
+        return None
+    async with asyncio.TaskGroup() as group:
+        tasks = [group.create_task(purge_handle(media)) for media in medias.inputs]
+    keys = [task.result() for task in tasks]
+    KVCache.batch_remove_cloudflare_kv(keys)
+    return JSONResponse(content={"success": True, "keys": keys})
 @router.post("/download",
              tags=["缓存"],
@@ -157,7 +177,8 @@ async def purge_kv(medias: MediaSources):
 @router.post("/media",
              tags=["缓存"],
              summary="清除指定的所有缓存",
-             description="清除指定的所有缓存(包括KV记录和S3存储文件)",
+             description="清除指定的所有缓存(包括KV记录和S3存储文件), 将要被淘汰使用DELETE /cache/替代",
+             deprecated=True,
              dependencies=[Depends(verify_token)])
 async def purge_media(medias: MediaSources):
     fn_id = current_function_call_id()
@@ -178,3 +199,27 @@ async def purge_media(medias: MediaSources):
     keys = [task.result() for task in tasks]
     KVCache.batch_remove_cloudflare_kv(keys)
     return JSONResponse(content={"success": True, "keys": keys})
+@router.post("/upload-s3", tags=['缓存'],
+             summary="上传文件到S3",
+             description="上传文件到S3当文件大于200M",
+             dependencies=[Depends(verify_token)])
+async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
+                    prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
+    fn_id = current_function_call_id()
+    if file.size > 200 * 1024 * 1024:  # 上传文件不大于200M
+        raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
+    key = f"upload/{prefix}/{file.filename}" if prefix else f"upload/{file.filename}"
+    local_path = f"{config.S3_mount_dir}/{key}"
+    logger.info(f"s3上传到{key}, size={file.size}")
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+        f.write(file.file.read())
+    logger.info(f"{local_path} 保存成功")
+    media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
+    media_source.status = MediaCacheStatus.ready
+    media_source.downloader_id = fn_id
+    return UploadResultResponse(media=media_source)

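For context, a hedged sketch of calling the new upload endpoint from a client. The router's mount prefix is not visible in this diff; /cache is assumed from the "DELETE /cache/" note in the deprecated description, and the host, token, and file names are placeholders.

import requests  # sketch only; any HTTP client works

# Hypothetical host, prefix and token; multipart fields follow the new s3_upload signature.
resp = requests.post(
    "https://worker.example.com/cache/upload-s3",
    headers={"Authorization": "Bearer <token>"},
    files={"file": ("clip.mp4", open("clip.mp4", "rb"), "video/mp4")},
    data={"prefix": "demo"},   # optional Form field; omit to upload under upload/<filename>
    timeout=600,
)
resp.raise_for_status()
print(resp.json()["media"])    # UploadResultResponse: the cached MediaSource (status=ready)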
View File

@@ -273,7 +273,8 @@ with downloader_image.imports():
                 process_span.set_status("failed")
             case MediaProtocol.http:
                 try:
-                    cache_filepath = f"{config.S3_mount_dir}/{media.cache_filepath}"
+                    cache_filepath = f"{config.S3_mount_dir}/{media.protocol.value}/{media.cache_filepath}"
+                    os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)
                     download_large_file(url=media.__str__(), output_path=cache_filepath)
                 except Exception as e:
                     logger.exception(e)

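A small illustration of the downloader change for http sources: cached files now land under a per-protocol subdirectory of the bucket mount, and the parent directory is created before writing. Mount point and file name below are hypothetical.

import os

S3_mount_dir = "/bucket"                                       # hypothetical mount point
cache_filepath = f"{S3_mount_dir}/http/videos/example.mp4"     # before this commit: /bucket/videos/example.mp4
os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)    # avoids a missing-directory error on first download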
View File

@@ -1,6 +1,10 @@
 import modal
 from dotenv import dotenv_values
+from BowongModalFunctions.config import WorkerConfig
+config = WorkerConfig()
 fastapi_image = (
     modal.Image
     .debian_slim(python_version="3.11")
@@ -13,6 +17,7 @@ fastapi_image = (
 app = modal.App(
     name="web_app",
     image=fastapi_image,
+    secrets=[modal.Secret.from_name("aws-s3-secret", environment_name=config.modal_environment)],
     include_source=False)
 with fastapi_image.imports():
@@ -22,7 +27,14 @@ with fastapi_image.imports():
 @app.function(scaledown_window=60,
               secrets=[
                   modal.Secret.from_name("cf-kv-secret", environment_name='dev'),
-              ])
+              ],
+              volumes={
+                  config.S3_mount_dir: modal.CloudBucketMount(
+                      bucket_name=config.S3_bucket_name,
+                      secret=modal.Secret.from_name("aws-s3-secret",
+                                                    environment_name=config.modal_environment),
+                  ),
+              }, )
 @modal.concurrent(max_inputs=100)
 @modal.asgi_app()
 def fastapi_webapp():
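
The web app and the cache router both read their S3 and Modal settings from WorkerConfig. The sketch below lists only the fields referenced in this commit; the real class may be implemented differently, and the default values are placeholders.

from dataclasses import dataclass

@dataclass
class WorkerConfigSketch:
    """Fields of WorkerConfig referenced in this commit; values are placeholders."""
    modal_app_name: str = "worker-app"          # used for modal.Function.from_name(...) lookups
    modal_environment: str = "dev"              # used to resolve Modal secrets and functions
    S3_region: str = "us-east-1"                # used to build the s3:// URN for uploaded files
    S3_bucket_name: str = "example-bucket"      # bucket mounted via modal.CloudBucketMount
    S3_mount_dir: str = "/bucket"               # mount point; uploads and http caches are written under it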