"""Getting the images from Imagenet To download ImageNet dataset, we provide a script which requires an input `txt` file containing the URLs to the images. > Note: There used to be a file containing the image URLs for ImageNet 2011 available without registration on the > official website. Since the link appears to be down, you may want to use a non-official file (see DATASET.md). ``` python -m koalarization.dataset.download urls.txt path/to/dest ``` Use `-h` to see the available options """ import argparse import hashlib import imghdr import sys import tarfile import urllib.request from itertools import islice from os.path import join, isfile from typing import List from .shared import maybe_create_folder class ImagenetDownloader: """Class instance to download the images""" def __init__(self, links_source, dest_dir): """Constructor. Args: links_source (str): Link or path to file containing dataset URLs. Use local file to boost performance. dest_dir (str): Destination folder to save downloaded images. """ print(links_source) # Destination folder maybe_create_folder(dest_dir) self.dest_dir = dest_dir # If the source is a link download it if links_source.startswith("http://"): print( "Using urllib.request for the link archive is extremely", "slow, it is better to download the tgz archive manually", "and pass its path to this constructor", file=sys.stderr, ) links_source, _ = urllib.request.urlretrieve( links_source, 'imagenet_fall11_urls.txt' ) # If the source is an archive extract it if links_source.endswith('.tgz'): with tarfile.open(links_source, 'r:gz') as tar: tar.extractall(path='.') links_source = 'imagenet_fall11_urls.txt' # if not isfile(links_source): # raise Exception('Links source not valid: {}'.format(links_source)) self.links_source = links_source def download_images(self, size=None, skip=0): """Download images. Args: size (int, optional): Number of images to download. Defaults to all images. skip (int, optional): Number of images to skip at first. Defaults to 0. Returns: List[str]: List with image paths. """ urls = self._image_urls_generator() urls = islice(urls, skip, None if size is None else skip+size) downloaded_images = map(self._download_img, urls) valid_images = filter(lambda x: x is not None, downloaded_images) return list(valid_images) def _download_img(self, image_url: str): """Download single image. Args: image_url (str): Image url. Returns: Union[str, None]: Image path if image was succesfully downloaded. Otherwise, None. """ image_name = self._encode_image_name(image_url) image_path = join(self.dest_dir, image_name) if not isfile(image_path): try: # TODO use request.get with accept jpg? request = urllib.request.urlopen(image_url, timeout=5) image = request.read() if imghdr.what("", image) == "jpeg": with open(image_path, "wb") as f: f.write(image) except Exception as e: print("Error downloading {}: {}".format(image_url, e), file=sys.stderr) return None return image_path def _image_urls_generator(self): """Generate image URL. Returns: Union[str, None]: List of image URLs. Yields: Iterator[Union[str, None]]: Iterator over image URLs. """ with open(self.links_source) as sources: while True: try: line = sources.readline() if line.startswith('#') or line == '\n': # Comments or empty lines continue if line == '': # End of file return url = line.rsplit(maxsplit=1)[-1] if url.startswith('http'): yield url except UnicodeDecodeError as ue: print("Unicode error: {}".format(ue), file=sys.stderr) @staticmethod def _encode_image_name(image_url: str) -> str: """Image name encoding. Args: image_url (str): Image URL. Returns: str: Encoded image name. """ hash = hashlib.md5(image_url.encode('utf-8')).hexdigest() encoded_name = f'{hash}.jpeg' return encoded_name def _parse_args(): """Get args. Returns: Namespace: Contains args """ # Argparse setup parser = argparse.ArgumentParser( description='Download and process images from a file of URLs.' ) parser.add_argument( '-c', '--count', default=None, type=int, help='download only COUNT images (default all)' ) parser.add_argument( '-s', '--skip', default=0, type=int, metavar='N', help='skip the first N images (default 0)' ) parser.add_argument( 'source', type=str, metavar='SOURCE', help='set source for the image links, can be the url, the archive or the file itself' ) parser.add_argument( 'output', default='.', type=str, metavar='OUT_DIR', help='save downloaded images in OUT_DIR' ) args = parser.parse_args() return args if __name__ == '__main__': args = _parse_args() ImagenetDownloader(links_source=args.source, dest_dir=args.output).download_images( size=args.count, skip=args.skip ) print("Done")