from collections import deque
import hashlib
import io
import os
import tarfile
from tempfile import NamedTemporaryFile
import gzip

import numpy
from numpy.testing import assert_equal

from PIL import Image
import six
from six.moves import xrange

import zmq

from fuel.converters.ilsvrc2010 import (extract_patch_images,
                                        image_consumer,
                                        load_from_tar_or_patch,
                                        other_set_producer,
                                        prepare_hdf5_file,
                                        prepare_metadata,
                                        process_train_set,
                                        process_other_set,
                                        read_devkit,
                                        read_metadata_mat_file,
                                        train_set_producer,
                                        DEVKIT_META_PATH,
                                        DEVKIT_ARCHIVE,
                                        TEST_GROUNDTRUTH)
from fuel.utils import find_in_data_path
from tests import skip_if_not_available


class MockSocket(object):
    """Mock of a ZeroMQ PUSH or PULL socket."""
    def __init__(self, socket_type, to_recv=()):
        self.socket_type = socket_type
        if self.socket_type not in (zmq.PUSH, zmq.PULL):
            raise NotImplementedError('only PUSH and PULL currently '
                                      'supported')
        self.sent = deque()
        self.to_recv = deque(to_recv)

    def send(self, data, flags=0, copy=True, track=False):
        assert self.socket_type == zmq.PUSH
        if track:
            # We don't emulate the behaviour required by this flag.
            raise NotImplementedError
        message = {'type': 'send', 'data': data, 'flags': flags, 'copy': copy}
        self.sent.append(message)

    def send_pyobj(self, obj, flags=0, protocol=2):
        assert self.socket_type == zmq.PUSH
        message = {'type': 'send_pyobj', 'obj': obj, 'flags': flags,
                   'protocol': protocol}
        self.sent.append(message)

    def recv(self, flags=0, copy=True, track=False):
        if track:
            # We don't emulate the behaviour required by this flag.
            raise NotImplementedError
        message = self.to_recv.popleft()
        assert message['type'] == 'recv'
        if 'flags' in message:
            assert message['flags'] == flags, 'flags did not match expected'
        if 'copy' in message:
            assert message['copy'] == copy, 'copy did not match expected'
        return message['data']

    def recv_pyobj(self, flags=0):
        message = self.to_recv.popleft()
        assert message['type'] == 'recv_pyobj'
        if 'flags' in message:
            assert flags == message['flags']
        return message['obj']
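

# Illustrative sketch of how MockSocket stands in for a ZeroMQ PUSH/PULL
# pair: the PUSH side records every message in `sent`, while the PULL side
# replays a scripted sequence of expected messages.
def _example_mock_socket_usage():
    push = MockSocket(zmq.PUSH)
    push.send_pyobj(('some_file.jpeg', 0), flags=zmq.SNDMORE)
    push.send(b'raw image bytes')
    assert push.sent.popleft()['obj'] == ('some_file.jpeg', 0)
    assert push.sent.popleft()['data'] == b'raw image bytes'

    pull = MockSocket(zmq.PULL, to_recv=[
        {'type': 'recv_pyobj', 'obj': ('some_file.jpeg', 0)},
        {'type': 'recv', 'data': b'raw image bytes'},
    ])
    assert pull.recv_pyobj() == ('some_file.jpeg', 0)
    assert pull.recv() == b'raw image bytes'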


class MockH5PYData(object):
    def __init__(self, shape, dtype):
        self.data = numpy.empty(shape, dtype)
        self.dims = MockH5PYDims(len(shape))

    def __setitem__(self, where, what):
        self.data[where] = what

    def __getitem__(self, index):
        return self.data[index]

    @property
    def shape(self):
        return self.data.shape

    @property
    def dtype(self):
        return self.data.dtype


class MockH5PYFile(dict):
    filename = 'NOT_A_REAL_FILE.hdf5'

    def __init__(self):
        self.attrs = {}
        self.flushed = 0
        self.opened = False
        self.closed = False

    def create_dataset(self, name, shape, dtype):
        self[name] = MockH5PYData(shape, dtype)

    def flush(self):
        self.flushed += 1

    def __enter__(self):
        self.opened = True
        return self

    def __exit__(self, type, value, traceback):
        self.closed = True


class MockH5PYDim(object):
    def __init__(self, dims):
        self.dims = dims
        self.scales = []

    def attach_scale(self, dataset):
        # h5py requires a dataset to be registered as a dimension scale
        # (via create_scale) before it can be attached; emulate that.
        assert dataset in self.dims.scales.values()
        self.scales.append(dataset)


class MockH5PYDims(object):
    def __init__(self, ndim):
        self._dims = [MockH5PYDim(self) for _ in xrange(ndim)]
        self.scales = {}

    def create_scale(self, dataset, name):
        self.scales[name] = dataset

    def __getitem__(self, index):
        return self._dims[index]
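

# Illustrative sketch of the (old-style) h5py dimension-scale calls these
# mocks emulate: a dataset must first be registered as a scale before it
# can be attached to a dimension.
def _example_mock_h5py_dims():
    f = MockH5PYFile()
    f.create_dataset('targets', (4, 1), dtype='int16')
    f.create_dataset('target_names', (1,), dtype='S7')
    f['targets'].dims.create_scale(f['target_names'], 'names')
    f['targets'].dims[1].attach_scale(f['target_names'])
    assert f['targets'].dims[1].scales == [f['target_names']]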


def create_jpeg_data(image):
    """Create a JPEG in memory.

    Parameters
    ----------
    image : ndarray, 3-dimensional
        Array data representing the image to save. The mode ('L', 'RGB'
        or 'CMYK') is determined by the length of the last (third) axis:
        1, 3 or 4 channels, respectively.

    Returns
    -------
    jpeg_data : bytes
        The image encoded as a JPEG, returned as raw bytes.

    """
    if image.shape[-1] == 1:
        mode = 'L'
    elif image.shape[-1] == 3:
        mode = 'RGB'
    elif image.shape[-1] == 4:
        mode = 'CMYK'
    else:
        raise ValueError("invalid shape")
    pil_image = Image.frombytes(mode=mode, size=image.shape[:2],
                                data=image.tobytes())
    jpeg_data = io.BytesIO()
    pil_image.save(jpeg_data, format='JPEG')
    return jpeg_data.getvalue()
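

# Illustrative sketch: the returned bytes form a complete JPEG stream that
# PIL can decode again (lossily, so pixel values need not roundtrip).
def _example_jpeg_roundtrip():
    image = numpy.zeros((16, 16, 3), dtype='uint8')
    jpeg = create_jpeg_data(image)
    decoded = Image.open(io.BytesIO(jpeg))
    assert decoded.size == (16, 16)
    assert decoded.mode == 'RGB'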


def create_fake_jpeg_tar(seed, min_num_images=5, max_num_images=50,
                         min_size=20, size_range=30, filenames=None,
                         random=True, gzip_probability=0.5, offset=0):
    """Create a TAR file of randomly generated JPEG files.

    Parameters
    ----------
    seed : int or sequence
        Seed for a `numpy.random.RandomState`.
    min_num_images : int, optional
        The minimum number of images to put in the TAR file.
    max_num_images : int, optional
        The maximum number of images to put in the TAR file.
    min_size : int, optional
        The minimum width and minimum height of each image.
    size_range : int, optional
        Maximum number of pixels added to `min_size` for image
        dimensions.
    filenames : list, optional
        If provided, use these filenames. Otherwise generate them
        randomly. Must be at least `max_num_images` long.
    random : bool, optional
        If `False`, each image is instead filled with a single constant
        value, namely its (zero-based) order of creation, making the
        contents predictable.
    gzip_probability : float, optional
        The probability with which each JPEG file is independently
        gzipped, without a gzip suffix being appended to its name.
    offset : int, optional
        Offset added to the image index when hashing filenames, so that
        successive calls can generate disjoint sets of names. Default: 0.

    Returns
    -------
    tar_data : bytes
        A TAR file represented as raw bytes, containing between
        `min_num_images` and `max_num_images` JPEG files (inclusive).
    ordered_files : list
        The names of the members, in the order in which they appear in
        the archive.

    Notes
    -----
    Randomly chooses between RGB, L and CMYK mode images. JPEGs are also
    randomly gzipped to simulate the awkward distribution format of
    ILSVRC2010.

    """
    rng = numpy.random.RandomState(seed)
    images = []
    if filenames is None:
        files = []
    else:
        if len(filenames) < max_num_images:
            raise ValueError('need at least max_num_images = %d filenames' %
                             max_num_images)
        files = filenames
    for i in xrange(rng.random_integers(min_num_images, max_num_images)):
        if filenames is None:
            max_len = 27  # 27 chars + '.JPEG' suffix = 32, fits dtype S32
            files.append('%s.JPEG' %
                         hashlib.sha1(bytes(i + offset)).hexdigest()[:max_len])
        im = rng.random_integers(0, 255,
                                 size=(rng.random_integers(min_size,
                                                           min_size +
                                                           size_range),
                                       rng.random_integers(min_size,
                                                           min_size +
                                                           size_range),
                                       rng.random_integers(1, 4)))
        if not random:
            im *= 0
            assert (im == 0).all()
            im += i
            assert (im == i).all()
        if im.shape[-1] == 2:
            # There is no 2-channel JPEG mode; fall back to greyscale.
            im = im[:, :, :1]
        # random_integers yields a platform integer array, but
        # create_jpeg_data expects one byte per channel.
        images.append(im.astype('uint8'))
    files = sorted(files)
    temp_tar = io.BytesIO()
    with tarfile.open(fileobj=temp_tar, mode='w') as tar:
        for fn, image in zip(files, images):
            try:
                with NamedTemporaryFile(mode='wb', suffix='.JPEG',
                                        delete=False) as f:
                    if rng.uniform() < gzip_probability:
                        gzip_data = io.BytesIO()
                        with gzip.GzipFile(mode='wb', fileobj=gzip_data) as gz:
                            gz.write(create_jpeg_data(image))
                        f.write(gzip_data.getvalue())
                    else:
                        f.write(create_jpeg_data(image))
                tar.add(f.name, arcname=fn)
            finally:
                os.remove(f.name)
    ordered_files = []
    with tarfile.open(fileobj=io.BytesIO(temp_tar.getvalue()),
                      mode='r') as tar:
        for info in tar.getmembers():
            ordered_files.append(info.name)
    return temp_tar.getvalue(), ordered_files
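

# Illustrative sketch: the archive reopens with `tarfile`, and the second
# return value lists the members in archive order.
def _example_fake_jpeg_tar_usage():
    tar_bytes, names = create_fake_jpeg_tar(seed=0, min_num_images=3,
                                            max_num_images=3,
                                            gzip_probability=0.0)
    with tarfile.open(fileobj=io.BytesIO(tar_bytes)) as tar:
        assert [info.name for info in tar.getmembers()] == names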


def create_fake_tar_of_tars(seed, num_inner_tars, *args, **kwargs):
    """Create a nested TAR of TARs of JPEGs.

    Parameters
    ----------
    seed : int or sequence
        Seed for a `numpy.random.RandomState`.
    num_inner_tars : int
        Number of TAR files to place inside.

    Returns
    -------
    tar_data : bytes
        A TAR file represented as raw bytes, containing the generated
        TAR files of JPEGs.
    names : list
        Names of the inner TAR files.
    jpeg_names : list of lists
        A list of lists containing the names of the JPEGs in each inner
        TAR file.

    Notes
    -----
    Remaining positional and keyword arguments are passed on to
    :func:`create_fake_jpeg_tar`.

    """
    seeds = numpy.arange(num_inner_tars) + seed
    tars, fns = [], []
    offset = 0
    for s in seeds:
        tar, fn = create_fake_jpeg_tar(s, *args, offset=offset, **kwargs)
        tars.append(tar)
        fns.append(fn)
        offset += len(fn)
    names = sorted(str(abs(hash(str(-i - 1)))) + '.tar'
                   for i, t in enumerate(tars))
    data = io.BytesIO()
    with tarfile.open(fileobj=data, mode='w') as outer:
        for tar, name in zip(tars, names):
            try:
                with NamedTemporaryFile(mode='wb', suffix='.tar',
                                        delete=False) as f:
                    f.write(tar)
                outer.add(f.name, arcname=name)
            finally:
                os.remove(f.name)
    return data.getvalue(), names, fns
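

# Illustrative sketch: each member of the outer TAR is itself a TAR whose
# JPEG names appear in the corresponding entry of the third return value.
def _example_fake_tar_of_tars_usage():
    outer_bytes, names, jpeg_names = create_fake_tar_of_tars(
        0, 2, min_num_images=2, max_num_images=2, gzip_probability=0.0)
    with tarfile.open(fileobj=io.BytesIO(outer_bytes)) as outer:
        for name, jpegs in zip(names, jpeg_names):
            with tarfile.open(fileobj=outer.extractfile(name)) as inner:
                assert sorted(inner.getnames()) == sorted(jpegs)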


def create_fake_patch_images(filenames=None, num_train=14, num_valid=15,
                             num_test=21):
    """Create a fake patch-image TAR akin to the ILSVRC2010 one.

    The first `num_train` filenames are placed under 'train/', the next
    `num_valid` under 'val/', and the last `num_test` under 'test/'.
    Images are generated with `random=False`, so each is filled with its
    order of creation.

    """
    if filenames is None:
        num = num_train + num_valid + num_test
        filenames = ['%x' % abs(hash(str(i))) + '.JPEG' for i in xrange(num)]
    else:
        filenames = list(filenames)  # Copy, so list not modified in-place.
    filenames[:num_train] = ['train/' + f
                             for f in filenames[:num_train]]
    filenames[num_train:num_train + num_valid] = [
        'val/' + f for f in filenames[num_train:num_train + num_valid]
    ]
    filenames[num_train + num_valid:] = [
        'test/' + f for f in filenames[num_train + num_valid:]
    ]
    tar = create_fake_jpeg_tar(1, min_num_images=len(filenames),
                               max_num_images=len(filenames),
                               filenames=filenames, random=False,
                               gzip_probability=0.0)[0]
    return tar


def test_prepare_metadata():
    skip_if_not_available(datasets=[DEVKIT_ARCHIVE, TEST_GROUNDTRUTH])
    devkit_path = find_in_data_path(DEVKIT_ARCHIVE)
    test_gt_path = find_in_data_path(TEST_GROUNDTRUTH)
    n_train, v_gt, t_gt, wnid_map = prepare_metadata(devkit_path,
                                                     test_gt_path)
    assert n_train == 1261406
    assert len(v_gt) == 50000
    assert len(t_gt) == 150000
    assert sorted(wnid_map.values()) == list(range(1000))
    assert all(isinstance(k, six.string_types) and len(k) == 9
               for k in wnid_map)


def test_prepare_hdf5_file():
    hdf5_file = MockH5PYFile()
    prepare_hdf5_file(hdf5_file, 10, 5, 2)

    def get_start_stop(hdf5_file, split):
        rows = [r for r in hdf5_file.attrs['split'] if
                (r['split'].decode('utf8') == split)]
        return dict([(r['source'].decode('utf8'), (r['start'], r['stop']))
                     for r in rows])

    # Verify properties of the train split.
    train_splits = get_start_stop(hdf5_file, 'train')
    assert all(v == (0, 10) for v in train_splits.values())
    assert set(train_splits.keys()) == set([u'encoded_images', u'targets',
                                            u'filenames'])

    # Verify properties of the valid split.
    valid_splits = get_start_stop(hdf5_file, 'valid')
    assert all(v == (10, 15) for v in valid_splits.values())
    assert set(valid_splits.keys()) == set([u'encoded_images', u'targets',
                                            u'filenames'])

    # Verify properties of the test split.
    test_splits = get_start_stop(hdf5_file, 'test')
    assert all(v == (15, 17) for v in test_splits.values())
    assert set(test_splits.keys()) == set([u'encoded_images', u'targets',
                                           u'filenames'])

    # Verify properties of the encoded_images HDF5 dataset.
    assert hdf5_file['encoded_images'].shape[0] == 17
    assert len(hdf5_file['encoded_images'].shape) == 1
    assert hdf5_file['encoded_images'].dtype.kind == 'O'
    assert (hdf5_file['encoded_images'].dtype.metadata['vlen'] ==
            numpy.dtype('uint8'))

    # Verify properties of the filenames dataset.
    assert hdf5_file['filenames'].shape[0] == 17
    assert len(hdf5_file['filenames'].shape) == 2
    assert hdf5_file['filenames'].dtype == numpy.dtype('S32')

    # Verify properties of the targets dataset.
    assert hdf5_file['targets'].shape[0] == 17
    assert hdf5_file['targets'].shape[1] == 1
    assert len(hdf5_file['targets'].shape) == 2
    assert hdf5_file['targets'].dtype == numpy.dtype('int16')


def test_process_train_set():
    tar_data, names, jpeg_names = create_fake_tar_of_tars(20150925, 5,
                                                          min_num_images=45,
                                                          max_num_images=55)
    all_jpegs = numpy.array(sum(jpeg_names, []))
    numpy.random.RandomState(20150925).shuffle(all_jpegs)
    patched_files = all_jpegs[:10]
    patches_data = create_fake_patch_images(filenames=patched_files,
                                            num_train=10, num_valid=0,
                                            num_test=0)
    hdf5_file = MockH5PYFile()
    prepare_hdf5_file(hdf5_file, len(all_jpegs), 0, 0)
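    # Each inner TAR's basename serves as a (fake) WNID; strip the '.tar'
    # suffix to build the WNID-to-class-index mapping.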
    wnid_map = dict(zip((n.split('.')[0] for n in names), range(len(names))))

    process_train_set(hdf5_file, io.BytesIO(tar_data),
                      io.BytesIO(patches_data), len(all_jpegs),
                      wnid_map)

    # Other tests cover that the actual images are what they should be.
    # Just do a basic verification of the filenames and targets.

    assert set(all_jpegs) == set(s.decode('ascii')
                                 for s in hdf5_file['filenames'][:, 0])
    assert len(hdf5_file['encoded_images'][:]) == len(all_jpegs)
    assert len(hdf5_file['targets'][:]) == len(all_jpegs)


def test_process_other_set():
    images, all_filenames = create_fake_jpeg_tar(3, min_num_images=30,
                                                 max_num_images=40,
                                                 gzip_probability=0.0)
    all_filenames_shuffle = numpy.array(all_filenames)
    numpy.random.RandomState(20151202).shuffle(all_filenames_shuffle)
    patched_files = all_filenames_shuffle[:15]
    patches_data = create_fake_patch_images(filenames=patched_files,
                                            num_train=0, num_valid=15,
                                            num_test=0)
    hdf5_file = MockH5PYFile()
    OFFSET = 50
    prepare_hdf5_file(hdf5_file, OFFSET, len(all_filenames), 0)
    groundtruth = [i % 10 for i in range(len(all_filenames))]
    process_other_set(hdf5_file, 'valid', io.BytesIO(images),
                      io.BytesIO(patches_data), groundtruth, OFFSET)

    # Other tests cover that the actual images are what they should be.
    # Just do a basic verification of the filenames and targets.

    assert all(hdf5_file['targets'][OFFSET:, 0] == groundtruth)
    assert all(a.decode('ascii') == b
               for a, b in zip(hdf5_file['filenames'][OFFSET:, 0],
                               all_filenames))


def test_train_set_producer():
    tar_data, names, jpeg_names = create_fake_tar_of_tars(20150923, 5,
                                                          min_num_images=45,
                                                          max_num_images=55)
    all_jpegs = numpy.array(sum(jpeg_names, []))
    numpy.random.RandomState(20150923).shuffle(all_jpegs)
    patched_files = all_jpegs[:10]
    patches_data = create_fake_patch_images(filenames=patched_files,
                                            num_train=10, num_valid=0,
                                            num_test=0)
    train_patches = extract_patch_images(io.BytesIO(patches_data), 'train')
    socket = MockSocket(zmq.PUSH)
    wnid_map = dict(zip((n.split('.')[0] for n in names), range(len(names))))

    train_set_producer(socket, io.BytesIO(tar_data), io.BytesIO(patches_data),
                       wnid_map)
    for tar_name in names:
        with tarfile.open(fileobj=io.BytesIO(tar_data)) as outer_tar:
            with tarfile.open(fileobj=outer_tar.extractfile(tar_name)) as tar:
                for record in tar:
                    jpeg = record.name
                    metadata_msg = socket.sent.popleft()
                    assert metadata_msg['type'] == 'send_pyobj'
                    assert metadata_msg['flags'] == zmq.SNDMORE
                    key = tar_name.split('.')[0]
                    assert metadata_msg['obj'] == (jpeg, wnid_map[key])

                    image_msg = socket.sent.popleft()
                    assert image_msg['type'] == 'send'
                    assert image_msg['flags'] == 0
                    if jpeg in train_patches:
                        assert image_msg['data'] == train_patches[jpeg]
                    else:
                        image_data, _ = load_from_tar_or_patch(tar, jpeg,
                                                               train_patches)
                        assert image_msg['data'] == image_data


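# Each logical item in the stream is a two-part message: a pickled
# (filename, target) tuple sent with SNDMORE, then the raw encoded image
# bytes.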
MOCK_CONSUMER_MESSAGES = [
    {'type': 'recv_pyobj', 'flags': zmq.SNDMORE, 'obj': ('foo.jpeg', 2)},
    {'type': 'recv', 'flags': 0, 'data': numpy.cast['uint8']([6, 6, 6])},
    {'type': 'recv_pyobj', 'flags': zmq.SNDMORE, 'obj': ('bar.jpeg', 3)},
    {'type': 'recv', 'flags': 0, 'data': numpy.cast['uint8']([1, 8, 1, 2, 0])},
    {'type': 'recv_pyobj', 'flags': zmq.SNDMORE, 'obj': ('baz.jpeg', 5)},
    {'type': 'recv', 'flags': 0, 'data': numpy.cast['uint8']([1, 9, 7, 9])},
    {'type': 'recv_pyobj', 'flags': zmq.SNDMORE, 'obj': ('bur.jpeg', 7)},
    {'type': 'recv', 'flags': 0, 'data': numpy.cast['uint8']([1, 8, 6, 7])},
]


def test_image_consumer():
    mock_messages = MOCK_CONSUMER_MESSAGES
    hdf5_file = MockH5PYFile()
    prepare_hdf5_file(hdf5_file, 4, 5, 8)
    socket = MockSocket(zmq.PULL, to_recv=mock_messages)
    image_consumer(socket, hdf5_file, 4)

    assert_equal(hdf5_file['encoded_images'][0], [6, 6, 6])
    assert_equal(hdf5_file['encoded_images'][1], [1, 8, 1, 2, 0])
    assert_equal(hdf5_file['encoded_images'][2], [1, 9, 7, 9])
    assert_equal(hdf5_file['encoded_images'][3], [1, 8, 6, 7])
    assert_equal(hdf5_file['filenames'][:4], [[b'foo.jpeg'], [b'bar.jpeg'],
                                              [b'baz.jpeg'], [b'bur.jpeg']])
    assert_equal(hdf5_file['targets'][:4], [[2], [3], [5], [7]])


def test_images_consumer_randomized():
    mock_messages = MOCK_CONSUMER_MESSAGES + [
        {'type': 'recv_pyobj', 'flags': zmq.SNDMORE, 'obj': ('jenny.jpeg', 1)},
        {'type': 'recv', 'flags': 0,
         'data': numpy.cast['uint8']([8, 6, 7, 5, 3, 0, 9])}
    ]
    hdf5_file = MockH5PYFile()
    prepare_hdf5_file(hdf5_file, 4, 5, 8)
    socket = MockSocket(zmq.PULL, to_recv=mock_messages)
    image_consumer(socket, hdf5_file, 5, offset=4, shuffle_seed=0)
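    # The consumer writes rows in a shuffled order, so compare as sets.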
    written_data = set(tuple(s) for s in hdf5_file['encoded_images'][4:9])
    expected_data = set(tuple(s['data']) for s in mock_messages[1::2])
    assert written_data == expected_data

    written_targets = set(hdf5_file['targets'][4:9].flatten())
    expected_targets = set(s['obj'][1] for s in mock_messages[::2])
    assert written_targets == expected_targets

    written_filenames = set(hdf5_file['filenames'][4:9].flatten())
    expected_filenames = set(s['obj'][0].encode('ascii')
                             for s in mock_messages[::2])
    assert written_filenames == expected_filenames


def test_other_set_producer():
    # Create some fake data.
    num = 21
    image_archive, filenames = create_fake_jpeg_tar(seed=1979,
                                                    min_num_images=num,
                                                    max_num_images=num)
    patches = create_fake_patch_images(filenames=filenames,
                                       num_train=7, num_valid=7, num_test=7)

    valid_patches = extract_patch_images(io.BytesIO(patches), 'valid')
    test_patches = extract_patch_images(io.BytesIO(patches), 'test')
    assert len(valid_patches) == 7
    assert len(test_patches) == 7

    groundtruth = numpy.random.RandomState(1979).random_integers(0, 50,
                                                                 size=num)
    assert len(groundtruth) == 21
    gt_lookup = dict(zip(sorted(filenames), groundtruth))
    assert len(gt_lookup) == 21

    def check(which_set, set_patches):
        # Run other_set_producer and push to a fake socket.
        socket = MockSocket(zmq.PUSH)
        other_set_producer(socket, which_set, io.BytesIO(image_archive),
                           io.BytesIO(patches), groundtruth)

        # Now verify the data that socket received.
        with tarfile.open(fileobj=io.BytesIO(image_archive)) as tar:
            num_patched = 0
            for im_fn in filenames:
                # Verify the label and flags of the first (metadata)
                # message.
                label = gt_lookup[im_fn]
                metadata_msg = socket.sent.popleft()
                assert metadata_msg['type'] == 'send_pyobj'
                assert metadata_msg['flags'] == zmq.SNDMORE
                assert metadata_msg['obj'] == (im_fn, label)
                # Verify that the second (data) message came from
                # the right place, either a patch file or a TAR.
                data_msg = socket.sent.popleft()
                assert data_msg['type'] == 'send'
                assert data_msg['flags'] == 0
                expected, patched = load_from_tar_or_patch(tar, im_fn,
                                                           set_patches)
                num_patched += int(patched)
                assert data_msg['data'] == expected
            assert num_patched == len(set_patches)

    check('valid', valid_patches)
    check('test', test_patches)


def test_load_from_tar_or_patch():
    # Setup fake tar files.
    images, all_filenames = create_fake_jpeg_tar(3, min_num_images=200,
                                                 max_num_images=200,
                                                 gzip_probability=0.0)
    patch_data = create_fake_patch_images(all_filenames[::4], num_train=50,
                                          num_valid=0, num_test=0)

    patches = extract_patch_images(io.BytesIO(patch_data), 'train')

    assert len(patches) == 50
    with tarfile.open(fileobj=io.BytesIO(images)) as tar:
        for fn in all_filenames:
            image, patched = load_from_tar_or_patch(tar, fn, patches)
            if fn in patches:
                assert image == patches[fn]
                assert patched
            else:
                tar_image = tar.extractfile(fn).read()
                assert image == tar_image
                assert not patched


def test_read_devkit():
    skip_if_not_available(datasets=[DEVKIT_ARCHIVE])
    synsets, cost_mat, raw_valid_gt = read_devkit(
        find_in_data_path(DEVKIT_ARCHIVE))
    # Sanity checks for the synsets and cost matrix appear in
    # test_read_metadata_mat_file.
    assert raw_valid_gt.min() == 1
    assert raw_valid_gt.max() == 1000
    assert raw_valid_gt.dtype.kind == 'i'
    assert raw_valid_gt.shape == (50000,)


def test_read_metadata_mat_file():
    skip_if_not_available(datasets=[DEVKIT_ARCHIVE])
    with tarfile.open(find_in_data_path(DEVKIT_ARCHIVE)) as tar:
        meta_mat = tar.extractfile(DEVKIT_META_PATH)
        synsets, cost_mat = read_metadata_mat_file(meta_mat)
    assert (synsets['ILSVRC2010_ID'] ==
            numpy.arange(1, len(synsets) + 1)).all()
    assert synsets['num_train_images'][1000:].sum() == 0
    assert (synsets['num_train_images'][:1000] > 0).all()
    assert synsets.ndim == 1
    assert synsets['wordnet_height'].min() == 0
    assert synsets['wordnet_height'].max() == 19
    assert synsets['WNID'].dtype == numpy.dtype('S9')
    assert (synsets['num_children'][:1000] == 0).all()
    assert (synsets['children'][:1000] == -1).all()

    # Assert the basics about the cost matrix.
    assert cost_mat.shape == (1000, 1000)
    assert cost_mat.dtype == 'uint8'
    assert cost_mat.min() == 0
    assert cost_mat.max() == 18
    assert (cost_mat == cost_mat.T).all()
    # Assert that the diagonal is 0.
    assert (cost_mat.flat[::1001] == 0).all()


def test_extract_patch_images():
    tar = create_fake_patch_images()
    assert len(extract_patch_images(io.BytesIO(tar), 'train')) == 14
    assert len(extract_patch_images(io.BytesIO(tar), 'valid')) == 15
    assert len(extract_patch_images(io.BytesIO(tar), 'test')) == 21