213 lines
8.7 KiB
Python
213 lines
8.7 KiB
Python
import os
|
||
import re
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
from functools import cached_property
|
||
from typing import List, Union, Optional, Any, Dict
|
||
from urllib.parse import urlparse
|
||
from pydantic import (BaseModel, Field, field_validator, ValidationError,
|
||
field_serializer, SerializationInfo, computed_field, FileUrl)
|
||
from pydantic.json_schema import JsonSchemaValue
|
||
from ..config import WorkerConfig
|
||
from ..utils.TimeUtils import TimeDelta
|
||
|
||
# Worker-wide configuration, instantiated once when this module is imported.
config = WorkerConfig()

# Region and bucket names read from the configuration; the models below
# compare media sources against these two values.
s3_region = config.S3_region
s3_bucket_name = config.S3_bucket_name
|
||
|
||
|
||
class MediaProtocol(str, Enum):
    # Protocols a media URN may use.  Members are ``str`` subclasses, so they
    # compare equal to their plain string values.
    # Declaration order is significant: iterating the enum yields the members
    # in exactly this order.
    http = "http"  # plain http:// or https:// URL
    s3 = "s3"      # s3:// URN
    vod = "vod"    # vod:// URN
    cos = "cos"    # cos:// URN
    hls = "hls"    # hls:// URN
|
||
|
||
|
||
class MediaCacheStatus(str, Enum):
    # Cache state of a media source.  Members are ``str`` subclasses, so they
    # compare equal to their plain string values.
    downloading = "downloading"
    failed = "failed"
    ready = "ready"
    deleted = "deleted"
    missing = "missing"
    unknown = "unknown"  # used as the default status of a MediaSource
