# Multi-threaded HTTP file downloader with single-thread fallback.
import os
import shutil
import threading
from urllib.parse import urlparse

import requests
from tqdm import tqdm
|
||
|
||
|
||
def download_file(url, output_path=None, num_threads=8, chunk_size=1024 * 1024):
    """Download a file, using multiple range-request threads when supported.

    Falls back to a single-threaded download when the server does not
    advertise ``Accept-Ranges: bytes`` or does not report a Content-Length.

    Args:
        url (str): Download URL.
        output_path (str, optional): Output file path; defaults to the file
            name taken from the URL path (or "downloaded_file" if empty).
        num_threads (int, optional): Number of download threads. Default 8.
        chunk_size (int, optional): Kept for backward compatibility; the
            per-thread byte range is derived from the file size, not from
            this value.

    Returns:
        bool: True on success, False on failure.
    """
    # Defined before the try so the except-cleanup below can never hit a
    # NameError when an exception fires before the part list is built.
    temp_files = []
    try:
        # Derive the output file name from the URL when not given.
        if not output_path:
            output_path = os.path.basename(urlparse(url).path)
            if not output_path:
                output_path = "downloaded_file"

        # Probe the server; the timeout prevents an indefinite hang.
        response = requests.head(url, timeout=30)
        response.raise_for_status()

        # Range requests are a prerequisite for multi-threaded download.
        supports_range = response.headers.get('Accept-Ranges') == 'bytes'
        if not supports_range:
            print("服务器不支持多线程下载,将使用单线程下载")
            return _download_single_thread(url, output_path)

        # Without a known size we cannot split the byte ranges.
        file_size = int(response.headers.get('Content-Length', 0))
        if not file_size:
            print("无法获取文件大小,将使用单线程下载")
            return _download_single_thread(url, output_path)

        # One temporary part file per thread.
        temp_files = [f"{output_path}.part{i}" for i in range(num_threads)]

        # Split [0, file_size) into contiguous ranges; the last range
        # absorbs the remainder left over by integer division.
        part_size = file_size // num_threads
        ranges = []
        for i in range(num_threads):
            start = i * part_size
            end = start + part_size - 1 if i < num_threads - 1 else file_size - 1
            ranges.append((start, end))

        progress = tqdm(total=file_size, unit='B', unit_scale=True, desc="下载中")

        threads = []
        for i in range(num_threads):
            start, end = ranges[i]
            thread = threading.Thread(
                target=_download_chunk,
                args=(url, temp_files[i], start, end, progress),
            )
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

        progress.close()

        # A missing part file means its worker thread failed; abort rather
        # than silently merging a corrupt file and reporting success.
        if not all(os.path.exists(f) for f in temp_files):
            print("部分下载线程失败")
            return False

        if not _merge_files(temp_files, output_path):
            print("合并文件失败")
            return False

        # Merge succeeded; the part files are no longer needed.
        for temp_file in temp_files:
            os.remove(temp_file)

        print(f"文件下载完成: {output_path}")
        return True

    except Exception as e:
        print(f"下载过程中发生错误: {e}")
        # Best-effort cleanup of any partial files left behind.
        for temp_file in temp_files:
            if os.path.exists(temp_file):
                os.remove(temp_file)
        return False
|
||
|
||
|
||
def _download_single_thread(url, output_path):
    """Download *url* to *output_path* with a single streaming GET request.

    Returns:
        bool: True on success, False on any error.
    """
    try:
        # stream=True avoids loading the whole body into memory; the
        # timeout prevents an indefinite hang on a dead connection.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        total = int(response.headers.get('Content-Length', 0))
        with open(output_path, 'wb') as f, \
                tqdm(desc="下载中", total=total, unit='B', unit_scale=True) as progress:
            # 64 KiB chunks: far fewer write calls than the original 1 KiB.
            for chunk in response.iter_content(chunk_size=64 * 1024):
                if chunk:
                    f.write(chunk)
                    progress.update(len(chunk))

        print(f"文件下载完成: {output_path}")
        return True
    except Exception as e:
        print(f"单线程下载失败: {e}")
        return False
|
||
|
||
|
||
def _download_chunk(url, output_path, start, end, progress):
    """Download bytes [start, end] of *url* into *output_path* via a Range request.

    On failure the incomplete part file is removed so the caller can detect
    the failed chunk (previously errors were only printed and a truncated
    part file was left behind, allowing a corrupt merge to go unnoticed).

    Returns:
        bool: True on success, False on failure. (Backward-compatible: the
        original implicitly returned None and no caller used the result.)
    """
    headers = {'Range': f'bytes={start}-{end}'}
    try:
        # Timeout prevents a hung connection from blocking the thread forever.
        response = requests.get(url, headers=headers, stream=True, timeout=30)
        response.raise_for_status()

        with open(output_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                if chunk:
                    f.write(chunk)
                    progress.update(len(chunk))
        return True
    except Exception as e:
        print(f"下载块失败 {start}-{end}: {e}")
        # Drop the incomplete part so a later merge cannot use it.
        if os.path.exists(output_path):
            os.remove(output_path)
        return False
|
||
|
||
|
||
def _merge_files(temp_files, output_path):
|
||
"""合并临时文件"""
|
||
try:
|
||
with open(output_path, 'wb') as outfile:
|
||
for temp_file in temp_files:
|
||
with open(temp_file, 'rb') as infile:
|
||
outfile.write(infile.read())
|
||
return True
|
||
except Exception as e:
|
||
print(f"合并文件失败: {e}")
|
||
return False
|
||
|
||
|
||
# Usage example
if __name__ == "__main__":
    # Replace with a real download URL before running.
    target = "https://example.com/large_file.zip"
    download_file(target, num_threads=4)