# Copyright 2016 TensorLab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

# _ds_examples.py
# Implementation of ExamplesDataSet and ExamplesDataSource.

import tensorflow as tf
from ._dataset import DataSet, DataSource
from ._schema import SchemaFieldType


class ExamplesDataSet(DataSet):
  """A DataSet representing tf.Example protobuf data stored in TFRecord files.
  """
  def __init__(self, schema, metadata=None, features=None, **kwargs):
    """Initializes an ExamplesDataSet with the specified DataSource instances.

    Arguments:
      schema: the description of the source data.
      metadata: additional per-field information associated with the data.
      features: the optional description of the transformed data.
      kwargs: the set of ExamplesDataSource instances or TFRecord paths to populate this DataSet,
          keyed by DataSource name. Path strings are wrapped in an ExamplesDataSource.
    Raises:
      ValueError: if a value is neither a path string nor an ExamplesDataSource, or if no
          DataSource is specified.
    """
    # Accept both byte and unicode strings as paths, so unicode paths are also wrapped on
    # Python 2 (on Python 3 both entries are simply str).
    string_types = (str, type(u''))

    datasources = {}
    # items() works on both Python 2 and 3; iteritems() is Python 2 only.
    for name, value in kwargs.items():
      if isinstance(value, string_types):
        value = ExamplesDataSource(value)

      if not isinstance(value, ExamplesDataSource):
        raise ValueError('The specified DataSource is not a ExamplesDataSource')
      datasources[name] = value

    if not datasources:
      raise ValueError('At least one DataSource must be specified.')

    super(ExamplesDataSet, self).__init__(datasources, schema, metadata, features)

  def parse_instances(self, instances, prediction=False):
    """Parses input instances according to the associated schema.

    Arguments:
      instances: The tensor containing serialized tf.Example strings.
      prediction: Whether the instances are being parsed for producing predictions or not.
          Currently unused by this implementation; kept for interface compatibility.
    Returns:
      A dictionary of tensors keyed by field names, as produced by tf.parse_example. Fields with
      length 0 are parsed as variable-length features and yield SparseTensors; all other fields
      are parsed as fixed-length features of shape [length].
    """
    # Convert the schema into an equivalent Example schema (expressed as features in Example
    # terminology).
    features = {}
    for field in self.schema:
      if field.type == SchemaFieldType.integer:
        dtype = tf.int64
        default_value = [0]
      elif field.type == SchemaFieldType.real:
        dtype = tf.float32
        default_value = [0.0]
      else:
        # discrete
        dtype = tf.string
        default_value = ['']

      if field.length == 0:
        # A length of 0 denotes a variable-length field.
        feature = tf.VarLenFeature(dtype=dtype)
      else:
        # The default must supply one value per element of the fixed-length feature; multiplying
        # the single-element list also covers length 1 unchanged.
        feature = tf.FixedLenFeature(shape=[field.length], dtype=dtype,
                                     default_value=default_value * field.length)

      features[field.name] = feature

    return tf.parse_example(instances, features, name='examples')


class ExamplesDataSource(DataSource):
  """A DataSource backed by one or more TFRecord files holding tf.Example records.
  """
  def __init__(self, path, compressed=False):
    """Creates an ExamplesDataSource over the given TFRecord file(s).

    Arguments:
      path: the TFRecord file holding the data; may also be a pattern matching several files.
      compressed: whether the TFRecord files are GZIP-compressed.
    """
    super(ExamplesDataSource, self).__init__()
    self._compressed = compressed
    self._path = path

  @property
  def path(self):
    """Retrieves the file path or pattern this DataSource reads from.
    """
    return self._path

  def read_instances(self, count, shuffle, epochs):
    """Builds the TensorFlow graph ops that read records from this DataSource's file(s).

    Arguments:
      count: the maximum number of records returned by each read operation.
      shuffle: whether the matched file names are shuffled by the input producer.
      epochs: the number of passes over the data to perform; 0 (or None) means unlimited.
    Returns:
      A string tensor containing the records that are read.
    """
    # string_input_producer interprets num_epochs=None as unlimited, so map a falsy value to None.
    num_epochs = epochs if epochs else None

    if self._compressed:
      reader_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
    else:
      reader_options = None

    filenames = tf.train.match_filenames_once(self._path, name='files')
    filename_queue = tf.train.string_input_producer(filenames,
                                                    num_epochs=num_epochs,
                                                    shuffle=shuffle,
                                                    name='queue')
    tfrecord_reader = tf.TFRecordReader(options=reader_options, name='reader')
    records = tfrecord_reader.read_up_to(filename_queue, count, name='read')[1]

    return records