modalDeploy/src/BowongModalFunctions/models/media_model.py

300 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
from datetime import datetime
from enum import Enum
from functools import cached_property
from typing import List, Union, Optional, Any, Dict
from urllib.parse import urlparse
from pydantic import (BaseModel, Field, field_validator, ValidationError,
field_serializer, SerializationInfo, computed_field, FileUrl, Base64Str, Base64Bytes, ConfigDict)
from pydantic.json_schema import JsonSchemaValue
from ..config import WorkerConfig
from ..utils.TimeUtils import TimeDelta
from ..utils.VideoUtils import VideoMetadata
# Worker configuration shared by every model in this module.
config = WorkerConfig()
# s3_region / s3_bucket_name identify the bucket whose objects are cached
# directly by key (see MediaSource.cache_filepath); s3_mount_point is the
# directory under which cached media is resolved (MediaSource.local_mount_path).
s3_region, s3_bucket_name, s3_mount_point = (
    config.S3_region,
    config.S3_bucket_name,
    config.S3_mount_dir,
)
class MediaProtocol(str, Enum):
    """URN schemes a media source may be addressed with.

    Members are str subclasses, so they compare equal to their plain
    string values. Declaration order matters: MediaSource.from_str lists
    the values in this order in its "unsupported scheme" error message.
    """

    http = "http"  # http:// and https:// URLs
    s3 = "s3"      # s3://{region}/{bucket}/{key}
    vod = "vod"    # vod://{endpoint}/{subAppId}/{fileId}
    cos = "cos"    # cos://{endpoint}/{bucket}/{key}
    hls = "hls"    # hls://{host}/{path}, stored as an https:// URL
    # NOTE(review): MediaSource.from_str has no parser branch for gs.
    gs = "gs"
class MediaCacheStatus(str, Enum):
    """State of a media source's cached copy on the Modal S3 mount.

    Used as MediaSource.status; string-valued so it serializes as text.
    """

    downloading = "downloading"
    failed = "failed"
    # Only this state lets MediaSource.local_available report True.
    ready = "ready"
    deleted = "deleted"
    missing = "missing"
    # Default value of MediaSource.status.
    unknown = "unknown"
