modalDeploy/src/BowongModalFunctions/models/media_model.py

213 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
from datetime import datetime
from enum import Enum
from functools import cached_property
from typing import List, Union, Optional, Any, Dict
from urllib.parse import urlparse
from pydantic import (BaseModel, Field, field_validator, ValidationError,
field_serializer, SerializationInfo, computed_field, FileUrl)
from pydantic.json_schema import JsonSchemaValue
from ..config import WorkerConfig
from ..utils.TimeUtils import TimeDelta
# Worker configuration is loaded once at import time. The configured S3
# region and bucket name are cached as module globals; MediaSource uses them
# to fill in a missing region on s3 URNs and to decide whether an object can
# be read in place from the local S3 mount.
config = WorkerConfig()
s3_region, s3_bucket_name = config.S3_region, config.S3_bucket_name
class MediaProtocol(str, Enum):
    # Source scheme of a media reference; each value matches the URN prefix
    # ("<value>://...") that MediaSource.from_str recognises.
    # NOTE: member order is significant: from_str joins the member values in
    # this order to build its "unsupported scheme" error message.
    http = "http"  # also covers https:// URLs
    s3 = "s3"
    vod = "vod"
    cos = "cos"
    hls = "hls"  # hls://host/path is stored as an https:// URL
class MediaCacheStatus(str, Enum):
    # Cache state of a media source on the Modal cluster mount; used as the
    # type of MediaSource.status, which defaults to ``unknown``.
    # NOTE(review): per-member meanings below are inferred from the names
    # only; confirm against the downloader that sets them.
    downloading = "downloading"
    failed = "failed"
    ready = "ready"
    deleted = "deleted"
    missing = "missing"
    unknown = "unknown"
