# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Run ALBERT on SQuAD v1.1 using sentence piece tokenization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import json
import os
import random
import time
from albert import fine_tuning_utils
from albert import modeling
from albert import squad_utils
import six
import tensorflow.compat.v1 as tf

from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import tpu as contrib_tpu


# pylint: disable=g-import-not-at-top
if six.PY2:
  import six.moves.cPickle as pickle
else:
  import pickle
# pylint: enable=g-import-not-at-top

# Command-line flag definitions. `flags` is an alias of `tf.flags`, and `FLAGS`
# is the flag-values object that the rest of this file reads.
flags = tf.flags

FLAGS = flags.FLAGS

## Model / tokenizer / output parameters
# NOTE: only `spm_model_file`, `albert_config_file` and `output_dir` are marked
# as required (see the `__main__` guard at the bottom of this file);
# `vocab_file` defaults to None.
flags.DEFINE_string(
    "albert_config_file", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the ALBERT model was trained on.")

flags.DEFINE_string("spm_model_file", None,
                    "The model file for sentence piece tokenization.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("train_file", None,
                    "SQuAD json for training. E.g., train-v1.1.json")

flags.DEFINE_string(
    "predict_file", None,
    "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")

flags.DEFINE_string("train_feature_file", None,
                    "training feature file.")

flags.DEFINE_string(
    "predict_feature_file", None,
    "Location of predict features. If it doesn't exist, it will be written. "
    "If it does exist, it will be read.")

flags.DEFINE_string(
    "predict_feature_left_file", None,
    "Location of predict features not passed to TPU. If it doesn't exist, it "
    "will be written. If it does exist, it will be read.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ALBERT model).")

flags.DEFINE_string(
    "albert_hub_module_handle", None,
    "If set, the ALBERT hub module to use.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 384,
    "The maximum total input sequence length after SentencePiece "
    "tokenization. Sequences longer than this will be truncated, and "
    "sequences shorter than this will be padded.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8,
                     "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer(
    "n_best_size", 20,
    "The total number of n-best predictions to generate in the "
    "nbest_predictions.json output file.")

flags.DEFINE_integer(
    "max_answer_length", 30,
    "The maximum length of an answer that can be generated. This is needed "
    "because the start and end predictions are not conditioned on one another.")

## TPU / cluster parameters
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE zone from "
    "metadata.")

flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

## Export parameters
flags.DEFINE_bool(
    "use_einsum", True,
    "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
    "be set to False for TFLite compatibility.")

flags.DEFINE_string(
    "export_dir",
    default=None,
    help=("The directory where the exported SavedModel will be stored."))


