# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collection of input pipelines.

An input pipeline defines how to read and parse data. It produces a tuple
of (features, labels) that can be read by tf.learn estimators.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import abc
import sys

import six

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder

from seq2seq.configurable import Configurable
from seq2seq.data import split_tokens_decoder, parallel_data_provider
from seq2seq.data.sequence_example_decoder import TFSEquenceExampleDecoder


def make_input_pipeline_from_def(def_dict, mode, **kwargs):
  """Creates an InputPipeline object from a dictionary definition.

  The pipeline class is looked up by name among the attributes of this
  module and instantiated as `pipeline_class(params=..., mode=...)`.

  Args:
    def_dict: A dictionary defining the input pipeline.
      It must have a "class" entry naming a class available in this module.
      An optional "params" entry holds the constructor parameters; a
      missing or `None` entry is treated as "no parameters".
    mode: A value in tf.contrib.learn.ModeKeys
    **kwargs: Additional constructor parameters. They are merged into, and
      take precedence over, the values found in `def_dict["params"]`.

  Returns:
    A new InputPipeline object

  Raises:
    ValueError: If `def_dict` has no "class" entry, or if the named class
      is not an attribute of this module.
  """
  if "class" not in def_dict:
    raise ValueError("Input Pipeline definition must have a class property.")

  class_ = def_dict["class"]
  module = sys.modules[__name__]
  if not hasattr(module, class_):
    raise ValueError("Invalid Input Pipeline class: {}".format(class_))

  pipeline_class = getattr(module, class_)

  # Constructor arguments. "params" may be absent or explicitly empty (e.g.
  # an entry left blank in a config file commonly loads as None); both cases
  # mean "no parameters" rather than an error.
  params = {}
  if "params" in def_dict and def_dict["params"] is not None:
    params.update(def_dict["params"])
  # Keyword arguments are applied last so that they win over the definition.
  params.update(kwargs)

  return pipeline_class(params=params, mode=mode)


