ComfyUI-CustomNode/utils/download_utils.py

146 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import requests
import threading
from tqdm import tqdm
from urllib.parse import urlparse
def download_file(url, output_path=None, num_threads=8, chunk_size=1024 * 1024):
"""
多线程下载文件
参数:
url (str): 下载URL
output_path (str, optional): 输出文件路径默认为URL中的文件名
num_threads (int, optional): 线程数默认为8
chunk_size (int, optional): 每个线程下载的块大小(字节)默认为1MB
返回:
bool: 下载成功返回True失败返回False
"""
try:
# 获取文件名
if not output_path:
output_path = os.path.basename(urlparse(url).path)
if not output_path:
output_path = "downloaded_file"
# 检查服务器是否支持范围请求
response = requests.head(url)
response.raise_for_status()
# 检查是否支持断点续传
supports_range = 'Accept-Ranges' in response.headers and response.headers['Accept-Ranges'] == 'bytes'
if not supports_range:
print("服务器不支持多线程下载,将使用单线程下载")
return _download_single_thread(url, output_path)
# 获取文件大小
file_size = int(response.headers.get('Content-Length', 0))
if not file_size:
print("无法获取文件大小,将使用单线程下载")
return _download_single_thread(url, output_path)
# 创建临时文件
temp_files = [f"{output_path}.part{i}" for i in range(num_threads)]
# 计算每个线程下载的范围
ranges = []
for i in range(num_threads):
start = i * (file_size // num_threads)
end = start + (file_size // num_threads) - 1 if i < num_threads - 1 else file_size - 1
ranges.append((start, end))
# 创建进度条
progress = tqdm(total=file_size, unit='B', unit_scale=True, desc="下载中")
# 启动线程
threads = []
for i in range(num_threads):
start, end = ranges[i]
thread = threading.Thread(target=_download_chunk, args=(url, temp_files[i], start, end, progress))
thread.start()
threads.append(thread)
# 等待所有线程完成
for thread in threads:
thread.join()
progress.close()
# 合并临时文件
if not _merge_files(temp_files, output_path):
print("合并文件失败")
return False
# 删除临时文件
for temp_file in temp_files:
os.remove(temp_file)
print(f"文件下载完成: {output_path}")
return True
except Exception as e:
print(f"下载过程中发生错误: {e}")
# 清理临时文件
for temp_file in temp_files:
if os.path.exists(temp_file):
os.remove(temp_file)
return False
def _download_single_thread(url, output_path):
"""单线程下载文件"""
try:
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f, \
tqdm(desc="下载中", total=int(response.headers.get('Content-Length', 0)),
unit='B', unit_scale=True) as progress:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
progress.update(len(chunk))
print(f"文件下载完成: {output_path}")
return True
except Exception as e:
print(f"单线程下载失败: {e}")
return False
def _download_chunk(url, output_path, start, end, progress):
"""下载文件的指定块"""
headers = {'Range': f'bytes={start}-{end}'}
try:
response = requests.get(url, headers=headers, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
progress.update(len(chunk))
except Exception as e:
print(f"下载块失败 {start}-{end}: {e}")
def _merge_files(temp_files, output_path):
"""合并临时文件"""
try:
with open(output_path, 'wb') as outfile:
for temp_file in temp_files:
with open(temp_file, 'rb') as infile:
outfile.write(infile.read())
return True
except Exception as e:
print(f"合并文件失败: {e}")
return False
# 使用示例
if __name__ == "__main__":
download_url = "https://example.com/large_file.zip" # 替换为实际的下载URL
download_file(download_url, num_threads=4)