Merge remote-tracking branch 'origin/feature/temp-comfy-cluster' into feature/temp-comfy-cluster
# Conflicts: # src/cluster/web.py
This commit is contained in:
commit
8a3404dc2e
|
|
@ -76,7 +76,7 @@ async def alias_middleware(request: Request, call_next):
|
|||
|
||||
@web_app.get("/scalar", include_in_schema=False)
|
||||
async def scalar():
|
||||
return get_scalar_api_reference(openapi_url=web_app.openapi_schema, title="Modal worker web endpoint")
|
||||
return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint")
|
||||
|
||||
|
||||
web_app.include_router(ffmpeg.router)
|
||||
|
|
|
|||
|
|
@ -3,10 +3,9 @@ from pydantic import BaseModel, Field, computed_field, field_validator
|
|||
from pydantic.json_schema import JsonSchemaValue
|
||||
from ..utils.TimeUtils import TimeDelta
|
||||
|
||||
|
||||
class FFMpegSliceSegment(BaseModel):
|
||||
start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
|
||||
end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
|
||||
start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者为标准格式的时间戳")
|
||||
end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者标准格式的时间戳")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
|
@ -15,11 +14,13 @@ class FFMpegSliceSegment(BaseModel):
|
|||
|
||||
@field_validator('start', mode='before')
|
||||
@classmethod
|
||||
def parse_start(cls, v: Union[float, TimeDelta]):
|
||||
def parse_start(cls, v: Union[float, str, TimeDelta]):
|
||||
if isinstance(v, float):
|
||||
return TimeDelta(seconds=v)
|
||||
elif isinstance(v, int):
|
||||
return TimeDelta(seconds=v)
|
||||
elif isinstance(v, str):
|
||||
return TimeDelta.from_format_string(v)
|
||||
elif isinstance(v, TimeDelta):
|
||||
return v
|
||||
else:
|
||||
|
|
@ -32,6 +33,8 @@ class FFMpegSliceSegment(BaseModel):
|
|||
return TimeDelta(seconds=v)
|
||||
elif isinstance(v, int):
|
||||
return TimeDelta(seconds=v)
|
||||
elif isinstance(v, str):
|
||||
return TimeDelta.from_format_string(v)
|
||||
elif isinstance(v, TimeDelta):
|
||||
return v
|
||||
else:
|
||||
|
|
@ -45,11 +48,11 @@ class FFMpegSliceSegment(BaseModel):
|
|||
"properties": {
|
||||
"start": {
|
||||
"type": "number",
|
||||
"examples": [5, 10.5]
|
||||
"examples": [5, 10.5, '00:00:10.500'],
|
||||
},
|
||||
"end": {
|
||||
"type": "number",
|
||||
"examples": [8, 12.5]
|
||||
"examples": [8, 12.5, '00:00:12.500'],
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@ from functools import cached_property
|
|||
from typing import List, Union, Optional, Any, Dict
|
||||
from urllib.parse import urlparse
|
||||
from pydantic import (BaseModel, Field, field_validator, ValidationError,
|
||||
field_serializer, SerializationInfo, computed_field)
|
||||
field_serializer, SerializationInfo, computed_field, FileUrl)
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from ..config import WorkerConfig
|
||||
from ..utils.TimeUtils import TimeDelta
|
||||
|
||||
config = WorkerConfig()
|
||||
|
||||
|
|
@ -165,7 +166,7 @@ class MediaSource(BaseModel):
|
|||
def __str__(self):
|
||||
match self.protocol:
|
||||
case MediaProtocol.http:
|
||||
return f"{self.protocol.value}://{self.path}"
|
||||
return self.urn
|
||||
case MediaProtocol.hls:
|
||||
return f"{self.protocol.value}://{self.path[8:]}" # strip "https://" from url
|
||||
case _:
|
||||
|
|
@ -208,4 +209,5 @@ class CacheResult(BaseModel):
|
|||
class DownloadResult(BaseModel):
|
||||
urls: List[str] = Field(description="下载链接")
|
||||
|
||||
|
||||
class UploadResultResponse(BaseModel):
|
||||
media: MediaSource = Field(description="上传完成的媒体资源")
|
||||
|
|
@ -1,17 +1,22 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import modal
|
||||
from loguru import logger
|
||||
from modal import current_function_call_id
|
||||
import sentry_sdk
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter, Depends, UploadFile, HTTPException, File, Form
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from starlette import status
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from ..config import WorkerConfig
|
||||
from ..middleware.authorization import verify_token
|
||||
from ..models.media_model import MediaSources, CacheResult, MediaSource, MediaCacheStatus, DownloadResult
|
||||
|
||||
from ..models.media_model import (MediaSources,
|
||||
CacheResult,
|
||||
MediaSource,
|
||||
MediaCacheStatus,
|
||||
DownloadResult, UploadResultResponse
|
||||
)
|
||||
from ..models.web_model import SentryTransactionInfo
|
||||
from ..utils.KVCache import KVCache
|
||||
from ..utils.SentryUtils import SentryUtils
|
||||
|
|
@ -52,15 +57,11 @@ async def cache(medias: MediaSources) -> CacheResult:
|
|||
queue_publish_span.set_data("messaging.destination.name",
|
||||
"video-downloader.cache_submit")
|
||||
queue_publish_span.set_data("messaging.message.body.size", 0)
|
||||
# video_cache = MediaCache(status=MediaCacheStatus.downloading,
|
||||
# downloader_id=fn_task.object_id)
|
||||
media.status = MediaCacheStatus.downloading
|
||||
media.downloader_id = fn_task.object_id
|
||||
# video_cache_status_json = video_cache.model_dump_json()
|
||||
modal_kv_cache.set_cache(media)
|
||||
else:
|
||||
media = cached_media
|
||||
# video_cache = MediaCache.model_validate_json(video_cache_status_json)
|
||||
match media.status:
|
||||
case MediaCacheStatus.ready:
|
||||
cache_hit = True
|
||||
|
|
@ -79,14 +80,9 @@ async def cache(medias: MediaSources) -> CacheResult:
|
|||
queue_publish_span.set_data("messaging.message.body.size", 0)
|
||||
media.status = MediaCacheStatus.downloading
|
||||
media.downloader_id = fn_task.object_id
|
||||
# video_cache = MediaCache(status=MediaCacheStatus.downloading,
|
||||
# downloader_id=fn_task.object_id)
|
||||
# video_cache_status_json = video_cache.model_dump_json()
|
||||
modal_kv_cache.set_cache(media)
|
||||
cache_hit = False
|
||||
|
||||
# caches[media.urn] = video_cache
|
||||
# logger.info(f"Media cache hit ? {cache_hit}")
|
||||
cache_span.set_data("cache.hit", cache_hit)
|
||||
return media
|
||||
|
||||
|
|
@ -98,6 +94,30 @@ async def cache(medias: MediaSources) -> CacheResult:
|
|||
KVCache.batch_update_cloudflare_kv(cache_task_result)
|
||||
return CacheResult(caches={media.urn: media for media in cache_task_result})
|
||||
|
||||
@router.delete("/",
|
||||
tags=["缓存"],
|
||||
summary="清除指定的所有缓存",
|
||||
description="清除指定的所有缓存(包括KV记录和S3存储文件)",
|
||||
dependencies=[Depends(verify_token)])
|
||||
async def purge_media_kv_file(medias: MediaSources):
|
||||
fn_id = current_function_call_id()
|
||||
fn = modal.Function.from_name(config.modal_app_name, "cache_delete", environment_name=config.modal_environment)
|
||||
|
||||
@SentryUtils.sentry_tracker(name="清除媒体源缓存", op="cache.purge", fn_id=fn_id,
|
||||
sentry_trace_id=None, sentry_baggage=None)
|
||||
async def purge_handle(media: MediaSource):
|
||||
cache_media = modal_kv_cache.pop(media.urn)
|
||||
if cache_media:
|
||||
deleted_cache: MediaSource = await fn.remote.aio(cache_media)
|
||||
return deleted_cache.urn
|
||||
return None
|
||||
|
||||
async with asyncio.TaskGroup() as group:
|
||||
tasks = [group.create_task(purge_handle(media)) for media in medias.inputs]
|
||||
|
||||
keys = [task.result() for task in tasks]
|
||||
KVCache.batch_remove_cloudflare_kv(keys)
|
||||
return JSONResponse(content={"success": True, "keys": keys})
|
||||
|
||||
@router.post("/download",
|
||||
tags=["缓存"],
|
||||
|
|
@ -157,7 +177,8 @@ async def purge_kv(medias: MediaSources):
|
|||
@router.post("/media",
|
||||
tags=["缓存"],
|
||||
summary="清除指定的所有缓存",
|
||||
description="清除指定的所有缓存(包括KV记录和S3存储文件)",
|
||||
description="清除指定的所有缓存(包括KV记录和S3存储文件), 将要被淘汰,使用DELETE /cache/替代",
|
||||
deprecated=True,
|
||||
dependencies=[Depends(verify_token)])
|
||||
async def purge_media(medias: MediaSources):
|
||||
fn_id = current_function_call_id()
|
||||
|
|
@ -178,3 +199,27 @@ async def purge_media(medias: MediaSources):
|
|||
keys = [task.result() for task in tasks]
|
||||
KVCache.batch_remove_cloudflare_kv(keys)
|
||||
return JSONResponse(content={"success": True, "keys": keys})
|
||||
|
||||
|
||||
@router.post("/upload-s3", tags=['缓存'],
|
||||
summary="上传文件到S3",
|
||||
description="上传文件到S3当文件大于200M",
|
||||
dependencies=[Depends(verify_token)])
|
||||
async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
|
||||
prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
|
||||
fn_id = current_function_call_id()
|
||||
|
||||
if file.size > 200 * 1024 * 1024: # 上传文件不大于200M
|
||||
raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
|
||||
|
||||
key = f"upload/{prefix}/{file.filename}" if prefix else f"upload/{file.filename}"
|
||||
local_path = f"{config.S3_mount_dir}/{key}"
|
||||
logger.info(f"s3上传到{key}, size={file.size}")
|
||||
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
with open(local_path, 'wb') as f:
|
||||
f.write(file.file.read())
|
||||
logger.info(f"{local_path} 保存成功")
|
||||
media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
|
||||
media_source.status = MediaCacheStatus.ready
|
||||
media_source.downloader_id = fn_id
|
||||
return UploadResultResponse(media=media_source)
|
||||
|
|
|
|||
|
|
@ -273,7 +273,8 @@ with downloader_image.imports():
|
|||
process_span.set_status("failed")
|
||||
case MediaProtocol.http:
|
||||
try:
|
||||
cache_filepath = f"{config.S3_mount_dir}/{media.cache_filepath}"
|
||||
cache_filepath = f"{config.S3_mount_dir}/{media.protocol.value}/{media.cache_filepath}"
|
||||
os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)
|
||||
download_large_file(url=media.__str__(), output_path=cache_filepath)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
|
|
|||
Loading…
Reference in New Issue