# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf


class MixtureModel:
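    """Distributed mixture model fitted with expectation-maximization (EM).

    Components are user-supplied distribution objects that implement
    initialize(dtype), get_log_probabilities(data), get_parameter_updaters(
    data, gamma_weighted, gamma_sum) and get_parameters(); they are spread
    round-robin over the workers of an optional TensorFlow cluster.
    """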

    def __init__(self, data, components, cluster=None, dtype=tf.float64):
        # Accept a single array or a list of per-source arrays; all sources
        # must share the same number of rows (data points).
        if isinstance(data, np.ndarray):
            data = [data]

        self.data = data
        self.dims = sum(d.shape[1] for d in data)
        self.num_points = data[0].shape[0]
        self.components = components

        self.tf_graph = tf.Graph()

        self._initialize_workers(cluster)
        self._initialize_component_mapping()
        self._initialize_data_sources()
        self._initialize_variables(dtype)
        self._initialize_graph(dtype)

    def _initialize_workers(self, cluster):
        # Without a cluster spec, run in-process on the default device.
        if cluster is None:
            self.master_host = ""
            self.workers = [None]
        else:
            # Connect to the master and address one device per worker task.
            self.master_host = "grpc://" + cluster.job_tasks("master")[0]
            self.workers = [
                "/job:worker/task:" + str(i)
                for i in range(cluster.num_tasks("worker"))
            ]

    def _initialize_component_mapping(self):
        # Distribute components across workers round-robin: component k
        # is hosted on worker k % num_workers.
        self.mapping = [
            component_id % len(self.workers)
            for component_id in range(len(self.components))
        ]

    def _initialize_data_sources(self):
        self.tf_input_sources = None
        self.tf_worker_data = None

        with self.tf_graph.as_default():
            # Placeholders feed the raw numpy data into the graph exactly
            # once, when the variables are initialized in train().
            self.tf_input_sources = []
            for data in self.data:
                self.tf_input_sources.append(tf.placeholder(
                    data.dtype, shape=[self.num_points, data.shape[1]]
                ))

            # Each worker keeps its own non-trainable copy of every data
            # source, so component computations stay local to their device.
            self.tf_worker_data = []
            for worker in self.workers:
                with tf.device(worker):
                    self.tf_worker_data.append(
                        [tf.Variable(source, trainable=False)
                         for source in self.tf_input_sources]
                    )

    def _initialize_variables(self, dtype):
        self.tf_dims = None
        self.tf_num_points = None
        self.tf_weights = None

        with self.tf_graph.as_default():
            self.tf_dims = tf.constant(self.dims, dtype=dtype)
            self.tf_num_points = tf.constant(self.num_points, dtype=dtype)
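            # Mixing weights start uniform at 1/K and are re-estimated in
            # every training step.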
            self.tf_weights = tf.Variable(
                tf.cast(tf.fill(
                    [len(self.components)], 1.0 / len(self.components)
                ), dtype)
            )

    def _initialize_graph(self, dtype=tf.float64):
        self.tf_train_step = None
        self.tf_component_parameters = None
        self.tf_mean_log_likelihood = None

        with self.tf_graph.as_default():
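            # E-step inputs: evaluate every component's per-point log-density
            # on the worker that hosts it, against that worker's data copy.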
            tf_component_log_probabilities = []
            for component_id in range(len(self.components)):
                worker_id = self.mapping[component_id]
                with tf.device(self.workers[worker_id]):
                    self.components[component_id].initialize(dtype)
                    tf_component_log_probabilities.append(
                        self.components[component_id].get_log_probabilities(
                            self.tf_worker_data[worker_id]
                        )
                    )

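            # Mixture log-likelihood via the log-sum-exp trick: shift each
            # point's weighted log-densities by their maximum so that the
            # exponentials can neither overflow nor vanish entirely.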
            tf_log_components = tf.parallel_stack(tf_component_log_probabilities)
            tf_log_weighted = tf_log_components + tf.expand_dims(tf.log(self.tf_weights), 1)
            tf_log_shift = tf.expand_dims(tf.reduce_max(tf_log_weighted, 0), 0)
            tf_exp_log_shifted = tf.exp(tf_log_weighted - tf_log_shift)
            tf_exp_log_shifted_sum = tf.reduce_sum(tf_exp_log_shifted, 0)
            tf_log_likelihood = tf.reduce_sum(tf.log(tf_exp_log_shifted_sum)) + tf.reduce_sum(tf_log_shift)

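            # Normalize by both the number of points and the dimensionality,
            # yielding a per-point, per-dimension mean log-likelihood.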
            self.tf_mean_log_likelihood = tf_log_likelihood / (self.tf_num_points * self.tf_dims)

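            # E-step responsibilities gamma[k, n] = w_k p_k(x_n) / p(x_n),
            # reusing the shifted exponentials computed above (the common
            # shift cancels in the ratio).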
            tf_gamma = tf_exp_log_shifted / tf_exp_log_shifted_sum
            tf_gamma_sum = tf.reduce_sum(tf_gamma, 1)
            tf_gamma_weighted = tf_gamma / tf.expand_dims(tf_gamma_sum, 1)
            tf_gamma_sum_split = tf.unstack(tf_gamma_sum)
            tf_gamma_weighted_split = tf.unstack(tf_gamma_weighted)

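            # M-step: each component re-estimates its own parameters from its
            # slice of the responsibilities, on its assigned worker.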
            tf_component_updaters = []
            for component_id in range(len(self.components)):
                worker_id = self.mapping[component_id]
                with tf.device(self.workers[worker_id]):
                    tf_component_updaters.extend(
                        self.components[component_id].get_parameter_updaters(
                            self.tf_worker_data[worker_id],
                            tf_gamma_weighted_split[component_id],
                            tf_gamma_sum_split[component_id]
                        )
                    )

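            # Closed-form mixing-weight update: w_k = (1/N) * sum_n gamma[k, n].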
            tf_new_weights = tf_gamma_sum / self.tf_num_points
            tf_weights_updater = self.tf_weights.assign(tf_new_weights)
            tf_all_updaters = tf_component_updaters + [tf_weights_updater]

            self.tf_train_step = tf.group(*tf_all_updaters)
            self.tf_component_parameters = [
                comp.get_parameters() for comp in self.components
            ]

    def train(self, tolerance=10e-6, max_steps=1000, feedback=None):
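        """Runs EM steps until the mean log-likelihood improves by at most
        tolerance, or until max_steps is reached.

        feedback, if given, is called as feedback(step, log_likelihood,
        difference) after each step; difference is None on the first step.
        Returns the final mean log-likelihood, mixing weights and the
        per-component parameters.
        """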
        with tf.Session(target=self.master_host, graph=self.tf_graph) as sess:
            # Initialize all variables, feeding the numpy data through the
            # placeholders; afterwards the data lives inside the graph.
            sess.run(
                tf.global_variables_initializer(),
                feed_dict=dict(zip(self.tf_input_sources, self.data))
            )

            previous_log_likelihood = -np.inf
            for step in range(max_steps):
                # One EM iteration: run all parameter updaters and fetch the
                # mean log-likelihood of the pre-update parameters (the
                # likelihood tensors feed the updaters, so they are evaluated
                # before the assignments take effect).
                _, current_log_likelihood = sess.run([
                    self.tf_train_step,
                    self.tf_mean_log_likelihood
                ])

                if step > 0:
                    difference = current_log_likelihood - previous_log_likelihood
                    if feedback is not None:
                        feedback(step, current_log_likelihood, difference)
                    # Stop once the improvement drops below the tolerance.
                    if tolerance is not None and difference <= tolerance:
                        break
                elif feedback is not None:
                    feedback(step, current_log_likelihood, None)
                previous_log_likelihood = current_log_likelihood

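            # Fetch the final mean log-likelihood, mixing weights and
            # per-component parameters in a single run.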
            return sess.run([
                self.tf_mean_log_likelihood,
                self.tf_weights, self.tf_component_parameters
            ])
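

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). MixtureModel assumes nothing about its
# components beyond the four methods called above, so the minimal
# unit-variance Gaussian below is a hypothetical example of that interface,
# not a component shipped with this module. The graph API used throughout
# (tf.placeholder, tf.Session) is TensorFlow 1.x.
# ---------------------------------------------------------------------------

class _ExampleIsotropicGaussian:
    """Hypothetical component: unit-variance Gaussian, only the mean is fit."""

    def __init__(self, initial_mean):
        self.initial_mean = np.asarray(initial_mean)
        self.tf_mean = None

    def initialize(self, dtype=tf.float64):
        self.tf_mean = tf.Variable(self.initial_mean, dtype=dtype)

    def get_log_probabilities(self, data):
        # data is the list of per-source [N, D] tensors; one source here.
        points = tf.cast(data[0], self.tf_mean.dtype)
        squared_distances = tf.reduce_sum(tf.square(points - self.tf_mean), 1)
        dims = tf.cast(tf.shape(data[0])[1], self.tf_mean.dtype)
        return -0.5 * (squared_distances + dims * np.log(2.0 * np.pi))

    def get_parameter_updaters(self, data, gamma_weighted, gamma_sum):
        # Standard EM mean update: mu = sum_n gamma_weighted[n] * x_n.
        points = tf.cast(data[0], self.tf_mean.dtype)
        new_mean = tf.reduce_sum(points * tf.expand_dims(gamma_weighted, 1), 0)
        return [self.tf_mean.assign(new_mean)]

    def get_parameters(self):
        return self.tf_mean


if __name__ == "__main__":
    # Toy run on synthetic 2-D data from two shifted blobs; the seed, sizes
    # and initial means are arbitrary illustration choices.
    np.random.seed(0)
    points = np.concatenate([
        np.random.randn(500, 2) + [-3.0, 0.0],
        np.random.randn(500, 2) + [3.0, 0.0],
    ])

    model = MixtureModel(points, [
        _ExampleIsotropicGaussian([-1.0, 0.0]),
        _ExampleIsotropicGaussian([1.0, 0.0]),
    ])

    def report(step, log_likelihood, difference):
        print(step, log_likelihood, difference)

    log_likelihood, weights, parameters = model.train(feedback=report)
    print("weights:", weights)
    print("means:", parameters)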