from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K
from . import backend


class Filterbank(Layer):
    """
    ### `Filterbank`

    `kapre.filterbank.Filterbank(n_fbs, trainable_fb, sr=None, init='mel', fmin=0., fmax=None,
                                 bins_per_octave=12, image_data_format='default', **kwargs)`

    Applies a filterbank, optionally trainable, along the frequency axis of a
    2D time-frequency input.

    #### Notes
        Input/output are 2D image format.
        E.g., if channels_first,
            - input_shape: ``(None, n_ch, n_freqs, n_time)``
            - output_shape: ``(None, n_ch, n_fbs, n_time)``
        If channels_last,
            - input_shape: ``(None, n_freqs, n_time, n_ch)``
            - output_shape: ``(None, n_fbs, n_time, n_ch)``


    #### Parameters
    * n_fbs: int
        - Number of filterbanks, i.e. the size of the output frequency axis.

    * trainable_fb: bool
        - Whether the filterbank weights are updated during training.

    * sr: int
        - Sampling rate. Required when ``init`` is ``'mel'`` or ``'log'``.
        - When ``fmax`` is ``None``, ``fmax`` defaults to ``sr / 2``.

    * init: str
        - ``'mel'``: initialize with a mel filterbank.
        - ``'log'``: initialize with a log-frequency filterbank.
        - ``'linear'`` and ``'uni_random'`` are accepted names but are not
          implemented yet; ``build`` raises ``NotImplementedError`` for them.

    * fmin: float
        - Min frequency of filterbanks.
        - If ``init == 'log'``, fmin should be > 0. Use ``None`` if unsure.

    * fmax: float
        - Max frequency of filterbanks. Defaults to ``sr / 2`` (stays ``None``
          if ``sr`` is also ``None``).
        - If ``init == 'log'``, fmax is ignored.

    * bins_per_octave: int
        - Passed to the ``'log'`` filterbank; ignored for other inits.

    * image_data_format: str
        - ``'channels_first'``, ``'channels_last'``, or ``'default'`` to use
          ``K.image_data_format()``.

    #### Raises
    * ValueError: for an unknown ``init`` or ``image_data_format``, or when
      ``init`` is ``'mel'``/``'log'`` and ``sr`` is ``None``.
    """

    def __init__(
        self,
        n_fbs,
        trainable_fb,
        sr=None,
        init='mel',
        fmin=0.0,
        fmax=None,
        bins_per_octave=12,
        image_data_format='default',
        **kwargs,
    ):
        # TODO: is sr necessary? is fmax necessary? init with None?
        # Validate with real exceptions; `assert` is stripped under `python -O`.
        if init not in ('mel', 'log', 'linear', 'uni_random'):
            raise ValueError(
                "init must be one of 'mel', 'log', 'linear', 'uni_random'; got %r" % (init,)
            )
        if init in ('mel', 'log') and sr is None:
            raise ValueError("sr is required when init is %r" % (init,))
        if image_data_format not in ('default', 'channels_first', 'channels_last'):
            raise ValueError(
                "image_data_format must be 'default', 'channels_first' or "
                "'channels_last'; got %r" % (image_data_format,)
            )

        self.n_fbs = n_fbs
        self.init = init
        self.sr = sr
        self.fmin = fmin
        # Default to the Nyquist frequency; leave as None when sr is unknown
        # instead of evaluating `None / 2.0`.
        if fmax is None and sr is not None:
            fmax = sr / 2.0
        self.fmax = fmax
        self.bins_per_octave = bins_per_octave
        self.trainable_fb = trainable_fb
        if image_data_format == 'default':
            self.image_data_format = K.image_data_format()
        else:
            self.image_data_format = image_data_format
        super(Filterbank, self).__init__(**kwargs)
        # Set after Layer.__init__, which otherwise resets supports_masking to False.
        self.supports_masking = True

    def build(self, input_shape):
        """Read axis sizes from ``input_shape`` and create the filterbank weight.

        The weight is registered via ``add_weight`` so that ``trainable_fb``
        actually decides whether it is trainable.

        Raises:
            NotImplementedError: if ``init`` is ``'linear'`` or ``'uni_random'``.
        """
        if self.image_data_format == 'channels_first':
            self.n_ch = input_shape[1]
            self.n_freq = input_shape[2]
            self.n_time = input_shape[3]
        else:
            self.n_ch = input_shape[3]
            self.n_freq = input_shape[1]
            self.n_time = input_shape[2]

        # The transposed kernel is multiplied against an input whose last axis
        # is frequency (see `call`), so it is expected to be (n_freq, n_fbs).
        if self.init == 'mel':
            fb_value = backend.filterbank_mel(
                sr=self.sr,
                n_freq=self.n_freq,
                n_mels=self.n_fbs,
                fmin=self.fmin,
                fmax=self.fmax,
            ).transpose()
        elif self.init == 'log':
            fb_value = backend.filterbank_log(
                sr=self.sr,
                n_freq=self.n_freq,
                n_bins=self.n_fbs,
                bins_per_octave=self.bins_per_octave,
                fmin=self.fmin,
            ).transpose()
        else:
            raise NotImplementedError(
                "Filterbank init %r is not implemented yet" % (self.init,)
            )

        # `trainable_weights` / `non_trainable_weights` are read-only views in
        # tf.keras, so appending to them does nothing; add_weight is the
        # supported way to control trainability.
        self.filterbank = self.add_weight(
            name='filterbank',
            shape=fb_value.shape,
            dtype=K.floatx(),
            initializer=lambda shape, dtype=None: K.constant(fb_value, dtype=dtype),
            trainable=self.trainable_fb,
        )
        super(Filterbank, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """Replace the frequency axis size with ``n_fbs``, keeping the other axes."""
        if self.image_data_format == 'channels_first':
            return input_shape[0], self.n_ch, self.n_fbs, self.n_time
        else:
            return input_shape[0], self.n_fbs, self.n_time, self.n_ch

    def call(self, x):
        """Project the frequency axis of ``x`` onto the filterbank."""
        # Permute so that the last axis is the freq axis, which K.dot contracts.
        if self.image_data_format == 'channels_first':
            x = K.permute_dimensions(x, [0, 1, 3, 2])
        else:
            x = K.permute_dimensions(x, [0, 3, 2, 1])
        output = K.dot(x, self.filterbank)
        # Both permutations are their own inverse, so re-applying restores the layout.
        if self.image_data_format == 'channels_first':
            return K.permute_dimensions(output, [0, 1, 3, 2])
        else:
            return K.permute_dimensions(output, [0, 3, 2, 1])

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = {
            'n_fbs': self.n_fbs,
            'sr': self.sr,
            'init': self.init,
            'fmin': self.fmin,
            'fmax': self.fmax,
            'bins_per_octave': self.bins_per_octave,
            'trainable_fb': self.trainable_fb,
            # Needed so from_config restores the same channel layout.
            'image_data_format': self.image_data_format,
        }
        base_config = super(Filterbank, self).get_config()
        return {**base_config, **config}