Add batch operation endpoint for cache files

shuohigh@gmail.com 2025-06-24 15:11:26 +08:00
parent 39044f22dc
commit ff437e63b3
3 changed files with 107 additions and 4 deletions


@@ -80,8 +80,11 @@ class MediaSource(BaseModel):
        cache_path = os.path.join(s3_mount_path, media_source.cache_filepath)
        # Check whether the media file exists in the cache
        if not config.modal_is_local and not os.path.exists(cache_path):
            raise ValueError(f"Media file {media_source.cache_filepath} does not exist in the cache")
        if media_source.status is MediaCacheStatus.unknown:
            if not os.path.exists(cache_path):
                media_source.status = MediaCacheStatus.missing
            else:
                media_source.status = MediaCacheStatus.ready
        return media_source
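
The hunk above resolves an unknown cache status from the file's presence on the mount. A minimal sketch of that decision in isolation (the cache path is hypothetical, not taken from this diff):

    import os

    # Sketch: the same decision the validator makes for an 'unknown' status
    cache_path = "/mnt/s3/some/video.mp4"  # hypothetical mount path
    status = MediaCacheStatus.ready if os.path.exists(cache_path) else MediaCacheStatus.missing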
@@ -297,3 +300,9 @@ class UploadMultipartPresignResponse(BaseModel):
    complete_url: str = Field(description="URL used to confirm completion of the multipart upload")
    urn: str = Field(description="URN of the resource obtained after a successful upload")
    expired_at: datetime = Field(description="Expiry timestamp of the signed upload URL")


class MediaCopyRequest(BaseModel):
    class MediaCopyTask(BaseModel):
        source: MediaSource = Field(description="Source media")
        destination: MediaSource = Field(description="Destination media")
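
A hypothetical instantiation of the new copy model, assuming MediaSource.from_str parses s3 URNs as it does elsewhere in this diff (both URNs are invented):

    src = MediaSource.from_str("s3://bucket/a.mp4")  # hypothetical URN
    dst = MediaSource.from_str("s3://bucket/b.mp4")  # hypothetical URN
    task = MediaCopyRequest.MediaCopyTask(source=src, destination=dst)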


@@ -24,6 +24,12 @@ class TaskStatus(str, Enum):
    expired = "expired"


class CacheOperationType(str, Enum):
    copy = "copy"
    move = "move"
    delete = "delete"


class ErrorCode(int, Enum):
    SUCCESS = 0
    PARAM_ERROR = 10001
@@ -99,6 +105,68 @@ class FFMPEGResult(BaseModel):
        return self.urn.replace(prefix, f"{config.S3_cdn_endpoint}")


class CacheTask(BaseModel):
    type: CacheOperationType = Field(description="Operation type")
    source: MediaSource = Field(description="Source media URN")
    target: Optional[MediaSource] = Field(default=None, description="Target media URN, required for copy/move")
    @field_validator('source', mode='before')
    @classmethod
    def parse_source(cls, v: Union[str, MediaSource]) -> MediaSource:
        if isinstance(v, str):
            media_source = MediaSource.from_str(v)
            if media_source.protocol == MediaProtocol.s3:
                return media_source
            else:
                # Raise ValueError so Pydantic wraps it into a ValidationError;
                # pydantic.ValidationError cannot be instantiated with a bare message
                raise ValueError('media only supports s3-format URNs')
        elif isinstance(v, MediaSource):
            return v
        else:
            raise ValueError("failed to parse media")
    @field_validator('target', mode='before')
    @classmethod
    def parse_target(cls, v: Union[str, MediaSource]) -> Optional[MediaSource]:
        if v is None:
            return None
        if isinstance(v, str):
            media_source = MediaSource.from_str(v)
            if media_source.protocol == MediaProtocol.s3:
                return media_source
            else:
                raise ValueError('media only supports s3-format URNs')
        elif isinstance(v, MediaSource):
            return v
        else:
            raise ValueError("failed to parse media")
    @model_validator(mode="after")
    def parse_model(self):
        match self.type:
            case CacheOperationType.copy | CacheOperationType.move:
                if self.target is None:
                    raise ValueError(f"a target URN is required for {self.type.value} operations")
        # An "after" model validator must return the model instance in every branch
        return self
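
A quick illustration of the intended validation behavior (URNs are invented; Pydantic wraps the ValueError raised in the validators into a ValidationError):

    import pydantic

    # A delete task needs no target
    CacheTask(type="delete", source="s3://bucket/a.mp4")  # hypothetical URN; passes

    # copy (and move) without a target is rejected by the after-validator
    try:
        CacheTask(type="copy", source="s3://bucket/a.mp4")
    except pydantic.ValidationError:
        pass  # expected: a target URN is required for copy operations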
class CacheTaskResult(CacheTask):
    success: bool = Field(default=False, description="Whether the task executed successfully")


class ClusterCacheBatchRequest(BaseModel):
    tasks: List[CacheTask] = Field(description="Batch operation tasks, executed in list order")
    model_config = ConfigDict()


class ClusterCacheBatchResponse(BaseModel):
    results: List[CacheTaskResult] = Field(description="Results of the batch operation tasks")


class BaseFFMPEGTaskRequest(BaseModel):
    webhook: Optional[WebhookNotify] = Field(description="Task webhook", default=None)
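
Putting the batch models together, a request could be built like this (URNs are invented; the dicts are coerced into CacheTask instances by Pydantic):

    req = ClusterCacheBatchRequest(tasks=[
        {"type": "copy", "source": "s3://bucket/a.mp4", "target": "s3://bucket/b.mp4"},
        {"type": "delete", "source": "s3://bucket/a.mp4"},
    ])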


@@ -1,7 +1,8 @@
import asyncio
import datetime
import os
from typing import Annotated, Optional, List, Tuple
import shutil
from typing import Annotated, Optional, List, Tuple, Dict
import modal
from loguru import logger
@@ -25,7 +26,8 @@ from ..models.media_model import (MediaSources,
                                  UploadMultipartPresignRequest, UploadMultipartPresignResponse, MediaProtocol
                                  )
from ..models.web_model import SentryTransactionInfo, MonitorLiveRoomProductRequest, ModalTaskResponse, \
    LiveRoomProductCachesResponse, CacheDeleteTaskResponse
    LiveRoomProductCachesResponse, CacheDeleteTaskResponse, ClusterCacheBatchRequest, CacheOperationType, \
    ClusterCacheBatchResponse, CacheTaskResult
from ..utils.KVCache import MediaSourceKVCache, LiveProductKVCache
from ..utils.SentryUtils import SentryUtils
@@ -226,6 +228,30 @@ async def purge_kv(medias: MediaSources):
        return JSONResponse(content={"success": False, "error": str(e)})
@router.post("/batch", summary="批量操作集群S3缓存",
description="批量操作集群S3缓存",
dependencies=[Depends(verify_token)])
async def s3_copy(body: ClusterCacheBatchRequest) -> ClusterCacheBatchResponse:
results: List[CacheTaskResult] = []
for task in body.tasks:
try:
match task.type:
case CacheOperationType.copy:
shutil.copy(task.source.local_mount_path, task.target.local_mount_path)
case CacheOperationType.delete:
os.remove(task.source.local_mount_path)
case CacheOperationType.move:
shutil.copy(task.source.local_mount_path, task.target.local_mount_path)
os.remove(task.source.local_mount_path)
result = CacheTaskResult(**task.model_dump(), success=True)
results.append(result)
except Exception as e:
logger.exception(e)
result = CacheTaskResult(**task.model_dump(), success=False)
results.append(result)
return ClusterCacheBatchResponse(results=results)
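
A sketch of calling the new endpoint with httpx; the host, route prefix, and auth header are assumptions, not taken from this diff:

    import httpx

    resp = httpx.post(
        "https://api.example.invalid/cache/batch",    # hypothetical host and prefix
        headers={"Authorization": "Bearer <token>"},  # hypothetical auth scheme
        json={"tasks": [
            {"type": "move", "source": "s3://bucket/a.mp4", "target": "s3://bucket/b.mp4"},
        ]},
    )
    for r in resp.json()["results"]:
        print(r["type"], r["success"])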
@router.post("/upload-s3",
summary="上传文件到S3",
description="上传文件到S3的文件必须小于200M",