import csv import errno import hashlib import logging import os import sys import tarfile import threading import zipfile from _io import TextIOWrapper from queue import Queue from typing import Any, Iterable, List, Optional, Tuple, Union import torch import urllib import urllib.request from torch.utils.data import Dataset from torch.utils.model_zoo import tqdm def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> Any: r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper. Borrowed and slightly modified from the Python docs: https://docs.python.org/2/library/csv.html#csv-examples Args: unicode_csv_data (TextIOWrapper): unicode csv data (see example below) Examples: >>> from torchaudio.datasets.utils import unicode_csv_reader >>> import io >>> with io.open(data_path, encoding="utf8") as f: >>> reader = unicode_csv_reader(f) """ # Fix field larger than field limit error maxInt = sys.maxsize while True: # decrease the maxInt value by factor 10 # as long as the OverflowError occurs. try: csv.field_size_limit(maxInt) break except OverflowError: maxInt = int(maxInt / 10) csv.field_size_limit(maxInt) for line in csv.reader(unicode_csv_data, **kwargs): yield line def makedir_exist_ok(dirpath: str) -> None: """ Python2 support for os.makedirs(.., exist_ok=True) """ try: os.makedirs(dirpath) except OSError as e: if e.errno == errno.EEXIST: pass else: raise def stream_url(url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True) -> Iterable: """Stream url by chunk Args: url (str): Url. start_byte (int, optional): Start streaming at that point (Default: ``None``). block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``). progress_bar (bool, optional): Display a progress bar (Default: ``True``). """ # If we already have the whole file, there is no need to download it again req = urllib.request.Request(url, method="HEAD") url_size = int(urllib.request.urlopen(req).info().get("Content-Length", -1)) if url_size == start_byte: return req = urllib.request.Request(url) if start_byte: req.headers["Range"] = "bytes={}-".format(start_byte) with urllib.request.urlopen(req) as upointer, tqdm( unit="B", unit_scale=True, unit_divisor=1024, total=url_size, disable=not progress_bar, ) as pbar: num_bytes = 0 while True: chunk = upointer.read(block_size) if not chunk: break yield chunk num_bytes += len(chunk) pbar.update(len(chunk)) def download_url(url: str, download_folder: str, filename: Optional[str] = None, hash_value: Optional[str] = None, hash_type: str = "sha256", progress_bar: bool = True, resume: bool = False) -> None: """Download file to disk. Args: url (str): Url. download_folder (str): Folder to download file. filename (str, optional): Name of downloaded file. If None, it is inferred from the url (Default: ``None``). hash_value (str, optional): Hash for url (Default: ``None``). hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). progress_bar (bool, optional): Display a progress bar (Default: ``True``). resume (bool, optional): Enable resuming download (Default: ``False``). """ req = urllib.request.Request(url, method="HEAD") req_info = urllib.request.urlopen(req).info() # Detect filename filename = filename or req_info.get_filename() or os.path.basename(url) filepath = os.path.join(download_folder, filename) if resume and os.path.exists(filepath): mode = "ab" local_size: Optional[int] = os.path.getsize(filepath) elif not resume and os.path.exists(filepath): raise RuntimeError( "{} already exists. Delete the file manually and retry.".format(filepath) ) else: mode = "wb" local_size = None if hash_value and local_size == int(req_info.get("Content-Length", -1)): with open(filepath, "rb") as file_obj: if validate_file(file_obj, hash_value, hash_type): return raise RuntimeError( "The hash of {} does not match. Delete the file manually and retry.".format( filepath ) ) with open(filepath, mode) as fpointer: for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar): fpointer.write(chunk) with open(filepath, "rb") as file_obj: if hash_value and not validate_file(file_obj, hash_value, hash_type): raise RuntimeError( "The hash of {} does not match. Delete the file manually and retry.".format( filepath ) ) def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool: """Validate a given file object with its hash. Args: file_obj: File object to read from. hash_value (str): Hash for url. hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). Returns: bool: return True if its a valid file, else False. """ if hash_type == "sha256": hash_func = hashlib.sha256() elif hash_type == "md5": hash_func = hashlib.md5() else: raise ValueError while True: # Read by chunk to avoid filling memory chunk = file_obj.read(1024 ** 2) if not chunk: break hash_func.update(chunk) return hash_func.hexdigest() == hash_value def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: """Extract archive. Args: from_path (str): the path of the archive. to_path (str, optional): the root path of the extraced files (directory of from_path) (Default: ``None``) overwrite (bool, optional): overwrite existing files (Default: ``False``) Returns: list: List of paths to extracted files even if not overwritten. Examples: >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz' >>> from_path = './validation.tar.gz' >>> to_path = './' >>> torchaudio.datasets.utils.download_from_url(url, from_path) >>> torchaudio.datasets.utils.extract_archive(from_path, to_path) """ if to_path is None: to_path = os.path.dirname(from_path) try: with tarfile.open(from_path, "r") as tar: logging.info("Opened tar file {}.".format(from_path)) files = [] for file_ in tar: # type: Any file_path = os.path.join(to_path, file_.name) if file_.isfile(): files.append(file_path) if os.path.exists(file_path): logging.info("{} already extracted.".format(file_path)) if not overwrite: continue tar.extract(file_, to_path) return files except tarfile.ReadError: pass try: with zipfile.ZipFile(from_path, "r") as zfile: logging.info("Opened zip file {}.".format(from_path)) files = zfile.namelist() for file_ in files: file_path = os.path.join(to_path, file_) if os.path.exists(file_path): logging.info("{} already extracted.".format(file_path)) if not overwrite: continue zfile.extract(file_, to_path) return files except zipfile.BadZipFile: pass raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.") def walk_files(root: str, suffix: Union[str, Tuple[str]], prefix: bool = False, remove_suffix: bool = False) -> Iterable[str]: """List recursively all files ending with a suffix at a given root Args: root (str): Path to directory whose folders need to be listed suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). It uses the Python "str.endswith" method and is passed directly prefix (bool, optional): If true, prepends the full path to each result, otherwise only returns the name of the files found (Default: ``False``) remove_suffix (bool, optional): If true, removes the suffix to each result defined in suffix, otherwise will return the result as found (Default: ``False``). """ root = os.path.expanduser(root) for dirpath, _, files in os.walk(root): for f in files: if f.endswith(suffix): if remove_suffix: f = f[: -len(suffix)] if prefix: f = os.path.join(dirpath, f) yield f class _DiskCache(Dataset): """ Wrap a dataset so that, whenever a new item is returned, it is saved to disk. """ def __init__(self, dataset: Dataset, location: str = ".cached") -> None: self.dataset = dataset self.location = location self._id = id(self) self._cache: List = [None] * len(dataset) def __getitem__(self, n: int) -> Any: if self._cache[n]: f = self._cache[n] return torch.load(f) f = str(self._id) + "-" + str(n) f = os.path.join(self.location, f) item = self.dataset[n] self._cache[n] = f makedir_exist_ok(self.location) torch.save(item, f) return item def __len__(self) -> int: return len(self.dataset) def diskcache_iterator(dataset: Dataset, location: str = ".cached") -> Dataset: return _DiskCache(dataset, location) class _ThreadedIterator(threading.Thread): """ Prefetch the next queue_length items from iterator in a background thread. Example: >> for i in bg_iterator(range(10)): >> print(i) """ class _End: pass def __init__(self, generator: Iterable, maxsize: int) -> None: threading.Thread.__init__(self) self.queue: Queue = Queue(maxsize) self.generator = generator self.daemon = True self.start() def run(self) -> None: for item in self.generator: self.queue.put(item) self.queue.put(self._End) def __iter__(self) -> Any: return self def __next__(self) -> Any: next_item = self.queue.get() if next_item == self._End: raise StopIteration return next_item # Required for Python 2.7 compatibility def next(self) -> Any: return self.__next__() def bg_iterator(iterable: Iterable, maxsize: int) -> Any: return _ThreadedIterator(iterable, maxsize=maxsize)