modalDeploy/src/cluster/video.py

318 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import modal
from dotenv import dotenv_values
# Container image shared by every function of this app. Layer order matters for the
# image hash, so the chain below is kept in the same order.
# NOTE(review): the '../' paths are presumably resolved against the working directory
# the deploy command runs from — confirm.
downloader_image = (
    modal.Image.debian_slim(python_version="3.11")
    # Project dependencies, then runtime environment variables.
    .pip_install_from_pyproject("../pyproject.toml")
    .env(dotenv_values("../.runtime.env"))
    # Local packages made importable inside the container.
    .add_local_python_source('cluster')
    .add_local_python_source('BowongModalFunctions')
)

# NOTE(review): cf-kv-secret is always read from the 'dev' environment, while the
# function-level secrets in this file use config.modal_environment — confirm intended.
app = modal.App(
    name="media_app",
    image=downloader_image,
    secrets=[
        modal.Secret.from_name("cf-kv-secret", environment_name='dev'),
    ],
    include_source=False,
)
# Imports declared inside Modal's `Image.imports()` context for `downloader_image`.
# A `with` block does not create a scope, so every name bound here is a module global
# used by the functions below.
with downloader_image.imports():
    import os, httpx, crcmod
    import sentry_sdk
    from sentry_sdk.integrations.loguru import LoguruIntegration
    from tqdm import tqdm
    from typing import Tuple, List
    from loguru import logger
    from datetime import datetime, UTC, timedelta
    from modal import current_function_call_id
    from tencentcloud.common.credential import Credential
    from tencentcloud.vod.v20180717.vod_client import VodClient
    from tencentcloud.vod.v20180717 import models as vod_request_models
    from BowongModalFunctions.config import WorkerConfig
    from BowongModalFunctions.utils.KVCache import KVCache
    from BowongModalFunctions.models.media_model import MediaSource, MediaCacheStatus, MediaProtocol
    from BowongModalFunctions.models.web_model import SentryTransactionInfo
# Worker settings: used below for the Modal environment, S3 bucket/mount, KV name
# and container concurrency.
config = WorkerConfig()

# Error and performance reporting for this module's containers.
sentry_sdk.init(
    dsn="https://85632fdcd62f699c2f88af6ca489e9ec@sentry.bowongai.com/3",
    environment=config.modal_environment,
    integrations=[LoguruIntegration()],
    # Sampling / payload options.
    traces_sample_rate=1.0,
    profiles_sample_rate=1.0,
    send_default_pii=True,
    add_full_stack=True,
    shutdown_timeout=2,
)

# Cloudflare KV credentials; a missing variable comes through as None.
cf_account_id, cf_kv_api_token, cf_kv_namespace_id = (
    os.environ.get(var_name)
    for var_name in ("CF_ACCOUNT_ID", "CF_KV_API_TOKEN", "CF_KV_NAMESPACE_ID")
)

# Modal-side cache store, mirrored to Cloudflare KV by the functions below.
modal_kv_cache = KVCache(kv_name=config.modal_kv_name, environment=config.modal_environment)
@sentry_sdk.trace
def batch_update_cloudflare_kv(caches: List[MediaSource]):
    """Upsert media cache records into the Cloudflare KV namespace in one bulk request.

    Each record is stored under ``cache.urn`` with ``cache.model_dump_json()`` as its
    value. Account id, namespace id and API token come from the module-level
    ``cf_*`` variables (read from the environment at import time).

    Args:
        caches: records to write. An empty list is a no-op and sends no request.

    Raises:
        httpx.RequestError: the request could not be completed.
        httpx.HTTPStatusError: Cloudflare answered with a 4xx/5xx status.
    """
    if not caches:
        return
    with httpx.Client() as client:
        try:
            response = client.put(
                f"https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/storage/kv/namespaces/{cf_kv_namespace_id}/bulk",
                headers={"Authorization": f"Bearer {cf_kv_api_token}"},
                json=[
                    {
                        # Cloudflare's bulk-write flag is `base64`; values are plain JSON text.
                        "base64": False,
                        "key": cache.urn,
                        "value": cache.model_dump_json(),
                    }
                    for cache in caches
                ],
            )
            response.raise_for_status()
        except httpx.RequestError as e:
            logger.error(f"An error occurred while putting kv to cloudflare: {e}")
            raise
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred while putting kv to cloudflare {e}")
            raise
        except Exception as e:
            logger.error(f"An unexpected error occurred: {str(e)}")
            raise
