import collections

import numpy as np
import tensorflow as tf

# Bundle of the three dataset splits plus the per-image dimensions.
Datasets = collections.namedtuple(
    'Datasets', 'train validation test height width channels')


def _preprocess_dataset(dataset, preprocess_fcn, dtype=tf.float32, reshape=True):
  """Run `preprocess_fcn` over one dataset split and wrap the result in a DataSet.

  Args:
    dataset: object exposing `images` and `labels` attributes.
    preprocess_fcn: callable `(images, labels) -> (images, labels)`.
    dtype: requested image dtype. `float32` asks for images rescaled into
      `[0, 1]`; `uint8` leaves them untouched.
    reshape: whether images should be flattened from
      [num examples, rows, columns, depth] to [num examples, rows*columns].

  Returns:
    A `DataSet` holding the preprocessed images and labels.
  """
  from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
  images, labels = preprocess_fcn(dataset.images, dataset.labels)

  # DataSet's positional order is (images, labels, fake_data, one_hot, dtype,
  # reshape). Passing dtype/reshape positionally would land them on
  # fake_data/one_hot, which switches the DataSet into fake-data mode (fixed
  # example count, no conversions), so they have to be given by keyword.
  #
  # DataSet rescales and flattens unconditionally when asked, so each
  # conversion is requested only if the arrays have not been through it yet:
  # a split that was loaded with the same dtype/reshape is already float and/or
  # flat and must not be divided by 255 or reshaped a second time.
  # NOTE(review): assumes `images` is a numpy array and that non-uint8 data is
  # already scaled -- confirm against the preprocess functions in use.
  wants_float = tf.as_dtype(dtype).base_dtype == tf.float32
  rescale = wants_float and images.dtype == np.uint8
  flatten = reshape and images.ndim == 4

  # For DataSet, dtype=uint8 simply means "store the images as they are".
  return DataSet(images, labels,
                 dtype=tf.float32 if rescale else tf.uint8,
                 reshape=flatten)

  
def _colorize(preprocess_fcn=None):
  """Build a preprocessing function that randomly tints images as RGB.

  The returned function replicates axis 3 of `images` three times and, for
  every image, zeroes a random subset of 0, 1 or 2 of the resulting channels,
  so at least one channel always keeps the original values. The input array is
  not modified.

  Args:
    preprocess_fcn: optional callable `(images, labels) -> (images, labels)`
      applied to the colorized images afterwards.

  Returns:
    A callable `(images, labels) -> (rgb_images, labels)`.
  """
  channel_count = 3

  def colorize_fcn(images, labels):
    # assumes images is [num examples, rows, columns, 1] -- TODO confirm
    # How many channels to blank per image; randint's upper bound is exclusive.
    drop_counts = np.random.randint(channel_count, size=images.shape[0])

    # np.repeat returns a new array, so the caller's images stay untouched.
    colored = np.repeat(images, channel_count, axis=3)

    # Each `image` is a view into `colored`, so the assignment edits it in place.
    for image, count in zip(colored, drop_counts):
      dropped = np.random.choice(channel_count, count, replace=False)
      image[..., dropped] = 0

    if preprocess_fcn is None:
      return (colored, labels)
    colored, labels = preprocess_fcn(colored, labels)
    return (colored, labels)

  return colorize_fcn


def get_dataset(data_dir, preprocess_fcn=None, dtype=tf.float32, reshape=True):
  """Load the MNIST train/validation/test splits, optionally preprocessed.

  Args:
    data_dir: directory handed to `input_data.read_data_sets`.
    preprocess_fcn: optional callable `(images, labels) -> (images, labels)`
      applied to every split; when `None` the loaded splits are returned as is.
    dtype: either `uint8` to leave the input as `[0, 255]`, or `float32` to
      rescale into `[0, 1]`.
    reshape: convert shape from [num examples, rows, columns, depth] to
      [num examples, rows*columns] (assuming depth == 1).

  Returns:
    A `Datasets` tuple with the three splits and height=28, width=28,
    channels=1.
  """
  from tensorflow.examples.tutorials.mnist import input_data

  loaded = input_data.read_data_sets(data_dir, dtype=dtype, reshape=reshape)
  splits = (loaded.train, loaded.validation, loaded.test)

  if preprocess_fcn is not None:
    splits = tuple(
        _preprocess_dataset(split, preprocess_fcn, dtype, reshape)
        for split in splits)

  # Image dimensions are fixed constants here, not derived from the data.
  return Datasets(*splits, height=28, width=28, channels=1)


def get_colorized_dataset(data_dir, preprocess_fcn=None, dtype=tf.float32, reshape=True):
  """Load the dataset with every split passed through the colorizing step.

  Same arguments as `get_dataset`; `preprocess_fcn`, when given, runs after
  colorization. The returned `Datasets` reports `channels=3`.

  NOTE(review): the colorizing step repeats images along axis 3, so it
  presumably needs unflattened input (`reshape=False`) -- confirm.
  """
  colorized = get_dataset(data_dir, _colorize(preprocess_fcn), dtype, reshape)
  # Same splits and dimensions, only the channel count differs.
  return colorized._replace(channels=3)