"""Gaussian complex-Angular-Centric-Gaussian mixture model This is a specific mixture model to integrate DC and spatial observations. It does and will not support independent dimensions. This also explains, why concrete variable names (i.e. F, T, embedding) are used instead of unnamed independent axes. The Gaussian distributions are assumed to be spherical (scaled identity). @article{Drude2019Integration, title={Integration of neural networks and probabilistic spatial models for acoustic blind source separation}, author={Drude, Lukas and Haeb-Umbach, Reinhold}, journal={IEEE Journal of Selected Topics in Signal Processing}, year={2019}, publisher={IEEE} } """ from operator import xor from typing import Any import numpy as np from dataclasses import dataclass from pb_bss.utils import unsqueeze from pb_bss.distribution import ( ComplexAngularCentralGaussian, ComplexAngularCentralGaussianTrainer, ) from pb_bss.distribution import GaussianTrainer from .mixture_model_utils import ( log_pdf_to_affiliation, log_pdf_to_affiliation_for_integration_models_with_inline_pa, ) from .utils import _ProbabilisticModel @dataclass class GCACGMM(_ProbabilisticModel): weight: np.array # Shape (), (K,), (F, K), (T, K) weight_constant_axis: tuple gaussian: Any # Gaussian, DiagonalGaussian, or SphericalGaussian cacg: ComplexAngularCentralGaussian spatial_weight: float spectral_weight: float def predict(self, observation, embedding): """ Args: observation: Shape (F, T, D) embedding: Shape (F, T, E) Returns: affiliation: Shape (F, K, T) """ assert np.iscomplexobj(observation), observation.dtype assert np.isrealobj(embedding), embedding.dtype observation = observation / np.maximum( np.linalg.norm(observation, axis=-1, keepdims=True), np.finfo(observation.dtype).tiny, ) affiliation, quadratic_form = self._predict(observation, embedding) return affiliation def _predict( self, observation, embedding, affiliation_eps=0., inline_permutation_alignment=False, ): """ Args: observation: Shape (F, T, D) embedding: Shape (F, T, E) Returns: affiliation: Shape (F, K, T) quadratic_form: Shape (F, K, T) """ F, T, D = observation.shape _, _, E = embedding.shape observation_ = observation[..., None, :, :] cacg_log_pdf, quadratic_form = self.cacg._log_pdf( np.swapaxes(observation_, -1, -2) ) embedding_ = np.reshape(embedding, (1, F * T, E)) gaussian_log_pdf = self.gaussian.log_pdf(embedding_) num_classes = gaussian_log_pdf.shape[0] gaussian_log_pdf = np.transpose( np.reshape(gaussian_log_pdf, (num_classes, F, T)), (1, 0, 2) ) if inline_permutation_alignment: affiliation \ = log_pdf_to_affiliation_for_integration_models_with_inline_pa( weight=unsqueeze(self.weight, self.weight_constant_axis), spatial_log_pdf=self.spatial_weight * cacg_log_pdf, spectral_log_pdf=self.spectral_weight * gaussian_log_pdf, affiliation_eps=affiliation_eps, ) else: affiliation = log_pdf_to_affiliation( weight=unsqueeze(self.weight, self.weight_constant_axis), log_pdf=( self.spatial_weight * cacg_log_pdf + self.spectral_weight * gaussian_log_pdf ), affiliation_eps=affiliation_eps, ) return affiliation, quadratic_form class GCACGMMTrainer: def fit( self, observation, embedding, initialization=None, num_classes=None, iterations=100, saliency=None, hermitize=True, covariance_norm='eigenvalue', eigenvalue_floor=1e-10, covariance_type="spherical", fixed_covariance=None, affiliation_eps=1e-10, weight_constant_axis=(-1,), spatial_weight=1., spectral_weight=1., inline_permutation_alignment=False, ) -> GCACGMM: """ Args: observation: Shape (F, T, D) embedding: Shape (F, T, E) initialization: Affiliations between 0 and 1. Shape (F, K, T) num_classes: Scalar >0 iterations: Scalar >0 saliency: Importance weighting for each observation, shape (F, T) hermitize: trace_norm: eigenvalue_floor: covariance_type: Either 'full', 'diagonal', or 'spherical' fixed_covariance: Learned, if None. If fixed, you need to provide a covariance matrix with the correct shape. affiliation_eps: Used in M-step to clip saliency. weight_constant_axis: Axis, along which weight is constant. The axis indices are based on affiliation shape. Consequently: (-3, -2, -1) == constant = '' (-3, -1) == 'k' (-1,) == vanilla == 'fk' (-3,) == 'kt' spatial_weight: spectral_weight: inline_permutation_alignment: Bool to enable inline permutation alignment for integration models. The idea is to reduce disagreement between the spatial and the spectral model. Returns: """ assert xor(initialization is None, num_classes is None), ( "Incompatible input combination. " "Exactly one of the two inputs has to be None: " f"{initialization is None} xor {num_classes is None}" ) assert np.iscomplexobj(observation), observation.dtype assert np.isrealobj(embedding), embedding.dtype assert observation.shape[-1] > 1 observation = observation / np.maximum( np.linalg.norm(observation, axis=-1, keepdims=True), np.finfo(observation.dtype).tiny, ) F, T, D = observation.shape _, _, E = embedding.shape if initialization is None and num_classes is not None: affiliation_shape = (F, num_classes, T) initialization = np.random.uniform(size=affiliation_shape) initialization /= np.einsum("...kt->...t", initialization)[ ..., None, : ] if saliency is None: saliency = np.ones_like(initialization[..., 0, :]) quadratic_form = np.ones_like(initialization) affiliation = initialization model = None for iteration in range(iterations): if model is not None: affiliation, quadratic_form = model._predict( observation=observation, embedding=embedding, inline_permutation_alignment=inline_permutation_alignment, affiliation_eps=affiliation_eps, ) model = self._m_step( observation, embedding, quadratic_form, affiliation=affiliation, saliency=saliency, hermitize=hermitize, covariance_norm=covariance_norm, eigenvalue_floor=eigenvalue_floor, covariance_type=covariance_type, fixed_covariance=fixed_covariance, weight_constant_axis=weight_constant_axis, spatial_weight=spatial_weight, spectral_weight=spectral_weight ) return model def fit_predict( self, observation, embedding, initialization=None, num_classes=None, iterations=100, saliency=None, hermitize=True, covariance_norm='eigenvalue', eigenvalue_floor=1e-10, covariance_type="spherical", fixed_covariance=None, affiliation_eps=1e-10, weight_constant_axis=(-1,), spatial_weight=1., spectral_weight=1., inline_permutation_alignment=False, ): """Fit a model. Then just return the posterior affiliations.""" model = self.fit( observation=observation, embedding=embedding, initialization=initialization, num_classes=num_classes, iterations=iterations, saliency=saliency, hermitize=hermitize, covariance_norm=covariance_norm, eigenvalue_floor=eigenvalue_floor, covariance_type=covariance_type, fixed_covariance=fixed_covariance, affiliation_eps=affiliation_eps, weight_constant_axis=weight_constant_axis, spatial_weight=spatial_weight, spectral_weight=spectral_weight, inline_permutation_alignment=inline_permutation_alignment, ) return model.predict(observation=observation, embedding=embedding) def _m_step( self, observation, embedding, quadratic_form, affiliation, saliency, hermitize, covariance_norm, eigenvalue_floor, covariance_type, fixed_covariance, weight_constant_axis, spatial_weight, spectral_weight ): F, T, D = observation.shape _, _, E = embedding.shape _, K, _ = affiliation.shape masked_affiliation = affiliation * saliency[..., None, :] if -2 in weight_constant_axis: weight = 1 / K else: weight = np.sum( masked_affiliation, axis=weight_constant_axis, keepdims=True ) weight /= np.sum(weight, axis=-2, keepdims=True) weight = np.squeeze(weight, axis=weight_constant_axis) embedding_ = np.reshape(embedding, (1, F * T, E)) masked_affiliation_ = np.reshape( np.transpose(masked_affiliation, (1, 0, 2)), (K, F * T) ) # 'fkt->k,ft' gaussian = GaussianTrainer()._fit( y=embedding_, saliency=masked_affiliation_, covariance_type=covariance_type, ) if fixed_covariance is not None: assert fixed_covariance.shape == gaussian.covariance.shape, ( f'{fixed_covariance.shape} != {gaussian.covariance.shape}' ) gaussian = gaussian.__class__( mean=gaussian.mean, covariance=fixed_covariance ) cacg = ComplexAngularCentralGaussianTrainer()._fit( y=np.swapaxes(observation[..., None, :, :], -1, -2), saliency=masked_affiliation, quadratic_form=quadratic_form, hermitize=hermitize, covariance_norm=covariance_norm, eigenvalue_floor=eigenvalue_floor, ) return GCACGMM( weight=weight, gaussian=gaussian, cacg=cacg, weight_constant_axis=weight_constant_axis, spatial_weight=spatial_weight, spectral_weight=spectral_weight )