import os import re from datetime import datetime from enum import Enum from functools import cached_property from typing import List, Union, Optional, Any, Dict from urllib.parse import urlparse from pydantic import (BaseModel, Field, field_validator, ValidationError, field_serializer, SerializationInfo, computed_field, FileUrl) from pydantic.json_schema import JsonSchemaValue from ..config import WorkerConfig from ..utils.TimeUtils import TimeDelta config = WorkerConfig() s3_region = config.S3_region s3_bucket_name = config.S3_bucket_name class MediaProtocol(str, Enum): http = "http" s3 = "s3" vod = "vod" cos = "cos" hls = "hls" class MediaCacheStatus(str, Enum): downloading = "downloading" failed = "failed" ready = "ready" deleted = "deleted" missing = "missing" unknown = "unknown" class MediaSource(BaseModel): path: str = Field(description="媒体源的路径") protocol: MediaProtocol = Field(description="媒体源的来源协议") endpoint: Optional[str] = Field(description="媒体源来源的终端地址,根据使用协议的不同,会有没有endpoint的情况") bucket: Optional[str] = Field(description="媒体源所使用的存储桶(s3)/SubAppId(vod)") urn: Optional[str] = Field(description="媒体源的唯一指定标识") status: MediaCacheStatus = Field(default=MediaCacheStatus.unknown, description="媒体源在Modal集群挂载的缓存状态") expired_at: Optional[datetime] = Field(description="缓存过期时间点, 为None时不会过期(过期处理WIP)", default=None) downloader_id: Optional[str] = Field(description="正在处理下载的Downloader ID", default=None) progress: int = Field(description="缓存进度", default=0) @classmethod def from_str(cls, media_url: str) -> 'MediaSource': if media_url.startswith('http://') or media_url.startswith('https://'): parsed_url = urlparse(media_url) path_str = f"{parsed_url.path}?{parsed_url.query}" if parsed_url.query else parsed_url.path return MediaSource(path=path_str, protocol=MediaProtocol.http, endpoint=parsed_url.netloc, # domain of http url bucket=None, urn=media_url) elif media_url.startswith('s3://'): pattern = r"^s3://[a-z]{2}-[a-z]+-\d/.*$" if not re.match(pattern, media_url): media_url = media_url.replace("s3://", f"s3://{s3_region}/") # s3://{endpoint}/{bucket}/{url} paths = media_url[5:].split('/') if len(paths) < 3: raise ValidationError("URN-s3 格式错误") media_source = MediaSource(path='/'.join(paths[2:]), protocol=MediaProtocol.s3, endpoint=paths[0], bucket=paths[1], urn=media_url) s3_mount_path = config.S3_mount_dir cache_path = os.path.join(s3_mount_path, media_source.cache_filepath) # 校验媒体文件是否存在缓存中 if not os.path.exists(cache_path): raise ValueError(f"媒体文件 {media_source.cache_filepath} 不存在于缓存中") return media_source elif media_url.startswith('vod://'): # vod://{endpoint}/{subAppId}/{fileId} paths = media_url[6:].split('/') if len(paths) < 3: raise ValidationError("URN-vod 格式错误") # 兼容有文件类型后缀和没有文件类型后缀的格式 url = paths[2] if '.' in os.path.basename(paths[2]) else paths[2] + ".mp4" return MediaSource(path=url, protocol=MediaProtocol.vod, bucket=paths[1], endpoint=paths[0], urn=media_url) elif media_url.startswith('cos://'): # cos://{endpoint}/{bucket}/{url} paths = media_url[6:].split('/') return MediaSource(path='/'.join(paths[2:]), protocol=MediaProtocol.cos, endpoint=paths[0], bucket=paths[1], urn=media_url) elif media_url.startswith('hls://'): # hls://merge-local-1324682537.cos.ap-shanghai.myqcloud.com/streams/1264/30322.m3u8 paths = media_url[6:] return MediaSource(path=f"https://{paths}", protocol=MediaProtocol.hls, endpoint=None, bucket=None, urn=media_url ) else: available_schemas = [member.value for member in MediaProtocol] available_schemas_str = ','.join(available_schemas) raise ValidationError(f"mediaUrl必须以{available_schemas_str}协议开头") @classmethod def __get_pydantic_json_schema__(cls, core_schema: Any, handler: Any) -> JsonSchemaValue: # Override the schema to represent it as a string return { "type": "string", "examples": [ "vod://ap-shanghai/subAppId/fileId", "http://example.com/video.mp4", "s3://endpoint/bucket/path/to/file", "cos://endpoint/bucket/path/to/file" ] } @computed_field(description="s3挂载路径下的缓存相对路径") @property def cache_filepath(self) -> str: match self.protocol: case MediaProtocol.s3: # 本地挂载缓存 if self.protocol == MediaProtocol.s3 and self.endpoint == s3_region and self.bucket == s3_bucket_name: return f"{self.path}" else: return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}" case MediaProtocol.http: clean_path = self.path.split('?')[0] if '?' in self.path else self.path return f"{self.protocol.value}/{self.endpoint}/{clean_path}" case _: return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}" @computed_field(description="文件后缀名") @cached_property def file_extension(self) -> Optional[str]: return os.path.basename(self.urn).split('.')[-1] @field_serializer('expired_at') def serialize_datetime(self, value: Optional[datetime], info: SerializationInfo) -> Optional[str]: if value: return value.isoformat() else: return None def get_cdn_url(self) -> str: if self.protocol == MediaProtocol.s3: return f"{self.path}" elif self.protocol == MediaProtocol.http: return f"{self.protocol.value}/{self.path}" return f"{self.protocol.value}/{self.endpoint}/{self.bucket}/{self.path}" def __str__(self): match self.protocol: case MediaProtocol.http: return self.urn case MediaProtocol.hls: return f"{self.protocol.value}://{self.path[8:]}" # strip "https://" from url case _: return f"{self.protocol.value}://{self.endpoint}/{self.bucket}/{self.path}" class MediaSources(BaseModel): inputs: List[MediaSource] = Field(examples=[ [ "vod://ap-shanghai/1500034234/1397757910405340824.mp4", "vod://ap-shanghai/1500034234/1397757910403699452.mp4", "s3://ap-northeast-2/modal-media-cache/concat/outputs/fc-01JTPV5FCNA74CKX3N3214XJPD/output.mp4" ] ], description="支持多种协议['vod://', 's3://'], 计划支持['cos://', http://]") @field_validator('inputs', mode='before') @classmethod def parse_inputs(cls, v: Union[str, MediaSource]) -> List[MediaSource]: if not v: raise ValidationError("inputs为空") result = [] for item in v: if isinstance(item, str): result.append(MediaSource.from_str(item)) elif isinstance(item, MediaSource): result.append(item) else: raise ValidationError("inputs元素类型错误: 必须是字符串") return result model_config = { "arbitrary_types_allowed": True } class CacheResult(BaseModel): caches: Dict[str, MediaSource] = Field(description="Cache ID") class DownloadResult(BaseModel): urls: List[str] = Field(description="下载链接") class UploadResultResponse(BaseModel): media: MediaSource = Field(description="上传完成的媒体资源")