# -*- coding: utf-8 -*-

import numpy as np
import scipy.sparse
import scipy.sparse.linalg  # imported explicitly for scipy.sparse.linalg.spsolve
import sklearn.mixture

try:
    # scikit-learn >= 0.22 keeps the implementation in a private module
    from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky
except ImportError:  # older scikit-learn
    from sklearn.mixture.gaussian_mixture import _compute_precision_cholesky

from sprocket.util.delta import construct_static_and_delta_matrix
from .diagGMM import BlockDiagonalGaussianMixture

class GMMTrainer(object):
    """GMM trainer
    This class offers the training of GMM with several types of covariance

    Parameters
    ----------
    n_mix : int, optional
        The number of mixture components of the GMM
        Default set to 32.
    n_iter : int, optional
        The number of iterations for the EM algorithm.
        Default set to 100.
    covtype : str, optional
        The type of covariance matrix of the GMM
        'full' : full-covariance matrix
        'block_diag' : block-diagonal matrix
        Default set to 'full'.

    Attributes
    ----------
    param :
        Sklearn-based model parameters of the GMM
    log_resp : array or None
        Log responsibilities stored by `estimate_responsibility` (the second
        value returned by the GMM's ``_e_step``); None until it is called.

    Raises
    ------
    ValueError
        If `covtype` is neither 'full' nor 'block_diag'.
    """

    def __init__(self, n_mix=32, n_iter=100, covtype='full'):
        self.n_mix = n_mix
        self.n_iter = n_iter
        self.covtype = covtype

        # NumPy's global RandomState; used to initialize single-path params
        self.random_state = np.random.mtrand._rand

        # filled by estimate_responsibility, consumed by train_singlepath
        self.log_resp = None

        # construct GMM parameter
        self.param = self._new_gmm()

    def _new_gmm(self):
        """Return a new, unfitted GMM object for ``self.covtype``

        Raises
        ------
        ValueError
            If ``self.covtype`` is neither 'full' nor 'block_diag'.
        """
        if self.covtype == 'full':
            return sklearn.mixture.GaussianMixture(
                n_components=self.n_mix,
                covariance_type='full',
                max_iter=self.n_iter)
        if self.covtype == 'block_diag':
            # NOTE(review): keyword names assumed from the diagGMM module
            # (not visible here) -- confirm the n_mix / n_iter signature.
            return BlockDiagonalGaussianMixture(
                n_mix=self.n_mix,
                n_iter=self.n_iter)
        raise ValueError('Covariance type should be full or block_diag')

    def open_from_param(self, param):
        """Open GMM from sklearn.GaussianMixture

        Parameters
        ----------
        param : sklearn.mixture.GaussianMixture
            Fitted GMM parameter (e.g. ``GMMTrainer.param``); it is used as
            the reference model by `estimate_responsibility`.
        """
        self.param = param

    def train(self, jnt):
        """Fit GMM parameter from given joint feature vector

        Parameters
        ----------
        jnt : array, shape(`T`, `dim`)
            Joint feature vector of original and target feature vector
            consisting of static and delta components
        """
        self.param.fit(jnt)

    def estimate_responsibility(self, ref_jnt):
        """E-step for the single-path training

        Parameters
        ----------
        ref_jnt : array, shape(`T`, `ref_dim`)
            Reference joint feature vector of original and target feature
            vector consisting of static and delta components, which was
            already fit.

        Raises
        ------
        ValueError
            If no GMM parameter has been loaded.
        """
        if self.param is None:
            raise ValueError(
                'Please load param before call estimate_responsibility')

        # perform e-step; only the log responsibilities are kept
        _, self.log_resp = self.param._e_step(ref_jnt)

    def train_singlepath(self, tar_jnt):
        """Fit GMM parameter based on single-path training
        M-step :
            Update GMM parameter using `self.log_resp`, and `tar_jnt`

        Parameters
        ----------
        tar_jnt : array, shape(`T`, `tar_dim`)
            Joint feature vector of original and target feature vector
            consisting of static and delta components, which will be modeled.

        Returns
        -------
        param :
            Sklearn-based model parameters of the GMM

        Raises
        ------
        ValueError
            If `estimate_responsibility` has not been called yet, or the
            covariance type is neither 'full' nor 'block_diag'.
        """
        # check before building/initializing the model, which is costly;
        # getattr keeps the check valid even if log_resp was never assigned
        if getattr(self, 'log_resp', None) is None:
            raise ValueError(
                'Please call estimate_responsibility before train_singlepath')

        single_param = self._new_gmm()

        # initialize target single-path param
        single_param._initialize_parameters(tar_jnt, self.random_state)

        # perform m-step with the responsibilities of the reference model
        single_param._m_step(tar_jnt, self.log_resp)

        return single_param

