"""
Copyright 2018 Lambda Labs. All Rights Reserved.
Licensed under
==========================================================================

"""
from __future__ import print_function
import os
import csv

import tensorflow as tf

from .inputter import Inputter
from source.augmenter.external import vgg_preprocessing


class StyleTransferCSVInputter(Inputter):
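  """Inputter that reads style-transfer images listed in CSV metadata files.

  Each dataset_meta CSV holds one image path per row, resolved relative to
  the directory containing the CSV itself.
  """
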
  def __init__(self, config, augmenter):
    super(StyleTransferCSVInputter, self).__init__(config, augmenter)

    self.num_samples = -1  # Computed lazily; -1 means "not counted yet".

    if self.config.mode == "infer":
      self.test_samples = self.config.test_samples

  def create_nonreplicated_fn(self):
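    """Create ops that must live outside the per-GPU replicated graph.

    Registers the total number of training steps as a graph constant
    named "max_step".
    """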
    batch_size = (self.config.batch_size_per_gpu *
                  self.config.gpu_count)
    max_step = (self.get_num_samples() * self.config.epochs // batch_size)
    tf.constant(max_step, name="max_step")

  def get_num_samples(self):
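    """Return the number of input samples, counting them on first call.

    For "infer" this is the number of test samples, for "export" a single
    dummy sample, and otherwise the total row count across all
    dataset_meta CSV files.
    """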
    if self.num_samples < 0:
      if self.config.mode == "infer":
        self.num_samples = len(self.test_samples)
      elif self.config.mode == "export":
        self.num_samples = 1
      else:
        self.num_samples = 0
        for meta in self.config.dataset_meta:
          with open(meta) as f:
            parsed = csv.reader(f, delimiter=",", quotechar="'")
            self.num_samples += len(list(parsed))
    return self.num_samples

  def get_samples_fn(self):
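    """Collect the paths of the images to feed into the pipeline.

    Returns:
      A one-element tuple of image paths, matching parse_fn's signature.
    """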
    if self.config.mode == "infer":
      images_path = self.test_samples
    elif self.config.mode in ("train", "eval"):
      for meta in self.config.dataset_meta:
        assert os.path.exists(meta), (
          "Cannot find dataset_meta file {}.".format(meta))

      images_path = []

      for meta in self.config.dataset_meta:
        dirname = os.path.dirname(meta)
        with open(meta) as f:
          parsed = csv.reader(f, delimiter=",", quotechar="'")
          for row in parsed:
            images_path.append(os.path.join(dirname, row[0]))
    else:
      raise ValueError(
        "Unsupported mode: {}".format(self.config.mode))

    return (images_path,)

  def parse_fn(self, image_path):
    """Decode and preprocess a single image.

    Args:
      image_path: Scalar string tensor holding the path to a JPEG image.

    Returns:
      A one-element tuple containing the preprocessed image tensor.
    """
    image = tf.read_file(image_path)
    image = tf.image.decode_jpeg(image,
                                 channels=self.config.image_depth,
                                 dct_method="INTEGER_ACCURATE")

    # Inference skips augmentation: cast to float and subtract the VGG
    # channel means, matching VGG-style preprocessing.
    if self.config.mode == "infer":
      image = tf.to_float(image)
      image = vgg_preprocessing._mean_image_subtraction(image)
    else:
      if self.augmenter:
        is_training = (self.config.mode == "train")
        image = self.augmenter.augment(
          image,
          self.config.image_height,
          self.config.image_width,
          self.config.resize_side_min,
          self.config.resize_side_max,
          is_training=is_training,
          speed_mode=self.config.augmenter_speed_mode)
    return (image,)

  def input_fn(self, test_samples=None):
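    """Build the input pipeline for the current mode.

    The test_samples argument is accepted for interface compatibility but
    is unused here; inference samples are read from the config instead.
    """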
    if self.config.mode == "export":
      # The placeholder is already float32, so no cast is needed before
      # the VGG mean subtraction.
      image = tf.placeholder(tf.float32,
                             shape=(None, None, 3),
                             name="input_image")
      image = vgg_preprocessing._mean_image_subtraction(image)
      image = tf.expand_dims(image, 0)
      return image
    else:
      batch_size = (self.config.batch_size_per_gpu *
                    self.config.gpu_count)

      samples = self.get_samples_fn()

      dataset = tf.data.Dataset.from_tensor_slices(samples)

      if self.config.mode == "train":
        # Use a shuffle buffer as large as the dataset so each epoch is a
        # full permutation of the samples.
        dataset = dataset.shuffle(self.get_num_samples())

      dataset = dataset.repeat(self.config.epochs)

      # Decode and preprocess images in parallel on the CPU.
      dataset = dataset.map(
        self.parse_fn,
        num_parallel_calls=4)

      # Drop the final partial batch so every batch has the same static
      # size and can be split evenly across GPUs.
      dataset = dataset.apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))

      # Prefetch a couple of batches to overlap input processing with
      # training.
      dataset = dataset.prefetch(2)

      iterator = dataset.make_one_shot_iterator()
      return iterator.get_next()


def build(config, augmenter):
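  """Construct a StyleTransferCSVInputter."""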
  return StyleTransferCSVInputter(config, augmenter)
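

# A minimal usage sketch (an illustrative example, not part of the original
# module). It assumes a config object exposing the attributes read above:
# mode, batch_size_per_gpu, gpu_count, epochs, dataset_meta, image_height,
# image_width, image_depth, resize_side_min, resize_side_max,
# augmenter_speed_mode, and test_samples for inference. The SimpleNamespace
# config and my_augmenter below are hypothetical placeholders; in the real
# project both are supplied by the surrounding framework.
#
#   from types import SimpleNamespace
#
#   config = SimpleNamespace(
#       mode="train",
#       batch_size_per_gpu=4, gpu_count=1, epochs=2,
#       dataset_meta=["data/train.csv"],
#       image_height=256, image_width=256, image_depth=3,
#       resize_side_min=256, resize_side_max=512,
#       augmenter_speed_mode=False)
#   inputter = build(config, augmenter=my_augmenter)
#   batch = inputter.input_fn()  # tensors for one batch of images
#
# Note: without an augmenter the decoded images keep their native sizes and
# cannot be batched together, so a resizing augmenter is effectively
# required for train/eval.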