class MediaSource(BaseModel):
path: str = Field(description="媒体源的路径")
protocol: MediaProtocol = Field(description="媒体源的来源协议")
endpoint: Optional[str] = Field(description="媒体源来源的终端地址根据使用协议的不同会有没有endpoint的情况")
bucket: Optional[str] = Field(description="媒体源所使用的存储桶(s3)/SubAppId(vod)")
urn: Optional[str] = Field(description="媒体源的唯一指定标识")
status: MediaCacheStatus = Field(default=MediaCacheStatus.unknown, description="媒体源在Modal集群挂载的缓存状态")
expired_at: Optional[datetime] = Field(description="缓存过期时间点, 为None时不会过期过期处理WIP", default=None)
downloader_id: Optional[str] = Field(description="正在处理下载的Downloader ID", default=None)
progress: int = Field(description="缓存进度", default=0)
metadata: Optional[VideoMetadata] = Field(description="媒体元数据", default=None)
content_length: Optional[int] = Field(description="媒体文件大小", default=None)
@classmethod
def from_str(cls, media_url: str) -> 'MediaSource':
if media_url.startswith('http://') or media_url.startswith('https://'):
parsed_url = urlparse(media_url)
path_str = f"{parsed_url.path}?{parsed_url.query}" if parsed_url.query else parsed_url.path
return MediaSource(path=path_str,
protocol=MediaProtocol.http,
endpoint=parsed_url.netloc, # domain of http url
bucket=None,
urn=media_url)
elif media_url.startswith('s3://'):
pattern = r"^s3://[a-z]{2}-[a-z]+-\d/.*$"
if not re.match(pattern, media_url):
media_url = media_url.replace("s3://", f"s3://{s3_region}/")
# s3://{endpoint}/{bucket}/{url}
paths = media_url[5:].split('/')
if len(paths) < 3:
raise ValidationError("URN-s3 格式错误")
media_source = MediaSource(path='/'.join(paths[2:]),
protocol=MediaProtocol.s3,
endpoint=paths[0],
bucket=paths[1],
urn=media_url)
s3_mount_path = config.S3_mount_dir
cache_path = os.path.join(s3_mount_path, media_source.cache_filepath)
# 校验媒体文件是否存在缓存中
if not config.modal_is_local and not os.path.exists(cache_path):
raise ValueError(f"媒体文件 {media_source.cache_filepath} 不存在于缓存中")
return media_source
elif media_url.startswith('vod://'): # vod://{endpoint}/{subAppId}/{fileId}
paths = media_url[6:].split('/')
if len(paths) < 3:
raise ValidationError("URN-vod 格式错误")
# 兼容有文件类型后缀和没有文件类型后缀的格式
url = paths[2] if '.' in os.path.basename(paths[2]) else paths[2] + ".mp4"
return MediaSource(path=url,
protocol=MediaProtocol.vod,
bucket=paths[1],
endpoint=paths[0],
urn=media_url)
elif media_url.startswith('cos://'): # cos://{endpoint}/{bucket}/{url}
paths = media_url[6:].split('/')
return MediaSource(path='/'.join(paths[2:]),
protocol=MediaProtocol.cos,
endpoint=paths[0],
bucket=paths[1],
urn=media_url)
elif media_url.startswith('hls://'):
# hls://merge-local-1324682537.cos.ap-shanghai.myqcloud.com/streams/1264/30322.m3u8
paths = media_url[6:]
return MediaSource(path=f"https://{paths}",
protocol=MediaProtocol.hls,
endpoint=None,
bucket=None,
urn=media_url
)
else:
available_schemas = [member.value for member in MediaProtocol]
available_schemas_str = ','.join(available_schemas)
raise ValidationError(f"mediaUrl必须以{available_schemas_str}协议开头")
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: Any) -> JsonSchemaValue:
# Override the schema to represent it as a string
return {
"type": "string",
"examples": [
"vod://ap-shanghai/subAppId/fileId",
"http://example.com/video.mp4",
"s3://endpoint/bucket/path/to/file",
"cos://endpoint/bucket/path/to/file"
]
}
@computed_field(description="s3挂载路径下的缓存相对路径")
@property
def cache_filepath(self) -> str:
match self.protocol:
case MediaProtocol.s3:
# 本地挂载缓存
if self.protocol == MediaProtocol.s3 and self.endpoint == s3_region and self.bucket == s3_bucket_name:
return f"{self.path}"
else:
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
case MediaProtocol.http:
clean_path = self.path.split('?')[0] if '?' in self.path else self.path
return f"{self.protocol.value}/{self.endpoint}/{clean_path}"
case _:
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
@computed_field(description="文件后缀名")
@cached_property
def file_extension(self) -> Optional[str]:
return os.path.basename(self.urn).split('.')[-1]
@computed_field(description="是否本地可用")
@property
def local_available(self) -> bool:
if self.status == MediaCacheStatus.ready:
return self.local_exists
return False
@computed_field(description="是否存在本地文件")
@property
def local_exists(self) -> bool:
return os.path.exists(self.local_mount_path)
@computed_field(description="本地挂载地址")
@property
def local_mount_path(self) -> str:
return f"{s3_mount_point}/{self.cache_filepath}"
@field_serializer('expired_at')
def serialize_datetime(self, value: Optional[datetime], info: SerializationInfo) -> Optional[str]:
if value:
return value.isoformat()
else:
return None
def get_cdn_url(self) -> str:
if self.protocol == MediaProtocol.s3:
return f"{self.path}"
elif self.protocol == MediaProtocol.http:
return f"{self.protocol.value}/{self.path}"
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
@computed_field(description="CDN挂载地址")
@property
def url(self) -> str:
return f"{config.S3_cdn_endpoint}/{self.get_cdn_url()}"
def __str__(self):
match self.protocol:
case MediaProtocol.http:
return self.urn
case MediaProtocol.hls:
return f"{self.protocol.value}://{self.path[8:]}" # strip "https://" from url
case _:
return f"{self.protocol.value}://{self.endpoint}/{self.bucket}/{self.path}"
class MediaSources(BaseModel):
    """A non-empty list of media inputs given as URN strings or MediaSource objects."""

    inputs: List[MediaSource] = Field(examples=[
        [
            "vod://ap-shanghai/1500034234/1397757910405340824.mp4",
            "vod://ap-shanghai/1500034234/1397757910403699452.mp4",
            "s3://ap-northeast-2/modal-media-cache/concat/outputs/fc-01JTPV5FCNA74CKX3N3214XJPD/output.mp4"
        ]
    ], description="支持多种协议['vod://', 's3://'], 计划支持['cos://', http://]")

    @field_validator('inputs', mode='before')
    @classmethod
    def parse_inputs(cls, v: Any) -> List[MediaSource]:
        """Normalize the raw ``inputs`` value into MediaSource objects.

        Accepts an iterable whose items are URN strings (parsed with
        MediaSource.from_str) or MediaSource instances. A single string or
        MediaSource is treated as a one-element list.

        Raises:
            ValueError: if ``inputs`` is empty or an item has another type.
                pydantic reports it as a ValidationError on ``inputs``.
                pydantic v2's ValidationError cannot be constructed directly,
                so ValueError is the supported way to fail a validator.
        """
        if not v:
            raise ValueError("inputs为空")
        # A lone string used to be iterated character by character.
        if isinstance(v, (str, MediaSource)):
            v = [v]
        result = []
        for item in v:
            if isinstance(item, str):
                result.append(MediaSource.from_str(item))
            elif isinstance(item, MediaSource):
                result.append(item)
            else:
                raise ValueError("inputs元素类型错误: 必须是字符串")
        return result

    model_config = ConfigDict(arbitrary_types_allowed=True)
