300 lines
12 KiB
Python
300 lines
12 KiB
Python
import os
|
||
import re
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
from functools import cached_property
|
||
from typing import List, Union, Optional, Any, Dict
|
||
from urllib.parse import urlparse
|
||
from pydantic import (BaseModel, Field, field_validator, ValidationError,
|
||
field_serializer, SerializationInfo, computed_field, FileUrl, Base64Str, Base64Bytes, ConfigDict)
|
||
from pydantic.json_schema import JsonSchemaValue
|
||
from ..config import WorkerConfig
|
||
from ..utils.TimeUtils import TimeDelta
|
||
from ..utils.VideoUtils import VideoMetadata
|
||
|
||
# Worker-wide configuration, instantiated once when this module is imported.
config = WorkerConfig()

# S3 settings read once from the configuration; the models below use them to
# normalise s3:// URNs and to resolve paths under the S3 mount point.
s3_region, s3_bucket_name, s3_mount_point = (
    config.S3_region,
    config.S3_bucket_name,
    config.S3_mount_dir,
)
|
||
|
||
|
||
class MediaProtocol(str, Enum):
    """Protocols a media resource can be addressed with.

    Each member's value is also the URN scheme name, i.e. URNs look like
    ``<value>://...``.  Declaration order matters: it is the order in which the
    supported schemes are listed in parse error messages.
    """

    http = "http"  # http:// and https:// URLs
    s3 = "s3"      # s3://[region/]bucket/key
    vod = "vod"    # vod://endpoint/subAppId/fileId
    cos = "cos"    # cos://endpoint/bucket/key
    hls = "hls"    # hls://host/path, fetched over https
    gs = "gs"      # NOTE(review): not handled by MediaSource.from_str yet
|
||
|
||
|
||
class MediaCacheStatus(str, Enum):
    """State of the locally mounted cache copy of a media source."""

    downloading = "downloading"
    failed = "failed"
    ready = "ready"          # only this state lets MediaSource.local_available be True
    deleted = "deleted"
    missing = "missing"
    unknown = "unknown"      # default for a newly created MediaSource