@sentry_sdk.trace
def batch_remove_cloudflare_kv(caches: List[MediaSource]):
    """Delete the KV entries for ``caches`` (keyed by ``cache.urn``) in one bulk request.

    Account id, namespace id and API token come from the module-level ``cf_*``
    variables (read from the environment at import time).

    Args:
        caches: records whose keys should be removed. An empty list is a no-op and
            sends no request.

    Raises:
        httpx.RequestError: the request could not be completed.
        httpx.HTTPStatusError: Cloudflare answered with a 4xx/5xx status.
    """
    if not caches:
        return
    with httpx.Client() as client:
        try:
            response = client.post(
                f"https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/storage/kv/namespaces/{cf_kv_namespace_id}/bulk/delete",
                headers={"Authorization": f"Bearer {cf_kv_api_token}"},
                json=[cache.urn for cache in caches],
            )
            response.raise_for_status()
        except httpx.RequestError as e:
            logger.error(f"An error occurred while deleting kv from cloudflare: {e}")
            raise
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error occurred while deleting kv from cloudflare {e}")
            raise
        except Exception as e:
            logger.error(f"An unexpected error occurred: {str(e)}")
            raise
@app.function(cpu=1, timeout=1800,
cloud="aws",
max_containers=config.video_downloader_concurrency,
volumes={
"/mntS3": modal.CloudBucketMount(
bucket_name=config.S3_bucket_name,
secret=modal.Secret.from_name("aws-s3-secret", environment_name=config.modal_environment),
),
},
secrets=[modal.Secret.from_name("tencent-cloud-secret", environment_name=config.modal_environment)])
@modal.concurrent(max_inputs=10)
async def cache_submit(media: MediaSource, sentry_trace: SentryTransactionInfo) -> MediaSource:
def vod_init():
tencent_secret_id = os.environ["VOD_SECRET_ID"]
tencent_secret_key = os.environ["VOD_SECRET_KEY"]
cred = Credential(secret_id=tencent_secret_id, secret_key=tencent_secret_key)
return VodClient(credential=cred, region='ap-shanghai')
def vod_info(media: MediaSource) -> Tuple[str, str, str]:
logger.info(f"Downloading media {media}")
request = vod_request_models.DescribeMediaInfosRequest()
request.SubAppId = int(media.bucket)
# 兼容fileId带文件类型的格式和不带文件类型的格式
request.FileIds = [media.path.split('.')[0] if '.' in media.path else media.path]
response = vod_client.DescribeMediaInfos(request)
if len(response.MediaInfoSet) > 0:
media_info = response.MediaInfoSet[0].BasicInfo
logger.info(f"VOD info = {media_info}")
file_extension = media_info.Type
cache_dir = f"/{config.S3_mount_dir}/{media.protocol.value}/{media.endpoint}/{media.bucket}"
cache_file = media.path if '.' in media.path else f"{media.path}.{file_extension}"
return (cache_dir, cache_file, media_info.MediaUrl)
else:
raise FileNotFoundError(
f"FileId : {media.path} not found in SubAppId: {media.bucket} at {media.endpoint}")
def vod_download(media: MediaSource, on_progress_update: callable(float) = None) -> str:
cache_dir, cache_file, url = vod_info(media)
local_cache_filepath = os.path.join(cache_dir, cache_file)
download_large_file(url=url, output_path=local_cache_filepath,
on_progress_callback=on_progress_update)
return local_cache_filepath
def download_large_file(url: str, output_path: str, protocol: MediaProtocol = MediaProtocol.vod,
on_progress_callback: callable(float) = None) -> None:
# 配置日志
logger.info(f"Starting download from {url}")
try:
# 使用 httpx 发送 HEAD 请求获取文件大小
# 设置请求头,支持断点续传
headers = {'Range': 'bytes=0-'}
with httpx.Client() as client:
match protocol:
case MediaProtocol.vod:
head_response = client.head(url)
file_size = int(head_response.headers.get('content-length', 0))
remote_crc64 = int(head_response.headers.get('X-Cos-Hash-Crc64ecma', 0))
logger.info(f"File size: {file_size / (1024 * 1024 * 1024):.2f} GB")
if os.path.exists(output_path):
local_file_size = os.path.getsize(output_path)
logger.info(f"File size match check {local_file_size} = {file_size}")
if local_file_size == file_size:
logger.info(f"Check file CRC64...")
# CRC64使用ECMA-182标准校验 ref: https://cloud.tencent.com/document/product/436/40334#python-sdk
c64 = crcmod.mkCrcFun(0x142F0E1EBA9EA3693, initCrc=0, xorOut=0xffffffffffffffff,
rev=True)
with open(output_path, "rb") as local_file:
local_crc64 = c64(local_file.read())
logger.info(f"File crc64 check {local_crc64} = {remote_crc64}")
if local_crc64 == remote_crc64:
logger.success("File size verification passed!")
return
logger.info(f"Downloading {url}...")
# 发起流式请求
with client.stream('GET', url, headers=headers) as response:
response.raise_for_status()
file_size = int(response.headers.get('content-length', 0))
# 设置进度条
progress_bar = tqdm(
total=file_size,
unit='iB',
unit_scale=True,
desc='Downloading'
)
# 以二进制写模式打开文件
with open(output_path, 'wb') as file:
# 分块下载每次读取1MB
chunk_size = 1024 * 1024 # 1MB
downloaded_size = 0
for chunk in response.iter_bytes(chunk_size=chunk_size):
if chunk:
file.write(chunk)
downloaded_size += len(chunk)
if on_progress_callback:
on_progress_callback(downloaded_size / file_size)
progress_bar.update(len(chunk))
# 每下载100MB记录一次日志
if downloaded_size % (100 * 1024 * 1024) == 0:
logger.info(
f"Downloaded: {downloaded_size / (1024 * 1024 * 1024):.2f} GB")
progress_bar.close()
# 验证下载是否完成
if os.path.exists(output_path):
final_size = os.path.getsize(output_path)
logger.info(f"Download completed successfully!")
logger.info(f"Final file size: {final_size / (1024 * 1024 * 1024):.2f} GB")
logger.info(f"File saved to: {os.path.abspath(output_path)}")
# 验证文件大小是否匹配
if final_size == file_size:
logger.info("File size verification passed!")
else:
logger.warning(f"File size mismatch! Expected: {file_size}, Got: {final_size}")
except httpx.RequestError as e:
logger.error(f"An error occurred while requesting {url}: {str(e)}")
raise e
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error occurred: {str(e)}")
raise e
except Exception as e:
logger.error(f"An unexpected error occurred: {str(e)}")
raise e
finally:
if 'progress_bar' in locals():
progress_bar.close()
vod_client = vod_init()
modal_kv = modal_kv_cache
fn_id = current_function_call_id()
with sentry_sdk.continue_trace(environ_or_headers={"sentry-trace": sentry_trace.x_trace_id,
"baggage": sentry_trace.x_baggage, }) as transaction:
transaction.set_context("runtime_environment", {
"MODAL_CLOUD_PROVIDER": os.environ.get('MODAL_CLOUD_PROVIDER', 'unknown'),
"MODAL_ENVIRONMENT": config.modal_environment,
"MODAL_IMAGE_ID": os.environ.get('MODAL_IMAGE_ID', 'unknown'),
"MODAL_IS_REMOTE": os.environ.get('MODAL_IS_REMOTE', 'unknown'),
"MODAL_REGION": os.environ.get('MODAL_REGION', 'unknown'),
"MODAL_TASK_ID": os.environ.get('MODAL_TASK_ID', 'unknown'),
"MODAL_IDENTITY_TOKEN": os.environ.get('MODAL_IDENTITY_TOKEN', 'unknown'),
})
with transaction.start_child(name="收到缓存视频任务", op="queue.receive") as receive_span:
receive_span.set_data("messaging.message.id", fn_id)
receive_span.set_data("messaging.destination.name", "video-downloader.cache_submit")
receive_span.set_data("messaging.message.retry.count", 0)
receive_span.set_data("cache.key", media.urn)
with receive_span.start_child(name="处理缓存视频任务", op="queue.process") as process_span:
process_span.set_data("messaging.message.id", fn_id)
process_span.set_data("messaging.destination.name", "video-downloader.cache_submit")
process_span.set_data("messaging.message.retry.count", 0)
process_span.set_data("cache.key", media.urn)
volume_cache_path = None
match media.protocol:
case MediaProtocol.vod:
try:
volume_cache_path = vod_download(media)
process_span.set_status("success")
except Exception as e:
logger.exception(e)
media.status = MediaCacheStatus.failed
modal_kv.set_cache(media)
batch_update_cloudflare_kv([media])
process_span.set_status("failed")
case MediaProtocol.http:
try:
cache_filepath = f"{config.S3_mount_dir}/{media.cache_filepath}"
download_large_file(url=media.__str__(), output_path=cache_filepath)
except Exception as e:
logger.exception(e)
media.status = MediaCacheStatus.failed
modal_kv.set_cache(media)
batch_update_cloudflare_kv([media])
process_span.set_status("failed")
case MediaProtocol.s3:
# 本地挂载缓存
if media.protocol == MediaProtocol.s3 and media.endpoint == config.S3_region and media.bucket == config.S3_bucket_name:
volume_cache_path = f"{config.S3_mount_dir}/{media.cache_filepath}"
else:
logger.error("protocol not yet supported")
case _:
process_span.set_status("failed")
logger.error(f"protocol not yet supported")
media.downloader_id = fn_id
media.status = MediaCacheStatus.ready if volume_cache_path else MediaCacheStatus.failed
media.progress = 1 if volume_cache_path else 0
media.expired_at = datetime.now(UTC) + timedelta(days=7) if volume_cache_path else None
modal_kv.set_cache(media)
batch_update_cloudflare_kv([media])
return media
@app.function(cpu=1, timeout=300,
              max_containers=config.video_downloader_concurrency,
              volumes={
                  "/mntS3": modal.CloudBucketMount(
                      bucket_name=config.S3_bucket_name,
                      secret=modal.Secret.from_name("aws-s3-secret", environment_name=config.modal_environment),
                  ),
              })
@modal.concurrent(max_inputs=10)
async def cache_delete(cache: MediaSource) -> MediaSource:
    """Remove the cached file for ``cache`` from the S3 bucket mount.

    Sets ``cache.status`` to ``MediaCacheStatus.deleted`` when a file was removed, or
    ``MediaCacheStatus.missing`` when there was nothing to remove. No KV store is
    updated here.

    Returns:
        The same ``cache`` object with its status updated.
    """
    # cache_submit stores http/s3 cache files at f"{config.S3_mount_dir}/{cache_filepath}";
    # resolve the same way instead of relative to the process working directory.
    cache_path = f"{config.S3_mount_dir}/{cache.cache_filepath}"
    # NOTE(review): for s3-protocol media in our own bucket, cache_submit copies nothing —
    # the "cached" path is the source object itself, so this deletes the original object.
    # Confirm that is intended.
    try:
        os.remove(cache_path)
    except FileNotFoundError:
        cache.status = MediaCacheStatus.missing
    else:
        cache.status = MediaCacheStatus.deleted
    return cache