Merge remote-tracking branch 'origin/feature/temp-comfy-cluster' into feature/temp-comfy-cluster
# Conflicts:
#	src/cluster/web.py
commit 8a3404dc2e
@@ -76,7 +76,7 @@ async def alias_middleware(request: Request, call_next):

 @web_app.get("/scalar", include_in_schema=False)
 async def scalar():
-    return get_scalar_api_reference(openapi_url=web_app.openapi_schema, title="Modal worker web endpoint")
+    return get_scalar_api_reference(openapi_url=web_app.openapi_schema or '/openapi.json', title="Modal worker web endpoint")


 web_app.include_router(ffmpeg.router)
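Note on this hunk (likely the src/cluster/web.py conflict named in the merge
message): FastAPI populates web_app.openapi_schema lazily, so it is still None
until /openapi.json is first served; the merged line falls back to FastAPI's
default schema URL instead of passing None through. A minimal self-contained
sketch of the resolved endpoint (assumes the scalar_fastapi package; the app
setup here is illustrative, not the project's actual wiring):

    from fastapi import FastAPI
    from scalar_fastapi import get_scalar_api_reference

    web_app = FastAPI(title="Modal worker web endpoint")

    @web_app.get("/scalar", include_in_schema=False)
    async def scalar():
        # openapi_schema is None before the schema is first generated,
        # so fall back to the default schema URL rather than passing None.
        return get_scalar_api_reference(
            openapi_url=web_app.openapi_schema or "/openapi.json",
            title="Modal worker web endpoint",
        )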
@@ -3,10 +3,9 @@ from pydantic import BaseModel, Field, computed_field, field_validator
 from pydantic.json_schema import JsonSchemaValue
 from ..utils.TimeUtils import TimeDelta
-

 class FFMpegSliceSegment(BaseModel):
-    start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
-    end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)")
+    start: TimeDelta = Field(description="视频切割的开始时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者为标准格式的时间戳")
+    end: TimeDelta = Field(description="视频切割的结束时间点秒数, 可为浮点小数(精确到小数点后3位,毫秒级)或者标准格式的时间戳")

     @computed_field
     @property
@@ -15,11 +14,13 @@ class FFMpegSliceSegment(BaseModel):

     @field_validator('start', mode='before')
     @classmethod
-    def parse_start(cls, v: Union[float, TimeDelta]):
+    def parse_start(cls, v: Union[float, str, TimeDelta]):
         if isinstance(v, float):
             return TimeDelta(seconds=v)
         elif isinstance(v, int):
             return TimeDelta(seconds=v)
+        elif isinstance(v, str):
+            return TimeDelta.from_format_string(v)
         elif isinstance(v, TimeDelta):
             return v
         else:
@@ -32,6 +33,8 @@ class FFMpegSliceSegment(BaseModel):
             return TimeDelta(seconds=v)
         elif isinstance(v, int):
             return TimeDelta(seconds=v)
+        elif isinstance(v, str):
+            return TimeDelta.from_format_string(v)
         elif isinstance(v, TimeDelta):
             return v
         else:
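Note: both validators now route strings through TimeDelta.from_format_string,
whose implementation is outside this diff. A minimal sketch of what such a
parser could look like (an assumption, not the project's actual code),
accepting the "HH:MM:SS.mmm" form used in the schema examples below:

    from datetime import timedelta

    def from_format_string(v: str) -> timedelta:
        # "00:00:10.500" -> hours, minutes, seconds (seconds keep milliseconds)
        hours, minutes, seconds = v.split(":")
        return timedelta(hours=int(hours), minutes=int(minutes), seconds=float(seconds))

    assert from_format_string("00:00:10.500") == timedelta(seconds=10.5)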
@@ -45,11 +48,11 @@ class FFMpegSliceSegment(BaseModel):
             "properties": {
                 "start": {
                     "type": "number",
-                    "examples": [5, 10.5]
+                    "examples": [5, 10.5, '00:00:10.500'],
                 },
                 "end": {
                     "type": "number",
-                    "examples": [8, 12.5]
+                    "examples": [8, 12.5, '00:00:12.500'],
                 }
             },
             "required": [
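Observation (not part of the diff): the examples above now include a string
timestamp while "type" remains "number". JSON Schema allows a type union, so a
consistent fragment could declare both accepted forms:

    start_schema = {
        "type": ["number", "string"],
        "examples": [5, 10.5, "00:00:10.500"],
    }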
@@ -6,9 +6,10 @@ from functools import cached_property
 from typing import List, Union, Optional, Any, Dict
 from urllib.parse import urlparse
 from pydantic import (BaseModel, Field, field_validator, ValidationError,
-                      field_serializer, SerializationInfo, computed_field)
+                      field_serializer, SerializationInfo, computed_field, FileUrl)
 from pydantic.json_schema import JsonSchemaValue
 from ..config import WorkerConfig
+from ..utils.TimeUtils import TimeDelta

 config = WorkerConfig()
@@ -165,7 +166,7 @@ class MediaSource(BaseModel):
     def __str__(self):
         match self.protocol:
             case MediaProtocol.http:
-                return f"{self.protocol.value}://{self.path}"
+                return self.urn
             case MediaProtocol.hls:
                 return f"{self.protocol.value}://{self.path[8:]}"  # strip "https://" from url
             case _:
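Note: __str__ now delegates to self.urn for http sources. The urn property is
defined outside this diff; a plausible shape, shown purely as an assumption
(the cache endpoints below key their KV entries by media.urn):

    from enum import Enum
    from pydantic import BaseModel, computed_field

    class MediaProtocol(str, Enum):
        http = "http"

    class MediaSourceSketch(BaseModel):
        protocol: MediaProtocol
        path: str

        @computed_field
        @property
        def urn(self) -> str:
            # Hypothetical: same "protocol://path" form the old f-string built.
            return f"{self.protocol.value}://{self.path}"

    print(MediaSourceSketch(protocol=MediaProtocol.http, path="example.com/a.mp4").urn)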
@@ -208,4 +209,5 @@ class CacheResult(BaseModel):
 class DownloadResult(BaseModel):
     urls: List[str] = Field(description="下载链接")


+class UploadResultResponse(BaseModel):
+    media: MediaSource = Field(description="上传完成的媒体资源")
@@ -1,17 +1,22 @@
 import asyncio
+import os
+from typing import Annotated, Optional

 import modal
+from loguru import logger
 from modal import current_function_call_id
 import sentry_sdk
-from fastapi import APIRouter, Depends
-from fastapi.responses import JSONResponse
+from fastapi import APIRouter, Depends, UploadFile, HTTPException, File, Form
+from fastapi.responses import JSONResponse, RedirectResponse
 from starlette import status
-from starlette.responses import RedirectResponse

 from ..config import WorkerConfig
 from ..middleware.authorization import verify_token
-from ..models.media_model import MediaSources, CacheResult, MediaSource, MediaCacheStatus, DownloadResult
+from ..models.media_model import (MediaSources,
+                                  CacheResult,
+                                  MediaSource,
+                                  MediaCacheStatus,
+                                  DownloadResult, UploadResultResponse
+                                  )
 from ..models.web_model import SentryTransactionInfo
 from ..utils.KVCache import KVCache
 from ..utils.SentryUtils import SentryUtils
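Note: dropping the separate starlette.responses import is safe because
fastapi.responses re-exports Starlette's response classes unchanged:

    from fastapi.responses import RedirectResponse as FastAPIRedirect
    from starlette.responses import RedirectResponse as StarletteRedirect

    # Same class object, so the consolidated import is behavior-preserving.
    assert FastAPIRedirect is StarletteRedirect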
@@ -52,15 +57,11 @@ async def cache(medias: MediaSources) -> CacheResult:
             queue_publish_span.set_data("messaging.destination.name",
                                         "video-downloader.cache_submit")
             queue_publish_span.set_data("messaging.message.body.size", 0)
-            # video_cache = MediaCache(status=MediaCacheStatus.downloading,
-            #                          downloader_id=fn_task.object_id)
             media.status = MediaCacheStatus.downloading
             media.downloader_id = fn_task.object_id
-            # video_cache_status_json = video_cache.model_dump_json()
             modal_kv_cache.set_cache(media)
         else:
             media = cached_media
-            # video_cache = MediaCache.model_validate_json(video_cache_status_json)
             match media.status:
                 case MediaCacheStatus.ready:
                     cache_hit = True
@@ -79,14 +80,9 @@ async def cache(medias: MediaSources) -> CacheResult:
                 queue_publish_span.set_data("messaging.message.body.size", 0)
                 media.status = MediaCacheStatus.downloading
                 media.downloader_id = fn_task.object_id
-                # video_cache = MediaCache(status=MediaCacheStatus.downloading,
-                #                          downloader_id=fn_task.object_id)
-                # video_cache_status_json = video_cache.model_dump_json()
                 modal_kv_cache.set_cache(media)
                 cache_hit = False

-        # caches[media.urn] = video_cache
-        # logger.info(f"Media cache hit ? {cache_hit}")
         cache_span.set_data("cache.hit", cache_hit)
         return media
@@ -98,6 +94,30 @@ async def cache(medias: MediaSources) -> CacheResult:
     KVCache.batch_update_cloudflare_kv(cache_task_result)
     return CacheResult(caches={media.urn: media for media in cache_task_result})


+@router.delete("/",
+               tags=["缓存"],
+               summary="清除指定的所有缓存",
+               description="清除指定的所有缓存(包括KV记录和S3存储文件)",
+               dependencies=[Depends(verify_token)])
+async def purge_media_kv_file(medias: MediaSources):
+    fn_id = current_function_call_id()
+    fn = modal.Function.from_name(config.modal_app_name, "cache_delete", environment_name=config.modal_environment)
+
+    @SentryUtils.sentry_tracker(name="清除媒体源缓存", op="cache.purge", fn_id=fn_id,
+                                sentry_trace_id=None, sentry_baggage=None)
+    async def purge_handle(media: MediaSource):
+        cache_media = modal_kv_cache.pop(media.urn)
+        if cache_media:
+            deleted_cache: MediaSource = await fn.remote.aio(cache_media)
+            return deleted_cache.urn
+        return None
+
+    async with asyncio.TaskGroup() as group:
+        tasks = [group.create_task(purge_handle(media)) for media in medias.inputs]
+
+    keys = [task.result() for task in tasks]
+    KVCache.batch_remove_cloudflare_kv(keys)
+    return JSONResponse(content={"success": True, "keys": keys})
+
+
 @router.post("/download",
              tags=["缓存"],
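Note: the new DELETE route fans each media source out to the Modal
"cache_delete" function through an asyncio.TaskGroup, then batch-removes the
returned keys from Cloudflare KV; purge_handle returns None on a cache miss,
and those None entries flow into batch_remove_cloudflare_kv unfiltered. A
hypothetical client call (base URL, token, and payload shape are assumptions;
the route path follows the deprecation note below, which points at
DELETE /cache/):

    import httpx

    resp = httpx.request(
        "DELETE",
        "https://worker.example.com/cache/",
        headers={"Authorization": "Bearer <token>"},
        # Assumed MediaSources wire shape; the handler iterates medias.inputs.
        json={"inputs": [{"urn": "http://example.com/video.mp4"}]},
    )
    print(resp.json())  # e.g. {"success": true, "keys": [...]}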
@@ -157,7 +177,8 @@ async def purge_kv(medias: MediaSources):
 @router.post("/media",
              tags=["缓存"],
              summary="清除指定的所有缓存",
-             description="清除指定的所有缓存(包括KV记录和S3存储文件)",
+             description="清除指定的所有缓存(包括KV记录和S3存储文件), 将要被淘汰,使用DELETE /cache/替代",
+             deprecated=True,
              dependencies=[Depends(verify_token)])
 async def purge_media(medias: MediaSources):
     fn_id = current_function_call_id()
@@ -178,3 +199,27 @@ async def purge_media(medias: MediaSources):
     keys = [task.result() for task in tasks]
     KVCache.batch_remove_cloudflare_kv(keys)
     return JSONResponse(content={"success": True, "keys": keys})
+
+
+@router.post("/upload-s3", tags=['缓存'],
+             summary="上传文件到S3",
+             description="上传文件到S3当文件大于200M",
+             dependencies=[Depends(verify_token)])
+async def s3_upload(file: Annotated[UploadFile, File(description="上传的文件")],
+                    prefix: Annotated[Optional[str], Form()] = None) -> UploadResultResponse:
+    fn_id = current_function_call_id()
+
+    if file.size > 200 * 1024 * 1024:  # 上传文件不大于200M
+        raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="上传文件不可超过200MB")
+
+    key = f"upload/{prefix}/{file.filename}" if prefix else f"upload/{file.filename}"
+    local_path = f"{config.S3_mount_dir}/{key}"
+    logger.info(f"s3上传到{key}, size={file.size}")
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+        f.write(file.file.read())
+    logger.info(f"{local_path} 保存成功")
+    media_source = MediaSource.from_str(f"s3://{config.S3_region}/{config.S3_bucket_name}/{key}")
+    media_source.status = MediaCacheStatus.ready
+    media_source.downloader_id = fn_id
+    return UploadResultResponse(media=media_source)
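Note: the upload handler writes through the S3 FUSE mount (config.S3_mount_dir)
rather than calling S3 directly, then reports the result as a ready
MediaSource. UploadFile.size can be None when the client sends no
Content-Length, which the size check above does not guard against. A
hypothetical client call (base URL and token are assumptions; the 200 MB limit
comes from the diff):

    import httpx

    with open("clip.mp4", "rb") as f:
        resp = httpx.post(
            "https://worker.example.com/cache/upload-s3",
            headers={"Authorization": "Bearer <token>"},
            files={"file": ("clip.mp4", f, "video/mp4")},
            data={"prefix": "demo"},  # optional form field; omit to upload to upload/<filename>
        )
    resp.raise_for_status()  # HTTP 413 if the file exceeds 200 MB
    print(resp.json()["media"])  # e.g. s3://<region>/<bucket>/upload/demo/clip.mp4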
@@ -273,7 +273,8 @@ with downloader_image.imports():
                         process_span.set_status("failed")
                 case MediaProtocol.http:
                     try:
-                        cache_filepath = f"{config.S3_mount_dir}/{media.cache_filepath}"
+                        cache_filepath = f"{config.S3_mount_dir}/{media.protocol.value}/{media.cache_filepath}"
+                        os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)
                         download_large_file(url=media.__str__(), output_path=cache_filepath)
                     except Exception as e:
                         logger.exception(e)
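Note: the added os.makedirs guard is needed because the cache path now
includes a per-protocol segment, so the target directory may not exist on the
first download. A minimal illustration (the mount path and layout are
assumptions):

    import os

    cache_filepath = "/mnt/s3/http/example/video.mp4"  # illustrative layout
    os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)  # safe to repeat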