Merge branch refs/heads/feature/modal-cluster into refs/heads/feature/temp-comfy-cluster

This commit is contained in:
康宇佳 2025-05-20 14:40:58 +08:00
commit 9f410d30f7
6 changed files with 91 additions and 28 deletions

View File

@ -76,7 +76,7 @@ async def alias_middleware(request: Request, call_next):
@web_app.get("/scalar", include_in_schema=False)
async def scalar():
return get_scalar_api_reference(openapi_url=web_app.openapi_schema, title="Modal worker web endpoint")
return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint")
web_app.include_router(ffmpeg.router)

View File

@ -3,10 +3,9 @@ from pydantic import BaseModel, Field, computed_field, field_validator
from pydantic.json_schema import JsonSchemaValue
from ..utils.TimeUtils import TimeDelta
class FFMpegSliceSegment(BaseModel):
start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者为标准格式的时间戳")
end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者标准格式的时间戳")
@computed_field
@property
@ -15,11 +14,13 @@ class FFMpegSliceSegment(BaseModel):
@field_validator('start', mode='before')
@classmethod
def parse_start(cls, v: Union[float, TimeDelta]):
def parse_start(cls, v: Union[float, str, TimeDelta]):
if isinstance(v, float):
return TimeDelta(seconds=v)
elif isinstance(v, int):
return TimeDelta(seconds=v)
elif isinstance(v, str):
return TimeDelta.from_format_string(v)
elif isinstance(v, TimeDelta):
return v
else:
@ -32,6 +33,8 @@ class FFMpegSliceSegment(BaseModel):
return TimeDelta(seconds=v)
elif isinstance(v, int):
return TimeDelta(seconds=v)
elif isinstance(v, str):
return TimeDelta.from_format_string(v)
elif isinstance(v, TimeDelta):
return v
else:
@ -45,11 +48,11 @@ class FFMpegSliceSegment(BaseModel):
"properties": {
"start": {
"type": "number",
"examples": [5, 10.5]
"examples": [5, 10.5, '00:00:10.500'],
},
"end": {
"type": "number",
"examples": [8, 12.5]
"examples": [8, 12.5, '00:00:12.500'],
}
},
"required": [

View File

@ -6,9 +6,10 @@ from functools import cached_property
from typing import List, Union, Optional, Any, Dict
from urllib.parse import urlparse
from pydantic import (BaseModel, Field, field_validator, ValidationError,
field_serializer, SerializationInfo, computed_field)
field_serializer, SerializationInfo, computed_field, FileUrl)
from pydantic.json_schema import JsonSchemaValue
from ..config import WorkerConfig
from ..utils.TimeUtils import TimeDelta
config = WorkerConfig()
@ -165,7 +166,7 @@ class MediaSource(BaseModel):
def __str__(self):
match self.protocol:
case MediaProtocol.http:
return f"{self.protocol.value}://{self.path}"
return self.urn
case MediaProtocol.hls:
return f"{self.protocol.value}://{self.path[8:]}" # strip "https://" from url
case _:
@ -208,4 +209,5 @@ class CacheResult(BaseModel):
class DownloadResult(BaseModel):
urls: List[str] = Field(description="下载链接")
class UploadResultResponse(BaseModel):
media: MediaSource = Field(description="上传完成的媒体资源")

View File

