Use custom download_url instead of torchvision's function

- importing torchvision is quiet timed consuming
This commit is contained in:
Dr.Lt.Data 2024-08-02 01:20:30 +09:00
parent 707ec9dff8
commit d7f5f08179

View File

@ -8,10 +8,34 @@ import configparser
import re
import json
import yaml
from torchvision.datasets.utils import download_url
import requests
from tqdm.auto import tqdm
from git.remote import RemoteProgress
def download_url(url, dest_folder, filename=None):
# Ensure the destination folder exists
if not os.path.exists(dest_folder):
os.makedirs(dest_folder)
# Extract filename from URL if not provided
if filename is None:
filename = os.path.basename(url)
# Full path to save the file
dest_path = os.path.join(dest_folder, filename)
# Download the file
response = requests.get(url, stream=True)
if response.status_code == 200:
with open(dest_path, 'wb') as file:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
file.write(chunk)
else:
print(f"Failed to download file from {url}")
config_path = os.path.join(os.path.dirname(__file__), "config.ini")
nodelist_path = os.path.join(os.path.dirname(__file__), "custom-node-list.json")
working_directory = os.getcwd()