import sys
import torch
import mir_eval
import numpy as np
from asteroid.data.avspeech_dataset import AVSpeechDataset


def snr(pred_signal: torch.Tensor, true_signal: torch.Tensor) -> torch.FloatTensor:
    """
        Compute the Signal-to-Noise Ratio, in decibels, of a prediction
        measured against a reference.

        The residual ``true_signal - pred_signal`` is treated as noise and
        the result is ``10 * log10(sum(true ** 2) / sum(residual ** 2))``,
        summed over every element of the tensors.

        Args:
            pred_signal (torch.Tensor): predicted signal spectrogram.
            true_signal (torch.Tensor): original signal spectrogram.

        Returns:
            torch.Tensor: scalar SNR in dB. A zero residual (identical
            inputs) divides by zero, giving ``inf`` (or ``nan`` when the
            reference is also all zeros).
    """
    signal_energy = true_signal.pow(2).sum()
    noise_energy = (true_signal - pred_signal).pow(2).sum()
    return 10 * torch.log10(signal_energy / noise_energy)

def sdr(pred_signal: torch.Tensor, true_signal: torch.Tensor) -> np.ndarray:
    """
        Calculate the Signal-to-Distortion Ratio for each source
        from two batches of signal spectrograms.

        Every entry along the first axis is decoded back to a waveform with
        ``AVSpeechDataset.decode`` and the resulting waveforms are scored with
        ``mir_eval.separation.bss_eval_sources``.

        Args:
            pred_signal (torch.Tensor): predicted signal spectrograms, one
                source per index of the first axis.
            true_signal (torch.Tensor): original signal spectrograms, indexed
                the same way as ``pred_signal``.

        Returns:
            np.ndarray: per-source SDR values as returned by
            ``bss_eval_sources`` (SIR/SAR are computed but discarded).
    """
    n_sources = pred_signal.shape[0]

    def _decode_batch(signal: torch.Tensor) -> np.ndarray:
        # Decode each source into a flat float64 waveform and stack them.
        # Stacking (instead of writing into a fixed (n_sources, 48_000)
        # buffer) accepts whatever length decode() produces; float64 matches
        # the dtype of the np.zeros buffers this replaced.
        waves = [
            # detach()/cpu() so .numpy() also works for tensors that require
            # grad or live on a GPU.
            AVSpeechDataset.decode(signal[i, ...]).detach().cpu().numpy().reshape(-1)
            for i in range(n_sources)
        ]
        if not waves:
            # np.stack rejects an empty list; hand bss_eval an empty 2-D array.
            return np.empty((0, 0), dtype=np.float64)
        return np.stack(waves).astype(np.float64)

    y_pred_wav = _decode_batch(pred_signal)
    y_wav = _decode_batch(true_signal)

    # bss_eval_sources expects (reference_sources, estimated_sources).
    sdr_values, _sir, _sar, _perm = mir_eval.separation.bss_eval_sources(y_wav, y_pred_wav)

    return sdr_values