import numpy as np from typing import Tuple, List import tensorflow as tf from tensorflow import Tensor import string from decompose.distributions.distribution import DrawType, UpdateType from decompose.distributions.cenNormal import CenNormal from decompose.likelihoods.likelihood import Likelihood from decompose.distributions.distribution import Properties class NormalNdLikelihood(Likelihood): def __init__(self, M: Tuple[int, ...], K: int=1, tau: float = 1./1e10, drawType: DrawType = DrawType.SAMPLE, updateType: UpdateType = UpdateType.ALL, dtype=tf.float32) -> None: Likelihood.__init__(self, M, K) self.__tauInit = tau self.__dtype = dtype self.__properties = Properties(name='likelihood', drawType=drawType, updateType=updateType, persistent=True) def init(self, data: Tensor) -> None: tau = self.__tauInit dtype = self.__dtype properties = self.__properties noiseDistribution = CenNormal(tau=tf.constant([tau], dtype=dtype), properties=properties) self.__noiseDistribution = noiseDistribution @property def noiseDistribution(self) -> CenNormal: return(self.__noiseDistribution) def residuals(self, U: Tuple[Tensor, ...], X: Tensor) -> Tensor: F = len(U) axisIds = string.ascii_lowercase[:F] subscripts = f'k{",k".join(axisIds)}->{axisIds}' Xhat = tf.einsum(subscripts, *U) residuals = X-Xhat return(residuals) def llh(self, U: Tuple[Tensor, ...], X: Tensor) -> Tensor: r = self.residuals(U, X) llh = tf.reduce_sum(self.noiseDistribution.llh(r)) return(llh) def loss(self, U: Tuple[Tensor, ...], X: Tensor) -> Tensor: loss = tf.reduce_sum(self.residuals(U, X)**2) return(loss) def update(self, U: Tuple[Tensor, ...], X: Tensor) -> None: if self.noiseDistribution.updateType == UpdateType.ALL: residuals = self.residuals(U, X) flattenedResiduals = tf.reshape(residuals, (-1,))[..., None] self.noiseDistribution.update(flattenedResiduals) def outterTensorProduct(self, Us): F = len(Us) axisIds = string.ascii_lowercase[:F] subscripts = f'k{",k".join(axisIds)}->{axisIds}k' Xhat = tf.einsum(subscripts, *Us) return(Xhat) def prepVars(self, f: int, U: List[Tensor], X: Tensor) -> Tuple[Tensor, Tensor, Tensor]: F = self.F Umf = [U[g] for g in range(F) if g != f] UmfOutter = self.outterTensorProduct(Umf) rangeFm1 = list(range(F-1)) A = tf.tensordot(X, UmfOutter, axes=([g for g in range(F) if g != f], rangeFm1)) B = tf.tensordot(UmfOutter, UmfOutter, axes=(rangeFm1, rangeFm1)) alpha = self.noiseDistribution.tau return(A, B, alpha)