# coding=utf-8
# Copyright 2020 TF.Text Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Read text from RecordIO of tf.Examples and generate sorted word counts."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import app
from absl import flags
import apache_beam as beam
import tensorflow.compat.v1 as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema
from tensorflow_text.tools.wordpiece_vocab import utils

FLAGS = flags.FLAGS

flags.DEFINE_string('input_path', None, 'The input file path.')
flags.DEFINE_string('output_path', None, 'The output file path.')
flags.DEFINE_string(
    'lang_set', 'en,es,ru,ar,de,fr,it,pt,ja,pl,fa,zh',
    'Set of languages used to build wordpiece model, '
    'given as a comma-separated list.')
flags.DEFINE_string('text_key', 'text', 'Text feature key in input examples.')
flags.DEFINE_string(
    'language_code_key', 'language_code', 'Language code feature key.')
flags.DEFINE_float(
    'smoothing_exponent', 0.5,
    'Exponent used in calculating exponential smoothing coefficients.')
flags.DEFINE_integer('max_word_length', 50,
                     'Keep only words shorter than max_word_length.')
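
# Fail fast when either required path flag is missing (standard absl idiom).
flags.mark_flags_as_required(['input_path', 'output_path'])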


def word_count(input_path, output_path, raw_metadata, min_token_frequency=2):
  """Returns a pipeline counting words and writing the output.

  Args:
    input_path: recordio file to read
    output_path: path in which to write the output
    raw_metadata: metadata of input tf.Examples
    min_token_frequency: the min frequency for a token to be included
  """

  lang_set = set(FLAGS.lang_set.split(','))

  # Create pipeline.
  pipeline = beam.Pipeline()

  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    converter = tft.coders.ExampleProtoCoder(
        raw_metadata.schema, serialized=False)

    # Read raw data and convert to TF Transform encoded dict.
    raw_data = (
        pipeline
        | 'ReadInputData' >> beam.io.tfrecordio.ReadFromTFRecord(
            input_path, coder=beam.coders.ProtoCoder(tf.train.Example))
        | 'DecodeInputData' >> beam.Map(converter.decode))
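    # Each decoded element is a dict keyed by feature name, e.g.
    # {'text': ..., 'language_code': ...} for the default keys.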

    # Apply TF Transform.
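    # AnalyzeAndTransformDataset returns ((transformed_dataset,
    # transformed_metadata), transform_fn); only the transformed data is
    # needed here, so the rest is discarded.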
    (transformed_data, _), _ = (
        (raw_data, raw_metadata)
        | 'FilterLangAndExtractToken' >> tft_beam.AnalyzeAndTransformDataset(
            utils.count_preprocessing_fn(FLAGS.text_key,
                                         FLAGS.language_code_key)))

    # Filter by languages.
    tokens = (
        transformed_data
        | 'FilterByLang' >> beam.ParDo(utils.FilterTokensByLang(lang_set)))

    # Calculate smoothing coefficients.
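    # Smoothing keeps high-resource languages from dominating the vocabulary;
    # with the default exponent of 0.5, a language's weight grows with the
    # square root of its token share (the usual exponential-smoothing scheme,
    # assumed to be what utils.CalculateCoefficients implements).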
    coeffs = (
        tokens
        | 'CalculateSmoothingCoefficients' >> beam.CombineGlobally(
            utils.CalculateCoefficients(FLAGS.smoothing_exponent)))

    # Apply smoothing, aggregate counts, and sort words by count.
    _ = (
        tokens
        | 'ApplyExponentialSmoothing' >> beam.ParDo(
            utils.ExponentialSmoothing(), beam.pvalue.AsSingleton(coeffs))
        | 'SumCounts' >> beam.CombinePerKey(sum)
        | 'FilterLowCounts' >> beam.ParDo(utils.FilterByCount(
            FLAGS.max_word_length, min_token_frequency))
        | 'MergeAndSortCounts' >> beam.CombineGlobally(utils.SortByCount())
        | 'Flatten' >> beam.FlatMap(lambda x: x)
        | 'FormatCounts' >> beam.Map(lambda tc: '%s\t%s' % (tc[0], tc[1]))
        | 'WriteSortedCount' >> beam.io.WriteToText(
            output_path, shard_name_template=''))
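    # With shard_name_template='', WriteToText produces a single unsharded
    # output file with one 'token<TAB>count' line per word, sorted by count.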

  return pipeline


def main(_):
  # Generate schema of input data.
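  # Each input tf.Example is expected to carry two scalar string features,
  # keyed by --text_key and --language_code_key.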
  raw_metadata = dataset_metadata.DatasetMetadata(
      dataset_schema.from_feature_spec({
          FLAGS.text_key: tf.FixedLenFeature([], tf.string),
          FLAGS.language_code_key: tf.FixedLenFeature([], tf.string),
      }))

  pipeline = word_count(FLAGS.input_path, FLAGS.output_path, raw_metadata)
  pipeline.run().wait_until_finish()


if __name__ == '__main__':
  app.run(main)