# Copyright 2016 Nader Akoury. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The trainer for the convolutional model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import time

import numpy as np
from six import iteritems  # pylint: disable=redefined-builtin
from six import itervalues  # pylint: disable=redefined-builtin
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from dvae.datasets.dataloader import Data
from dvae.models.factory import Model
import dvae.utils.graph as graph_utils
import dvae.utils.stats as stats_utils
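
# Example usage -- a rough sketch only; how the Model, dataset, and summary
# writer are constructed is project-specific, and every value below is a
# placeholder rather than a tested configuration:
#
#     model = ...    # a dvae.models.factory.Model built in some tf.Graph
#     dataset = ...  # a dvae.datasets.dataloader dataset with train/validate/test splits
#     loss = ...     # the scalar loss tensor the tower should minimize
#     writer = tf.summary.FileWriter('/tmp/train')
#
#     with tf.Session(graph=model.graph) as session:
#         tower = Tower('tower0', '/gpu:0', model, loss, dataset)
#         trainer = ModelTrainer(
#             session, batch_size=128, towers=[tower],
#             decay_step=10, learning_rate=1e-3, summary_writer=writer)
#         trainer.train(num_epochs=20, metric_frequency=100, validation_frequency=1000)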


class Tower(object):
    """ A single tower to be trained on a device. """
    def __init__(self, scope, device, models, loss, dataset):
        if not isinstance(models, (list, tuple, Model)):
            raise TypeError('models must be a list, tuple, or dvae.models.factory.Model')

        if isinstance(models, Model):
            models = (models,)

        if len(models) == 0:
            raise ValueError('At least one model required for training')

        self.scope = scope
        self.device = device
        self.models = models
        self.dataset = dataset.copy()

        self.graph = models[0].graph
        for model in models[1:]:
            if self.graph != model.graph:
                raise ValueError('All models must be from the same graph!')

        self._initialize_metrics()
        self._initialize_summaries()
        self._initialize_loss(loss)

    def _initialize_metrics(self):
        """ Initialize the model metrics """
        for model in self.models:
            model.initialize_metrics()

    def _initialize_summaries(self):
        """ Initialize the model summaries """
        for model in self.models:
            model.initialize_summaries()

    def _initialize_loss(self, loss):
        """ Initialize the tower loss """
        update_ops = self.update_ops
        if len(update_ops) > 0:
            # Make sure any update ops run before the loss is evaluated
            with tf.control_dependencies(update_ops):
                self.loss = tf.identity(loss)
        else:
            self.loss = loss

    def get_collection(self, key):
        """ Get all the variables of the models in the tower """
        collection = set()
        for model in self.models:
            collection.update(model.get_collection(key))

        return collection

    @property
    def global_variables(self):
        """ Get all the variables of the models in the tower """
        return self.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    @property
    def trainable_variables(self):
        """ Get the trainable variables of the models in the tower """
        return self.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    @property
    def update_ops(self):
        """ Get the update operations of the models in the tower """
        return self.get_collection(tf.GraphKeys.UPDATE_OPS)

    def compute_gradients(self, optimizer):
        """ Compute the gradients for the model """
        var_list = set()
        for model in self.models:
            var_list.update(model.trainable_variables)

        with tf.name_scope(self.scope), tf.device(self.device):
            grads_and_vars = optimizer.compute_gradients(self.loss, var_list=list(var_list))

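            # Clip each gradient to [-5, 5] to guard against exploding gradients;
            # variables that received no gradient are dropped from the list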
            with tf.name_scope('clip_gradients'):
                return [
                    (tf.clip_by_value(grad, -5, 5), var)
                    for grad, var in grads_and_vars
                    if grad is not None]

    def feed(self, feed_dict, data, training=False):
        """ Feed all the models in the tower """
        for model in self.models:
            model.feed(feed_dict, data, training=training)

    def collect_summaries(self, collection, summaries):
        """ Collect summaries for all the models in the tower """
        for model in self.models:
            model.collect_summaries(collection, summaries)

    def collect_metrics(self, collection, metrics):
        """ Collect the metrics from the models for the values being fetched """
        for model in self.models:
            model.collect_metrics(collection, metrics)


class ModelTrainer(object):
    """ Class used to train a model. """
    def __init__(self, session, batch_size, towers,
                 decay_step, learning_rate, summary_writer, name=None):
        if not isinstance(towers, (list, tuple, Tower)):
            raise TypeError('towers must be a list, tuple, or dvae.trainer.Tower')

        if isinstance(towers, Tower):
            towers = (towers,)

        if len(towers) == 0:
            raise ValueError('At least one tower required for training')

        self.data = towers[0].dataset
        self.graph = towers[0].graph
        for tower in towers[1:]:
            if self.graph != tower.graph:
                raise ValueError('All towers must be from the same graph!')

        self.batch_size = batch_size
        self.name = name
        self.session = session
        self.summary_writer = summary_writer
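        # The two most recent (global step, wall clock time) snapshots, used to
        # compute throughput between metric reports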
        self.timer = ((-1, 0), (-1, 0))
        self.towers = towers

        with self.graph.as_default(), tf.variable_scope(self.name, default_name='trainer'):
            self._initialize_metrics()
            self._init_training(decay_step, learning_rate)
            self._init_variables()

            with stats_utils.summary_scope('training', graph=self.graph):
                self._init_summaries()

    @property
    def train_dir(self):
        """ Get the training directory """
        return self.summary_writer.get_logdir()

    @property
    def global_step(self):
        """ Get the current global step value """
        return tf.train.global_step(self.session, self.global_step_tensor)

    @property
    def training_samples(self):
        """ Get the total number of training samples """
        return len(self.data.train)

    def _init_training(self, decay_step, learning_rate):
        """ Initialization of the training parameters """
        self.global_step_tensor = tf.Variable(
            tf.constant(0, tf.int32, shape=[]), False,
            name='global_step', collections=[tf.GraphKeys.LOCAL_VARIABLES])

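        # Decay the learning rate by a factor of 10 once every `decay_step` epochs
        # worth of samples (the decay counter is samples seen, i.e. steps * batch
        # size), with a staircase schedule and a floor of 1e-5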
        self.learning_rate = tf.maximum(tf.train.exponential_decay(
            learning_rate, self.global_step_tensor * self.batch_size,
            int(decay_step * self.training_samples), 0.1, staircase=True), 1e-5)

        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)

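        # Average each variable's gradient across the towers; zip pairs up the
        # (gradient, variable) tuples position by position and the variable
        # reference is taken from the first tower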
        self.gradients = []
        with tf.name_scope('compute_gradients'):
            for grads in zip(*[tower.compute_gradients(self.optimizer) for tower in self.towers]):
                gradient = tf.reduce_mean(tf.stack([grad for grad, _ in grads]), 0)
                self.gradients.append((gradient, grads[0][1]))

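        # Ops that flip the graph-wide training flag on or off (e.g. so dropout and
        # batch norm behave correctly) before a batch is run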
        enable_training = graph_utils.set_training(True, graph=self.graph, session=self.session)
        disable_training = graph_utils.set_training(False, graph=self.graph, session=self.session)

        self.summary_operation = disable_training

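        # Evaluation runs with training disabled and also updates the streaming
        # metrics for the validation/test split being evaluated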
        self.evaluation_operations = {}
        for data_scope in (Data.VALIDATE, Data.TEST):
            self.evaluation_operations[data_scope] = tf.group(
                disable_training, *self.update_metrics[data_scope])

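        # Training enables training mode first, then runs the towers' update ops,
        # and only then applies the averaged gradients and increments the global step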
        with tf.control_dependencies([enable_training]):
            update_ops = self._update_ops()
            with tf.control_dependencies([update_ops]):
                apply_gradients = self.optimizer.apply_gradients(
                    self.gradients, global_step=self.global_step_tensor)

                self.training_operation = tf.group(
                    enable_training, update_ops,
                    self.learning_rate, apply_gradients,
                    *self.update_metrics[Data.TRAIN])

    def _update_ops(self):
        """ Get all the update ops for the towers """
        update_ops = []
        for tower in self.towers:
            update_ops.extend(tower.update_ops)

        return tf.group(*update_ops)

    def _initialize_metrics(self):
        """ Initialize the model metrics """
        self.metrics = {}
        self.metric_values = {}
        self.update_metrics = {}
        self.reset_metrics = {}
        for data_scope in (Data.TRAIN, Data.VALIDATE, Data.TEST):
            metrics = self.collect_metrics(data_scope)
            self.metrics[data_scope] = metrics

            self.metric_values[data_scope] = {
                name: metric['scalar']
                for name, metric in iteritems(metrics)}

            self.update_metrics[data_scope] = [
                metric['update_op']
                for metric in itervalues(metrics)]

            metric_variables = []
            with stats_utils.metric_scope(data_scope, graph=self.graph) as scope:
                for local in tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope):
                    metric_variables.append(local)
            self.reset_metrics[data_scope] = tf.variables_initializer(metric_variables)

    def _init_variables(self):
        """ Create the initialization operation for the variables """
        # The Adam optimizer uses two accumulator variables that can only be accessed through a
        # protected method since the variables aren't scoped in any way. Trying to add a
        # tf.variable_scope around apply_gradients where the variables are created did not help.
        var_list = set(self.optimizer._get_beta_accumulators()) # pylint: disable=protected-access
        slot_names = self.optimizer.get_slot_names()
        for tower in self.towers:
            variables = tower.global_variables
            var_list.update(variables)

            for slot_name in slot_names:
                for variable in variables:
                    slot = self.optimizer.get_slot(variable, slot_name)
                    if slot is not None:
                        var_list.add(slot)

        # Initialize all the variables
        self.initialization_operation = tf.group(
            tf.variables_initializer(var_list),

            # Apparently local variables are not part of 'all' variables... go figure
            # This is needed for metrics for example
            tf.local_variables_initializer())

    def _init_summaries(self):
        """ Initialization of the training parameters """
        summaries = []
        summaries.append(tf.summary.scalar('learning_rate', self.learning_rate))

        for gradient, variable in self.gradients:
            if gradient is not None:
                summaries.append(tf.summary.scalar(
                    variable.op.name + '/gradient_avg',
                    tf.reduce_mean(gradient)))

                summaries.append(tf.summary.histogram(
                    variable.op.name + '/gradients', gradient))

        for variable in tf.trainable_variables():
            summaries.append(tf.summary.histogram(variable.op.name, variable))

        self.training_summary = tf.summary.merge(summaries)

    def write_summaries(self, data):
        """ Write the summaries for the current step and collection """
        summaries = [self.summary_operation]
        if data.collection == Data.TRAIN:
            summaries.append(self.training_summary)

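        # Build a single feed dict covering every tower, then gather each tower's
        # summaries for this data collection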
        feed_dict = {}
        for tower in self.towers:
            tower.feed(feed_dict, data)
            tower.collect_summaries(data.collection, summaries)

        for summary in summaries:
            self.summary_writer.add_summary(self.session.run(
                summary, feed_dict=feed_dict), global_step=self.global_step)

    def collect_metrics(self, collection):
        """ Collect the metrics for the values being fetched """
        metrics = {}
        for tower in self.towers:
            tower.collect_metrics(collection, metrics)

        return metrics

    def extract_metrics(self, metrics, values):
        """ Extract the metrics from the passed in values """
        for key, value in iteritems(values):
            if key in metrics:
                metrics[key]['value'] = value

        return metrics

    def run(self, operation, data, run_options=None, run_metadata=None):
        """ Execute a single batch for the given data """
        training = (data.collection == Data.TRAIN)

        feed_dict = {}
        for tower in self.towers:
            tower.feed(feed_dict, data, training=training)

        values = self.session.run(
            operation, feed_dict=feed_dict,
            options=run_options, run_metadata=run_metadata)

        metric_values = self.session.run(self.metric_values[data.collection], feed_dict=feed_dict)
        return values, self.extract_metrics(self.metrics[data.collection], metric_values)

    def evaluate(self, data):
        """ Evaluate the model with the given data. """
        if data.collection not in self.evaluation_operations:
            raise ValueError('Cannot evaluate data from {0}'.format(data.collection))

        # Reset any metrics before evaluating
        self.session.run(self.reset_metrics[data.collection])
        for step in xrange(int(np.ceil(len(data) / self.batch_size))):
            offset = (step * self.batch_size)
            batch = data[offset:(offset + self.batch_size), ...]
            _, metrics = self.run(self.evaluation_operations[data.collection], batch)

        self.write_summaries(data[:self.batch_size, ...])
        self.output_metrics(data, metrics)

    def optimize(self, data, with_metrics=False, with_trace=False):
        """ Optimize a single batch """
        run_metadata = tf.RunMetadata() if with_trace else None
        trace = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) if with_trace else None

        _, metrics = self.run(
            self.training_operation, data,
            run_options=trace, run_metadata=run_metadata)

        if with_metrics:
            self.timer_update()
            steps, elapsed = self.elapsed()
            num_devices = len(self.towers)
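            # Throughput counts the examples processed by every tower since the
            # previous metrics report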
            examples = steps * self.batch_size * num_devices
            print('Step {}, examples/sec {:.3f}, ms/batch {:.1f}'.format(
                self.global_step, examples / elapsed, 1000 * elapsed / (steps * num_devices)))

            self.output_metrics(data, metrics)
            self.write_summaries(data)

        if with_trace:
            step = '{}/step{}'.format(self.name, self.global_step)
            self.summary_writer.add_run_metadata(run_metadata, step, global_step=self.global_step)

    def reset_timer(self):
        """ Reset the timer """
        self.timer = ((self.global_step, time.time()), (-1, 0))

    def timer_update(self):
        """ Update the current training timer """
        if self.timer[0][0] == self.global_step:
            return

        self.timer = ((self.global_step, time.time()), self.timer[0])

    def elapsed(self):
        """ Return the elapsed steps and time since the last training update """
        (current_step, current_time), (previous_step, previous_time) = self.timer
        return current_step - previous_step, current_time - previous_time

    def output_metrics(self, data, metrics):
        """ Output the current training metrics """
        print('{} {}'.format(data.collection, ', '.join(
            [('{}: ' + metric['format']).format(name, metric['value'])
             for name, metric in iteritems(metrics)])))
        sys.stdout.flush()

    def shuffle(self):
        """ Shuffle the data in each tower """
        for tower in self.towers:
            tower.dataset.train.shuffle()

    def train(self, num_epochs, metric_frequency=0, validation_frequency=0, trace_frequency=0):
        """Generate a classification prediction using the passed in data """
        self.session.run(self.initialization_operation)
        self.reset_timer()

        # TODO: Have num_epochs account for training with multiple GPUs
        for _ in xrange(num_epochs):
            self.shuffle()

            for step in xrange(int(np.ceil(self.training_samples / self.batch_size))):
                offset = (step * self.batch_size)
                data = self.data.train[offset:(offset + self.batch_size), ...]

                # Global step is initialized to zero and isn't updated until a call to optimize
                # so need to add 1 for the current step being processed
                global_step = self.global_step + 1

                with_trace = trace_frequency > 0 and global_step % trace_frequency == 0
                with_metrics = metric_frequency > 0 and global_step % metric_frequency == 0
                self.optimize(data, with_metrics=with_metrics, with_trace=with_trace)

                if validation_frequency > 0 and global_step % validation_frequency == 0:
                    self.evaluate(self.data.validation)

        self.evaluate(self.data.test)