## AUTHOR:         Aaron Nicolson
## AFFILIATION:    Signal Processing Laboratory, Griffith University
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import numpy as np
from scipy.special import exp1, i0, i1


def mmse_stsa(xi, gamma):
    """
    Computes the MMSE-STSA gain function.

    Wherever the closed-form expression evaluates to NaN or +/-inf (for
    example 0/0 when gamma is 0, or 0*inf when the Bessel terms overflow
    for large nu), the Wiener gain xi/(1 + xi) is used instead.

    Argument/s:
        xi - a priori SNR (scalar or array-like).
        gamma - a posteriori SNR (scalar or array-like, broadcastable
            against xi).

    Returns:
        G - MMSE-STSA gain function (ndarray with the broadcast shape of
            xi and gamma; 0-d for scalar inputs).
    """
    # The non-finite intermediate results are replaced below, so the
    # floating-point warnings they would trigger are only noise.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        nu = np.multiply(xi, np.divide(gamma, np.add(1, xi)))
        half_nu = np.divide(nu, 2)
        scale = np.multiply(np.divide(np.sqrt(np.pi), 2),
            np.divide(np.sqrt(nu), gamma))
        bessel = np.add(np.multiply(np.add(1, nu), i0(half_nu)),
            np.multiply(nu, i1(half_nu)))
        G = np.multiply(np.multiply(scale, np.exp(np.negative(half_nu))),
            bessel)  # MMSE-STSA gain function.

    wiener = np.divide(xi, np.add(1, xi))  # Wiener gain.

    # np.where broadcasts, so (unlike boolean-mask assignment into G) this
    # also works for scalar inputs and for xi/gamma of different but
    # compatible shapes.
    return np.where(np.isfinite(G), G, wiener)


def mmse_lsa(xi, gamma):
    """
    Computes the MMSE-LSA gain function.

    Argument/s:
        xi - a priori SNR.
        gamma - a posteriori SNR.

    Returns:
        MMSE-LSA gain function.
    """
    wiener = np.divide(xi, np.add(1, xi))
    nu = np.multiply(wiener, gamma)
    # exp1 is the exponential integral E1 from scipy.special.
    return np.multiply(wiener, np.exp(np.multiply(0.5, exp1(nu))))  # MMSE-LSA gain function.


def wf(xi):
    """
    Computes the Wiener filter (WF) gain function, xi/(xi + 1).

    Argument/s:
        xi - a priori SNR.

    Returns:
        WF gain function.
    """
    return np.divide(xi, np.add(xi, 1.0))  # WF gain function.


def srwf(xi):
    """
    Computes the square-root Wiener filter (SRWF) gain function, i.e. the
    square root of the WF gain.

    Argument/s:
        xi - a priori SNR.

    Returns:
        SRWF gain function.
    """
    return np.sqrt(wf(xi))  # SRWF gain function.


def cwf(xi):
    """
    Computes the constrained Wiener filter (cWF) gain function, i.e. the
    WF gain evaluated at sqrt(xi).

    Argument/s:
        xi - a priori SNR.

    Returns:
        cWF gain function.
    """
    return wf(np.sqrt(xi))  # cWF gain function.


def irm(xi):
    """
    Computes the ideal ratio mask (IRM), which is evaluated here as the
    SRWF gain.

    Argument/s:
        xi - a priori SNR.

    Returns:
        IRM.
    """
    return srwf(xi)  # IRM.
def ibm(xi):
    """
    Computes the ideal binary mask (IBM) with a threshold of 0 dB.

    Argument/s:
        xi - a priori SNR.

    Returns:
        IBM, as float32: 1.0 where xi > 1 and 0.0 elsewhere.
    """
    # Comparison ufuncs only produce boolean output, so passing
    # dtype=np.float32 to np.greater is rejected by current NumPy releases.
    # Compare first, then cast the boolean mask.
    return np.greater(xi, 1).astype(np.float32)  # IBM (1 corresponds to 0 dB).


def deepmmse(xi, gamma):
    """
    DeepMMSE utilises the MMSE noise periodogram estimate gain function,
    1/(1 + xi) + xi/(gamma*(1 + xi)).

    Argument/s:
        xi - a priori SNR.
        gamma - a posteriori SNR.

    Returns:
        MMSE-Noise_PSD gain function.
    """
    one_plus_xi = np.add(1, xi)
    return np.add(np.divide(1, one_plus_xi),
        np.divide(xi, np.multiply(gamma, one_plus_xi)))  # MMSE noise periodogram estimate gain function.


def gfunc(xi, gamma=None, gtype='mmse-lsa'):
    """
    Computes the selected gain function.

    Argument/s:
        xi - a priori SNR.
        gamma - a posteriori SNR (required for 'mmse-lsa', 'mmse-stsa' and
            'deepmmse'; ignored by the other gain functions).
        gtype - gain function type; one of 'mmse-lsa', 'mmse-stsa', 'wf',
            'srwf', 'cwf', 'irm', 'ibm' or 'deepmmse'.

    Returns:
        G - gain function.

    Raises:
        ValueError - if gtype is not a known gain function type, or if
            gamma is None for a gain function that requires it.
    """
    # Fail early with a clear message rather than an opaque TypeError from
    # inside NumPy when the a posteriori SNR was forgotten.
    if gamma is None and gtype in ('mmse-lsa', 'mmse-stsa', 'deepmmse'):
        raise ValueError('The ' + gtype + ' gain function requires gamma.')

    if gtype == 'mmse-lsa':
        G = mmse_lsa(xi, gamma)
    elif gtype == 'mmse-stsa':
        G = mmse_stsa(xi, gamma)
    elif gtype == 'wf':
        G = wf(xi)
    elif gtype == 'srwf':
        G = srwf(xi)
    elif gtype == 'cwf':
        G = cwf(xi)
    elif gtype == 'irm':
        G = irm(xi)
    elif gtype == 'ibm':
        G = ibm(xi)
    elif gtype == 'deepmmse':
        G = deepmmse(xi, gamma)
    else:
        raise ValueError('Invalid gain function type.')
    return G