# Taken from https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/cifar10_utils.py
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util


def str_to_bool(in_str):
    if "t" in in_str.lower():
        return True
    else:
        return False


class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
    """Hook to print out examples per second.
    Total time is tracked and then divided by the total number of steps
    to get the average step time and then batch_size is used to determine
    the running average of examples per second. The examples per second for the
    most recent interval is also logged.
  """

    def __init__(self, batch_size, every_n_steps=100, every_n_secs=None):
        """Initializer for ExamplesPerSecondHook.
      Args:
          batch_size: Total batch size used to calculate examples/second from
          global time.
          every_n_steps: Log stats every n steps.
          every_n_secs: Log stats every n seconds.
    """
        if (every_n_steps is None) == (every_n_secs is None):
            raise ValueError(
                "exactly one of every_n_steps" " and every_n_secs should be provided."
            )
        self._timer = basic_session_run_hooks.SecondOrStepTimer(
            every_steps=every_n_steps, every_secs=every_n_secs
        )

        self._step_train_time = 0
        self._total_steps = 0
        self._batch_size = batch_size

    def begin(self):
        self._global_step_tensor = training_util.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use StepCounterHook.")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return basic_session_run_hooks.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        _ = run_context

        global_step = run_values.results
        if self._timer.should_trigger_for_step(global_step):
            elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
                global_step
            )
            if elapsed_time is not None:
                steps_per_sec = elapsed_steps / elapsed_time
                self._step_train_time += elapsed_time
                self._total_steps += elapsed_steps

                average_examples_per_sec = self._batch_size * (
                    self._total_steps / self._step_train_time
                )
                current_examples_per_sec = steps_per_sec * self._batch_size
                # Average examples/sec followed by current examples/sec
                logging.info(
                    "%s: %g (%g), step = %g",
                    "Average examples/sec",
                    average_examples_per_sec,
                    current_examples_per_sec,
                    self._total_steps,
                )