@six.add_metaclass(abc.ABCMeta)
class InputPipeline(Configurable):
  """Base class that every input pipeline derives from.

  An InputPipeline describes how data is read, parsed, and divided into
  features and labels.

  Params:
    shuffle: If true, shuffle the data.
    num_epochs: Number of times to iterate through the dataset. If None,
      iterate forever.
  """

  def __init__(self, params, mode):
    """Forwards `params` and `mode` to the `Configurable` initializer."""
    Configurable.__init__(self, params, mode)

  @staticmethod
  def default_params():
    """Returns the defaults: `shuffle` is True and `num_epochs` is None."""
    defaults = {}
    defaults["shuffle"] = True
    defaults["num_epochs"] = None
    return defaults

  def make_data_provider(self, **kwargs):
    """Builds the DataProvider instance for this input pipeline.

    Subclasses override this; extra keyword arguments are meant to be
    handed on to the DataProvider.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("Not implemented.")

  @property
  def feature_keys(self):
    """The names of the features this input pipeline provides, as a set
    of strings. Empty in the base class.
    """
    return set()

  @property
  def label_keys(self):
    """The names of the labels this input pipeline provides, as a set
    of strings. Empty in the base class.
    """
    return set()

  @staticmethod
  def read_from_data_provider(data_provider):
    """Reads every item a DataProvider lists and returns them as a dict.

    Args:
      data_provider: An object exposing `list_items()` and `get(items)`.

    Returns:
      A dictionary mapping each listed item name to the value returned
      for it by `data_provider.get`.
    """
    requested = list(data_provider.list_items())
    fetched = data_provider.get(requested)
    # Pair the names with the fetched values position by position.
    return dict(zip(data_provider.list_items(), fetched))


class ParallelTextInputPipeline(InputPipeline):
  """Input pipeline over two parallel text files whose lines are aligned
  one-to-one (line N of the source corresponds to line N of the target).

  Params:
    source_files: An array of file names for the source data.
    target_files: An array of file names for the target data. These must
      be aligned to the `source_files`. When empty, no target dataset is
      created.
    source_delimiter: A character to split the source text on. Defaults
      to  " " (space). For character-level training this can be set to the
      empty string.
    target_delimiter: Same as `source_delimiter` but for the target text.
  """

  @staticmethod
  def default_params():
    """Returns the base-class defaults extended with the text-file params."""
    defaults = InputPipeline.default_params()
    defaults["source_files"] = []
    defaults["target_files"] = []
    defaults["source_delimiter"] = " "
    defaults["target_delimiter"] = " "
    return defaults

  def make_data_provider(self, **kwargs):
    """Creates a ParallelDataProvider over the source and target files.

    The source decoder is configured to append "SEQUENCE_END"; the target
    decoder to prepend "SEQUENCE_START" and append "SEQUENCE_END". Extra
    keyword arguments are passed through to the ParallelDataProvider.
    """

    def _text_line_dataset(file_names, decoder):
      # Both sides are plain text files read line by line.
      return tf.contrib.slim.dataset.Dataset(
          data_sources=file_names,
          reader=tf.TextLineReader,
          decoder=decoder,
          num_samples=None,
          items_to_descriptions={})

    source_decoder = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="source_tokens",
        length_feature_name="source_len",
        append_token="SEQUENCE_END",
        delimiter=self.params["source_delimiter"])
    source_dataset = _text_line_dataset(
        self.params["source_files"], source_decoder)

    # Without target files the provider receives None as second dataset.
    if len(self.params["target_files"]) == 0:
      target_dataset = None
    else:
      target_decoder = split_tokens_decoder.SplitTokensDecoder(
          tokens_feature_name="target_tokens",
          length_feature_name="target_len",
          prepend_token="SEQUENCE_START",
          append_token="SEQUENCE_END",
          delimiter=self.params["target_delimiter"])
      target_dataset = _text_line_dataset(
          self.params["target_files"], target_decoder)

    return parallel_data_provider.ParallelDataProvider(
        dataset1=source_dataset,
        dataset2=target_dataset,
        shuffle=self.params["shuffle"],
        num_epochs=self.params["num_epochs"],
        **kwargs)

  @property
  def feature_keys(self):
    """Feature names produced by this pipeline."""
    return {"source_tokens", "source_len"}

  @property
  def label_keys(self):
    """Label names produced by this pipeline."""
    return {"target_tokens", "target_len"}

class ParallelTextInputPipelineFairseq(InputPipeline):
  """Input pipeline over two parallel (line-by-line aligned) text files in
  which the target decoder uses "SEQUENCE_END" as its prepend token as well
  as its append token.

  NOTE(review): the class name suggests this start-token choice follows the
  fairseq convention -- confirm.

  Params:
    source_files: An array of file names for the source data.
    target_files: An array of file names for the target data. These must
      be aligned to the `source_files`. When empty, no target dataset is
      created.
    source_delimiter: A character to split the source text on. Defaults
      to  " " (space). For character-level training this can be set to the
      empty string.
    target_delimiter: Same as `source_delimiter` but for the target text.
  """

  @staticmethod
  def default_params():
    """Returns the base-class defaults extended with the text-file params."""
    text_defaults = {
        "source_files": [],
        "target_files": [],
        "source_delimiter": " ",
        "target_delimiter": " ",
    }
    merged = InputPipeline.default_params()
    merged.update(text_defaults)
    return merged

  def make_data_provider(self, **kwargs):
    """Creates a ParallelDataProvider over the source and target files.

    The source decoder is configured to append "SEQUENCE_END"; the target
    decoder both prepends and appends "SEQUENCE_END". Extra keyword
    arguments are passed through to the ParallelDataProvider.
    """
    src_files = self.params["source_files"]
    tgt_files = self.params["target_files"]

    src_decoder = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="source_tokens",
        length_feature_name="source_len",
        append_token="SEQUENCE_END",
        delimiter=self.params["source_delimiter"])
    src_dataset = tf.contrib.slim.dataset.Dataset(
        data_sources=src_files,
        reader=tf.TextLineReader,
        decoder=src_decoder,
        num_samples=None,
        items_to_descriptions={})

    # The target side is optional: stays None when no files are given.
    tgt_dataset = None
    if len(tgt_files) > 0:
      tgt_decoder = split_tokens_decoder.SplitTokensDecoder(
          tokens_feature_name="target_tokens",
          length_feature_name="target_len",
          prepend_token="SEQUENCE_END",
          append_token="SEQUENCE_END",
          delimiter=self.params["target_delimiter"])
      tgt_dataset = tf.contrib.slim.dataset.Dataset(
          data_sources=tgt_files,
          reader=tf.TextLineReader,
          decoder=tgt_decoder,
          num_samples=None,
          items_to_descriptions={})

    provider = parallel_data_provider.ParallelDataProvider(
        dataset1=src_dataset,
        dataset2=tgt_dataset,
        shuffle=self.params["shuffle"],
        num_epochs=self.params["num_epochs"],
        **kwargs)
    return provider

  @property
  def feature_keys(self):
    """Feature names produced by this pipeline."""
    return {"source_tokens", "source_len"}

  @property
  def label_keys(self):
    """Label names produced by this pipeline."""
    return {"target_tokens", "target_len"}


class TFRecordInputPipeline(InputPipeline):
  """Input pipeline over TFRecord files in which each record carries both
  the source and the target sequence as string features.

  Params:
    files: An array of file names to read from.
    source_field: The TFRecord feature field containing the source text.
    target_field: The TFRecord feature field containing the target text.
      Parsed with an empty-string default value.
    source_delimiter: A character to split the source text on. Defaults
      to  " " (space). For character-level training this can be set to the
      empty string.
    target_delimiter: Same as `source_delimiter` but for the target text.
  """

  @staticmethod
  def default_params():
    """Returns the base-class defaults extended with the TFRecord params."""
    defaults = InputPipeline.default_params()
    defaults["files"] = []
    defaults["source_field"] = "source"
    defaults["target_field"] = "target"
    defaults["source_delimiter"] = " "
    defaults["target_delimiter"] = " "
    return defaults

  def make_data_provider(self, **kwargs):
    """Creates a DatasetDataProvider that yields the items `source_tokens`,
    `source_len`, `target_tokens` and `target_len`.

    Extra keyword arguments are passed through to the DatasetDataProvider.
    """
    source_field = self.params["source_field"]
    target_field = self.params["target_field"]

    source_splitter = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="source_tokens",
        length_feature_name="source_len",
        append_token="SEQUENCE_END",
        delimiter=self.params["source_delimiter"])

    target_splitter = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="target_tokens",
        length_feature_name="target_len",
        prepend_token="SEQUENCE_START",
        append_token="SEQUENCE_END",
        delimiter=self.params["target_delimiter"])

    keys_to_features = {
        source_field: tf.FixedLenFeature((), tf.string),
        target_field: tf.FixedLenFeature((), tf.string, default_value="")
    }

    def _split_handler(field, splitter, item):
      # Handler that runs `splitter` on the raw `field` tensor and keeps
      # only the single requested `item` from its output.
      def _decode(keys_to_tensors):
        return splitter.decode(keys_to_tensors[field], [item])[0]

      return tfexample_decoder.ItemHandlerCallback(
          keys=[field], func=_decode)

    items_to_handlers = {
        "source_tokens": _split_handler(
            source_field, source_splitter, "source_tokens"),
        "source_len": _split_handler(
            source_field, source_splitter, "source_len"),
        "target_tokens": _split_handler(
            target_field, target_splitter, "target_tokens"),
        "target_len": _split_handler(
            target_field, target_splitter, "target_len"),
    }

    record_decoder = tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    record_dataset = tf.contrib.slim.dataset.Dataset(
        data_sources=self.params["files"],
        reader=tf.TFRecordReader,
        decoder=record_decoder,
        num_samples=None,
        items_to_descriptions={})

    return tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
        dataset=record_dataset,
        shuffle=self.params["shuffle"],
        num_epochs=self.params["num_epochs"],
        **kwargs)

  @property
  def feature_keys(self):
    """Feature names produced by this pipeline."""
    return {"source_tokens", "source_len"}

  @property
  def label_keys(self):
    """Label names produced by this pipeline."""
    return {"target_tokens", "target_len"}


class ImageCaptioningInputPipeline(InputPipeline):
  """Input pipeline over TFRecord files holding an image together with its
  caption, decoded through `TFSEquenceExampleDecoder`.

  Params:
    files: An array of file names to read from.
    image_field: The feature field containing the encoded image data
      (parsed as a scalar string).
    image_format: Default value used for the "image/format" feature when a
      record does not provide it.
    caption_ids_field: The sequence feature field containing the caption
      ids (int64).
    caption_tokens_field: The sequence feature field containing the caption
      tokens (string).
  """

  @staticmethod
  def default_params():
    """Returns the base-class defaults extended with the captioning params."""
    defaults = InputPipeline.default_params()
    defaults["files"] = []
    defaults["image_field"] = "image/data"
    defaults["image_format"] = "jpg"
    defaults["caption_ids_field"] = "image/caption_ids"
    defaults["caption_tokens_field"] = "image/caption"
    return defaults

  def make_data_provider(self, **kwargs):
    """Creates a DatasetDataProvider that yields the items `image`,
    `target_ids`, `target_tokens` and `target_len`.

    Extra keyword arguments are passed through to the DatasetDataProvider.
    """
    image_field = self.params["image_field"]
    caption_ids_field = self.params["caption_ids_field"]
    caption_tokens_field = self.params["caption_tokens_field"]

    context_keys_to_features = {
        image_field: tf.FixedLenFeature([], dtype=tf.string),
        "image/format": tf.FixedLenFeature(
            [], dtype=tf.string, default_value=self.params["image_format"]),
    }

    sequence_keys_to_features = {
        caption_ids_field: tf.FixedLenSequenceFeature([], dtype=tf.int64),
        caption_tokens_field: tf.FixedLenSequenceFeature(
            [], dtype=tf.string)
    }

    def _caption_length(keys_to_tensors):
      # The target length is the number of elements in the caption tokens.
      return tf.size(keys_to_tensors[caption_tokens_field])

    items_to_handlers = {
        "image": tfexample_decoder.Image(
            image_key=image_field,
            format_key="image/format",
            channels=3),
        "target_ids": tfexample_decoder.Tensor(caption_ids_field),
        "target_tokens": tfexample_decoder.Tensor(caption_tokens_field),
        "target_len": tfexample_decoder.ItemHandlerCallback(
            keys=[caption_tokens_field], func=_caption_length)
    }

    example_decoder = TFSEquenceExampleDecoder(
        context_keys_to_features, sequence_keys_to_features, items_to_handlers)

    caption_dataset = tf.contrib.slim.dataset.Dataset(
        data_sources=self.params["files"],
        reader=tf.TFRecordReader,
        decoder=example_decoder,
        num_samples=None,
        items_to_descriptions={})

    return tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
        dataset=caption_dataset,
        shuffle=self.params["shuffle"],
        num_epochs=self.params["num_epochs"],
        **kwargs)

  @property
  def feature_keys(self):
    """Feature names produced by this pipeline."""
    return {"image"}

  @property
  def label_keys(self):
    """Label names produced by this pipeline."""
    return {"target_tokens", "target_ids", "target_len"}