class MediaSource(BaseModel):
path: str = Field(description="媒体源的路径")
protocol: MediaProtocol = Field(description="媒体源的来源协议")
endpoint: Optional[str] = Field(description="媒体源来源的终端地址根据使用协议的不同会有没有endpoint的情况")
bucket: Optional[str] = Field(description="媒体源所使用的存储桶(s3)/SubAppId(vod)")
urn: Optional[str] = Field(description="媒体源的唯一指定标识")
status: MediaCacheStatus = Field(default=MediaCacheStatus.unknown, description="媒体源在Modal集群挂载的缓存状态")
expired_at: Optional[datetime] = Field(description="缓存过期时间点, 为None时不会过期过期处理WIP", default=None)
downloader_id: Optional[str] = Field(description="正在处理下载的Downloader ID", default=None)
progress: int = Field(description="缓存进度", default=0)
@classmethod
def from_str(cls, media_url: str) -> 'MediaSource':
if media_url.startswith('http://') or media_url.startswith('https://'):
parsed_url = urlparse(media_url)
path_str = f"{parsed_url.path}?{parsed_url.query}" if parsed_url.query else parsed_url.path
return MediaSource(path=path_str,
protocol=MediaProtocol.http,
endpoint=parsed_url.netloc, # domain of http url
bucket=None,
urn=media_url)
elif media_url.startswith('s3://'):
pattern = r"^s3://[a-z]{2}-[a-z]+-\d/.*$"
if not re.match(pattern, media_url):
media_url = media_url.replace("s3://", f"s3://{s3_region}/")
# s3://{endpoint}/{bucket}/{url}
paths = media_url[5:].split('/')
if len(paths) < 3:
raise ValidationError("URN-s3 格式错误")
media_source = MediaSource(path='/'.join(paths[2:]),
protocol=MediaProtocol.s3,
endpoint=paths[0],
bucket=paths[1],
urn=media_url)
s3_mount_path = config.S3_mount_dir
cache_path = os.path.join(s3_mount_path, media_source.cache_filepath)
# 校验媒体文件是否存在缓存中
if not os.path.exists(cache_path):
raise ValueError(f"媒体文件 {media_source.cache_filepath} 不存在于缓存中")
return media_source
elif media_url.startswith('vod://'): # vod://{endpoint}/{subAppId}/{fileId}
paths = media_url[6:].split('/')
if len(paths) < 3:
raise ValidationError("URN-vod 格式错误")
# 兼容有文件类型后缀和没有文件类型后缀的格式
url = paths[2] if '.' in os.path.basename(paths[2]) else paths[2] + ".mp4"
return MediaSource(path=url,
protocol=MediaProtocol.vod,
bucket=paths[1],
endpoint=paths[0],
urn=media_url)
elif media_url.startswith('cos://'): # cos://{endpoint}/{bucket}/{url}
paths = media_url[6:].split('/')
return MediaSource(path='/'.join(paths[2:]),
protocol=MediaProtocol.cos,
endpoint=paths[0],
bucket=paths[1],
urn=media_url)
elif media_url.startswith('hls://'):
# hls://merge-local-1324682537.cos.ap-shanghai.myqcloud.com/streams/1264/30322.m3u8
paths = media_url[6:]
return MediaSource(path=f"https://{paths}",
protocol=MediaProtocol.hls,
endpoint=None,
bucket=None,
urn=media_url
)
else:
available_schemas = [member.value for member in MediaProtocol]
available_schemas_str = ','.join(available_schemas)
raise ValidationError(f"mediaUrl必须以{available_schemas_str}协议开头")
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: Any, handler: Any) -> JsonSchemaValue:
# Override the schema to represent it as a string
return {
"type": "string",
"examples": [
"vod://ap-shanghai/subAppId/fileId",
"http://example.com/video.mp4",
"s3://endpoint/bucket/path/to/file",
"cos://endpoint/bucket/path/to/file"
]
}
@computed_field(description="s3挂载路径下的缓存相对路径")
@property
def cache_filepath(self) -> str:
match self.protocol:
case MediaProtocol.s3:
# 本地挂载缓存
if self.protocol == MediaProtocol.s3 and self.endpoint == s3_region and self.bucket == s3_bucket_name:
return f"{self.path}"
else:
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
case MediaProtocol.http:
clean_path = self.path.split('?')[0] if '?' in self.path else self.path
return f"{self.protocol.value}/{self.endpoint}/{clean_path}"
case _:
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
@computed_field(description="文件后缀名")
@cached_property
def file_extension(self) -> Optional[str]:
return os.path.basename(self.urn).split('.')[-1]
@field_serializer('expired_at')
def serialize_datetime(self, value: Optional[datetime], info: SerializationInfo) -> Optional[str]:
if value:
return value.isoformat()
else:
return None
def get_cdn_url(self) -> str:
if self.protocol == MediaProtocol.s3:
return f"{self.path}"
elif self.protocol == MediaProtocol.http:
return f"{self.protocol.value}/{self.path}"
return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}"
def __str__(self):
match self.protocol:
case MediaProtocol.http:
return self.urn
case MediaProtocol.hls:
return f"{self.protocol.value}://{self.path[8:]}" # strip "https://" from url
case _:
return f"{self.protocol.value}://{self.endpoint}/{self.bucket}/{self.path}"
class MediaSources(BaseModel):
    # A non-empty list of media inputs. Elements may be URN strings (parsed
    # with MediaSource.from_str), MediaSource instances, or dicts such as those
    # produced by MediaSource.model_dump().
    inputs: List[MediaSource] = Field(examples=[
        [
            "vod://ap-shanghai/1500034234/1397757910405340824.mp4",
            "vod://ap-shanghai/1500034234/1397757910403699452.mp4",
            "s3://ap-northeast-2/modal-media-cache/concat/outputs/fc-01JTPV5FCNA74CKX3N3214XJPD/output.mp4"
        ]
    ], description="支持多种协议['vod://', 's3://'], 计划支持['cos://', http://]")

    @field_validator('inputs', mode='before')
    @classmethod
    def parse_inputs(cls, v: Any) -> List[MediaSource]:
        """Normalise ``inputs`` into a list of MediaSource objects.

        A single string/MediaSource is treated as a one-element list.

        Raises:
            ValueError: empty input or an element of an unsupported type.
                pydantic reports it to callers as ``ValidationError``. That
                class cannot be instantiated directly in pydantic v2, which is
                why ValueError is raised here.
        """
        if not v:
            raise ValueError("inputs为空")
        if isinstance(v, (str, MediaSource)):
            # Without this, a bare string would be iterated character by character.
            v = [v]
        result = []
        for item in v:
            if isinstance(item, str):
                result.append(MediaSource.from_str(item))
            elif isinstance(item, MediaSource):
                result.append(item)
            elif isinstance(item, dict):
                # Round-trip support for serialized MediaSource objects.
                result.append(MediaSource.model_validate(item))
            else:
                raise ValueError("inputs元素类型错误: 必须是字符串")
        return result

    model_config = {
        "arbitrary_types_allowed": True
    }
class CacheResult(BaseModel):
    # Mapping of string keys to cached media sources.
    # NOTE(review): the field description reads "Cache ID"; presumably the
    # dict keys are cache IDs. Confirm against the code that builds this model.
    caches: Dict[str, MediaSource] = Field(description="Cache ID")
class DownloadResult(BaseModel):
    # Model carrying a list of download links
    # (field description "下载链接" = "download links").
    urls: List[str] = Field(description="下载链接")
class UploadResultResponse(BaseModel):
    # Result of a completed upload: the uploaded media resource
    # (field description "上传完成的媒体资源" = "media resource whose upload completed").
    media: MediaSource = Field(description="上传完成的媒体资源")