from __future__ import absolute_import
from __future__ import print_function

import pyspark
import h5py
import json
from keras.optimizers import serialize as serialize_optimizer
from keras.models import load_model

from .utils import subtract_params
from .utils import lp_to_simple_rdd
from .utils import model_to_dict
from .mllib import to_matrix, from_matrix, to_vector, from_vector
from .worker import AsynchronousSparkWorker, SparkWorker
from .parameter import HttpServer, SocketServer
from .parameter import HttpClient, SocketClient


class SparkModel(object):

    def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
                 num_workers=None, custom_objects=None, batch_size=32, port=4000, *args, **kwargs):
        """SparkModel

        Base class for distributed training on RDDs. A SparkModel wraps a
        compiled Keras model as its master network and adds a parallelisation
        mode, a parameter server mode and an averaging frequency.

        :param model: Compiled Keras model
        :param mode: String, choose from `asynchronous`, `synchronous` and `hogwild`
        :param frequency: String, either `epoch` or `batch`
        :param parameter_server_mode: String, either `http` or `socket`
        :param num_workers: int, number of workers used for training (defaults
            to None, which keeps the RDD's existing partitioning)
        :param custom_objects: Keras custom objects
        :param batch_size: batch size used for training and inference
        :param port: port used in case of 'http' parameter server mode
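
        Example (a minimal sketch; assumes `model` is a compiled Keras model
        created elsewhere)::

            spark_model = SparkModel(model, frequency='epoch',
                                     mode='asynchronous', num_workers=2)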
        """

        self._master_network = model
        if not hasattr(model, "loss"):
            raise ValueError(
                "Compile your Keras model before initializing an Elephas model with it")
        metrics = model.metrics
        loss = model.loss
        optimizer = serialize_optimizer(model.optimizer)

        if custom_objects is None:
            custom_objects = {}
        if metrics is None:
            metrics = ["accuracy"]
        self.mode = mode
        self.frequency = frequency
        self.num_workers = num_workers
        self.weights = self._master_network.get_weights()
        self.pickled_weights = None
        self.master_optimizer = optimizer
        self.master_loss = loss
        self.master_metrics = metrics
        self.custom_objects = custom_objects
        self.parameter_server_mode = parameter_server_mode
        self.batch_size = batch_size
        self.port = port
        self.kwargs = kwargs

        self.serialized_model = model_to_dict(model)
        if self.mode != 'synchronous':
            if self.parameter_server_mode == 'http':
                self.parameter_server = HttpServer(
                    self.serialized_model, self.mode, self.port)
                self.client = HttpClient(self.port)
            elif self.parameter_server_mode == 'socket':
                self.parameter_server = SocketServer(self.serialized_model)
                self.client = SocketClient()
            else:
                raise ValueError("Parameter server mode has to be either `http` or `socket`, "
                                 "got {}".format(self.parameter_server_mode))

    @staticmethod
    def get_train_config(epochs, batch_size, verbose, validation_split):
        return {'epochs': epochs,
                'batch_size': batch_size,
                'verbose': verbose,
                'validation_split': validation_split}

    def get_config(self):
        base_config = {
            'parameter_server_mode': self.parameter_server_mode,
            'mode': self.mode,
            'frequency': self.frequency,
            'num_workers': self.num_workers,
            'batch_size': self.batch_size}
        config = base_config.copy()
        config.update(self.kwargs)
        return config

    def save(self, file_name):
        """Save the master network as HDF5 and append the distributed
        configuration, so that `load_spark_model` can restore both.
        """
        model = self._master_network
        model.save(file_name)
        with h5py.File(file_name, mode='a') as f:
            f.attrs['distributed_config'] = json.dumps({
                'class_name': self.__class__.__name__,
                'config': self.get_config()
            }).encode('utf8')
            f.flush()

    @property
    def master_network(self):
        return self._master_network

    @master_network.setter
    def master_network(self, network):
        self._master_network = network

    def start_server(self):
        self.parameter_server.start()

    def stop_server(self):
        self.parameter_server.stop()

    def predict(self, data):
        """Get prediction probabilities for a numpy array of features
        """
        return self._master_network.predict(data)

    def predict_classes(self, data):
        """ Predict classes for a numpy array of features
        """
        return self._master_network.predict_classes(data)

    def fit(self, rdd, epochs=10, batch_size=32,
            verbose=0, validation_split=0.1):
        """
        Train an elephas model on an RDD. The Keras model configuration as specified
        in the elephas model is sent to Spark workers, abd each worker will be trained
        on their data partition.

        :param rdd: RDD with features and labels
        :param epochs: number of epochs used for training
        :param batch_size: batch size used for training
        :param verbose: logging verbosity level (0, 1 or 2)
        :param validation_split: percentage of data set aside for validation
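
        Example (a sketch; assumes `sc` is a running SparkContext and that
        `to_simple_rdd` is the elephas helper zipping numpy features and
        labels into an RDD of pairs)::

            from elephas.utils.rdd_utils import to_simple_rdd

            rdd = to_simple_rdd(sc, x_train, y_train)
            spark_model.fit(rdd, epochs=5, batch_size=32, verbose=0,
                            validation_split=0.1)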
        """
        print('>>> Fit model')
        if self.num_workers:
            rdd = rdd.repartition(self.num_workers)

        if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
            self._fit(rdd, epochs, batch_size, verbose, validation_split)
        else:
            raise ValueError(
                "Choose from one of the modes: asynchronous, synchronous or hogwild")

    def _fit(self, rdd, epochs, batch_size, verbose, validation_split):
        """Protected train method to make wrapping of modes easier
        """
        self._master_network.compile(optimizer=self.master_optimizer,
                                     loss=self.master_loss,
                                     metrics=self.master_metrics)
        if self.mode in ['asynchronous', 'hogwild']:
            self.start_server()
        train_config = self.get_train_config(
            epochs, batch_size, verbose, validation_split)
        mode = self.parameter_server_mode
        freq = self.frequency
        optimizer = self.master_optimizer
        loss = self.master_loss
        metrics = self.master_metrics
        custom = self.custom_objects

        yaml = self._master_network.to_yaml()
        init = self._master_network.get_weights()
        parameters = rdd.context.broadcast(init)

        if self.mode in ['asynchronous', 'hogwild']:
            print('>>> Initialize workers')
            worker = AsynchronousSparkWorker(
                yaml, parameters, mode, train_config, freq, optimizer, loss, metrics, custom)
            print('>>> Distribute load')
            rdd.mapPartitions(worker.train).collect()
            print('>>> Async training complete.')
            new_parameters = self.client.get_parameters()
        elif self.mode == 'synchronous':
            worker = SparkWorker(yaml, parameters, train_config,
                                 optimizer, loss, metrics, custom)
            gradients = rdd.mapPartitions(worker.train).collect()
            new_parameters = self._master_network.get_weights()
            for grad in gradients:  # apply each worker's weight delta to the master weights
                new_parameters = subtract_params(new_parameters, grad)
            print('>>> Synchronous training complete.')
        else:
            raise ValueError("Unsupported mode {}".format(self.mode))
        self._master_network.set_weights(new_parameters)
        if self.mode in ['asynchronous', 'hogwild']:
            self.stop_server()


def load_spark_model(file_name):
    """Load a Keras model from HDF5 and restore the distributed configuration
    written by `SparkModel.save`.
    """
    model = load_model(file_name)
    with h5py.File(file_name, mode='r') as f:
        elephas_conf = json.loads(f.attrs.get('distributed_config'))

    class_name = elephas_conf.get('class_name')
    config = elephas_conf.get('config')
    if class_name == "SparkModel":
        return SparkModel(model=model, **config)
    elif class_name == "SparkMLlibModel":
        return SparkMLlibModel(model=model, **config)
    raise ValueError("Unknown distributed model class: {}".format(class_name))
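
# Save/load round-trip sketch (`model`, `rdd` and the file name below are
# illustrative placeholders):
#
#   spark_model = SparkModel(model, mode='synchronous')
#   spark_model.fit(rdd, epochs=5)
#   spark_model.save('elephas_model.h5')
#   spark_model = load_spark_model('elephas_model.h5')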


class SparkMLlibModel(SparkModel):

    def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
                 num_workers=4, elephas_optimizer=None, custom_objects=None, batch_size=32, port=4000, *args, **kwargs):
        """SparkMLlibModel

        The Spark MLlib model takes RDDs of LabeledPoints for training.

        :param model: Compiled Keras model
        :param mode: String, choose from `asynchronous`, `synchronous` and `hogwild`
        :param frequency: String, either `epoch` or `batch`
        :param parameter_server_mode: String, either `http` or `socket`
        :param num_workers: int, number of workers used for training (defaults to 4)
        :param elephas_optimizer: accepted for backwards compatibility, currently unused
        :param custom_objects: Keras custom objects
        :param batch_size: batch size used for training and inference
        :param port: port used in case of 'http' parameter server mode
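
        Example (a sketch; assumes `lp_rdd` is an RDD of
        `pyspark.mllib.regression.LabeledPoint` built elsewhere)::

            spark_model = SparkMLlibModel(model, frequency='batch',
                                          mode='synchronous')
            spark_model.fit(lp_rdd, epochs=5, categorical=True, nb_classes=10)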
        """
        SparkModel.__init__(self, model=model, mode=mode, frequency=frequency,
                            parameter_server_mode=parameter_server_mode, num_workers=num_workers,
                            custom_objects=custom_objects,
                            batch_size=batch_size, port=port, *args, **kwargs)

    def fit(self, labeled_points, epochs=10, batch_size=32, verbose=0, validation_split=0.1,
            categorical=False, nb_classes=None):
        """Train an elephas model on an RDD of LabeledPoints.

        :param labeled_points: RDD of `pyspark.mllib.regression.LabeledPoint`
        :param categorical: boolean, whether labels should be one-hot encoded
        :param nb_classes: number of classes used for one-hot encoding
        """
        rdd = lp_to_simple_rdd(labeled_points, categorical, nb_classes)
        if self.num_workers:
            rdd = rdd.repartition(self.num_workers)
        self._fit(rdd=rdd, epochs=epochs, batch_size=batch_size,
                  verbose=verbose, validation_split=validation_split)

    def predict(self, mllib_data):
        """Predict probabilities for an RDD of features
        """
        if isinstance(mllib_data, pyspark.mllib.linalg.Matrix):
            return to_matrix(self._master_network.predict(from_matrix(mllib_data)))
        elif isinstance(mllib_data, pyspark.mllib.linalg.Vector):
            return to_vector(self._master_network.predict(from_vector(mllib_data)))
        else:
            raise ValueError(
                'Provide either an MLlib matrix or vector, got {}'.format(type(mllib_data)))
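
# Prediction sketch for MLlib data (`Vectors.dense` is pyspark's standard
# dense-vector constructor; the feature values are illustrative):
#
#   from pyspark.mllib.linalg import Vectors
#   probabilities = spark_model.predict(Vectors.dense([0.1, 0.2, 0.3]))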