|
||
|
||
|
||
class MediaSource(BaseModel):
|
||
path: str = Field(description="媒体源的路径")
|
||
protocol: MediaProtocol = Field(description="媒体源的来源协议")
|
||
endpoint: Optional[str] = Field(description="媒体源来源的终端地址,根据使用协议的不同,会有没有endpoint的情况")
|
||
bucket: Optional[str] = Field(description="媒体源所使用的存储桶(s3)/SubAppId(vod)")
|
||
urn: Optional[str] = Field(description="媒体源的唯一指定标识")
|
||
status: MediaCacheStatus = Field(default=MediaCacheStatus.unknown, description="媒体源在Modal集群挂载的缓存状态")
|
||
expired_at: Optional[datetime] = Field(description="缓存过期时间点, 为None时不会过期(过期处理WIP)", default=None)
|
||
downloader_id: Optional[str] = Field(description="正在处理下载的Downloader ID", default=None)
|
||
progress: int = Field(description="缓存进度", default=0)
|
||
metadata: Optional[VideoMetadata] = Field(description="媒体元数据", default=None)
|
||
content_length: Optional[int] = Field(description="媒体文件大小", default=None)
|
||
|
||
@classmethod
|
||
def from_str(cls, media_url: str) -> 'MediaSource':
|
||
if media_url.startswith('http://') or media_url.startswith('https://'):
|
||
parsed_url = urlparse(media_url)
|
||
path_str = f"{parsed_url.path}?{parsed_url.query}" if parsed_url.query else parsed_url.path
|
||
return MediaSource(path=path_str,
|
||
protocol=MediaProtocol.http,
|
||
endpoint=parsed_url.netloc, # domain of http url
|
||
bucket=None,
|
||
urn=media_url)
|
||
elif media_url.startswith('s3://'):
|
||
pattern = r"^s3://[a-z]{2}-[a-z]+-\d/.*$"
|
||
if not re.match(pattern, media_url):
|
||
media_url = media_url.replace("s3://", f"s3://{s3_region}/")
|
||
# s3://{endpoint}/{bucket}/{url}
|
||
paths = media_url[5:].split('/')
|
||
|
||
if len(paths) < 3:
|
||
raise ValidationError("URN-s3 格式错误")
|
||
|
||
media_source = MediaSource(path='/'.join(paths[2:]),
|
||
protocol=MediaProtocol.s3,
|
||
endpoint=paths[0],
|
||
bucket=paths[1],
|
||
urn=media_url)
|
||
|
||
s3_mount_path = config.S3_mount_dir
|
||
cache_path = os.path.join(s3_mount_path, media_source.cache_filepath)
|
||
|
||
# 校验媒体文件是否存在缓存中
|
||
if not config.modal_is_local and not os.path.exists(cache_path):
|
||
raise ValueError(f"媒体文件 {media_source.cache_filepath} 不存在于缓存中")
|
||
|
||
return media_source
|
||
|
||
elif media_url.startswith('vod://'): # vod://{endpoint}/{subAppId}/{fileId}
|
||
paths = media_url[6:].split('/')
|
||
if len(paths) < 3:
|
||
raise ValidationError("URN-vod 格式错误")
|
||
# 兼容有文件类型后缀和没有文件类型后缀的格式
|
||
url = paths[2] if '.' in os.path.basename(paths[2]) else paths[2] + ".mp4"
|
||
return MediaSource(path=url,
|
||
protocol=MediaProtocol.vod,
|
||
bucket=paths[1],
|
||
endpoint=paths[0],
|
||
urn=media_url)
|
||
elif media_url.startswith('cos://'): # cos://{endpoint}/{bucket}/{url}
|
||
paths = media_url[6:].split('/')
|
||
return MediaSource(path='/'.join(paths[2:]),
|
||
protocol=MediaProtocol.cos,
|
||
endpoint=paths[0],
|
||
bucket=paths[1],
|
||
urn=media_url)
|
||
elif media_url.startswith('hls://'):
|
||
# hls://merge-local-1324682537.cos.ap-shanghai.myqcloud.com/streams/1264/30322.m3u8
|
||
paths = media_url[6:]
|
||
return MediaSource(path=f"https://{paths}",
|
||
protocol=MediaProtocol.hls,
|
||
endpoint=None,
|
||
bucket=None,
|
||
urn=media_url
|
||
)
|
||
else:
|
||
available_schemas = [member.value for member in MediaProtocol]
|
||
available_schemas_str = ','.join(available_schemas)
|
||
raise ValidationError(f"mediaUrl必须以{available_schemas_str}协议开头")
|
||
|
||
@classmethod
|
||
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: Any) -> JsonSchemaValue:
|
||
# Override the schema to represent it as a string
|
||
return {
|
||
"type": "string",
|
||
"examples": [
|
||
"vod://ap-shanghai/subAppId/fileId",
|
||
"http://example.com/video.mp4",
|
||
"s3://endpoint/bucket/path/to/file",
|
||
"cos://endpoint/bucket/path/to/file"
|
||
]
|
||
}
|
||
|
||
@computed_field(description="s3挂载路径下的缓存相对路径")
|
||
@property
|
||
def cache_filepath(self) -> str:
|
||
match self.protocol:
|
||
case MediaProtocol.s3:
|
||
# 本地挂载缓存
|
||
if self.protocol == MediaProtocol.s3 and self.endpoint == s3_region and self.bucket == s3_bucket_name:
|
||
return f"{self.path}"
|
||
else:
|
||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||
case MediaProtocol.http:
|
||
clean_path = self.path.split('?')[0] if '?' in self.path else self.path
|
||
return f"{self.protocol.value}/{self.endpoint}/{clean_path}"
|
||
case _:
|
||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||
|
||
@computed_field(description="文件后缀名")
|
||
@cached_property
|
||
def file_extension(self) -> Optional[str]:
|
||
return os.path.basename(self.urn).split('.')[-1]
|
||
|
||
@computed_field(description="是否本地可用")
|
||
@property
|
||
def local_available(self) -> bool:
|
||
if self.status == MediaCacheStatus.ready:
|
||
return self.local_exists
|
||
return False
|
||
|
||
@computed_field(description="是否存在本地文件")
|
||
@property
|
||
def local_exists(self) -> bool:
|
||
return os.path.exists(self.local_mount_path)
|
||
|
||
@computed_field(description="本地挂载地址")
|
||
@property
|
||
def local_mount_path(self) -> str:
|
||
return f"{s3_mount_point}/{self.cache_filepath}"
|
||
|
||
@field_serializer('expired_at')
|
||
def serialize_datetime(self, value: Optional[datetime], info: SerializationInfo) -> Optional[str]:
|
||
if value:
|
||
return value.isoformat()
|
||
else:
|
||
return None
|
||
|
||
def get_cdn_url(self) -> str:
|
||
if self.protocol == MediaProtocol.s3:
|
||
return f"{self.path}"
|
||
elif self.protocol == MediaProtocol.http:
|
||
return f"{self.protocol.value}/{self.path}"
|
||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||
|
||
@computed_field(description="CDN挂载地址")
|
||
@property
|
||
def url(self) -> str:
|
||
return f"{config.S3_cdn_endpoint}/{self.get_cdn_url()}"
|
||
|
||
def __str__(self):
|
||
match self.protocol:
|
||
case MediaProtocol.http:
|
||
return self.urn
|
||
case MediaProtocol.hls:
|
||
return f"{self.protocol.value}://{self.path[8:]}" # strip "https://" from url
|
||
case _:
|
||
return f"{self.protocol.value}://{self.endpoint}/{self.bucket}/{self.path}"
|
||
|
||
|
||
class MediaSources(BaseModel):
    """A non-empty list of media sources given as URN strings or MediaSource objects."""

    inputs: List[MediaSource] = Field(examples=[
        [
            "vod://ap-shanghai/1500034234/1397757910405340824.mp4",
            "vod://ap-shanghai/1500034234/1397757910403699452.mp4",
            "s3://ap-northeast-2/modal-media-cache/concat/outputs/fc-01JTPV5FCNA74CKX3N3214XJPD/output.mp4"
        ]
    ], description="支持多种协议['vod://', 's3://'], 计划支持['cos://', http://]")

    @field_validator('inputs', mode='before')
    @classmethod
    def parse_inputs(cls, v: Union[str, MediaSource, List[Union[str, MediaSource]]]) -> List[MediaSource]:
        """Convert raw ``inputs`` into a list of :class:`MediaSource`.

        Strings are parsed with :meth:`MediaSource.from_str`. MediaSource instances
        are passed through unchanged. A single string or MediaSource is treated as a
        one-element list. Previously such a value was iterated character by character.

        Raises:
            ValueError: If ``inputs`` is empty or an element is neither a string nor a
                MediaSource. pydantic reports this as a ``ValidationError`` at
                ``loc=('inputs',)``. pydantic v2's ``ValidationError`` cannot be raised
                by constructing it directly, which is what the previous code attempted.
        """
        if not v:
            raise ValueError("inputs为空")

        if isinstance(v, (str, MediaSource)):
            v = [v]

        result = []
        for item in v:
            if isinstance(item, str):
                result.append(MediaSource.from_str(item))
            elif isinstance(item, MediaSource):
                result.append(item)
            else:
                raise ValueError("inputs元素类型错误: 必须是字符串")
        return result

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||
|
||
|
||
class CacheResult(BaseModel):
    """Media sources keyed by cache id."""

    caches: Dict[str, MediaSource] = Field(
        description="Cache ID",
    )
