#!/usr/bin/env python
# Copyright 2017 IIE, CAS.
# Written by Shancheng Fang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Main script to run training and evaluation of models.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import logging
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import Experiment
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config
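# NOTE: the tf.contrib.learn Experiment/learn_runner APIs exist only in
# TensorFlow 1.x (contrib was removed in 2.x), so this script assumes a
# 1.x installation.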

import config
import datasets
from model.model import Model
from utils.hooks import Prediction, FalsePrediction
from utils.metrics import sequence_accuracy, char_accuracy

FLAGS = tf.flags.FLAGS
config.define()
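# `config.define` is expected to register every flag referenced below
# (schedule, output_dir, batch_size, ...); tf.app.run() parses them before
# invoking main().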

def _create_dataset_params():
  """Create dataset params
  """
  dparams = {
      "dataset_name": FLAGS.dataset_name,
      "dataset_dir": FLAGS.dataset_dir,
      "batch_size": FLAGS.batch_size
  }
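  # The schedule-specific keys below control epoching: training shuffles and
  # repeats indefinitely; one-shot evaluation makes a single pass and keeps
  # the smaller final batch; continuous eval loops over fixed-size batches,
  # so presumably `--eval_steps` bounds each evaluation pass.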

  if FLAGS.schedule == 'train':
    split_name = FLAGS.split_name or 'train'
    dparams.update({
        'shuffle': True,
        'smaller_final_batch': False,
        'num_epochs': None,
        'split_name': split_name})
  elif FLAGS.schedule == 'evaluate':
    split_name = FLAGS.split_name or 'test'
    dparams.update({
        'shuffle': False,
        'smaller_final_batch': True,
        'num_epochs': 1,
        'split_name': split_name})
  else:  # 'continuous_eval'
    split_name = FLAGS.split_name or 'test'
    dparams.update({
        'shuffle': False,
        'smaller_final_batch': False,
        'num_epochs': None,
        'split_name': split_name})
  return dparams

def _create_model_params(dataset):
  """Create model params
  """
  mparams = {
      "optimizer": FLAGS.optimizer,
      "learning_rate": FLAGS.learning_rate,
      "clip_gradients": FLAGS.clip_gradients,
      "dataset": dataset.params,
      "optimizer_params": {
          "momentum": FLAGS.momentum,
          "use_nesterov": FLAGS.use_nesterov
      },
      "summary": FLAGS.summary,
      "max_outputs": FLAGS.max_outputs,
      "beam_width": FLAGS.beam_width,
      "output_dir": FLAGS.output_dir,
      "checkpoint": FLAGS.checkpoint
  }
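  # `optimizer_params` is forwarded as-is; momentum/use_nesterov are
  # presumably only consumed when `--optimizer` selects a momentum-style
  # optimizer, and ignored otherwise.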
  return mparams

def _create_hooks(mparams, output_dir):
  """Create hooks
  """
  # Create training hooks
  train_hooks = []
  # Create evaluation hooks
  eval_hooks = []

  # Write predictions to a file under the output directory
  prediction_hook = Prediction(mparams, output_dir)
  eval_hooks.append(prediction_hook)

  # Write mispredicted samples to a separate file
  false_prediction_hook = FalsePrediction(mparams, output_dir)
  eval_hooks.append(false_prediction_hook)

  if FLAGS.schedule == 'continuous_eval':
    eval_output_dir = os.path.join(output_dir, 'eval_continuous')
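    # SummaryAtEndHook writes the evaluation metrics once, at session end,
    # so continuous-eval results show up in TensorBoard under this directory.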
    eval_hooks.append(tf.contrib.training.SummaryAtEndHook(eval_output_dir))
  elif FLAGS.schedule == 'evaluate':
    # Run a single pass: evaluate until the input data are exhausted.
    FLAGS.eval_steps = None

  if FLAGS.debug:
    from tensorflow.python import debug as tf_debug
    debug_hook = tf_debug.LocalCLIDebugHook()
    train_hooks.append(debug_hook)
    eval_hooks.append(debug_hook)
  return train_hooks, eval_hooks

def _create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """
  # RunConfig for the estimator
  session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
      per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction,
      allow_growth=FLAGS.gpu_allow_growth))
  estimator_config = run_config.RunConfig(
      session_config=session_config,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction,
      tf_random_seed=FLAGS.tf_random_seed,
      log_step_count_steps=FLAGS.log_step,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)
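  # Note: `session_config` above already pins the per-process GPU memory
  # fraction, so the separate `gpu_memory_fraction` argument is presumably
  # redundant; the flags are also assumed to leave at most one of
  # save_checkpoints_secs / save_checkpoints_steps set.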

  # Dataset
  mode = (tf.contrib.learn.ModeKeys.TRAIN if FLAGS.schedule == 'train'
          else tf.contrib.learn.ModeKeys.EVAL)
  dataset = datasets.create_dataset(
      def_dict=_create_dataset_params(),
      mode=mode,
      use_beam_search=FLAGS.beam_width)
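  # `use_beam_search` receives the raw beam width; any non-zero value is
  # assumed to be treated as truthy and to enable beam-search decoding.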

  # Model function
  def model_fn(features, labels, mode, params):
    """Builds the model graph and returns an EstimatorSpec."""
    model = Model(params, mode)
    predictions, loss, train_op = model(features, labels)
    eval_metrics = {
        'character': char_accuracy(predictions['predicted_ids'],
                                   labels['label']),
        'sequence': sequence_accuracy(predictions['predicted_ids'],
                                      labels['label'])
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metrics)
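  # The Estimator consumes only the mode-relevant EstimatorSpec fields:
  # loss/train_op during training, loss/eval_metric_ops during evaluation.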
  # Model parameters
  mparams = _create_model_params(dataset)
  # Estimator
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=estimator_config,
      params=mparams)

  train_hooks, eval_hooks = _create_hooks(mparams, output_dir)

  if FLAGS.schedule != 'train':
    # Mirror the TensorFlow log to a file in the output directory.
    file_name = "{}-tensorflow.log".format(mparams['dataset']['dataset_name'])
    file_name = os.path.join(FLAGS.output_dir, file_name)
    log = logging.getLogger('tensorflow')
    handler = logging.FileHandler(file_name)
    log.addHandler(handler)

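  # The same `create_input_fn` serves both training and evaluation, since the
  # dataset was already specialized by schedule; `eval_delay_secs=0` starts
  # evaluation immediately instead of waiting out Experiment's default delay.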
  return Experiment(
      estimator=estimator,
      train_input_fn=dataset.create_input_fn,
      eval_input_fn=dataset.create_input_fn,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      train_monitors=train_hooks,
      eval_hooks=eval_hooks,
      eval_delay_secs=0)

def main(_argv):
  """Main function
  """
  schedules = ['train', 'evaluate', 'continuous_eval']
  assert FLAGS.schedule in schedules,\
                      "Only schedules: %s supported!"%(','.join(schedules))

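  # learn_runner dispatches on the schedule name, invoking the Experiment
  # method of the same name (train / evaluate / continuous_eval).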
  learn_runner.run(
      experiment_fn=_create_experiment,
      output_dir=FLAGS.output_dir,
      schedule=FLAGS.schedule)

if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()