Python numpy.frombuffer() Examples

The following are 30 code examples of numpy.frombuffer(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module numpy, or try the search function.
Example #1
Source File: base.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def ctypes2numpy_shared(cptr, shape):
    """View the memory behind a ctypes float pointer as a NumPy array.

    No data is copied: the returned array is backed by the same memory
    that `cptr` points to.

    Parameters
    ----------
    cptr : ctypes.POINTER(mx_float)
        Pointer to the first element of the memory region.

    shape : tuple
        Shape of the resulting array.

    Returns
    -------
    out : numpy.ndarray
        A float32 array of the requested shape that shares memory
        with `cptr`.

    Raises
    ------
    RuntimeError
        If `cptr` is not an instance of ``ctypes.POINTER(mx_float)``.
    """
    if not isinstance(cptr, ctypes.POINTER(mx_float)):
        raise RuntimeError('expected float pointer')
    # Total element count is the product of all dimensions (1 for an
    # empty shape).
    num_elements = 1
    for dim in shape:
        num_elements = num_elements * dim
    # Overlay a fixed-size ctypes array on the pointed-to address so that
    # NumPy can consume it through the buffer protocol.
    array_type = mx_float * num_elements
    shared = array_type.from_address(ctypes.addressof(cptr.contents))
    flat = np.frombuffer(shared, dtype=np.float32)
    return flat.reshape(shape)
Example #2
Source File: async_.py    From chainerrl with MIT License 6 votes vote down vote up
def set_shared_params(a, b):
    """Point a link's params (and persistent values) at shared buffers.

    Every key of ``b`` has to match either a param or a persistent value
    of ``a``; leftover keys trip the final assertion.

    Args:
      a (chainer.Link): link whose params are to be replaced
      b (dict): dict that consists of (param_name, multiprocessing.Array)
    """
    assert isinstance(a, chainer.Link)
    unused = set(b.keys())
    for name, param in a.namedparams():
        if name not in b:
            continue
        # Reinterpret the shared buffer with the param's own dtype/shape.
        shared = np.frombuffer(b[name], dtype=param.dtype)
        param.array = shared.reshape(param.shape)
        unused.remove(name)
    for name, _ in chainerrl.misc.namedpersistent(a):
        if name not in b:
            continue
        _set_persistent_values_recursively(a, name, b[name])
        unused.remove(name)
    assert not unused
Example #3
Source File: download_and_convert_mnist.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def _extract_labels(filename, num_labels):
  """Read label IDs from a gzipped MNIST labels file.

  The first 8 bytes of the decompressed stream are skipped, then one
  byte per label is read.

  Args:
    filename: The path to an MNIST labels file.
    num_labels: The number of labels to read from the file.

  Returns:
    A 1-D int64 numpy array holding the labels that were read.
  """
  print('Extracting labels from: ', filename)
  with gzip.open(filename) as bytestream:
    bytestream.read(8)  # skip the header
    raw = bytestream.read(1 * num_labels)
  # Labels are stored as single bytes; widen them to int64.
  return np.frombuffer(raw, dtype=np.uint8).astype(np.int64)
Example #4
Source File: download_and_convert_mnist.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def _extract_images(filename, num_images):
  """Read raw pixel data from a gzipped MNIST images file.

  The first 16 bytes of the decompressed stream are skipped, then
  `_IMAGE_SIZE * _IMAGE_SIZE * _NUM_CHANNELS` bytes are read per image.

  Args:
    filename: The path to an MNIST images file.
    num_images: The number of images to read from the file.

  Returns:
    A uint8 numpy array of shape [number_of_images, height, width, channels].
  """
  print('Extracting images from: ', filename)
  with gzip.open(filename) as bytestream:
    bytestream.read(16)  # skip the header
    num_bytes = _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS
    raw = bytestream.read(num_bytes)
  pixels = np.frombuffer(raw, dtype=np.uint8)
  return pixels.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
Example #5
Source File: serialization.py    From QCElemental with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def msgpackext_decode(obj: Any) -> Any:
    """
    Turn the dictionary representation of an encoded array back into an array.

    An object containing the ``b"_nd_"`` key is rebuilt from its ``b"data"``
    buffer and ``b"dtype"``; when a ``b"shape"`` entry is present it is applied
    to the array. Anything else is handed back untouched.

    Parameters
    ----------
    obj : Any
        An encoded object, likely a dictionary.

    Returns
    -------
    Any
        The decoded form of the object.
    """

    # Not an encoded array: nothing to decode.
    if b"_nd_" not in obj:
        return obj

    decoded = np.frombuffer(obj[b"data"], dtype=obj[b"dtype"])
    if b"shape" in obj:
        decoded.shape = obj[b"shape"]
    return decoded
Example #6
Source File: input_data.py    From IntroToDeepLearning with MIT License 6 votes vote down vote up
def extract_images(filename):
  """Read a gzipped MNIST image file into a uint8 array [index, y, x, 1].

  The header is read with `_read32`: a magic number that must equal 2051,
  followed by the image count, the row count and the column count.

  Raises:
    ValueError: if the magic number is not 2051.
  """
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    # The three remaining header words, in file order.
    num_images, rows, cols = (_read32(bytestream) for _ in range(3))
    raw = bytestream.read(rows * cols * num_images)
    pixels = numpy.frombuffer(raw, dtype=numpy.uint8)
    return pixels.reshape(num_images, rows, cols, 1)
Example #7
Source File: mnist.py    From fine-lm with MIT License 6 votes vote down vote up
def _extract_mnist_images(filename, num_images):
  """Read raw pixels from a gzipped MNIST images file.

  The first 16 bytes of the decompressed stream are skipped, then
  `_MNIST_IMAGE_SIZE * _MNIST_IMAGE_SIZE` bytes are read per image.

  Args:
    filename: The path to an MNIST images file.
    num_images: The number of images to read from the file.

  Returns:
    A uint8 numpy array of shape [number_of_images, height, width, 1].
  """
  with gzip.open(filename) as bytestream:
    bytestream.read(16)  # skip the header
    num_bytes = _MNIST_IMAGE_SIZE * _MNIST_IMAGE_SIZE * num_images
    raw = bytestream.read(num_bytes)
  pixels = np.frombuffer(raw, dtype=np.uint8)
  return pixels.reshape(num_images, _MNIST_IMAGE_SIZE, _MNIST_IMAGE_SIZE, 1)
Example #8
Source File: arrayUtil.py    From hsds with Apache License 2.0 6 votes vote down vote up
def bytesToArray(data, dt, shape):
    """Build a writable numpy array of dtype `dt` and shape `shape` from `data`.

    Fixed-length dtypes are interpreted directly from the buffer.  For
    variable-length dtypes (as reported by `isVlen`) the elements are decoded
    one at a time with `readElement`, whose return value is used as the
    offset for the next element.
    """
    nelements = getNumElements(shape)
    if isVlen(dt):
        arr = np.zeros((nelements,), dtype=dt)
        position = 0
        for i in range(nelements):
            position = readElement(data, position, arr, i, dt)
    else:
        # regular numpy array straight from the bytes
        arr = np.frombuffer(data, dtype=dt)
    arr = arr.reshape(shape)

    # Callers may need to update the array.  A read-only result is copied
    # rather than having its flag flipped, which is not recommended
    # (cf: https://github.com/numpy/numpy/issues/9440).
    if arr.flags['WRITEABLE']:
        return arr
    return arr.copy()
Example #9
Source File: mnist_input_data.py    From python-esppy with Apache License 2.0 6 votes vote down vote up
def extract_images(filename):
  """Load a gzipped MNIST image file as a uint8 array [index, y, x, 1].

  `_read32` supplies the header words: the magic number (which must be
  2051), then the number of images, rows and columns.

  Raises:
    ValueError: if the magic number is not 2051.
  """
  print('Extracting %s' % filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    header = [_read32(bytestream) for _ in range(3)]
    num_images, rows, cols = header
    raw = bytestream.read(rows * cols * num_images)
    pixels = numpy.frombuffer(raw, dtype=numpy.uint8)
    return pixels.reshape(num_images, rows, cols, 1)
Example #10
Source File: mnist.py    From dataflow with Apache License 2.0 6 votes vote down vote up
def extract_images(filename):
    """Load a gzipped MNIST image file as float32 pixels scaled by 1/255.

    The header is read with `_read32`: a magic number that must equal 2051,
    then the number of images, rows and columns.  The uint8 pixel data is
    reshaped to [index, y, x, 1], converted to float32 and divided by 255.

    Raises:
        ValueError: if the magic number is not 2051.
    """
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images, rows, cols = (_read32(bytestream) for _ in range(3))
        raw = bytestream.read(rows * cols * num_images)
        pixels = numpy.frombuffer(raw, dtype=numpy.uint8)
        pixels = pixels.reshape(num_images, rows, cols, 1)
        return pixels.astype('float32') / 255.0
Example #11
Source File: kaldi_io.py    From Attentive-Filtering-Network with MIT License 6 votes vote down vote up
def read_vec_int(file_or_fd):
  """ [int-vec] = read_vec_int(file_or_fd)
   Read kaldi integer vector, ascii or binary input,

   A binary vector of dimension zero yields an empty int32 array.
   The stream is closed before returning (also when parsing fails), but
   only if `open_or_fd` handed back an object other than `file_or_fd`.
  """
  fd = open_or_fd(file_or_fd)
  try:
    binary = fd.read(2).decode()
    if binary == '\0B': # binary flag
      assert(fd.read(1).decode() == '\4') # int-size
      vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim
      if vec_size == 0:
        # No elements follow; indexing vec[0] below would raise IndexError.
        ans = np.array([], dtype='int32')
      else:
        # Elements from int32 vector are stored in tuples: (sizeof(int32), value),
        vec = np.frombuffer(fd.read(vec_size*5), dtype=[('size','int8'),('value','int32')], count=vec_size)
        assert(vec[0]['size'] == 4) # int32 size,
        ans = vec[:]['value'] # values are in 2nd column,
    else: # ascii,
      arr = (binary + fd.readline().decode()).strip().split()
      try:
        # the surrounding brackets are optional
        arr.remove('[')
        arr.remove(']')
      except ValueError:
        pass
      ans = np.array(arr, dtype=int)
  finally:
    # cleanup: only close a stream that was opened here
    if fd is not file_or_fd:
      fd.close()
  return ans

# Writing, 
Example #12
Source File: datasets.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def _get_data(self):
        """Fetch the image and label archives and load them into memory.

        The train or test file descriptors are chosen by ``self._train``;
        each descriptor supplies a file name (index 0) and a SHA-1 hash
        (index 1) that are passed to ``download``.  Labels end up in
        ``self._label`` as an int32 numpy array and images in ``self._data``
        as an ``nd.array`` of shape (number of labels, 28, 28, 1).
        """
        if self._train:
            data_spec, label_spec = self._train_data, self._train_label
        else:
            data_spec, label_spec = self._test_data, self._test_label

        namespace = 'gluon/dataset/'+self._namespace
        data_url = _get_repo_file_url(namespace, data_spec[0])
        data_file = download(data_url, path=self._root,
                             sha1_hash=data_spec[1])
        label_url = _get_repo_file_url(namespace, label_spec[0])
        label_file = download(label_url, path=self._root,
                              sha1_hash=label_spec[1])

        with gzip.open(label_file, 'rb') as fin:
            # Consume the two-word header; its values are not used.
            struct.unpack(">II", fin.read(8))
            raw_labels = fin.read()
        label = np.frombuffer(raw_labels, dtype=np.uint8).astype(np.int32)

        with gzip.open(data_file, 'rb') as fin:
            # Consume the four-word header; its values are not used.
            struct.unpack(">IIII", fin.read(16))
            raw_pixels = fin.read()
        data = np.frombuffer(raw_pixels, dtype=np.uint8)
        data = data.reshape(len(label), 28, 28, 1)

        self._data = nd.array(data, dtype=data.dtype)
        self._label = label
Example #13
Source File: carbonara.py    From gnocchi with Apache License 2.0 6 votes vote down vote up
def unserialize(cls, data, block_size, back_window):
        """Rebuild an instance from an lz4-compressed serialization.

        The decompressed buffer is read in two parts: `nb_points`
        little-endian uint64 entries, cumulatively summed to form the
        timestamps, followed by little-endian float64 values starting at
        ``nb_points * cls._SERIALIZATION_TIMESTAMP_LEN``.

        :raises InvalidData: if numpy cannot decode the buffer.
        """
        uncompressed = lz4.block.decompress(data)
        nb_points = (
            len(uncompressed) // cls._SERIALIZATION_TIMESTAMP_VALUE_LEN
        )
        values_offset = nb_points * cls._SERIALIZATION_TIMESTAMP_LEN

        try:
            deltas = numpy.frombuffer(
                uncompressed, dtype='<Q', count=nb_points)
            values = numpy.frombuffer(
                uncompressed, dtype='<d', offset=values_offset)
        except ValueError:
            raise InvalidData

        timestamps = numpy.cumsum(deltas)
        return cls.from_data(
            timestamps,
            values,
            block_size=block_size,
            back_window=back_window)
Example #14
Source File: kaldi_io.py    From Attentive-Filtering-Network with MIT License 6 votes vote down vote up
def _read_mat_binary(fd):
  """Read one binary matrix from `fd`, positioned at the 3-char type header.

  Headers starting with 'CM' are delegated to `_read_compressed_mat`;
  'FM ' selects float32 samples and 'DM ' float64 samples.  Any other
  header raises `UnknownMatrixHeader`.  Returns a (rows, cols) array.
  """
  # Data type
  header = fd.read(3).decode()
  if header.startswith('CM'):
    # 'CM', 'CM2', 'CM3' are possible values,
    return _read_compressed_mat(fd, header)
  if header == 'FM ':
    sample_size = 4 # floats
  elif header == 'DM ':
    sample_size = 8 # doubles
  else:
    raise UnknownMatrixHeader("The header contained '%s'" % header)
  assert(sample_size > 0)
  # Dimensions: (int8 size marker, int32 rows, int8 size marker, int32 cols)
  dims = np.frombuffer(fd.read(10), dtype='int8,int32,int8,int32', count=1)[0]
  _rows_marker, rows, _cols_marker, cols = dims
  # Read whole matrix
  buf = fd.read(rows * cols * sample_size)
  if sample_size == 4:
    vec = np.frombuffer(buf, dtype='float32')
  elif sample_size == 8:
    vec = np.frombuffer(buf, dtype='float64')
  else:
    raise BadSampleSize
  return np.reshape(vec, (rows, cols))
Example #15
Source File: input_data.py    From IntroToDeepLearning with MIT License 6 votes vote down vote up
def extract_images(filename):
  """Read a gzipped MNIST image file into a uint8 array [index, y, x, 1].

  The header is read with `_read32`: a magic number that must equal 2051,
  followed by the image count, the row count and the column count.

  Raises:
    ValueError: if the magic number is not 2051.
  """
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    # The three remaining header words, in file order.
    num_images, rows, cols = (_read32(bytestream) for _ in range(3))
    raw = bytestream.read(rows * cols * num_images)
    pixels = numpy.frombuffer(raw, dtype=numpy.uint8)
    return pixels.reshape(num_images, rows, cols, 1)
Example #16
Source File: mnist_input_data.py    From Make_Money_with_Tensorflow with GNU General Public License v3.0 6 votes vote down vote up
def extract_images(filename):
  """Load a gzipped MNIST image file as a uint8 array [index, y, x, 1].

  `_read32` supplies the header words: the magic number (which must be
  2051), then the number of images, rows and columns.

  Raises:
    ValueError: if the magic number is not 2051.
  """
  print('Extracting %s' % filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    header = [_read32(bytestream) for _ in range(3)]
    num_images, rows, cols = header
    raw = bytestream.read(rows * cols * num_images)
    pixels = numpy.frombuffer(raw, dtype=numpy.uint8)
    return pixels.reshape(num_images, rows, cols, 1)
Example #17
Source File: dataset_tool.py    From disentangling_conditional_gans with MIT License 6 votes vote down vote up
def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
    """Export `num_images` three-digit stacks of MNIST to `tfrecord_dir`.

    The training images are read from `mnist_dir`, zero-padded from 28x28
    to 32x32, and each exported image is a (3, 32, 32) stack of three
    digits drawn with a `RandomState` seeded by `random_seed`.
    """
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    images_path = os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz')
    with gzip.open(images_path, 'rb') as stream:
        # Pixel data starts after the 16-byte header.
        images = np.frombuffer(stream.read(), np.uint8, offset=16)
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255

    with TFRecordExporter(tfrecord_dir, num_images) as tfr:
        rnd = np.random.RandomState(random_seed)
        for _ in range(num_images):
            picks = rnd.randint(images.shape[0], size=3)
            tfr.add_image(images[picks])

#---------------------------------------------------------------------------- 
Example #18
Source File: dataset_tool.py    From disentangling_conditional_gans with MIT License 6 votes vote down vote up
def create_mnist(tfrecord_dir, mnist_dir):
    """Export the MNIST training set, with one-hot labels, to `tfrecord_dir`.

    Images are read from `mnist_dir`, reshaped to (N, 1, 28, 28) and
    zero-padded to 32x32.  Images are written in the order chosen by the
    exporter's `choose_shuffled_order()`, and the one-hot labels are
    written in that same order.
    """
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    images_path = os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz')
    with gzip.open(images_path, 'rb') as stream:
        # Pixel data starts after the 16-byte header.
        images = np.frombuffer(stream.read(), np.uint8, offset=16)
    labels_path = os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz')
    with gzip.open(labels_path, 'rb') as stream:
        # Label data starts after the 8-byte header.
        labels = np.frombuffer(stream.read(), np.uint8, offset=8)
    images = images.reshape(-1, 1, 28, 28)
    images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    # One row per label with a single 1.0 in the label's column.
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for position in range(order.size):
            source_index = order[position]
            tfr.add_image(images[source_index])
        tfr.add_labels(onehot[order])

#---------------------------------------------------------------------------- 
Example #19
Source File: test_sequentialfile.py    From baseband with GNU General Public License v3.0 6 votes vote down vote up
def _setup(self, tmpdir):
        """Write the 26-letter test payload across three files in `tmpdir`.

        Each file receives at most ``max_file_size`` (10) bytes.  The
        payload, its uint8 view, the file names, the per-file sizes and the
        cumulative offsets (starting at 0) are stored on ``self``.
        """
        payload = b'abcdefghijklmnopqrstuvwxyz'
        self.data = payload
        self.uint8_data = np.frombuffer(payload, dtype=np.uint8)
        self.size = len(payload)
        self.files = [str(tmpdir.join('file{:1d}.raw'.format(index)))
                      for index in range(3)]
        self.max_file_size = 10
        self.sizes = []
        self.offsets = [0]
        for number, filename in enumerate(self.files):
            start = number * self.max_file_size
            chunk = self.data[start:start + self.max_file_size]
            with open(filename, 'wb') as fw:
                fw.write(chunk)
            self.sizes.append(len(chunk))
            self.offsets.append(self.offsets[-1] + len(chunk))
Example #20
Source File: __init__.py    From gnocchi with Apache License 2.0 5 votes vote down vote up
def _unserialize_measures(self, measure_id, data):
        """Decode a raw measures buffer into a structured numpy array.

        :param measure_id: Identifier used only in the error log message.
        :param data: Buffer interpreted with `TIMESERIES_ARRAY_DTYPE`.
        :raises ValueError: re-raised after logging when numpy cannot
                            decode `data`.
        """
        try:
            measures = numpy.frombuffer(data, dtype=TIMESERIES_ARRAY_DTYPE)
        except ValueError:
            LOG.error(
                "Unable to decode measure %s, possible data corruption",
                measure_id)
            raise
        return measures
Example #21
Source File: mnist_input_data.py    From python-esppy with Apache License 2.0 5 votes vote down vote up
def extract_labels(filename, one_hot=False):
  """Read a gzipped MNIST label file into a 1D uint8 array [index].

  The header is read with `_read32`: a magic number that must equal 2049,
  followed by the number of labels.  With `one_hot` set, the labels are
  passed through `dense_to_one_hot` before being returned.

  Raises:
    ValueError: if the magic number is not 2049.
  """
  print('Extracting %s' % filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    raw = bytestream.read(num_items)
    labels = numpy.frombuffer(raw, dtype=numpy.uint8)
    return dense_to_one_hot(labels) if one_hot else labels
Example #22
Source File: kaldi_io.py    From Attentive-Filtering-Network with MIT License 5 votes vote down vote up
def read_vec_flt(file_or_fd):
  """ [flt-vec] = read_vec_flt(file_or_fd)
   Read kaldi float vector, ascii or binary input,

   Binary input yields float32 ('FV ' header) or float64 ('DV ' header)
   data; any other binary header raises UnknownVectorHeader.
   The stream is closed before returning (also when parsing fails), but
   only if `open_or_fd` handed back an object other than `file_or_fd`.
  """
  fd = open_or_fd(file_or_fd)
  try:
    binary = fd.read(2).decode()
    if binary == '\0B': # binary flag
      # Data type,
      header = fd.read(3).decode()
      if header == 'FV ': sample_size = 4 # floats
      elif header == 'DV ': sample_size = 8 # doubles
      else: raise UnknownVectorHeader("The header contained '%s'" % header)
      assert(sample_size > 0)
      # Dimension,
      assert(fd.read(1).decode() == '\4') # int-size
      vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim
      # Read whole vector,
      buf = fd.read(vec_size * sample_size)
      if sample_size == 4 : ans = np.frombuffer(buf, dtype='float32')
      elif sample_size == 8 : ans = np.frombuffer(buf, dtype='float64')
      else : raise BadSampleSize
      # No early return here: fall through so the cleanup below also runs
      # for binary input.
    else: # ascii,
      arr = (binary + fd.readline().decode()).strip().split()
      try:
        # the surrounding brackets are optional
        arr.remove('[')
        arr.remove(']')
      except ValueError:
        pass
      ans = np.array(arr, dtype=float)
  finally:
    # cleanup: only close a stream that was opened here
    if fd is not file_or_fd:
      fd.close()
  return ans

# Writing, 
Example #23
Source File: header.py    From baseband with GNU General Public License v3.0 5 votes vote down vote up
def fromfile(cls, fh, ntrack, decade=None, ref_time=None, verify=True):
        """Read Mark 4 header from file.

        Parameters
        ----------
        fh : filehandle
            To read header from.
        ntrack : int
            Number of Mark 4 bitstreams.
        decade : int or None
            Decade in which the observations were taken.  Can instead pass an
            approximate ``ref_time``.
        ref_time : `~astropy.time.Time` or None
            Reference time within 4 years of the observation time.  Used only
            if ``decade`` is not given.
        verify : bool, optional
            Whether to do basic verification of integrity.  Default: `True`.

        Raises
        ------
        EOFError
            If fewer than ``ntrack * 160 // 8`` bytes could be read or the
            bytes read cannot be interpreted with the stream dtype.
        """
        dtype = cls._stream_dtype(ntrack)
        header_nbytes = ntrack * 160 // 8
        try:
            stream = np.frombuffer(fh.read(header_nbytes), dtype=dtype)
        except ValueError as exc:
            raise EOFError("could not read full Mark 4 Header.") from exc
        # Explicit check instead of ``assert`` so that short reads are still
        # detected when Python runs with -O (which strips assertions).
        if len(stream) * dtype.itemsize != header_nbytes:
            raise EOFError("could not read full Mark 4 Header.")

        words = stream2words(stream)
        self = cls(words, decade=decade, ref_time=ref_time, verify=verify)
        self.mutable = False
        return self
Example #24
Source File: mnist_input_data.py    From python-esppy with Apache License 2.0 5 votes vote down vote up
def _read32(bytestream):
  """Read 4 bytes from `bytestream` as one big-endian unsigned 32-bit integer."""
  big_endian_u4 = numpy.dtype(numpy.uint32).newbyteorder('>')
  word = bytestream.read(4)
  return numpy.frombuffer(word, dtype=big_endian_u4)[0]
Example #25
Source File: shmem_vec_env.py    From HardRLWithYoutube with MIT License 5 votes vote down vote up
def _decode_obses(self, obs):
        """Assemble the current observations from the shared buffers.

        For each key in ``self.obs_keys`` the per-environment buffers in
        ``self.obs_bufs`` are viewed with that key's dtype and shape and
        stacked into one new array; the resulting dict is passed through
        ``dict_to_obs``.  The ``obs`` argument is not used.
        """
        decoded = {}
        for key in self.obs_keys:
            shared = [env_bufs[key] for env_bufs in self.obs_bufs]
            views = []
            for buf in shared:
                flat = np.frombuffer(buf.get_obj(), dtype=self.obs_dtypes[key])
                views.append(flat.reshape(self.obs_shapes[key]))
            # np.array copies the views into a single stacked array.
            decoded[key] = np.array(views)
        return dict_to_obs(decoded)
Example #26
Source File: carbonara.py    From gnocchi with Apache License 2.0 5 votes vote down vote up
def unserialize(cls, data, key, aggregation):
        """Unserialize an aggregated timeserie.

        Empty ``data`` produces a timeserie built from two empty lists.
        Compressed data (per ``cls.is_compressed``) is lz4-decompressed after
        dropping its first byte, then split into little-endian uint16
        entries, which are scaled by ``key.sampling``, cumulatively summed
        and offset by ``key.key``, and little-endian float64 values.
        Otherwise the data is read as (bool, float64) records and only the
        records whose flag is set are kept.

        :param data: Raw data buffer.
        :param key: A :class:`SplitKey` key.
        :param aggregation: The Aggregation object of this timeseries.
        :raises InvalidData: if numpy cannot decode the buffer.
        """
        x, y = [], []

        if not data:
            return cls.from_data(aggregation, y, x)

        if cls.is_compressed(data):
            # Compressed format
            payload = memoryview(data)[1:].tobytes()
            uncompressed = lz4.block.decompress(payload)
            nb_points = len(uncompressed) // cls.COMPRESSED_SERIAL_LEN

            try:
                deltas = numpy.frombuffer(uncompressed, dtype='<H',
                                          count=nb_points)
                x = numpy.frombuffer(
                    uncompressed, dtype='<d',
                    offset=nb_points*cls.COMPRESSED_TIMESPAMP_LEN)
            except ValueError:
                raise InvalidData()
            y = numpy.cumsum(deltas * key.sampling) + key.key
        else:
            # Padded format
            record_dtype = [('b', '<?'), ('v', '<d')]
            try:
                everything = numpy.frombuffer(data, dtype=record_dtype)
            except ValueError:
                raise InvalidData()
            # Positions of the records whose flag is set.
            index = numpy.nonzero(everything['b'])[0]
            y = index * key.sampling + key.key
            x = everything['v'][index]

        return cls.from_data(aggregation, y, x)
Example #27
Source File: bitcoding.py    From L3C-PyTorch with GNU General Public License v3.0 5 votes vote down vote up
def read_bytes(f, ts):
    """Yield one single-element array per type in `ts`, read in order from `f`.

    For each type, as many bytes as one instance of it occupies are read
    from `f` and interpreted as that type.
    """
    for scalar_type in ts:
        width = scalar_type().itemsize
        raw = f.read(width)
        yield np.frombuffer(raw, scalar_type, count=1)


# --- 
Example #28
Source File: figure_plotter.py    From L3C-PyTorch with GNU General Public License v3.0 5 votes vote down vote up
def _render_to_rgb(figure, close):
    """Render `figure` with an Agg canvas and return it as a CHW uint8 array.

    The canvas' RGBA buffer is reshaped to (height, width, 4), the alpha
    channel is dropped and the channel axis is moved to the front.  When
    `close` is true the figure is closed through pyplot before returning.
    """
    canvas = plt_backend_agg.FigureCanvasAgg(figure)
    canvas.draw()
    rgba_flat = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
    width, height = figure.canvas.get_width_height()
    rgb_hwc = rgba_flat.reshape([height, width, 4])[..., :3]
    rgb_chw = np.moveaxis(rgb_hwc, source=2, destination=0)
    if close:
        plt.close(figure)
    return rgb_chw
Example #29
Source File: shmem_vec_env.py    From HardRLWithYoutube with MIT License 5 votes vote down vote up
def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_bufs, obs_shapes, obs_dtypes, keys):
    """
    Run one environment, driven by commands received over `pipe`.

    Observations are copied into the shared buffers in `obs_bufs` rather
    than being sent over the pipe: `_write_obs` returns None, so the
    'reset' reply and the first element of the 'step' reply are None.
    Recognised commands are 'reset', 'step', 'render' and 'close'; anything
    else raises RuntimeError.  The environment is closed on exit.
    """
    def _write_obs(maybe_dict_obs):
        flatdict = obs_to_dict(maybe_dict_obs)
        for key in keys:
            shared = obs_bufs[key].get_obj()
            view = np.frombuffer(shared, dtype=obs_dtypes[key])  # pylint: disable=W0212
            view = view.reshape(obs_shapes[key])
            np.copyto(view, flatdict[key])

    env = env_fn_wrapper.x()
    parent_pipe.close()
    try:
        while True:
            cmd, data = pipe.recv()
            if cmd == 'reset':
                reply = _write_obs(env.reset())
            elif cmd == 'step':
                obs, reward, done, info = env.step(data)
                if done:
                    # Start the next episode right away; its first
                    # observation replaces the terminal one.
                    obs = env.reset()
                reply = (_write_obs(obs), reward, done, info)
            elif cmd == 'render':
                reply = env.render(mode='rgb_array')
            elif cmd == 'close':
                pipe.send(None)
                break
            else:
                raise RuntimeError('Got unrecognized cmd %s' % cmd)
            pipe.send(reply)
    except KeyboardInterrupt:
        print('ShmemVecEnv worker: got KeyboardInterrupt')
    finally:
        env.close()
Example #30
Source File: variational_autoencoder.py    From Recipes with MIT License 5 votes vote down vote up
def load_dataset():
    """Load MNIST images as arrays scaled by 1/255, split into train/val/test.

    Missing archive files are downloaded into the current directory first.
    Images are shaped (N, 1, 28, 28) with the last two axes swapped.  The
    last 10000 training images form the validation set.

    Returns the tuple ``(X_train, X_val, X_test)``.
    """
    if sys.version_info[0] != 2:
        from urllib.request import urlretrieve
    else:
        from urllib import urlretrieve

    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    import gzip

    def load_mnist_images(filename):
        if not os.path.exists(filename):
            download(filename)
        with gzip.open(filename, 'rb') as f:
            raw = f.read()
        # Pixel data starts after the 16-byte header.
        pixels = np.frombuffer(raw, np.uint8, offset=16)
        pixels = pixels.reshape(-1, 1, 28, 28).transpose(0,1,3,2)
        return pixels / np.float32(255)

    train_all = load_mnist_images('train-images-idx3-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    X_train, X_val = train_all[:-10000], train_all[-10000:]
    return X_train, X_val, X_test

# ############################# Output images ################################
# image processing using PIL