# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf

from distribution_base import DistributionBase
from full_covariance import FullCovariance


class GaussianDistribution(DistributionBase):
    """Multivariate Gaussian component whose state lives in TensorFlow.

    The mean is stored in a ``tf.Variable``. Everything covariance-related
    (matrix, quadratic form, log-determinant, update op) is delegated to a
    covariance object, which is a ``FullCovariance`` by default.

    TensorFlow state is created lazily by :meth:`initialize`. The other
    methods dereference that state, so they must only be called afterwards.

    Args:
        dims: Dimensionality D of the distribution.
        mean: Optional initial mean passed to ``tf.Variable``. When omitted,
            a zero vector of length ``dims`` is used.
        covariance: Optional covariance object. It must provide
            ``initialize(dtype)``, ``get_matrix()``,
            ``get_inv_quadratic_form(x, mean)``, ``get_log_determinant()``
            and ``get_value_updater(x, mean, gamma_weighted, gamma_sum)``.
            When omitted, ``FullCovariance(dims)`` is created.
    """

    def __init__(self, dims, mean=None, covariance=None):
        self.dims = dims
        self.mean = mean
        self.covariance = covariance

        # TensorFlow objects; populated by initialize().
        self.tf_mean = None
        self.tf_covariance = None
        self.tf_ln2piD = None

    def initialize(self, dtype=tf.float64):
        """Create the TensorFlow state (mean, covariance, constants).

        Each member is created only if it is still ``None``. Calling this
        again therefore leaves already-created members untouched, even if a
        different ``dtype`` is passed.

        Args:
            dtype: TensorFlow dtype used for the mean, the covariance and
                the normalizing constant.
        """
        if self.tf_mean is None:
            if self.mean is not None:
                self.tf_mean = tf.Variable(self.mean, dtype=dtype)
            else:
                # Build the zero mean directly in the target dtype rather than
                # filling a float32 tensor and casting it afterwards.
                self.tf_mean = tf.Variable(tf.zeros([self.dims], dtype=dtype))

        if self.tf_covariance is None:
            if self.covariance is not None:
                self.tf_covariance = self.covariance
            else:
                self.tf_covariance = FullCovariance(self.dims)

            # Only initialize the covariance the first time it is attached.
            self.tf_covariance.initialize(dtype)

        if self.tf_ln2piD is None:
            # D * ln(2*pi): the dimension-dependent part of the log-normalizer.
            self.tf_ln2piD = tf.constant(np.log(2 * np.pi) * self.dims, dtype=dtype)

    def get_parameters(self):
        """Return ``[mean_variable, covariance_matrix]``.

        Requires :meth:`initialize` to have been called.
        """
        return [
            self.tf_mean,
            self.tf_covariance.get_matrix()
        ]

    def get_log_probabilities(self, data):
        """Build the per-sample log-density tensor.

        Computes ``-0.5 * (D*ln(2*pi) + log_det + quad)``, where ``log_det``
        comes from the covariance object's ``get_log_determinant()`` and
        ``quad`` from its ``get_inv_quadratic_form(data[0], mean)``.

        NOTE(review): this equals the Gaussian log-pdf only if those helpers
        return ln|Sigma| and (x - mu)^T Sigma^-1 (x - mu) respectively. That
        is implied by their names; confirm against the covariance class.

        Args:
            data: Indexable whose first element holds the samples.
                Presumably an (N, D) tensor; confirm with the caller.

        Returns:
            The log-density tensor.
        """
        samples = data[0]
        tf_quadratic_form = self.tf_covariance.get_inv_quadratic_form(samples, self.tf_mean)
        tf_log_coefficient = self.tf_ln2piD + self.tf_covariance.get_log_determinant()

        return -0.5 * (tf_log_coefficient + tf_quadratic_form)

    def get_parameter_updaters(self, data, gamma_weighted, gamma_sum):
        """Build the ops that move the mean and covariance to new estimates.

        The new mean is ``sum_i gamma_weighted[i] * data[0][i]``. That is a
        proper weighted average only if ``gamma_weighted`` sums to 1; this is
        presumably the caller's responsibility (TODO confirm). The covariance
        updater is given the new mean tensor directly, not ``self.tf_mean``,
        so it does not depend on the mean assignment having run first.

        Args:
            data: Indexable whose first element holds the samples (see
                :meth:`get_log_probabilities`).
            gamma_weighted: Per-sample weights, broadcast along axis 1 of
                ``data[0]``.
            gamma_sum: Forwarded unchanged to the covariance updater.

        Returns:
            ``[covariance_update_op, mean_assign_op]``.
        """
        samples = data[0]
        tf_new_mean = tf.reduce_sum(samples * tf.expand_dims(gamma_weighted, 1), 0)
        tf_covariance_updater = self.tf_covariance.get_value_updater(
            samples, tf_new_mean, gamma_weighted, gamma_sum)

        return [tf_covariance_updater, self.tf_mean.assign(tf_new_mean)]