@ -1,17 +1,22 @@
import asyncio
import os
from typing import Annotated, Optional
import modal
from loguru import logger
from modal import current_function_call_id
import sentry_sdk
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from fastapi import APIRouter, Depends, UploadFile, HTTPException, File, Form
from fastapi.responses import JSONResponse, RedirectResponse
from starlette import status
from starlette.responses import RedirectResponse
from ..config import WorkerConfig
from ..middleware.authorization import verify_token
from ..models.media_model import MediaSources, CacheResult, MediaSource, MediaCacheStatus, DownloadResult
from ..models.media_model import (MediaSources,
CacheResult,
MediaSource,
MediaCacheStatus,
DownloadResult, UploadResultResponse
)
from ..models.web_model import SentryTransactionInfo
from ..utils.KVCache import KVCache
from ..utils.SentryUtils import SentryUtils
@ -52,15 +57,11 @@ async def cache(medias: MediaSources) -> CacheResult:
queue_publish_span.set_data("messaging.destination.name",
"video-downloader.cache_submit")
queue_publish_span.set_data("messaging.message.body.size", 0)
# video_cache = MediaCache(status=MediaCacheStatus.downloading,
# downloader_id=fn_task.object_id)
media.status = MediaCacheStatus.downloading
media.downloader_id = fn_task.object_id
# video_cache_status_json = video_cache.model_dump_json()
modal_kv_cache.set_cache(media)
else:
media = cached_media
# video_cache = MediaCache.model_validate_json(video_cache_status_json)
match media.status:
case MediaCacheStatus.ready:
cache_hit = True
@ -79,14 +80,9 @@ async def cache(medias: MediaSources) -> CacheResult:
queue_publish_span.set_data("messaging.message.body.size", 0)
media.status = MediaCacheStatus.downloading
media.downloader_id = fn_task.object_id
# video_cache = MediaCache(status=MediaCacheStatus.downloading,
# downloader_id=fn_task.object_id)
# video_cache_status_json = video_cache.model_dump_json()
modal_kv_cache.set_cache(media)
cache_hit = False
# caches[media.urn] = video_cache
# logger.info(f"Media cache hit ? {cache_hit}")
cache_span.set_data("cache.hit", cache_hit)
return media
@ -98,6 +94,30 @@ async def cache(medias: MediaSources) -> CacheResult:
KVCache.batch_update_cloudflare_kv(cache_task_result)
return CacheResult(caches={media.urn: media for media in cache_task_result})
@router.delete("/",
               tags=["缓存"],
               summary="清除指定的所有缓存",
               description="清除指定的所有缓存(包括KV记录和S3存储文件)",
               dependencies=[Depends(verify_token)])
async def purge_media_kv_file(medias: MediaSources):
    """Purge the given media sources: drop each KV cache record and delete its stored file.

    For every input media, the local KV entry is popped first; when an entry
    existed, the remote Modal function ``cache_delete`` removes the backing
    S3 object. Finally the collected URNs are removed from Cloudflare KV.

    Args:
        medias: Batch of media sources whose caches should be purged.

    Returns:
        JSONResponse with ``{"success": true, "keys": [...]}`` where ``keys``
        holds the purged URNs (``None`` for sources that had no cache entry).
    """
    fn_id = current_function_call_id()
    # Remote Modal function that deletes the cached file from storage.
    fn = modal.Function.from_name(config.modal_app_name, "cache_delete", environment_name=config.modal_environment)

    @SentryUtils.sentry_tracker(name="清除媒体源缓存", op="cache.purge", fn_id=fn_id,
                                sentry_trace_id=None, sentry_baggage=None)
    async def purge_handle(media: MediaSource):
        # Pop the KV record first; only call the remote delete when a record existed.
        cache_media = modal_kv_cache.pop(media.urn)
        if cache_media:
            deleted_cache: MediaSource = await fn.remote.aio(cache_media)
            return deleted_cache.urn
        return None

    # Fan out one purge task per media; TaskGroup awaits them all before exiting.
    async with asyncio.TaskGroup() as group:
        tasks = [group.create_task(purge_handle(media)) for media in medias.inputs]
    keys = [task.result() for task in tasks]
    # NOTE(review): keys may contain None for cache misses — confirm
    # batch_remove_cloudflare_kv tolerates None entries.
    KVCache.batch_remove_cloudflare_kv(keys)
    return JSONResponse(content={"success": True, "keys": keys})
