# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset from the Natural Questions long answer task.

Fields:
  `question`: <string> [question_len]; tokens in the question.
  `context`: <string> [num_candidates, context_len]; tokens in each candidate.
  'answer_indices': <int32>[num_annotations]: answer indicated by each
  annotator.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

import tensorflow.compat.v1 as tf

flags.DEFINE_string("nq_long_train_pattern", None,
                    "Path to NQ long answer training data.")

flags.DEFINE_string("nq_long_eval_pattern", None,
                    "Path to NQ long answer eval data.")

FLAGS = flags.FLAGS


def split_on_whitespace(str_tensor):
  """Splits a scalar string tensor into a 1-D tensor of whitespace tokens."""
  return tf.string_split(tf.expand_dims(str_tensor, -1)).values


def parse_example(serialized_example):
  """Parse example."""
  features = tf.parse_single_example(
      serialized_example,
      features={
          "question":
              tf.FixedLenFeature([], tf.string),
          "context":
              tf.FixedLenSequenceFeature(
                  dtype=tf.string, shape=[], allow_missing=True),
          "long_answer_indices":
              tf.FixedLenSequenceFeature(
                  dtype=tf.int64, shape=[], allow_missing=True)
      })
  features["question"] = features["question"]
  features["context"] = features["context"]
  features["long_answer_indices"] = tf.to_int32(features["long_answer_indices"])
  return features
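

# Illustrative sketch (not part of the original pipeline): shows how a record
# matching the parse spec in `parse_example` could be serialized. The helper
# name and the encoding choices below are hypothetical.
def make_serialized_example(question, contexts, long_answer_indices):
  """Serializes one NQ long-answer example that `parse_example` can read."""
  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              # Scalar string, matching the FixedLenFeature spec.
              "question":
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[question.encode("utf-8")])),
              # Variable-length string list, one entry per candidate.
              "context":
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[c.encode("utf-8") for c in contexts])),
              # Variable-length int64 list, one entry per annotator.
              "long_answer_indices":
                  tf.train.Feature(
                      int64_list=tf.train.Int64List(
                          value=long_answer_indices)),
          }))
  return example.SerializeToString()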


def get_dataset(is_train):
  """Gets a tf.data.Dataset representing the NQ data."""
  if is_train:
    data_pattern = FLAGS.nq_long_train_pattern
  else:
    data_pattern = FLAGS.nq_long_eval_pattern

  data_files = tf.gfile.Glob(data_pattern)
  assert data_files, "No files found matching %s" % data_pattern

  def _load_records(filenames):
    return tf.data.TFRecordDataset(filenames, buffer_size=16 * 1024)

  if is_train:
    # During training, read from all files in parallel to improve the speed of
    # the input pipeline.
    dataset = tf.data.Dataset.from_tensor_slices(tf.constant(data_files))
    dataset = dataset.apply(
        tf.data.experimental.shuffle_and_repeat(buffer_size=len(data_files)))
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            _load_records, sloppy=is_train, cycle_length=len(data_files)))
  else:
    dataset = _load_records(data_files)
  dataset = dataset.map(parse_example, num_parallel_calls=6)
  return dataset
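

# Minimal usage sketch, illustrative only: builds the training dataset and
# logs one parsed example using TF1-style graph execution. It assumes that
# --nq_long_train_pattern points at existing TFRecord files; this entry
# point is not part of the original module.
def main(_):
  dataset = get_dataset(is_train=True)
  # A one-shot iterator suffices here; the pipeline has no initializable
  # state beyond the file list captured at graph construction.
  features = dataset.make_one_shot_iterator().get_next()
  with tf.Session() as sess:
    parsed = sess.run(features)
    tf.logging.info("question: %s", parsed["question"])
    tf.logging.info("num candidates: %d", parsed["context"].shape[0])


if __name__ == "__main__":
  from absl import app  # Local import to keep the sketch self-contained.
  app.run(main)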