|
||
|
||
|
||
class MediaSource(BaseModel):
|
||
path: str = Field(description="媒体源的路径")
|
||
protocol: MediaProtocol = Field(description="媒体源的来源协议")
|
||
endpoint: Optional[str] = Field(description="媒体源来源的终端地址,根据使用协议的不同,会有没有endpoint的情况")
|
||
bucket: Optional[str] = Field(description="媒体源所使用的存储桶(s3)/SubAppId(vod)")
|
||
urn: Optional[str] = Field(description="媒体源的唯一指定标识")
|
||
status: MediaCacheStatus = Field(default=MediaCacheStatus.unknown, description="媒体源在Modal集群挂载的缓存状态")
|
||
expired_at: Optional[datetime] = Field(description="缓存过期时间点, 为None时不会过期(过期处理WIP)", default=None)
|
||
downloader_id: Optional[str] = Field(description="正在处理下载的Downloader ID", default=None)
|
||
progress: int = Field(description="缓存进度", default=0)
|
||
|
||
@classmethod
|
||
def from_str(cls, media_url: str) -> 'MediaSource':
|
||
if media_url.startswith('http://') or media_url.startswith('https://'):
|
||
parsed_url = urlparse(media_url)
|
||
path_str = f"{parsed_url.path}?{parsed_url.query}" if parsed_url.query else parsed_url.path
|
||
return MediaSource(path=path_str,
|
||
protocol=MediaProtocol.http,
|
||
endpoint=parsed_url.netloc, # domain of http url
|
||
bucket=None,
|
||
urn=media_url)
|
||
elif media_url.startswith('s3://'):
|
||
|
||
pattern = r"^s3://[a-z]{2}-[a-z]+-\d/.*$"
|
||
if not re.match(pattern, media_url):
|
||
media_url = media_url.replace("s3://", f"s3://{s3_region}/")
|
||
# s3://{endpoint}/{bucket}/{url}
|
||
paths = media_url[5:].split('/')
|
||
|
||
|
||
if len(paths) < 3:
|
||
raise ValidationError("URN-s3 格式错误")
|
||
|
||
media_source = MediaSource(path='/'.join(paths[2:]),
|
||
protocol=MediaProtocol.s3,
|
||
endpoint=paths[0],
|
||
bucket=paths[1],
|
||
urn=media_url)
|
||
|
||
s3_mount_path = config.S3_mount_dir
|
||
cache_path = os.path.join(s3_mount_path, media_source.cache_filepath)
|
||
|
||
# 校验媒体文件是否存在缓存中
|
||
if not os.path.exists(cache_path):
|
||
raise ValueError(f"媒体文件 {media_source.cache_filepath} 不存在于缓存中")
|
||
|
||
return media_source
|
||
|
||
elif media_url.startswith('vod://'): # vod://{endpoint}/{subAppId}/{fileId}
|
||
paths = media_url[6:].split('/')
|
||
if len(paths) < 3:
|
||
raise ValidationError("URN-vod 格式错误")
|
||
# 兼容有文件类型后缀和没有文件类型后缀的格式
|
||
url = paths[2] if '.' in os.path.basename(paths[2]) else paths[2] + ".mp4"
|
||
return MediaSource(path=url,
|
||
protocol=MediaProtocol.vod,
|
||
bucket=paths[1],
|
||
endpoint=paths[0],
|
||
urn=media_url)
|
||
elif media_url.startswith('cos://'): # cos://{endpoint}/{bucket}/{url}
|
||
paths = media_url[6:].split('/')
|
||
return MediaSource(path='/'.join(paths[2:]),
|
||
protocol=MediaProtocol.cos,
|
||
endpoint=paths[0],
|
||
bucket=paths[1],
|
||
urn=media_url)
|
||
elif media_url.startswith('hls://'):
|
||
# hls://merge-local-1324682537.cos.ap-shanghai.myqcloud.com/streams/1264/30322.m3u8
|
||
paths = media_url[6:]
|
||
return MediaSource(path=f"https://{paths}",
|
||
protocol=MediaProtocol.hls,
|
||
endpoint=None,
|
||
bucket=None,
|
||
urn=media_url
|
||
)
|
||
else:
|
||
available_schemas = [member.value for member in MediaProtocol]
|
||
available_schemas_str = ','.join(available_schemas)
|
||
raise ValidationError(f"mediaUrl必须以{available_schemas_str}协议开头")
|
||
|
||
@classmethod
|
||
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: Any) -> JsonSchemaValue:
|
||
# Override the schema to represent it as a string
|
||
return {
|
||
"type": "string",
|
||
"examples": [
|
||
"vod://ap-shanghai/subAppId/fileId",
|
||
"http://example.com/video.mp4",
|
||
"s3://endpoint/bucket/path/to/file",
|
||
"cos://endpoint/bucket/path/to/file"
|
||
]
|
||
}
|
||
|
||
@computed_field(description="s3挂载路径下的缓存相对路径")
|
||
@property
|
||
def cache_filepath(self) -> str:
|
||
match self.protocol:
|
||
case MediaProtocol.s3:
|
||
# 本地挂载缓存
|
||
if self.protocol == MediaProtocol.s3 and self.endpoint == s3_region and self.bucket == s3_bucket_name:
|
||
return f"{self.path}"
|
||
else:
|
||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||
case MediaProtocol.http:
|
||
clean_path = self.path.split('?')[0] if '?' in self.path else self.path
|
||
return f"{self.protocol.value}/{self.endpoint}/{clean_path}"
|
||
case _:
|
||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||
|
||
|
||
@computed_field(description="文件后缀名")
|
||
@cached_property
|
||
def file_extension(self) -> Optional[str]:
|
||
return os.path.basename(self.urn).split('.')[-1]
|
||
|
||
@field_serializer('expired_at')
|
||
def serialize_datetime(self, value: Optional[datetime], info: SerializationInfo) -> Optional[str]:
|
||
if value:
|
||
return value.isoformat()
|
||
else:
|
||
return None
|
||
|
||
def get_cdn_url(self) -> str:
|
||
if self.protocol == MediaProtocol.s3:
|
||
return f"{self.path}"
|
||
elif self.protocol == MediaProtocol.http:
|
||
return f"{self.protocol.value}/{self.path}"
|
||
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
|
||
|
||
def __str__(self):
|
||
match self.protocol:
|
||
case MediaProtocol.http:
|
||
return self.urn
|
||
case MediaProtocol.hls:
|
||
return f"{self.protocol.value}://{self.path[8:]}" # strip "https://" from url
|
||
case _:
|
||
return f"{self.protocol.value}://{self.endpoint}/{self.bucket}/{self.path}"
|
||
|
||
|
||
class MediaSources(BaseModel):
    # A non-empty list of media sources.  Each entry may be supplied either
    # as a URN string (parsed with MediaSource.from_str) or as an existing
    # MediaSource instance.
    inputs: List[MediaSource] = Field(examples=[
        [
            "vod://ap-shanghai/1500034234/1397757910405340824.mp4",
            "vod://ap-shanghai/1500034234/1397757910403699452.mp4",
            "s3://ap-northeast-2/modal-media-cache/concat/outputs/fc-01JTPV5FCNA74CKX3N3214XJPD/output.mp4"
        ]
    ], description="支持多种协议['vod://', 's3://'], 计划支持['cos://', http://]")

    @field_validator('inputs', mode='before')
    @classmethod
    def parse_inputs(cls, v: List[Union[str, MediaSource]]) -> List[MediaSource]:
        """Convert the raw ``inputs`` value into a list of ``MediaSource``.

        Raises:
            ValueError: if ``inputs`` is empty or contains an element that is
                neither a string nor a ``MediaSource``.  ``ValueError`` is
                raised (rather than constructing pydantic's
                ``ValidationError``, which has no public constructor in
                pydantic v2) so pydantic reports it as a validation error.
        """
        if not v:
            raise ValueError("inputs为空")
        result = []
        for item in v:
            if isinstance(item, str):
                result.append(MediaSource.from_str(item))
            elif isinstance(item, MediaSource):
                result.append(item)
            else:
                raise ValueError("inputs元素类型错误: 必须是字符串")
        return result

    model_config = {
        "arbitrary_types_allowed": True
    }
|
||
|
||
|
||
class CacheResult(BaseModel):
    # Mapping of string keys to MediaSource entries.
    # Documented with comments instead of a docstring because pydantic
    # copies a model docstring into the generated JSON schema description.
    caches: Dict[str, MediaSource] = Field(
        description="Cache ID",
    )
|
||
|
||
|
||
class DownloadResult(BaseModel):
    # Holds a list of download links as plain strings.
    # Documented with comments instead of a docstring because pydantic
    # copies a model docstring into the generated JSON schema description.
    urls: List[str] = Field(
        description="下载链接",
    )
|
||
|
||
class UploadResultResponse(BaseModel):
    # Wraps the MediaSource describing an uploaded media resource.
    # Documented with comments instead of a docstring because pydantic
    # copies a model docstring into the generated JSON schema description.
    media: MediaSource = Field(
        description="上传完成的媒体资源",
    )