# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example implementation of code to run on the Cloud ML service.

This file is generic and can be reused by other models without modification.
The only assumption this module has is that there exists model module that
implements create_model() function. The function creates class implementing
problem specific implementations of build_train_graph(), build_eval_graph(),
build_prediction_graph() and format_metric_values().
"""

import argparse
import errno
import json
import logging
import os
import shutil
import subprocess
import time
import uuid

import tensorflow as tf
from tensorflow.python.lib.io import file_io

import mark_done
import model as model_lib


class Evaluator(object):
  """Loads variables from latest checkpoint and performs model evaluation."""

  def __init__(self, args, model, data_paths, dataset='eval'):
    self.eval_batch_size = args.eval_batch_size
    self.num_eval_batches = args.eval_set_size // self.eval_batch_size
    self.batch_of_examples = []
    self.checkpoint_path = train_dir(args.output_path)
    self.output_path = os.path.join(args.output_path, dataset)
    self.eval_data_paths = data_paths
    self.batch_size = args.batch_size
    self.stream = args.streaming_eval
    self.model = model

  def evaluate(self, num_eval_batches=None):
    """Run one round of evaluation, return loss and accuracy."""

    num_eval_batches = num_eval_batches or self.num_eval_batches
    with tf.Graph().as_default() as graph:
      self.tensors = self.model.build_eval_graph(self.eval_data_paths,
                                                 self.eval_batch_size)

      self.summary = tf.summary.merge_all()

      self.saver = tf.train.Saver()

    self.summary_writer = tf.summary.FileWriter(self.output_path)
    self.sv = tf.train.Supervisor(
        graph=graph,
        logdir=self.output_path,
        summary_op=None,
        global_step=None,
        saver=self.saver)

    last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
    with self.sv.managed_session(
        master='', start_standard_services=False) as session:
      self.sv.saver.restore(session, last_checkpoint)

      if self.stream:
        self.sv.start_queue_runners(session)
        for _ in range(num_eval_batches):
          session.run(self.tensors.metric_updates)
      else:
        # (Re)fill the cache of evaluation batches so that it always holds
        # exactly num_eval_batches batches read from the input queues.
        self.batch_of_examples = []
        self.sv.start_queue_runners(session)
        for _ in range(num_eval_batches):
          self.batch_of_examples.append(session.run(self.tensors.examples))

        for i in range(num_eval_batches):
          session.run(self.tensors.metric_updates,
                      {self.tensors.examples: self.batch_of_examples[i]})

      metric_values = session.run(self.tensors.metric_values)
      global_step = tf.train.global_step(session, self.tensors.global_step)
      summary = session.run(self.summary)
      self.summary_writer.add_summary(summary, global_step)
      self.summary_writer.flush()
      self.batch_of_examples = []
      return metric_values

  def write_predictions(self):
    """Run one round of predictions and write predictions to csv file."""
    num_eval_batches = self.num_eval_batches + 1
    with tf.Graph().as_default() as graph:
      self.tensors = self.model.build_eval_graph(self.eval_data_paths,
                                                 self.batch_size)
      self.saver = tf.train.Saver()
    self.sv = tf.train.Supervisor(
        graph=graph,
        logdir=self.output_path,
        summary_op=None,
        global_step=None,
        saver=self.saver)

    last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
    with self.sv.managed_session(
        master='', start_standard_services=False) as session:
      self.sv.saver.restore(session, last_checkpoint)

      with file_io.FileIO(os.path.join(self.output_path,
                                       'predictions.csv'), 'w') as f:
        to_run = [self.tensors.keys] + self.tensors.predictions
        self.sv.start_queue_runners(session)
        last_log_progress = 0
        for i in range(num_eval_batches):
          progress = i * 100 // num_eval_batches
          if progress > last_log_progress:
            logging.info('%3d%% predictions processed', progress)
            last_log_progress = progress

          res = session.run(to_run)
          for element in range(len(res[0])):
            f.write('%s' % res[0][element])
            for prediction in res[1:]:
              f.write(',')
              f.write(str(prediction[element]))
            f.write('\n')


class Trainer(object):
  """Performs model training and optionally evaluation."""

  def __init__(self, args, model, cluster, task):
    self.args = args
    self.model = model
    self.cluster = cluster
    self.task = task
    self.evaluator = Evaluator(self.args, self.model, self.args.eval_data_paths,
                               'eval_set')
    self.train_evaluator = Evaluator(self.args, self.model,
                                     self.args.train_data_paths, 'train_set')
    self.min_train_eval_rate = args.min_train_eval_rate

  def run_training(self):
    """Runs a Master."""
    ensure_output_path(self.args.output_path)
    self.train_path = train_dir(self.args.output_path)
    self.model_path = model_dir(self.args.output_path)
    self.is_master = self.task.type != 'worker'
    log_interval = self.args.log_interval_secs
    self.eval_interval = self.args.eval_interval_secs
    if self.is_master and self.task.index > 0:
      raise ValueError('Only one replica of master expected')

    if self.cluster:
      logging.info('Starting %s/%d', self.task.type, self.task.index)
      server = start_server(self.cluster, self.task)
      target = server.target
      device_fn = tf.train.replica_device_setter(
          ps_device='/job:ps',
          worker_device='/job:%s/task:%d' % (self.task.type, self.task.index),
          cluster=self.cluster)
      # We use a device_filter to limit the communication between this job
      # and the parameter servers, i.e., there is no need to directly
      # communicate with the other workers; attempting to do so can result
      # in reliability problems.
      device_filters = [
          '/job:ps', '/job:%s/task:%d' % (self.task.type, self.task.index)
      ]
      config = tf.ConfigProto(device_filters=device_filters)
    else:
      target = ''
      device_fn = ''
      config = None

    with tf.Graph().as_default() as graph:
      with tf.device(device_fn):
        # Build the training graph.
        self.tensors = self.model.build_train_graph(self.args.train_data_paths,
                                                    self.args.batch_size)

        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        self.saver = tf.train.Saver()

        self.summary_op = tf.summary.merge_all()

    # Create a "supervisor", which oversees the training process.
    self.sv = tf.train.Supervisor(
        graph,
        is_chief=self.is_master,
        logdir=self.train_path,
        init_op=init_op,
        saver=self.saver,
        # Write summary_ops by hand.
        summary_op=None,
        global_step=self.tensors.global_step,
        # No saving; we do it manually in order to easily evaluate immediately
        # afterwards.
        save_model_secs=0)

    should_retry = True
    to_run = [self.tensors.global_step, self.tensors.train]

    while should_retry:
      try:
        should_retry = False
        with self.sv.managed_session(target, config=config) as session:
          self.start_time = start_time = time.time()
          self.last_save = self.last_log = 0
          self.global_step = self.last_global_step = 0
          self.local_step = self.last_local_step = 0
          self.last_global_time = self.last_local_time = start_time

          # Loop until the supervisor shuts down or args.max_steps have
          # completed.
          max_steps = self.args.max_steps
          while not self.sv.should_stop() and self.global_step < max_steps:
            try:
              # Run one step of the model.
              self.global_step = session.run(to_run)[0]
              self.local_step += 1

              self.now = time.time()
              is_time_to_eval = (self.now - self.last_save) > self.eval_interval
              is_time_to_log = (self.now - self.last_log) > log_interval
              should_eval = self.is_master and is_time_to_eval
              should_log = is_time_to_log or should_eval

              if should_log:
                self.log(session)

              if should_eval:
                self.eval(session)

            except tf.errors.AbortedError:
              should_retry = True

          if self.is_master:
            # Take the final checkpoint and compute the final accuracy.
            self.eval(session)

            # Export the model for inference.
            self.model.export(
                tf.train.latest_checkpoint(self.train_path), self.model_path)
            
            # Creates a "Done" fine to indicate that the rest of the model deployment can
            # continue
            mark_done.mark_done(self.model_path)

      except tf.errors.AbortedError:
        should_retry = True

    # Ask for all the services to stop.
    self.sv.stop()

  def log(self, session):
    """Logs training progress."""
    logging.info('Train [%s/%d], step %d (%.3f sec) %.1f '
                 'global steps/s, %.1f local steps/s', self.task.type,
                 self.task.index, self.global_step,
                 (self.now - self.start_time),
                 (self.global_step - self.last_global_step) /
                 (self.now - self.last_global_time),
                 (self.local_step - self.last_local_step) /
                 (self.now - self.last_local_time))

    self.last_log = self.now
    self.last_global_step, self.last_global_time = self.global_step, self.now
    self.last_local_step, self.last_local_time = self.local_step, self.now

  def eval(self, session):
    """Runs evaluation loop."""
    eval_start = time.time()
    self.saver.save(session, self.sv.save_path, self.tensors.global_step)
    logging.info(
        'Eval, step %d:\n- on train set %s\n-- on eval set %s',
        self.global_step,
        self.model.format_metric_values(self.train_evaluator.evaluate()),
        self.model.format_metric_values(self.evaluator.evaluate()))
    now = time.time()

    # Make sure eval doesn't consume too much of total time.
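    # For example, with the default min_train_eval_rate of 20, an eval pass
    # that takes 3 seconds pushes eval_interval up to at least 60 seconds.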
    eval_time = now - eval_start
    train_eval_rate = self.eval_interval / eval_time
    if train_eval_rate < self.min_train_eval_rate and self.last_save > 0:
      logging.info('Adjusting eval interval from %.2fs to %.2fs',
                   self.eval_interval, self.min_train_eval_rate * eval_time)
      self.eval_interval = self.min_train_eval_rate * eval_time

    self.last_save = now
    self.last_log = now

  def save_summaries(self, session):
    self.sv.summary_computed(session,
                             session.run(self.summary_op), self.global_step)
    self.sv.summary_writer.flush()


def main(_):
  model, argv = model_lib.create_model()
  run(model, argv)


def run(model, argv):
  """Runs the training loop."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--train_data_paths',
      type=str,
      action='append',
      help='The paths to the training data files. Can be a comma-separated '
      'list of files or a glob pattern.')
  parser.add_argument(
      '--eval_data_paths',
      type=str,
      action='append',
      help='The paths to the files used for evaluation. Can be a '
      'comma-separated list of files or a glob pattern.')
  parser.add_argument(
      '--output_path',
      type=str,
      help='The path to which checkpoints and other outputs '
      'should be saved. This can be either a local or GCS '
      'path.')
  parser.add_argument(
      '--max_steps',
      type=int,
      help='Maximum number of training steps to perform.')
  parser.add_argument(
      '--batch_size',
      type=int,
      help='Number of examples to be processed per mini-batch.')
  parser.add_argument(
      '--eval_set_size', type=int, help='Number of examples in the eval set.')
  parser.add_argument(
      '--eval_batch_size', type=int, help='Number of examples per eval batch.')
  parser.add_argument(
      '--eval_interval_secs',
      type=float,
      default=5,
      help='Minimum interval, in seconds, between computing evaluation '
      'metrics and saving evaluation summaries.')
  parser.add_argument(
      '--log_interval_secs',
      type=float,
      default=5,
      help='Minimum interval, in seconds, between logging training metrics '
      'and saving training summaries.')
  parser.add_argument(
      '--write_predictions',
      action='store_true',
      default=False,
      help='If set, the model is restored from the latest checkpoint, '
      'predictions are written to a CSV file, and no training is performed.')
  parser.add_argument(
      '--min_train_eval_rate',
      type=int,
      default=20,
      help='Minimum train/eval time ratio on the master. The default value '
      'of 20 means that 20x more time is spent on training than on '
      'evaluation. If evaluation takes longer, eval_interval_secs is '
      'increased.')
  parser.add_argument(
      '--write_to_tmp',
      action='store_true',
      default=False,
      help='If set, all checkpoints and summaries are written to the local '
      'filesystem (/tmp/) and copied to GCS once training is done. This can '
      'speed up training, but if the training job fails all summaries and '
      'checkpoints are lost.')
  parser.add_argument(
      '--copy_train_data_to_tmp',
      action='store_true',
      default=False,
      help='If set, training data is copied to the local filesystem (/tmp/). '
      'This can speed up training but requires extra space on the local '
      'filesystem.')
  parser.add_argument(
      '--copy_eval_data_to_tmp',
      action='store_true',
      default=False,
      help='If set, evaluation data is copied to the local filesystem '
      '(/tmp/). This can speed up training but requires extra space on the '
      'local filesystem.')
  parser.add_argument(
      '--streaming_eval',
      action='store_true',
      default=False,
      help='If set, evaluation is performed in streaming mode: during each '
      'eval cycle the evaluation data is read and parsed from files, which '
      'allows for a very large evaluation set. If not set (the default), '
      'evaluation data is read once and cached in memory. This results in a '
      'faster evaluation cycle but can use more memory (in streaming mode a '
      'large per-file read-ahead buffer is used, which may exceed the eval '
      'data size).')
  parser.add_argument(
      '--label_count',
      type=int)

  args, _ = parser.parse_known_args(argv)

  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
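  # The TF_CONFIG environment variable is set by the Cloud ML service and
  # describes both the cluster and this process's role in it. An illustrative
  # (not exhaustive) value:
  #   {"cluster": {"master": ["host:2222"], "ps": [...], "worker": [...]},
  #    "task": {"type": "worker", "index": 0}}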

  # Print the job data as provided by the service.
  logging.info('Original job data: %s', env.get('job', {}))

  # First find out if there's a task value in the environment variable.
  # If there is none or it is empty, define a default one.
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  task = type('TaskSpec', (object,), task_data)
  trial = task_data.get('trial')

  logging.info("output_path = " + args.output_path)
  if trial is not None:
    args.output_path = os.path.join(args.output_path, trial)
  if args.write_to_tmp and args.output_path.startswith('gs://'):
    output_path = args.output_path
    args.output_path = os.path.join('/tmp/', str(uuid.uuid4()))
    os.makedirs(args.output_path)
  else:
    output_path = None
  logging.info("output_path = " + args.output_path)
  if args.copy_train_data_to_tmp:
    args.train_data_paths = copy_data_to_tmp(args.train_data_paths)
  if args.copy_eval_data_to_tmp:
    args.eval_data_paths = copy_data_to_tmp(args.eval_data_paths)

  if not args.eval_batch_size:
    # If eval_batch_size is not set, use min(batch_size, eval_set_size).
    args.eval_batch_size = min(args.batch_size, args.eval_set_size)
    logging.info("setting eval batch size to %s", args.eval_batch_size)

  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  if args.write_predictions:
    write_predictions(args, model, cluster, task)
  else:
    dispatch(args, model, cluster, task)

  if output_path and (not cluster or not task or task.type == 'master'):
    subprocess.check_call([
        'gsutil', '-m', '-q', 'cp', '-r', args.output_path + '/*', output_path
    ])
    shutil.rmtree(args.output_path, ignore_errors=True)

def copy_data_to_tmp(input_files):
  """Copies data to /tmp/ and returns glob matching the files."""
  files = []
  for e in input_files:
    for path in e.split(','):
      files.extend(file_io.get_matching_files(path))

  for path in files:
    if not path.startswith('gs://'):
      return input_files

  tmp_path = os.path.join('/tmp/', str(uuid.uuid4()))
  os.makedirs(tmp_path)
  subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path])
  return [os.path.join(tmp_path, '*')]


def write_predictions(args, model, cluster, task):
  if not cluster or not task or task.type == 'master':
    pass  # Run locally.
  else:
    raise ValueError('invalid task_type %s' % (task.type,))

  logging.info('Starting to write predictions on %s/%d', task.type, task.index)
  evaluator = Evaluator(args, model, args.eval_data_paths)
  evaluator.write_predictions()
  logging.info('Done writing predictions on %s/%d', task.type, task.index)


def dispatch(args, model, cluster, task):
  if not cluster or not task or task.type == 'master':
    # Run locally.
    Trainer(args, model, cluster, task).run_training()
  elif task.type == 'ps':
    run_parameter_server(cluster, task)
  elif task.type == 'worker':
    Trainer(args, model, cluster, task).run_training()
  else:
    raise ValueError('invalid task_type %s' % (task.type,))


def run_parameter_server(cluster, task):
  logging.info('Starting parameter server %d', task.index)
  server = start_server(cluster, task)
  server.join()


def start_server(cluster, task):
  if not task.type:
    raise ValueError('task.type must be specified.')
  if task.index is None:
    raise ValueError('task.index must be specified.')

  # Create and start a server.
  return tf.train.Server(
      tf.train.ClusterSpec(cluster),
      protocol='grpc',
      job_name=task.type,
      task_index=task.index)


def ensure_output_path(output_path):
  if not output_path:
    raise ValueError('output_path must be specified')

  # GCS doesn't have real directories.
  if output_path.startswith('gs://'):
    return

  ensure_dir(output_path)


def ensure_dir(path):
  try:
    os.makedirs(path)
  except OSError as e:
    # If the directory already exists, ignore the error.
    if e.errno != errno.EEXIST:
      raise


def train_dir(output_path):
  return os.path.join(output_path, 'train')


def eval_dir(output_path):
  return os.path.join(output_path, 'eval')


def model_dir(output_path):
  return os.path.join(output_path, 'model')


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  tf.app.run()