"""

# Example call:
cd path/where/to/store/simulation/results

# Define how you want to parallelize (keep exactly ONE of the two `run`
# definitions below).
# Note: The mixture models need significantly more computational power.
# Option 1: ccsalloc is an HPC scheduler; this command requests 50 workers,
# each with 2 GB memory and 1 CPU:
run="ccsalloc --res=rset=50:mem=2g:ncpus=1 --tracefile=ompi.%reqid.trace -t 2h ompi -- "
# Option 2: run locally with plain MPI:
run="mpiexec -np 16 "

# For the experiments in the paper, run the following commands:
${run} python -m sms_wsj.examples.reference_systems with observation
${run} python -m sms_wsj.examples.reference_systems with mm_masking
${run} python -m sms_wsj.examples.reference_systems with mm_mvdr_souden
${run} python -m sms_wsj.examples.reference_systems with irm_mvdr
${run} python -m sms_wsj.examples.reference_systems with ibm_mvdr
${run} python -m sms_wsj.examples.reference_systems with image
${run} python -m sms_wsj.examples.reference_systems with image_early
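
# Each run writes its results to the directory given by `out` (here equal to
# the named config, e.g. `observation`). Per dataset you get
#   <out>/<dataset>_scores.json   (per-example metrics) and
#   <out>/<dataset>_summary.json  (metric means over the dataset),
# e.g., inspect a summary with:
cat observation/cv_dev93_summary.json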

"""

import numpy as np
from pathlib import Path
import sacred

from einops import rearrange

import dlp_mpi
from lazy_dataset import from_list

from nara_wpe.utils import stft as _stft, istft as _istft
from pb_bss.extraction import mask_module
from pb_bss.extraction import (
    apply_beamforming_vector,
    get_power_spectral_density_matrix,
    get_single_source_bf_vector,
)
from pb_bss.evaluation.wrapper import OutputMetrics
from pb_bss.distribution import CACGMMTrainer
from pb_bss import initializer
from pb_bss.permutation_alignment import DHTVPermutationAlignment

from sms_wsj.database import SmsWsj, AudioReader
from sms_wsj.io import dump_audio, dump_json

experiment = sacred.Experiment('Ref systems')


@experiment.config
def config():
    dataset = ['cv_dev93', 'test_eval92']  # list of datasets, or a single name, e.g. 'test_eval92'
    Observation = None

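    # STFT parameters: 512-point FFT with a 128-sample shift
    # (64 ms / 16 ms at the 8 kHz sample rate of SMS-WSJ).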
    stft_size = 512
    stft_shift = 128
    stft_window_length = None
    stft_window = 'hann'
    json_path = None


@experiment.named_config
def observation():
    out = 'observation'
    Observation = 'Observation'
    mask_estimator = 'IBM'
    beamformer = 'ch0'
    postfilter = None


@experiment.named_config
def mm_masking():
    out = 'mm_masking'
    Observation = 'Observation'
    mask_estimator = 'cacgmm'
    beamformer = 'ch0'
    postfilter = 'mask_mul'
    weight_constant_axis = -3  # pi_tk
    # weight_constant_axis = -1  # pi_fk


@experiment.named_config
def mm_mvdr_souden():
    out = 'mm_mvdr_souden'
    Observation = 'Observation'
    mask_estimator = 'cacgmm'
    beamformer = 'mvdr_souden'
    postfilter = None
    weight_constant_axis = -3  # pi_tk
    # weight_constant_axis = -1  # pi_fk


@experiment.named_config
def irm_mvdr():
    out = 'irm_mvdr'
    Observation = 'Observation'
    mask_estimator = 'IRM'
    beamformer = 'mvdr_souden'
    postfilter = None


@experiment.named_config
def ibm_mvdr():
    out = 'ibm_mvdr'
    Observation = 'Observation'
    mask_estimator = 'IBM'
    beamformer = 'mvdr_souden'
    postfilter = None


@experiment.named_config
def image():
    out = 'image'
    Observation = 'speech_image'
    # mask_estimator = 'ICM_0'
    mask_estimator = None
    beamformer = 'ch0'
    postfilter = None


@experiment.named_config
def image_mask():
    out = 'image_mask'
    Observation = 'Observation'
    mask_estimator = 'ICM_0'
    # mask_estimator = None
    beamformer = 'ch0'
    postfilter = 'mask_mul'


@experiment.named_config
def image_early():
    out = 'image_early'
    Observation = 'speech_reverberation_early'
    # mask_estimator = 'ICM_0_early'
    mask_estimator = None
    beamformer = 'ch0'
    postfilter = None


def get_multi_speaker_metrics(
    mask,  # T Ktarget F
    Observation,  # D T F (stft signal)
    speech_source,  # Ksource N (time signal)
    Speech_image=None,  # Ksource D T F (stft signal)
    Noise_image=None,  # D T F (stft signal)
    istft=None,  # callable(signal, num_samples=num_samples)
    bf_algorithm='mvdr_souden',
    postfilter=None,  # [None, 'mask_mul']
) -> OutputMetrics:
    """

    >>> from IPython.lib.pretty import pprint
    >>> from pb_bss.testing import dummy_data
    >>> from paderbox.transform.module_stft import stft, istft
    >>> from pb_bss.extraction import ideal_ratio_mask, phase_sensitive_mask
    >>> from pb_bss.extraction import ideal_complex_mask

    >>> example = dummy_data.reverberation_data()

    >>> Observation = stft(example['audio_data']['observation'])
    >>> Speech_image = stft(example['audio_data']['speech_image'])
    >>> Noise_image = stft(example['audio_data']['noise_image'])
    >>> speech_source = example['audio_data']['speech_source']

    >>> mask = ideal_ratio_mask(np.abs([*Speech_image, Noise_image]).sum(1))
    >>> X_mask = mask[:-1]
    >>> N_mask = mask[-1]
    >>> kwargs = {}
    >>> kwargs['mask'] = np.stack([*mask], 1)
    >>> kwargs['Observation'] = Observation
    >>> kwargs['Speech_image'] = Speech_image
    >>> kwargs['Noise_image'] = Noise_image
    >>> kwargs['speech_source'] = speech_source
    >>> kwargs['istft'] = istft
    >>> pprint(get_multi_speaker_metrics(**kwargs).as_dict())
    {'pesq': array([1.996, 2.105]),
     'stoi': array([0.8425774 , 0.86015112]),
     'mir_eval_sxr_sdr': array([13.82179099, 11.37128002]),
     'mir_eval_sxr_sir': array([21.39419702, 18.52582023]),
     'mir_eval_sxr_sar': array([14.68805087, 12.3606874 ]),
     'mir_eval_sxr_selection': array([0, 1]),
     'invasive_sxr_sdr': array([17.17792759, 14.49937822]),
     'invasive_sxr_sir': array([18.9065789 , 16.07738463]),
     'invasive_sxr_snr': array([22.01439067, 19.66127281])}
    >>> pprint(get_multi_speaker_metrics(**kwargs, postfilter='mask_mul').as_dict())
    {'pesq': array([2.235, 2.271]),
     'stoi': array([0.84173865, 0.85532424]),
     'mir_eval_sxr_sdr': array([14.17958101, 11.69826193]),
     'mir_eval_sxr_sir': array([29.62978561, 26.10579693]),
     'mir_eval_sxr_sar': array([14.3099193, 11.8692283]),
     'mir_eval_sxr_selection': array([0, 1]),
     'invasive_sxr_sdr': array([24.00659296, 20.80162802]),
     'invasive_sxr_sir': array([27.13945978, 24.21115858]),
     'invasive_sxr_snr': array([26.89769041, 23.44632734])}
    >>> pprint(get_multi_speaker_metrics(**kwargs, bf_algorithm='ch0', postfilter='mask_mul').as_dict())
    {'pesq': array([1.969, 2.018]),
     'stoi': array([0.81097215, 0.80093435]),
     'mir_eval_sxr_sdr': array([10.2343187 ,  8.29797827]),
     'mir_eval_sxr_sir': array([16.84226656, 14.64059341]),
     'mir_eval_sxr_sar': array([11.3932819 ,  9.59180288]),
     'mir_eval_sxr_selection': array([0, 1]),
     'invasive_sxr_sdr': array([14.70258429, 11.87061145]),
     'invasive_sxr_sir': array([14.74794743, 11.92701556]),
     'invasive_sxr_snr': array([34.53605847, 30.76351885])}

    >>> mask = ideal_ratio_mask(np.abs([*Speech_image, Noise_image])[:, 0])
    >>> kwargs['mask'] = np.stack([*mask], 1)
    >>> kwargs['speech_source'] = example['audio_data']['speech_image'][:, 0]
    >>> pprint(get_multi_speaker_metrics(**kwargs, bf_algorithm='ch0', postfilter='mask_mul').as_dict())
    {'pesq': array([3.471, 3.47 ]),
     'stoi': array([0.96011783, 0.96072581]),
     'mir_eval_sxr_sdr': array([13.50013349, 10.59091527]),
     'mir_eval_sxr_sir': array([17.67436882, 14.76824653]),
     'mir_eval_sxr_sar': array([15.66698718, 12.82478905]),
     'mir_eval_sxr_selection': array([0, 1]),
     'invasive_sxr_sdr': array([15.0283757 , 12.18546349]),
     'invasive_sxr_sir': array([15.07095641, 12.23764194]),
     'invasive_sxr_snr': array([35.13536337, 31.41445774])}

    >>> mask = phase_sensitive_mask(np.array([*Speech_image, Noise_image])[:, 0])
    >>> kwargs['mask'] = np.stack([*mask], 1)
    >>> kwargs['speech_source'] = example['audio_data']['speech_image'][:, 0]
    >>> pprint(get_multi_speaker_metrics(**kwargs, bf_algorithm='ch0', postfilter='mask_mul').as_dict())
    {'pesq': array([3.965, 3.968]),
     'stoi': array([0.98172316, 0.98371817]),
     'mir_eval_sxr_sdr': array([17.08649852, 14.51167667]),
     'mir_eval_sxr_sir': array([25.39489935, 24.17276323]),
     'mir_eval_sxr_sar': array([17.79271334, 15.0251782 ]),
     'mir_eval_sxr_selection': array([0, 1]),
     'invasive_sxr_sdr': array([14.67450877, 12.21865275]),
     'invasive_sxr_sir': array([14.77642923, 12.32843497]),
     'invasive_sxr_snr': array([31.02059848, 28.2459515 ])}
    >>> mask = ideal_complex_mask(np.array([*Speech_image, Noise_image])[:, 0])
    >>> kwargs['mask'] = np.stack([*mask], 1)
    >>> kwargs['speech_source'] = example['audio_data']['speech_image'][:, 0]
    >>> pprint(get_multi_speaker_metrics(**kwargs, bf_algorithm='ch0', postfilter='mask_mul').as_dict())
    {'pesq': array([4.549, 4.549]),
     'stoi': array([1., 1.]),
     'mir_eval_sxr_sdr': array([149.04269346, 147.03728106]),
     'mir_eval_sxr_sir': array([170.73079352, 168.36046824]),
     'mir_eval_sxr_sar': array([149.07223578, 147.06942287]),
     'mir_eval_sxr_selection': array([0, 1]),
     'invasive_sxr_sdr': array([12.32048218,  9.61471296]),
     'invasive_sxr_sir': array([12.41346788,  9.69274082]),
     'invasive_sxr_snr': array([29.06057363, 27.10901422])}

    """
    _, N = speech_source.shape
    K = mask.shape[-2]
    D, T, F = Observation.shape

    assert K < 10, (K, mask.shape, N, D, T, F)
    assert D < 30, (K, N, D, T, F)

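    # Mask-weighted spatial covariance (PSD) matrices, one per frequency bin
    # and source class.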
    psds = get_power_spectral_density_matrix(
        rearrange(Observation, 'd t f -> f d t', d=D, t=T, f=F),
        rearrange(mask, 't k f -> f k t', k=K, t=T, f=F),
    )  # shape: f, ktarget, d, d

    assert psds.shape == (F, K, D, D), (psds.shape, (F, K, D, D))

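    # For each target source, compute a single-source beamformer that treats
    # the summed PSDs of all other sources as distortion (leave-one-out).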
    beamformers = list()
    for k_target in range(K):
        target_psd = psds[:, k_target]
        distortion_psd = np.sum(np.delete(psds, k_target, axis=1), axis=1)

        beamformers.append(
            get_single_source_bf_vector(
                bf_algorithm,
                target_psd_matrix=target_psd,
                noise_psd_matrix=distortion_psd,
            )
        )
    beamformers = np.stack(beamformers, axis=1)
    assert beamformers.shape == (F, K, D), (beamformers.shape, (F, K, D))

    def postfilter_fn(Signal):
        if postfilter is None:
            return Signal
        elif postfilter == 'mask_mul':
            return Signal * rearrange(mask, 't k f -> k f t', k=K, t=T, f=F)
        else:
            raise ValueError(postfilter)

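    # Apply the beamformers to the observation and optionally re-apply the
    # masks as a postfilter before going back to the time domain.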
    Speech_prediction = apply_beamforming_vector(
        vector=rearrange(beamformers, 'f k d -> k f d', k=K, d=D, f=F),
        mix=rearrange(Observation, 'd t f -> f d t', d=D, t=T, f=F),
    )
    Speech_prediction = postfilter_fn(Speech_prediction)
    speech_prediction = istft(rearrange(Speech_prediction, 'k f t -> k t f', k=K, t=T, f=F), num_samples=N)

    if Speech_image is None:
        speech_contribution = None
    else:
        Speech_contribution = apply_beamforming_vector(
            vector=rearrange(beamformers, 'f k d -> k f d', k=K, d=D, f=F),
            mix=rearrange(Speech_image, '(ksource k) d t f -> ksource k f d t', k=1, d=D, t=T, f=F),
        )
        Speech_contribution = postfilter_fn(Speech_contribution)
        # ksource in [K-1, K]
        speech_contribution = istft(rearrange(Speech_contribution, 'ksource k f t -> ksource k t f', k=K, t=T, f=F), num_samples=N)

    if Noise_image is None:
        noise_contribution = None
    else:
        Noise_contribution = apply_beamforming_vector(
            vector=rearrange(beamformers, 'f k d -> k f d', k=K, d=D, f=F),
            mix=rearrange(Noise_image, '(k d) t f -> k f d t', k=1, d=D, t=T, f=F),
        )
        Noise_contribution = postfilter_fn(Noise_contribution)
        noise_contribution = istft(rearrange(Noise_contribution, 'k f t -> k t f', k=K, t=T, f=F), num_samples=N)

    metric = OutputMetrics(
            speech_prediction=speech_prediction,
            speech_source=speech_source,
            speech_contribution=speech_contribution,
            noise_contribution=noise_contribution,
            sample_rate=8000,
            enable_si_sdr=False,
    )

    return metric


@experiment.capture
def get_dataset(dataset, json_path):
    """
    >>> from IPython.lib.pretty import pprint
    >>> np.set_string_function(lambda a: f'array(shape={a.shape}, dtype={a.dtype})')
    >>> pprint(get_dataset('cv_dev93')[0])  # doctest: +ELLIPSIS
    {...
     'example_id': '0_4k6c0303_4k4c0319',
     ...
     'snr': 23.287502642941252,
     'dataset': 'cv_dev93',
     'audio_data': {'observation': array(shape=(6, 93389), dtype=float64),
      'speech_source': array(shape=(2, 93389), dtype=float64),
      'speech_reverberation_early': array(shape=(2, 6, 93389), dtype=float64),
      'speech_reverberation_tail': array(shape=(2, 6, 93389), dtype=float64),
      'noise_image': array(shape=(6, 93389), dtype=float64),
      'speech_image': array(shape=(2, 6, 93389), dtype=float64),
      'Speech_source': array(shape=(2, 733, 257), dtype=complex128),
      'Speech_reverberation_early': array(shape=(2, 6, 733, 257), dtype=complex128),
      'Speech_reverberation_tail': array(shape=(2, 6, 733, 257), dtype=complex128),
      'Speech_image': array(shape=(2, 6, 733, 257), dtype=complex128),
      'Noise_image': array(shape=(6, 733, 257), dtype=complex128),
      'Observation': array(shape=(6, 733, 257), dtype=complex128)}}

    """
    db = SmsWsj(json_path=json_path)
    ds = db.get_dataset(dataset)
    ds = ds.map(AudioReader())

    def calculate_stfts(ex):
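        # Convention: uppercase keys hold STFT-domain signals, lowercase keys
        # hold time-domain signals.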
        ex['audio_data']['Speech_source'] = stft(ex['audio_data']['speech_source'])
        ex['audio_data']['Speech_reverberation_early'] = stft(ex['audio_data']['speech_reverberation_early'])
        ex['audio_data']['Speech_reverberation_tail'] = stft(ex['audio_data']['speech_reverberation_tail'])
        ex['audio_data']['Speech_image'] = stft(ex['audio_data']['speech_image'])
        ex['audio_data']['Noise_image'] = stft(ex['audio_data']['noise_image'])
        ex['audio_data']['Observation'] = stft(ex['audio_data']['observation'])
        return ex

    return ds.map(calculate_stfts)


@experiment.capture
def stft(
        signal,
        *,
        stft_size=512,
        stft_shift=128,
        stft_window_length=None,
        stft_window='hann',
):
    return _stft(
        signal,
        size=stft_size,
        shift=stft_shift,
        window_length=stft_window_length,
        window=stft_window,
    )


@experiment.capture
def istft(
        signal,
        num_samples,
        *,
        stft_size=512,
        stft_shift=128,
        stft_window_length=None,
        stft_window='hann',
):
    time_signal = _istft(
        signal,
        size=stft_size,
        shift=stft_shift,
        window_length=stft_window_length,
        window=stft_window,
        # num_samples=num_samples,  # this stft does not support num_samples
    )

    pad = True
    if pad:
        assert time_signal.shape[-1] >= num_samples, (time_signal.shape, num_samples)
        assert time_signal.shape[-1] < num_samples + stft_shift, (time_signal.shape, num_samples)
        time_signal = time_signal[..., :num_samples]
    else:
        raise ValueError(
            pad,
            'When padding is False in the stft, the signal is cut. '
            'This operation cannot be inverted.',
        )
    return time_signal


def get_scores(
        ex,
        mask,

        Observation='Observation',

        beamformer='mvdr_souden',
        postfilter=None,
):
    """
    Calculate the scores, where the predicted/estimated signal is tested
    against the source/desired signal.
    This function is intended for oracle tests, to figure out which metrics
    work with which source signal.

    SI-SDR does not work when the desired signal is the signal before the
    room impulse response, and it gives strange results when the channel is
    changed.

    Example:

        >>> from IPython.lib.pretty import pprint
        >>> ex = get_dataset('cv_dev93')[0]
        >>> mask = get_mask_from_oracle(ex, 'IBM')
        >>> metric, result = get_scores(ex, mask)
        >>> pprint(result)
        {'pesq': array([2.014, 1.78 ]),
         'stoi': array([0.68236465, 0.61319396]),
         'mir_eval_sxr_sdr': array([10.23933413, 10.01566298]),
         'invasive_sxr_sdr': array([15.76439393, 13.86230425])}
    """

    if Observation == 'Observation':
        metric = get_multi_speaker_metrics(
            mask=rearrange(mask, 'k t f -> t k f'),  # T Ktarget F
            Observation=ex['audio_data'][Observation],  # D T F (stft signal)
            speech_source=ex['audio_data']['speech_source'],  # Ksource N (time signal)
            Speech_image=ex['audio_data']['Speech_image'],  # Ksource D T F (stft signal)
            Noise_image=ex['audio_data']['Noise_image'],  # D T F (stft signal)
            istft=istft,  # callable(signal, num_samples=num_samples)
            bf_algorithm=beamformer,
            postfilter=postfilter,  # [None, 'mask_mul']
        )
    else:
        assert mask is None, mask
        assert beamformer == 'ch0', beamformer
        assert postfilter is None, postfilter
        metric = OutputMetrics(
            speech_prediction=ex['audio_data'][Observation][:, 0],
            speech_source=ex['audio_data']['speech_source'],
            # speech_contribution=speech_contribution,
            # noise_contribution=noise_contribution,
            sample_rate=8000,
            enable_si_sdr=False,
        )

    result = metric.as_dict()
    del result['mir_eval_sxr_selection']
    del result['mir_eval_sxr_sar']
    del result['mir_eval_sxr_sir']
    if 'invasive_sxr_sir' in result:
        del result['invasive_sxr_sir']
        del result['invasive_sxr_snr']

    return metric, result


@experiment.capture
def get_mask_from_cacgmm(
        ex,  # (D, T, F)
        weight_constant_axis=-1,
):  # (K, T, F)
    """

    Args:
        ex: Example dict with `ex['audio_data']['Observation']` of shape
            (D, T, F).
        weight_constant_axis: -3 for time-varying mixture weights (pi_tk),
            -1 for frequency-varying mixture weights (pi_fk).

    Returns:
        Class affiliations (masks) of shape (K, T, F).

    >>> ex = get_dataset('cv_dev93')[0]
    >>> get_mask_from_cacgmm(ex).shape
    (3, 733, 257)

    """
    Observation = ex['audio_data']['Observation']
    Observation = rearrange(Observation, 'd t f -> f t d')

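    # Fit a complex angular central Gaussian mixture model (cACGMM) per
    # frequency bin; num_classes=3 covers the two speakers plus noise.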
    trainer = CACGMMTrainer()

    initialization: 'F, K, T' = initializer.iid.dirichlet_uniform(
        Observation,
        num_classes=3,
        permutation_free=False,
    )

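    # The classes of the per-frequency mixture models are permuted against
    # each other, so align them across frequency; for pi_tk this is
    # additionally done inline during EM.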
    pa = DHTVPermutationAlignment.from_stft_size(512)

    affiliation = trainer.fit_predict(
        Observation,
        initialization=initialization,
        weight_constant_axis=weight_constant_axis,
        inline_permutation_aligner=pa if weight_constant_axis != -1 else None
    )

    mapping = pa.calculate_mapping(
        rearrange(affiliation, 'f k t -> k f t'))

    affiliation = rearrange(pa.apply_mapping(
        rearrange(affiliation, 'f k t -> k f t'), mapping
    ), 'k f t -> k t f')

    return affiliation


def get_mask_from_oracle(
        ex,
        mask_estimator,
):  # (K, T, F)
    """

    Args:
        ex: Example dict from `get_dataset` (time-domain and STFT signals).
        mask_estimator: One of 'IBM', 'IRM', 'ICM_0', 'ICM_0_early'.

    Returns:
        Oracle mask of shape (K, T, F).

    >>> ex = get_dataset('cv_dev93')[0]
    >>> mask = get_mask_from_oracle(ex, 'ICM_0_early')
    >>> mask.shape
    (2, 733, 257)

    >>> obs = np.sum(ex['audio_data']['speech_reverberation_early'] + ex['audio_data']['speech_reverberation_tail'], axis=0) + ex['audio_data']['noise_image']
    >>> np.testing.assert_allclose(ex['audio_data']['observation'], obs, atol=1e-7)
    >>> Obs = np.sum(ex['audio_data']['Speech_reverberation_early'] + ex['audio_data']['Speech_reverberation_tail'], axis=0) + ex['audio_data']['Noise_image']
    >>> np.testing.assert_allclose(ex['audio_data']['Observation'], Obs, atol=2e-7)

    >>> Speech_reverberation_early_0 = ex['audio_data']['Observation'][0] * mask
    >>> np.testing.assert_allclose(Speech_reverberation_early_0, ex['audio_data']['Speech_reverberation_early'][:, 0], atol=1e-13, rtol=1e-13)
    >>> speech_reverberation_early_0 = istft(Speech_reverberation_early_0, num_samples=obs.shape[-1])
    >>> np.testing.assert_allclose(speech_reverberation_early_0, ex['audio_data']['speech_reverberation_early'][:, 0], atol=1e-13, rtol=1e-13)

    >>> mask = get_mask_from_oracle(ex, 'ICM_0')
    >>> mask.shape
    (2, 733, 257)
    >>> Speech_reverberation_early_0 = ex['audio_data']['Observation'][0] * mask
    >>> np.testing.assert_allclose(Speech_reverberation_early_0, ex['audio_data']['Speech_image'][:, 0], atol=1e-13, rtol=1e-13)

    >>> mask = get_mask_from_oracle(ex, 'IBM')
    >>> mask.shape
    (3, 733, 257)

    >>> mask = get_mask_from_oracle(ex, 'IRM')
    >>> mask.shape
    (3, 733, 257)

    """
    from pb_bss.extraction.mask_module import (
        ideal_ratio_mask,
        ideal_complex_mask,
        ideal_binary_mask,
    )

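    # The ideal complex masks (ICM) are computed on channel 0: stack the
    # target signals plus the residual (observation minus their sum), build
    # complex masks, and drop the residual's mask via `[:-1, ...]`.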
    if mask_estimator == 'ICM_0_early':
        # K, D, T, F =  ex['audio_data']['Speech_reverberation_early'].shape
        return ideal_complex_mask(
            [
            *ex['audio_data']['Speech_reverberation_early'][..., 0, :, :],
            # np.sum(ex['audio_data']['Speech_reverberation_tail'][..., 0, :, :]
            # , axis=0) + ex['audio_data']['Noise_image'][..., 0, :, :]
            ex['audio_data']['Observation'][0, :, :]
            - np.sum(
                ex['audio_data']['Speech_reverberation_early'][..., 0, :, :],
                axis=0)
            ]
        )[:-1, ...]
    elif mask_estimator == 'ICM_0':
        return ideal_complex_mask(
            [
            *ex['audio_data']['Speech_image'][..., 0, :, :],
            ex['audio_data']['Observation'][0, :, :]
            - np.sum(
                ex['audio_data']['Speech_image'][..., 0, :, :],
                axis=0)
            ]
        )[:-1, ...]
    elif mask_estimator in ['IBM', 'IRM']:
        signal = np.sqrt(
            np.abs(
                (np.array([
                    *ex['audio_data']['Speech_image'],
                    ex['audio_data']['Noise_image']
                ]) ** 2)
            ).sum(axis=1)
        )
        if mask_estimator == 'IRM':
            return ideal_ratio_mask(signal)
        elif mask_estimator == 'IBM':
            mask = ideal_binary_mask(signal)
            mask = np.clip(mask, 1e-10, 1 - 1e-10)
            return mask
        else:
            raise NotImplementedError(mask_estimator)
    else:
        raise NotImplementedError(mask_estimator)


@experiment.automain
def main(
        _run,
        out,
        mask_estimator,
        Observation,
        beamformer,
        postfilter,
        normalize_audio=True,
):
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    data = []

    out = Path(out)

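    # Distribute the (sorted) examples over the MPI workers; each worker
    # processes its share and the master gathers all scores below.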
    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):

        if mask_estimator is None:
            mask = None
        elif mask_estimator == 'cacgmm':
            mask = get_mask_from_cacgmm(ex)
        else:
            mask = get_mask_from_oracle(ex, mask_estimator)

        metric, score = get_scores(
            ex,
            mask,
            Observation=Observation,
            beamformer=beamformer,
            postfilter=postfilter,
        )
        
        est0, est1 = metric.speech_prediction_selection
        dump_audio(est0, out / ex['dataset'] / f"{ex['example_id']}_0.wav", normalize=normalize_audio)
        dump_audio(est1, out / ex['dataset'] / f"{ex['example_id']}_1.wav", normalize=normalize_audio)

        data.append(dict(
            example_id=ex['example_id'],
            value=score,
            dataset=ex['dataset'],
        ))

        # print(score, repr(score))

    data = dlp_mpi.gather(data)

    if dlp_mpi.IS_MASTER:
        data = [
            entry
            for worker_data in data
            for entry in worker_data
        ]

        data = {  # itertools.groupby expects sorted input, hence lazy_dataset's groupby
            dataset: list(subset)
            for dataset, subset in from_list(data).groupby(
                lambda ex: ex['dataset']
            ).items()
        }

        for dataset, sub_data in data.items():
            print(f'Write details to {out / f"{dataset}_scores.json"}.')
            dump_json(sub_data, out / f'{dataset}_scores.json')

        for dataset, sub_data in data.items():
            summary = {}
            for k in sub_data[0]['value'].keys():
                m = np.mean([
                    d['value'][k]
                    for d in sub_data
                ])
                print(dataset, k, m)
                summary[k] = m
            dump_json(summary, out / f'{dataset}_summary.json')