@router.post("/download",
tags=["缓存"],
@ -157,7 +177,8 @@ async def purge_kv(medias: MediaSources):
@router.post("/media",
tags=["缓存"],
summary="清除指定的所有缓存",
description="清除指定的所有缓存(包括KV记录和S3存储文件)",
description="清除指定的所有缓存(包括KV记录和S3存储文件), 将要被淘汰使用DELETE /cache/替代",
deprecated=True,
dependencies=[Depends(verify_token)])
async def purge_media(medias: MediaSources):
fn_id = current_function_call_id()
@ -178,3 +199,27 @@ async def purge_media(medias: MediaSources):
keys = [task.result() for task in tasks]
KVCache.batch_remove_cloudflare_kv(keys)
return JSONResponse(content={"success": True, "keys": keys})
@router.post("/upload-s3", tags=['缓存'],
             summary="上传文件到S3",
             description="上传文件到S3当文件大于200M",
             dependencies=[Depends(verify_token)])
async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
                    prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
    """Upload a file into the S3-mounted volume and register it as a ready MediaSource.

    Args:
        file: Multipart upload; rejected when larger than 200 MB or when its
            size is unknown (missing Content-Length).
        prefix: Optional key prefix placed under ``upload/``.

    Returns:
        UploadResultResponse wrapping the MediaSource pointing at the stored object.

    Raises:
        HTTPException: 413 when the file exceeds 200 MB or its size is unknown;
            400 when the resolved path would escape the mounted bucket directory.
    """
    fn_id = current_function_call_id()
    max_size = 200 * 1024 * 1024  # 上传文件不大于200M
    # file.size is None when the client omits Content-Length; treat that as
    # oversized instead of crashing on a `None > int` comparison.
    if file.size is None or file.size > max_size:
        raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
    # basename() strips any client-supplied path components so a crafted
    # filename cannot escape the upload/ directory (path traversal).
    safe_name = os.path.basename(file.filename or "upload.bin")
    key = f"upload/{prefix}/{safe_name}" if prefix else f"upload/{safe_name}"
    local_path = os.path.normpath(os.path.join(config.S3_mount_dir, key))
    # Refuse keys that would still resolve outside the mount (e.g. ".." in prefix).
    if not local_path.startswith(os.path.normpath(config.S3_mount_dir) + os.sep):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="非法的上传路径")
    logger.info(f"s3上传到{key}, size={file.size}")
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    # Stream in 1 MiB chunks instead of reading up to 200 MB into memory at once.
    with open(local_path, 'wb') as out:
        while chunk := file.file.read(1024 * 1024):
            out.write(chunk)
    logger.info(f"{local_path} 保存成功")
    media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
    media_source.status = MediaCacheStatus.ready
    media_source.downloader_id = fn_id
    return UploadResultResponse(media=media_source)

View File

@ -273,7 +273,8 @@ with downloader_image.imports():
process_span.set_status("failed")
case MediaProtocol.http:
try:
cache_filepath = f"{config.S3_mount_dir}/{media.cache_filepath}"
cache_filepath = f"{config.S3_mount_dir}/{media.protocol.value}/{media.cache_filepath}"
os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)
download_large_file(url=media.__str__(), output_path=cache_filepath)
except Exception as e:
logger.exception(e)

View File

@ -1,6 +1,10 @@
import modal
from dotenv import dotenv_values
from BowongModalFunctions.config import WorkerConfig
config = WorkerConfig()
fastapi_image = (
modal.Image
.debian_slim(python_version="3.11")
@ -13,6 +17,7 @@ fastapi_image = (
app = modal.App(
name="web_app",
image=fastapi_image,
secrets=[modal.Secret.from_name("aws-s3-secret", environment_name=config.modal_environment)],
include_source=False)
with fastapi_image.imports():
@ -22,7 +27,14 @@ with fastapi_image.imports():
@app.function(scaledown_window=60,
secrets=[
modal.Secret.from_name("cf-kv-secret", environment_name='dev'),
])
],
volumes={
config.S3_mount_dir: modal.CloudBucketMount(
bucket_name=config.S3_bucket_name,
secret=modal.Secret.from_name("aws-s3-secret",
environment_name=config.modal_environment),
),
}, )
@modal.concurrent(max_inputs=100)
@modal.asgi_app()
def fastapi_webapp():