from abc import ABCMeta, abstractmethod
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import itertools


class Mode:
    TRAIN = 'train'
    EVAL = 'eval'
    PRED = 'pred'


class BaseModel(metaclass=ABCMeta):
    """Base model class.

    Arguments:
        data: A dictionary of `tf.data.Dataset` objects, can include the keys
            `"training"`, `"validation"`, and `"test"`.
        n_gpus: An integer, the number of GPUs available.
        data_shape: A dictionary, where the keys are the input features of the prediction
            network and the values are the associated shapes. Only required if `data` is
            empty or `None`.
        config: A dictionary containing the configuration parameters.
            Entries `"batch_size"` and `"learning_rate"` are required if `data`is given.

    Models should inherit from this class and implement the following methods:
        `_model`, `_loss`, and `_metrics`.
    Additionally, the following static attributes should be defined:
        input_spec: A dictionary, where the keys are the input features (e.g. `"image"`)
            and the associated values are dictionaries containing `"shape"` (list of
            dimensions, e.g. `[N, H, W, C]` where `None` indicates an unconstrained
            dimension) and `"type"` (e.g. `tf.float32`).
        required_config_keys: A list containing the required configuration entries.
        default_config: A dictionary of potential default configuration values.
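
    Example:
        A hedged sketch of a hypothetical subclass `DummyNet` and of its typical usage.
        The names `"image"`, `"label"`, `"num_classes"`, `train_set`, `val_set` and
        `image` are illustrative assumptions, not part of this class:

            class DummyNet(BaseModel):
                input_spec = {
                    'image': {'shape': [None, 28, 28, 1], 'type': tf.float32}}
                required_config_keys = ['num_classes']
                default_config = {'num_classes': 10}

                def _model(self, inputs, mode, **config):
                    x = tf.layers.flatten(inputs['image'])
                    return {'logits': tf.layers.dense(x, config['num_classes'])}

                def _loss(self, outputs, inputs, **config):
                    # Assumes an integer 'label' entry in the datasets
                    return tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=inputs['label'], logits=outputs['logits']))

                def _metrics(self, outputs, inputs, **config):
                    pred = tf.cast(tf.argmax(outputs['logits'], axis=-1), tf.int32)
                    correct = tf.cast(tf.equal(pred, inputs['label']), tf.float32)
                    return {'accuracy': tf.reduce_mean(correct)}

            with DummyNet(data={'training': train_set, 'validation': val_set},
                          batch_size=32, learning_rate=0.001) as net:
                net.train(iterations=1000, validation_interval=100)
                logits = net.predict({'image': image}, keys='logits')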
    """
    dataset_names = {'training', 'validation', 'test'}
    required_baseconfig = ['batch_size', 'learning_rate']
    _default_config = {'eval_batch_size': 1}

    @abstractmethod
    def _model(self, inputs, mode, **config):
        """Implements the graph of the model.

        This method is called three times: for training, evaluation and prediction
        (see the `mode` argument) and can return different tensors depending on the
        mode. It is good practice to support both the NCHW (channels first) and NHWC
        (channels last) data formats through a dedicated configuration entry.

        Arguments:
            inputs: A dictionary of input features, where the keys are their names
                (e.g. `"image"`) and the values are of type `tf.Tensor`. Same keys as
                in the datasets given during the object instantiation.
            mode: An attribute of the `Mode` class, either `Mode.TRAIN`, `Mode.EVAL` or
                `Mode.PRED`.
            config: A configuration dictionary, given during the object instantiation.

        Returns:
            A dictionary of outputs, where the keys are their names (e.g. `"logits"`) and
            the values are the corresponding `tf.Tensor`.
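
        Example:
            A hedged sketch supporting both data formats through a hypothetical
            `"data_format"` configuration entry; the `"image"` input and the
            `"num_classes"` entry are also assumptions:

                def _model(self, inputs, mode, **config):
                    data_format = config.get('data_format', 'channels_last')
                    image = inputs['image']  # assumed to be NHWC
                    if data_format == 'channels_first':
                        image = tf.transpose(image, [0, 3, 1, 2])  # NHWC -> NCHW
                    x = tf.layers.conv2d(image, 32, 3, activation=tf.nn.relu,
                                         data_format=data_format)
                    axes = [2, 3] if data_format == 'channels_first' else [1, 2]
                    x = tf.reduce_mean(x, axis=axes)  # global average pooling
                    return {'logits': tf.layers.dense(x, config['num_classes'])}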
        """
        raise NotImplementedError

    @abstractmethod
    def _loss(self, outputs, inputs, **config):
        """Implements the sub-graph computing the training loss.

        This method is called on the outputs of the `_model` method in training mode.

        Arguments:
            outputs: A dictionary, as returned by `_model` called with `mode=Mode.TRAIN`.
            inputs: A dictionary of input features (same as for `_model`).
            config: A configuration dictionary.

        Returns:
            A tensor corresponding to the loss to be minimized during training.
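
        Example:
            A hedged sketch of a cross-entropy loss, assuming a `"logits"` entry in
            `outputs` and an integer `"label"` entry in `inputs`:

                def _loss(self, outputs, inputs, **config):
                    # 'label' is an assumed dataset entry, not part of this class
                    return tf.losses.sparse_softmax_cross_entropy(
                        labels=inputs['label'], logits=outputs['logits'])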
        """
        raise NotImplementedError

    @abstractmethod
    def _metrics(self, outputs, inputs, **config):
        """Implements the sub-graph computing the evaluation metrics.

        This method is called on the outputs of the `_model` method in evaluation mode.

        Arguments:
            outputs: A dictionary, as returned by `_model` called with `mode=Mode.EVAL`.
            inputs: A dictionary of input features (same as for `_model`).
            config: A configuration dictionary.

        Returns:
            A dictionary of metrics, where the keys are their names (e.g. `"accuracy"`)
            and the values are the corresponding `tf.Tensor`.
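
        Example:
            A hedged sketch of an accuracy metric, assuming a `"logits"` entry in
            `outputs` and an integer `"label"` entry in `inputs`:

                def _metrics(self, outputs, inputs, **config):
                    # 'label' is an assumed dataset entry, not part of this class
                    pred = tf.cast(tf.argmax(outputs['logits'], axis=-1), tf.int32)
                    correct = tf.cast(tf.equal(pred, inputs['label']), tf.float32)
                    return {'accuracy': tf.reduce_mean(correct)}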
        """
        raise NotImplementedError

    def __init__(self, data=None, n_gpus=1, data_shape=None, **config):
        self.datasets = dict(data) if data else {}  # avoid a mutable default argument
        self.data_shape = data_shape
        self.n_gpus = n_gpus
        self.graph = tf.get_default_graph()
        self.name = self.__class__.__name__.lower()  # get child name

        # Update config
        self.config = dict(self._default_config)  # copy, do not mutate class attribute
        self.config.update(getattr(self, 'default_config', {}))
        self.config.update(config)

        required = list(getattr(self, 'required_config_keys', []))
        if self.datasets:
            required += self.required_baseconfig
        for r in required:
            assert r in self.config, 'Required configuration entry: \'{}\''.format(r)
        assert set(self.datasets) <= self.dataset_names, \
            'Unknown dataset name: {}'.format(set(self.datasets)-self.dataset_names)
        assert n_gpus > 0, 'TODO: CPU-only training is currently not supported.'

        if data_shape is None:
            self.data_shape = {i: s['shape'] for i, s in self.input_spec.items()}

        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
            self._build_graph()

    def _gpu_tower(self, data, mode):
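        """Build one replica (tower) of the model on each available GPU.

        Depending on `mode`, the towers return their losses, gradients and update ops
        (training), their metrics (evaluation), or their predictions.
        """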
        # Split the batch between the GPUs (data parallelism)
        with tf.device('/cpu:0'):
            with tf.name_scope('{}_data_sharding'.format(mode)):
                batch_size = self.config['batch_size'] if (mode == Mode.TRAIN) \
                        else self.config['eval_batch_size']
                shards = {d: tf.unstack(v, num=batch_size*self.n_gpus, axis=0)
                          for d, v in data.items()}
                shards = [{d: tf.stack(v[i::self.n_gpus]) for d, v in shards.items()}
                          for i in range(self.n_gpus)]

        # Create towers, i.e. copies of the model for each GPU,
        # with their own loss and gradients.
        tower_losses = []
        tower_gradvars = []
        tower_preds = []
        tower_metrics = []
        for i in range(self.n_gpus):
            worker = '/gpu:{}'.format(i)
            device_setter = tf.train.replica_device_setter(
                    worker_device=worker, ps_device='/cpu:0', ps_tasks=1)
            with tf.name_scope('{}_{}'.format(mode, i)) as scope:
                with tf.device(device_setter):
                    net_outputs = self._model(shards[i], mode, **self.config)
                    if mode == Mode.TRAIN:
                        loss = self._loss(net_outputs, shards[i], **self.config)
                        loss += tf.reduce_sum(
                                tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                                  scope))
                        model_params = tf.trainable_variables()
                        grad = tf.gradients(loss, model_params)
                        tower_losses.append(loss)
                        tower_gradvars.append(zip(grad, model_params))
                        if i == 0:
                            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                           scope)
                    elif mode == Mode.EVAL:
                        tower_metrics.append(self._metrics(
                            net_outputs, shards[i], **self.config))
                    else:
                        tower_preds.append(net_outputs)

        if mode == Mode.TRAIN:
            return tower_losses, tower_gradvars, update_ops
        elif mode == Mode.EVAL:
            return tower_metrics
        else:
            return tower_preds

    def _train_graph(self, data):
        tower_losses, tower_gradvars, update_ops = self._gpu_tower(data, Mode.TRAIN)

        # Perform the consolidation on CPU
        gradvars = []
        with tf.device('/cpu:0'):
            # Average losses and gradients
            with tf.name_scope('tower_averaging'):
                all_grads = {}
                for grad, var in itertools.chain(*tower_gradvars):
                    if grad is not None:
                        all_grads.setdefault(var, []).append(grad)
                for var, grads in all_grads.items():
                    if len(grads) == 1:
                        avg_grad = grads[0]
                    else:
                        avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
                    gradvars.append((avg_grad, var))
                self.loss = tf.reduce_mean(tower_losses)
                tf.summary.scalar('loss', self.loss)

            # Create optimizer ops
            self.global_step = tf.Variable(0, trainable=False, name='global_step')
            opt = tf.train.RMSPropOptimizer(self.config['learning_rate'])
            with tf.control_dependencies(update_ops):
                self.trainer = opt.apply_gradients(
                        gradvars, global_step=self.global_step)

    def _eval_graph(self, data):
        tower_metrics = self._gpu_tower(data, Mode.EVAL)
        with tf.device('/cpu:0'):
            self.metrics = {m: tf.reduce_mean(tf.stack([t[m] for t in tower_metrics]))
                            for m in tower_metrics[0]}

    def _pred_graph(self, data):
        with tf.name_scope('pred'):
            with tf.device('/gpu:0'):
                pred_out = self._model(data, Mode.PRED, **self.config)
        self.pred_out = {n: tf.identity(p, name=n) for n, p in pred_out.items()}

    def _build_graph(self):
        # Training and evaluation network, if tf datasets provided
        if self.datasets:
            # Generate iterators for the given tf datasets
            self.dataset_iterators = {}
            with tf.device('/cpu:0'):
                for n, d in self.datasets.items():
                    if n == 'training':
                        train_batch = self.config['batch_size']*self.n_gpus
                        d = d.repeat().batch(train_batch).prefetch(train_batch)
                        self.dataset_iterators[n] = d.make_one_shot_iterator()
                    else:
                        d = d.batch(self.config['eval_batch_size']*self.n_gpus)
                        self.dataset_iterators[n] = d.make_initializable_iterator()
                    output_types = d.output_types
                    output_shapes = d.output_shapes
                    self.datasets[n] = d

                    # Perform compatibility checks with the inputs of the child model
                    for i, spec in self.input_spec.items():
                        assert i in output_shapes
                        tf.TensorShape(output_shapes[i]).assert_is_compatible_with(
                                tf.TensorShape(spec['shape']))

                # Used for input shapes of the prediction network
                if self.data_shape is None:
                    self.data_shape = output_shapes

                # Handle for the feedable iterator
                self.handle = tf.placeholder(tf.string, shape=[])
                iterator = tf.data.Iterator.from_string_handle(
                        self.handle, output_types, output_shapes)
                data = iterator.get_next()

            # Build the actual training and evaluation models
            self._train_graph(data)
            self._eval_graph(data)
            self.summaries = tf.summary.merge_all()

        # Prediction network with feed_dict
        self.pred_in = {i: tf.placeholder(self.input_spec[i]['type'], shape=s, name=i)
                        for i, s in self.data_shape.items()}
        self._pred_graph(self.pred_in)

        # Start session
        sess_config = tf.ConfigProto(device_count={'GPU': self.n_gpus})
        sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=sess_config)

        # Register tf dataset handles
        if self.datasets:
            self.dataset_handles = {}
            for n, i in self.dataset_iterators.items():
                self.dataset_handles[n] = self.sess.run(i.string_handle())

        self.sess.run([tf.global_variables_initializer(),
                       tf.local_variables_initializer()])

    def train(self, iterations, validation_interval=100, output_dir=None,
              save_interval=None, checkpoint_path=None, keep_checkpoints=1):
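        """Train the model on the `"training"` dataset.

        Arguments:
            iterations: An integer, the number of training iterations to perform.
            validation_interval: An integer, the number of iterations between two
                evaluations of the `"validation"` dataset (if provided).
            output_dir: A string, the directory where TensorBoard summaries are
                written (disabled if `None`).
            save_interval: An integer, the number of iterations between two
                checkpoints (requires `checkpoint_path`).
            checkpoint_path: A string, the prefix of the saved checkpoint files.
            keep_checkpoints: An integer, the maximum number of checkpoints to keep.
        """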
        assert 'training' in self.datasets, 'Training dataset is required.'
        if output_dir is not None:
            train_writer = tf.summary.FileWriter(output_dir)
        if not hasattr(self, 'saver'):
            with tf.device('/cpu:0'):
                self.saver = tf.train.Saver(save_relative_paths=True,
                                            max_to_keep=keep_checkpoints)
        if not self.graph.finalized:
            self.graph.finalize()

        tf.logging.info('Start training')
        for i in range(iterations):
            loss, summaries, _ = self.sess.run(
                    [self.loss, self.summaries, self.trainer],
                    feed_dict={self.handle: self.dataset_handles['training']})

            if save_interval and checkpoint_path and i != 0 and i % save_interval == 0:
                self.save(checkpoint_path)
            if 'validation' in self.datasets and i % validation_interval == 0:
                metrics = self.evaluate('validation', mute=True)
                tf.logging.info(
                        'Iter {:4d}: loss {:.4f}'.format(i, loss) +
                        ''.join([', {} {:.4f}'.format(m, metrics[m]) for m in metrics]))

                if output_dir is not None:
                    train_writer.add_summary(summaries, i)
                    metrics_summaries = tf.Summary(value=[
                        tf.Summary.Value(tag=m, simple_value=v)
                        for m, v in metrics.items()])
                    train_writer.add_summary(metrics_summaries, i)
        tf.logging.info('Training finished')

    def predict(self, data, keys='*', batch=False):
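        """Compute the outputs of the prediction network.

        Arguments:
            data: A dictionary of input features, either a single example (if `batch`
                is false) or batched along the first dimension.
            keys: A string or a list of strings, the names of the outputs to return;
                `'*'` returns all of them.
            batch: A boolean, whether `data` already includes a batch dimension.

        Returns:
            The requested outputs as NumPy arrays: a single array if `keys` is a
            string other than `'*'`, a dictionary of arrays otherwise.
        """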
        assert set(data.keys()) >= set(self.data_shape.keys())
        if isinstance(keys, str):
            if keys == '*':
                op = self.pred_out  # just gather all outputs
            else:
                op = self.pred_out[keys]
        else:
            op = {k: self.pred_out[k] for k in keys}
        if not batch:  # add batch dimension
            data = {d: [v] for d, v in data.items()}
        feed = {self.pred_in[i]: data[i] for i in self.data_shape}
        pred = self.sess.run(op, feed_dict=feed)
        if not batch:  # remove batch dimension
            if isinstance(pred, dict):
                pred = {p: v[0] for p, v in pred.items()}
            else:
                pred = pred[0]
        return pred

    def evaluate(self, dataset, max_iterations=None, mute=False):
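        """Evaluate the model on one of the provided datasets.

        Arguments:
            dataset: A string, the name of the dataset to evaluate, e.g.
                `"validation"` or `"test"`.
            max_iterations: An integer, the maximum number of batches to evaluate,
                or `None` to iterate over the whole dataset.
            mute: A boolean, whether to disable logging and the progress bar.

        Returns:
            A dictionary of metrics, averaged over the evaluated batches.
        """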
        assert dataset in self.datasets
        self.sess.run(self.dataset_iterators[dataset].initializer)

        if not mute:
            tf.logging.info('Starting evaluation of dataset \'{}\''.format(dataset))
            if max_iterations:
                pbar = tqdm(total=max_iterations, ascii=True)
        i = 0
        metrics = []
        while True:
            try:
                metrics.append(self.sess.run(self.metrics,
                               feed_dict={self.handle: self.dataset_handles[dataset]}))
            except tf.errors.OutOfRangeError:
                break
            if max_iterations:
                i += 1
                if not mute:
                    pbar.update(1)
                if i == max_iterations:
                    break
        if not mute:
            tf.logging.info('Finished evaluation')
            if max_iterations:
                pbar.close()

        # List of dicts to dict of lists
        metrics = dict(zip(metrics[0], zip(*[m.values() for m in metrics])))
        metrics = {m: np.nanmean(metrics[m], axis=0) for m in metrics}
        return metrics

    def _checkpoint_var_search(self, checkpoint_path):
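        """Match the variables of the graph with the ones stored in a checkpoint.

        Returns:
            A tuple containing the list of restorable variables and the sorted lists
            of matched, missing, and shape-conflicting variable names.
        """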
        reader = tf.train.NewCheckpointReader(checkpoint_path)
        saved_shapes = reader.get_variable_to_shape_map()
        model_names = tf.model_variables()  # Used by tf.slim layers
        if not model_names:
            model_names = tf.global_variables()  # Fallback when slim is not used
        model_names = set([v.name.split(':')[0] for v in model_names])
        checkpoint_names = set(saved_shapes.keys())
        found_names = model_names & checkpoint_names
        missing_names = model_names - checkpoint_names
        shape_conflicts = set()
        restored = []
        with tf.variable_scope('', reuse=True):
            for name in found_names:
                var = tf.get_variable(name)
                var_shape = var.get_shape().as_list()
                if var_shape == saved_shapes[name]:
                    restored.append(var)
                else:
                    shape_conflicts.add(name)
        found_names -= shape_conflicts
        return (restored, sorted(found_names),
                sorted(missing_names), sorted(shape_conflicts))

    def load(self, checkpoint_path, flexible_restore=True):
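        """Restore the model weights from a checkpoint.

        Arguments:
            checkpoint_path: A string, the path to a checkpoint or to a directory of
                checkpoints (in which case the latest one is used).
            flexible_restore: A boolean, whether to restore only the variables that
                exist in both the graph and the checkpoint with compatible shapes.
        """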
        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
            if checkpoint_path is None:
                raise ValueError('Checkpoint directory is empty.')
        if flexible_restore:
            var_list, found, missing, conflicts = self._checkpoint_var_search(
                    checkpoint_path)
            tf.logging.info('Restoring variables: \n\t{}'.format(
                '\n\t'.join(found)))
            if len(missing) > 0:
                tf.logging.info('Variables not found in checkpoint: \n\t{}'.format(
                    '\n\t'.join(missing)))
            if len(conflicts) > 0:
                tf.logging.info('Variables with incompatible shapes: \n\t{}'.format(
                    '\n\t'.join(conflicts)))
        else:
            var_list = None
        with tf.device('/cpu:0'):
            saver = tf.train.Saver(var_list=var_list, save_relative_paths=True)
        saver.restore(self.sess, checkpoint_path)

    def save(self, checkpoint_path):
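        """Save the current weights to a checkpoint.

        Arguments:
            checkpoint_path: A string, the prefix of the checkpoint files to write;
                the current global step is appended to the file names.
        """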
        step = self.sess.run(self.global_step)
        tf.logging.info('Saving checkpoint for iteration #{}'.format(step))
        self.saver.save(self.sess, checkpoint_path, write_meta_graph=False,
                        global_step=step)

    def close(self):
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()