class CacheResult(BaseModel):
    """Result of a cache request: MediaSource entries keyed by a string id."""

    # NOTE(review): per the field description the keys are cache ids — confirm with the producer.
    caches: dict[str, MediaSource] = Field(description="Cache ID")
class DownloadResult(BaseModel):
    """Response holding a list of download links."""

    urls: list[str] = Field(description="下载链接")
class UploadResultResponse(BaseModel):
    """Response returned once an upload finishes, describing the stored media."""

    media: MediaSource = Field(description="上传完成的媒体资源")
class Base64File(BaseModel):
    """A file transmitted inline as base64 text.

    ``raw_content`` is declared as pydantic's Base64Bytes, so the base64
    input is decoded during validation and the field holds the raw bytes.
    """

    raw_content: Base64Bytes = Field(description="Base64编码的原始内容")
    # File name including its extension.
    filename: str = Field(description="文件名, 包含文件类型后缀")
    # Declared content (MIME) type of the data, as supplied by the client.
    content_type: str = Field(description="文件数据类型")

    @computed_field(description="文件字节大小")
    @property
    def size(self) -> int:
        """Size in bytes of the decoded content."""
        # raw_content is already decoded, so its length is the byte size.
        return len(self.raw_content)
class UploadBase64Request(BaseModel):
    """Request to upload an inline base64 file, optionally under a directory prefix."""

    file: Base64File = Field(description="上传的文件")
    # Optional directory prefix for the stored object; None when omitted.
    prefix: str | None = Field(default=None, description="文件存在的前缀目录")
class UploadPresignRequest(BaseModel):
    """Request for a presigned PUT URL for a single object."""

    # Object key to upload to, e.g. "123/456/abc.mp4".
    key: str = Field(description="上传文件的key", examples=['123/456/abc.mp4'])
    # Content type of the object to upload, e.g. "video/mp4".
    content_type: str = Field(description="上传对象的文件类型", examples=['video/mp4'])
class UploadPresignResponse(BaseModel):
    """Presigned single-object upload: PUT URL, resulting URN and expiry time."""

    url: str = Field(description="就近加速的PUT上传地址")
    urn: str = Field(description="上传成功后获得的对应资源URN")
    # When the URL's signature expires.
    expired_at: datetime = Field(description="上传地址签名过期时间戳")
class UploadMultipartPresignRequest(UploadPresignRequest):
    """Presign request for a multipart upload; adds the number of parts to key/content_type."""

    # How many part URLs to presign (no range validation is applied here).
    parts_count: int = Field(description="分片数量")
class UploadMultipartPresignResponse(BaseModel):
    """Presigned multipart upload.

    Carries the per-part PUT URLs, the list/complete endpoints, the
    resulting URN and the signature expiry.
    """

    # One presigned PUT URL per part.
    urls: list[str] = Field(description="就近加速的PUT分片上传地址")
    # URL to query which parts have been uploaded.
    list_url: str = Field(description="用于确认分片上传状态的请求地址")
    # URL to call to finalize the multipart upload.
    complete_url: str = Field(description="用于确认完成分片上传的请求地址")
    urn: str = Field(description="上传成功后获得的对应资源URN")
    expired_at: datetime = Field(description="上传地址签名过期时间戳")