def validate_flags_or_throw(albert_config):
  """Validates the command-line FLAGS against each other and the model config.

  Args:
    albert_config: Model configuration object. Only its
      `max_position_embeddings` attribute is read here.

  Raises:
    ValueError: If no action (`do_train`, `do_predict`, `export_dir`) is
      requested, if a file flag needed by a requested action is unset, or if
      `max_seq_length` is inconsistent with `max_query_length` or with the
      model's `max_position_embeddings`.
  """
  if not (FLAGS.do_train or FLAGS.do_predict or FLAGS.export_dir):
    raise ValueError(
        "At least one of `do_train` or `do_predict` or `export_dir` "
        "must be True.")

  if FLAGS.do_train:
    if not FLAGS.train_file:
      raise ValueError(
          "If `do_train` is True, then `train_file` must be specified.")
    # Training reads (and, if absent, first writes) this feature file, so an
    # unset value must be rejected here rather than failing later on a None
    # path.
    if not FLAGS.train_feature_file:
      raise ValueError(
          "If `do_train` is True, then `train_feature_file` must be "
          "specified.")

  if FLAGS.do_predict:
    if not FLAGS.predict_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_file` must be specified.")
    if not FLAGS.predict_feature_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_feature_file` must be "
          "specified.")
    if not FLAGS.predict_feature_left_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_feature_left_file` must be "
          "specified.")

  if FLAGS.max_seq_length > albert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the ALBERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, albert_config.max_position_embeddings))

  if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
    raise ValueError(
        "The max_seq_length (%d) must be greater than max_query_length "
        "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))


def build_squad_serving_input_fn(seq_length):
  """Returns a serving-input receiver fn for pre-tokenized SQuAD inputs.

  Args:
    seq_length: Length of the second dimension of every input placeholder.

  Returns:
    A zero-argument callable that creates int32 placeholders of shape
    `[1, seq_length]` named `input_ids`, `input_mask` and `segment_ids`, and
    wraps them in a `tf.estimator.export.ServingInputReceiver` in which the
    same dict serves as both the features and the receiver tensors.
  """
  feature_names = ("input_ids", "input_mask", "segment_ids")

  def _seq_serving_input_fn():
    """Creates the placeholders and the ServingInputReceiver around them."""
    placeholders = {}
    for feature_name in feature_names:
      placeholders[feature_name] = tf.placeholder(
          shape=[1, seq_length], name=feature_name, dtype=tf.int32)
    return tf.estimator.export.ServingInputReceiver(
        features=placeholders, receiver_tensors=placeholders)

  return _seq_serving_input_fn


def main(_):
  """Fine-tunes, evaluates and/or exports ALBERT on SQuAD v1.1.

  What runs is selected by FLAGS:
    * `do_train`: converts the training examples to features (unless
      `train_feature_file` already exists) and trains the estimator.
    * `do_predict`: evaluates the checkpoints found in `output_dir`, copies
      the one with the best F1 to `model.ckpt-best`, and writes the metrics to
      `eval_results.txt` in `output_dir`.
    * `export_dir`: exports a SavedModel and a TFLite model.

  Args:
    _: Unused; supplied by `tf.app.run`.

  Raises:
    ValueError: If the flags are inconsistent (see `validate_flags_or_throw`)
      or if no `model.ckpt-best` checkpoint exists when the final evaluation
      is about to run.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  validate_flags_or_throw(albert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  tokenizer = fine_tuning_utils.create_vocab(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file,
      hub_module=FLAGS.albert_hub_module_handle)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  if FLAGS.do_train:
    # Never run more steps per estimator call than the checkpoint interval.
    iterations_per_loop = int(min(FLAGS.iterations_per_loop,
                                  FLAGS.save_checkpoints_steps))
  else:
    iterations_per_loop = FLAGS.iterations_per_loop
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      keep_checkpoint_max=0,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

  model_fn = squad_utils.v1_model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      use_einsum=FLAGS.use_einsum,
      hub_module=FLAGS.albert_hub_module_handle)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    # We write to a temporary file to avoid storing very large constant tensors
    # in memory. An existing feature file is reused as-is.
    if not tf.gfile.Exists(FLAGS.train_feature_file):
      train_writer = squad_utils.FeatureWriter(
          filename=FLAGS.train_feature_file, is_training=True)
      squad_utils.convert_examples_to_features(
          examples=train_examples,
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=True,
          output_fn=train_writer.process_feature,
          do_lower_case=FLAGS.do_lower_case)
      train_writer.close()

    # The number of split examples is not logged: no feature writer exists
    # when the feature file was reused.
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    del train_examples

    train_input_fn = squad_utils.input_fn_builder(
        input_file=FLAGS.train_feature_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.train_batch_size,
        is_v2=False)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_predict:
    with tf.gfile.Open(FLAGS.predict_file) as predict_file:
      prediction_json = json.load(predict_file)["data"]

    eval_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists(
        FLAGS.predict_feature_left_file)):
      tf.logging.info("Loading eval features from {}".format(
          FLAGS.predict_feature_left_file))
      # NOTE(security): unpickling can execute arbitrary code. Only point
      # `predict_feature_left_file` at a file written by this script.
      with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
        eval_features = pickle.load(fin)
    else:
      eval_writer = squad_utils.FeatureWriter(
          filename=FLAGS.predict_feature_file, is_training=False)
      eval_features = []

      def append_feature(feature):
        """Keeps the feature in memory and also writes it to the file."""
        eval_features.append(feature)
        eval_writer.process_feature(feature)

      squad_utils.convert_examples_to_features(
          examples=eval_examples,
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=False,
          output_fn=append_feature,
          do_lower_case=FLAGS.do_lower_case)
      eval_writer.close()

      with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
        pickle.dump(eval_features, fout)

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = squad_utils.input_fn_builder(
        input_file=FLAGS.predict_feature_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.predict_batch_size,
        is_v2=False)

    # Suffixes of the files that this script treats as one checkpoint.
    ckpt_exts = ("meta", "data-00000-of-00001", "index")

    def get_result(checkpoint):
      """Evaluates `checkpoint`; returns (metrics dict, its global step)."""
      # If running eval on the TPU, you will need to specify the number of
      # steps.
      reader = tf.train.NewCheckpointReader(checkpoint)
      global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
      all_results = []
      for result in estimator.predict(
          predict_input_fn, yield_single_examples=True,
          checkpoint_path=checkpoint):
        if len(all_results) % 1000 == 0:
          tf.logging.info("Processing example: %d", len(all_results))
        unique_id = int(result["unique_ids"])
        start_log_prob = [float(x) for x in result["start_log_prob"].flat]
        end_log_prob = [float(x) for x in result["end_log_prob"].flat]
        all_results.append(
            squad_utils.RawResult(
                unique_id=unique_id,
                start_log_prob=start_log_prob,
                end_log_prob=end_log_prob))

      output_prediction_file = os.path.join(
          FLAGS.output_dir, "predictions.json")
      output_nbest_file = os.path.join(
          FLAGS.output_dir, "nbest_predictions.json")

      result_dict = {}
      squad_utils.accumulate_predictions_v1(
          result_dict, eval_examples, eval_features,
          all_results, FLAGS.n_best_size, FLAGS.max_answer_length)
      predictions = squad_utils.write_predictions_v1(
          result_dict, eval_examples, eval_features, all_results,
          FLAGS.n_best_size, FLAGS.max_answer_length,
          output_prediction_file, output_nbest_file)

      return squad_utils.evaluate_v1(
          prediction_json, predictions), int(global_step)

    def _find_valid_cands(curr_step):
      """Lists the `.index` files in output_dir with a step above curr_step."""
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      candidates = []
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          idx = ckpt_name.split("-")[-1]
          if idx != "best" and int(idx) > curr_step:
            candidates.append(filename)
      return candidates

    def _remove_checkpoint(ckpt_prefix):
      """Deletes every file of the checkpoint stored under `ckpt_prefix`."""
      for ext in ckpt_exts:
        src_ckpt = ckpt_prefix + ".{}".format(ext)
        tf.logging.info("removing {}".format(src_ckpt))
        tf.gfile.Remove(src_ckpt)

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    key_name = "f1"
    # The `with` block guarantees the results file is flushed and closed even
    # if an evaluation raises.
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      if tf.gfile.Exists(checkpoint_path + ".index"):
        # Resume from a previously selected best checkpoint.
        result = get_result(checkpoint_path)
        best_perf = result[0][key_name]
        global_step = result[1]
      else:
        global_step = -1
        best_perf = -1
        checkpoint_path = None

      # `num_train_steps` is None when prediction runs without training. An
      # int cannot be ordered against None on Python 3, so in that case the
      # polling loop is skipped and only the existing best checkpoint is
      # evaluated below.
      while num_train_steps is not None and global_step < num_train_steps:
        steps_and_files = {}
        for filename in tf.gfile.ListDirectory(FLAGS.output_dir):
          if not filename.endswith(".index"):
            continue
          cur_filename = os.path.join(FLAGS.output_dir, filename[:-6])
          step_suffix = cur_filename.split("-")[-1]
          if step_suffix == "best":
            continue
          gstep = int(step_suffix)
          if gstep not in steps_and_files:
            tf.logging.info("Add {} to eval list.".format(cur_filename))
            steps_and_files[gstep] = cur_filename
        tf.logging.info("found {} files.".format(len(steps_and_files)))
        if not steps_and_files:
          tf.logging.info("found 0 file, global step: {}. Sleeping."
                          .format(global_step))
          time.sleep(60)
          continue

        evaluated_any = False
        for step, checkpoint_path in sorted(steps_and_files.items()):
          if global_step >= step:
            # Already evaluated; drop it once enough newer checkpoints exist.
            if len(_find_valid_cands(step)) > 1:
              _remove_checkpoint(checkpoint_path)
            continue
          result, global_step = get_result(checkpoint_path)
          evaluated_any = True
          tf.logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            for ext in ckpt_exts:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tgt_ckpt = checkpoint_path.rsplit(
                  "-", 1)[0] + "-best.{}".format(ext)
              tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
              tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
              writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
          writer.write("best {} = {}\n".format(key_name, best_perf))
          tf.logging.info("  best {} = {}\n".format(key_name, best_perf))

          if len(_find_valid_cands(global_step)) > 2:
            _remove_checkpoint(checkpoint_path)
          writer.write("=" * 50 + "\n")

        if not evaluated_any:
          # Every checkpoint on disk has already been evaluated; wait for a
          # new one instead of re-listing the directory in a tight loop.
          tf.logging.info("no new checkpoint, global step: {}. Sleeping."
                          .format(global_step))
          time.sleep(60)

      checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
      if not tf.gfile.Exists(checkpoint_path + ".index"):
        raise ValueError(
            "No best checkpoint found at {}; nothing to evaluate.".format(
                checkpoint_path))
      result, global_step = get_result(checkpoint_path)
      tf.logging.info("***** Final Eval results *****")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
      writer.write("best perf happened at step: {}".format(global_step))

  if FLAGS.export_dir:
    tf.gfile.MakeDirs(FLAGS.export_dir)
    squad_serving_input_fn = (
        build_squad_serving_input_fn(FLAGS.max_seq_length))
    tf.logging.info("Starting to export model.")
    subfolder = estimator.export_saved_model(
        export_dir_base=os.path.join(FLAGS.export_dir, "saved_model"),
        serving_input_receiver_fn=squad_serving_input_fn)

    tf.logging.info("Starting to export TFLite.")
    converter = tf.lite.TFLiteConverter.from_saved_model(
        subfolder,
        input_arrays=["input_ids", "input_mask", "segment_ids"],
        output_arrays=["start_logits", "end_logits"])
    float_model = converter.convert()
    tflite_file = os.path.join(FLAGS.export_dir, "albert_model.tflite")
    with tf.gfile.GFile(tflite_file, "wb") as f:
      f.write(float_model)


if __name__ == "__main__":
  # Flags that must be given on the command line, marked in this order before
  # the app is started.
  for required_flag in ("spm_model_file", "albert_config_file", "output_dir"):
    flags.mark_flag_as_required(required_flag)
  tf.app.run()