class GMMConvertor(object):
    """A GMM Convertor
    This class offers several conversion techniques such as Maximum
    Likelihood Parameter Generation (MLPG) and Minimum Mean Square Error
    (MMSE). Note that the conversion is performed while regarding the GMM
    covariance as a full-covariance matrix.

    Parameters
    ----------
    n_mix : int, optional
        The number of mixture components of the GMM
        Default set to 32. The conversion itself uses the number of
        components of the GMM passed to `open_from_param`.
    covtype : str, optional
        The covariance type of the GMM. It is only stored as `covtype`;
        the conversion always treats the covariances as full matrices.
        Default set to 'full'.
    gmmmode : str, optional
        The type of the GMM for opening
        `None` : Normal JD-GMM
        `diff` : Differential GMM
        `intra` : Intra-speaker GMM

    Attributes
    ----------
    param :
        Sklearn-based model parameters of the GMM
    w : shape (`n_mix`)
        Vector of mixture component weight of the GMM
    jmean : shape (`n_mix`, `jnt.shape[1]`)
        Array of joint mean vector of the GMM
    jcov : shape (`n_mix`, `jnt.shape[1]`, `jnt.shape[1]`)
        Array of joint covariance matrix of the GMM
    """

    def __init__(self, n_mix=32, covtype='full', gmmmode=None):
        self.n_mix = n_mix
        self.covtype = covtype
        self.gmmmode = gmmmode

    def open_from_param(self, param):
        """Open GMM from GMMTrainer and prepare it for conversion

        Parameters
        ----------
        param : sklearn.mixture.GaussianMixture
            Fitted joint GMM (e.g. ``GMMTrainer.param``); it must provide
            ``weights_``, ``means_`` and full ``covariances_``.

        Raises
        ------
        ValueError
            If `gmmmode` is not one of None, 'diff' or 'intra'.
        """
        self.param = param
        # convert() needs the derived parameters (meanX, A, pX, ...)
        self._deploy_parameters()

    def convert(self, data, cvtype='mlpg'):
        """Convert data based on conditional probability density function

        Parameters
        ----------
        data : array, shape(`T`, `dim`)
            Original data (static and delta components) to be converted
        cvtype : str, optional
            Type of conversion technique
            `mlpg` : maximum likelihood parameter generation
            `mmse` : minimum mean square error

        Returns
        -------
        odata : array, shape(`T`, `dim // 2`)
            Converted static feature sequence

        Raises
        ------
        ValueError
            If `cvtype` is neither 'mlpg' nor 'mmse'.
        """
        # validate first so that a bad mode fails before any computation
        if cvtype not in ('mlpg', 'mmse'):
            raise ValueError('please choose conversion mode in `mlpg`, `mmse`')

        if cvtype == 'mlpg':
            # maximum likelihood parameter generation
            _, _, mseq, covseq = self._gmmmap(data)
            odata = self._mlpg(mseq, covseq)
        else:
            # minimum mean square error: only the posteriors are needed, so
            # the (T, dim, dim) conditional precision sequence is not built
            wseq = self.pX.predict_proba(data)
            odata = self._mmse(wseq, data)

        return odata

    def _gmmmap(self, sddata):
        """Estimate the frame-wise conditional parameter sequence

        Parameters
        ----------
        sddata : array, shape(`T`, `dim`)
            Source static and delta feature sequence

        Returns
        -------
        cseq : array, shape(`T`)
            Index of the most probable mixture component in each frame
        wseq : array, shape(`T`, `n_mix`)
            Posterior probability of each mixture component in each frame
        mseq : array, shape(`T`, `dim`)
            Conditional mean vector of the selected component in each frame
        covseq : array, shape(`T`, `dim`, `dim`)
            Conditional precision (inverse covariance) matrix of the
            selected component in each frame
        """
        # estimate posterior sequence
        wseq = self.pX.predict_proba(sddata)

        # estimate mixture sequence (maximum posterior component per frame)
        cseq = np.argmax(wseq, axis=1)

        # conditional mean vector sequence, for all frames at once:
        # meanY[m] + A[m] (x_t - meanX[m]) with m = cseq[t]
        mseq = self.meanY[cseq] + np.einsum(
            'tij,tj->ti', self.A[cseq], sddata - self.meanX[cseq])

        # conditional precision sequence of the selected components
        covseq = self.cond_cov_inv[cseq]

        return cseq, wseq, mseq, covseq

    def _mmse(self, wseq, sddata):
        """Posterior-weighted conditional mean (MMSE estimate)

        Parameters
        ----------
        wseq : array, shape(`T`, `n_mix`)
            Posterior probability of each mixture component in each frame
        sddata : array, shape(`T`, `dim`)
            Source static and delta feature sequence

        Returns
        -------
        array, shape(`T`, `dim // 2`)
            Static part of the MMSE estimate
        """
        sddim = sddata.shape[1]

        odata = np.zeros(sddata.shape)
        for m in range(wseq.shape[1]):
            # conditional mean of every frame under component m
            cond_mean = self.meanY[m] + (sddata - self.meanX[m]) @ self.A[m].T
            odata += wseq[:, m:m + 1] * cond_mean

        # return static and throw away delta component
        return odata[:, :sddim // 2]

    def _mlpg(self, mseq, covseq):
        """Maximum likelihood parameter generation

        Solves (W'DW) y = W'Dm for the static sequence y, where m is the
        flattened conditional mean sequence and D the block-diagonal
        conditional precision matrix.

        Parameters
        ----------
        mseq : array, shape(`T`, `dim`)
            Conditional mean vector sequence (static and delta)
        covseq : array, shape(`T`, `dim`, `dim`)
            Conditional precision matrix sequence

        Returns
        -------
        odata : array, shape(`T`, `dim // 2`)
            Generated static feature sequence
        """
        T, sddim = mseq.shape

        # prepare W; the products below require T * sddim rows and
        # T * (sddim // 2) columns
        W = construct_static_and_delta_matrix(T, sddim // 2)

        # prepare D (sparse block-diagonal of the per-frame precisions)
        D = get_diagonal_precision_matrix(T, sddim, covseq)

        # W'D, W'DW and W'Dm
        WD = W.T @ D
        WDW = WD @ W
        WDm = WD @ mseq.flatten()

        # estimate y = (W'DW)^-1 * W'Dm
        odata = scipy.sparse.linalg.spsolve(
            WDW, WDm, use_umfpack=False).reshape(T, sddim // 2)

        return odata

    def _deploy_parameters(self):
        """Split the joint GMM into source/target blocks and derive the
        conversion parameters

        Raises
        ------
        ValueError
            If `gmmmode` is not one of None, 'diff' or 'intra'.
        """
        # read JD-GMM parameters from self.param
        self.w = self.param.weights_
        self.jmean = self.param.means_
        self.jcov = self.param.covariances_

        # divide GMM parameters into source and target parameters; these are
        # views of self.param's arrays, and the transforms below rebind the
        # attributes instead of modifying them in place
        sddim = self.jmean.shape[1] // 2
        self.meanX = self.jmean[:, :sddim]
        self.meanY = self.jmean[:, sddim:]
        self.covXX = self.jcov[:, :sddim, :sddim]
        self.covXY = self.jcov[:, :sddim, sddim:]
        self.covYX = self.jcov[:, sddim:, :sddim]
        self.covYY = self.jcov[:, sddim:, sddim:]

        # change model parameter of GMM into that of gmmmode
        if self.gmmmode is None:
            pass
        elif self.gmmmode == 'diff':
            self._transform_gmm_into_diffgmm()
        elif self.gmmmode == 'intra':
            self._transform_gmm_into_intragmm()
        else:
            raise ValueError('please choose GMM mode in [None, diff, intra]')

        # estimate parameters for conversion
        self._set_Ab()
        self._set_pX()

    def _set_Ab(self):
        """Compute the per-mixture regression A, offset b and conditional
        precision from the source/target blocks"""
        # inverse covariance XX of every mixture (inv works on stacks)
        self.covXXinv = np.linalg.inv(self.covXX)

        # A = yxcov_m * xxcov_m^-1
        self.A = self.covYX @ self.covXXinv

        # b = mean^Y - A * mean^X
        self.b = self.meanY - np.einsum('mij,mj->mi', self.A, self.meanX)

        # conditional covariance: cov^(Y|X)^-1 = (yycov - A * xycov)^-1
        self.cond_cov_inv = np.linalg.inv(self.covYY - self.A @ self.covXY)

    def _set_pX(self):
        """Build the marginal GMM of X used to compute frame posteriors"""
        # probability density function of X
        self.pX = sklearn.mixture.GaussianMixture(
            n_components=len(self.w), covariance_type='full')
        self.pX.weights_ = self.w
        self.pX.means_ = self.meanX
        self.pX.covariances_ = self.covXX

        # following attribute is required to estimate the posterior
        self.pX.precisions_cholesky_ = _compute_precision_cholesky(
            self.covXX, 'full')

    def _transform_gmm_into_diffgmm(self):
        """Turn the (X, Y) joint GMM into a joint GMM of (X, Y - X)"""
        # statement order matters: covYY and covXY use the original joint
        # blocks, while covYX is the transpose of the *updated* covXY
        self.meanY = self.meanY - self.meanX
        self.covYY = self.covXX + self.covYY - self.covXY - self.covYX
        self.covXY = self.covXY - self.covXX
        self.covYX = self.covXY.transpose(0, 2, 1)

    def _transform_gmm_into_intragmm(self):
        """Turn the (X, Y) joint GMM into an intra-speaker GMM of X"""
        self.meanY = self.meanX
        # uses the original covYY, so it must run before covYY is replaced
        self.covXY = self.covXY @ np.linalg.solve(self.covYY, self.covYX)
        self.covYX = self.covXY
        self.covYY = self.covXX

def get_diagonal_precision_matrix(T, D, covseq):
    """Place per-frame matrices along the diagonal of one sparse matrix.

    Parameters
    ----------
    T : int
        Number of frames (not used by the computation).
    D : int
        Dimension of each block (not used by the computation).
    covseq : array, shape (`T`, `D`, `D`)
        Sequence of per-frame matrices; ``covseq[t]`` becomes the t-th
        diagonal block. The shape is assumed, not checked against T and D.

    Returns
    -------
    Sparse CSR matrix whose diagonal blocks are the entries of `covseq`.
    """
    frame_blocks = list(covseq)
    return scipy.sparse.block_diag(frame_blocks, format='csr')