|
||
|
||
|
||
class DownloadResult(BaseModel):
    """List of download links."""

    urls: List[str] = Field(
        description="下载链接",
    )
|
||
|
||
|
||
class UploadResultResponse(BaseModel):
    """Response wrapping the media source created by a completed upload."""

    media: MediaSource = Field(
        description="上传完成的媒体资源",
    )
|
||
|
||
|
||
class Base64File(BaseModel):
    """A file submitted inline as Base64 text.

    ``raw_content`` is declared as pydantic's ``Base64Bytes``, so after validation it
    holds the decoded bytes, not the Base64 text.
    """

    raw_content: Base64Bytes = Field(description="Base64编码的原始内容")
    filename: str = Field(description="文件名, 包含文件类型后缀")
    content_type: str = Field(description="文件数据类型")

    @computed_field(description="文件字节大小")
    @property
    def size(self) -> int:
        """Size in bytes of the decoded file content."""
        # raw_content is already decoded by Base64Bytes, so no padding arithmetic
        # on the Base64 text is needed here.
        return len(self.raw_content)
|
||
|
||
|
||
class UploadBase64Request(BaseModel):
    """Request to upload a Base64-encoded file, optionally under a key prefix."""

    file: Base64File = Field(
        description="上传的文件",
    )
    prefix: Optional[str] = Field(
        default=None,
        description="文件存在的前缀目录",
    )
|
||
|
||
|
||
class UploadPresignRequest(BaseModel):
    """Request for a presigned PUT upload of a single object."""

    key: str = Field(
        examples=['123/456/abc.mp4'],
        description="上传文件的key",
    )
    content_type: str = Field(
        examples=['video/mp4'],
        description="上传对象的文件类型",
    )
|
||
|
||
|
||
class UploadPresignResponse(BaseModel):
    """Presigned PUT upload URL, the resulting URN, and the signature's expiry time."""

    url: str = Field(
        description="就近加速的PUT上传地址",
    )
    urn: str = Field(
        description="上传成功后获得的对应资源URN",
    )
    expired_at: datetime = Field(
        description="上传地址签名过期时间戳",
    )
|
||
|
||
|
||
class UploadMultipartPresignRequest(UploadPresignRequest):
    """Presign request for a multipart upload; adds the number of parts."""

    parts_count: int = Field(
        description="分片数量",
    )
|
||
|
||
|
||
class UploadMultipartPresignResponse(BaseModel):
    """Presigned URLs for a multipart upload.

    Contains one PUT URL per part, plus the URLs for listing parts and completing
    the upload, the resulting URN, and the signature's expiry time.
    """

    urls: List[str] = Field(
        description="就近加速的PUT分片上传地址",
    )
    list_url: str = Field(
        description="用于确认分片上传状态的请求地址",
    )
    complete_url: str = Field(
        description="用于确认完成分片上传的请求地址",
    )
    urn: str = Field(
        description="上传成功后获得的对应资源URN",
    )
    expired_at: datetime = Field(
        description="上传地址签名过期时间戳",
    )
|