import numpy as np
import soundfile as sf
import argparse
import os
import keras
import sklearn
import librosa
from keras import backend as K

# Machine epsilon of double precision floats; added to the frame norms in
# `count` so the normalisation factor can never be exactly zero.
# `np.float` was only an alias of the builtin `float` and has been removed
# from NumPy (1.24+), so the builtin is used directly.
eps = np.finfo(float).eps

def class_mae(y_true, y_pred):
    """Mean absolute error between predicted and true class indices.

    Both tensors are reduced to class indices with an argmax over their last
    axis, and the absolute index difference is averaged over the (new) last
    axis.

    Parameters
    ----------
    y_true : tensor
        Ground-truth scores; presumably one-hot encoded with the classes on
        the last axis (TODO confirm against the training code).
    y_pred : tensor
        Predicted scores with the same layout as `y_true`.

    Returns
    -------
    tensor
        Mean absolute class-index error in the backend float type.
    """
    # argmax yields integer tensors; cast before averaging so backends that
    # keep the input dtype in `mean` do not truncate the result.
    index_error = K.cast(
        K.abs(
            K.argmax(y_pred, axis=-1) - K.argmax(y_true, axis=-1)
        ),
        K.floatx()
    )
    return K.mean(index_error, axis=-1)

def count(audio, model, scaler):
    """Run the model on one audio excerpt and return the winning class index.

    Parameters
    ----------
    audio : array
        Audio samples passed to `librosa.stft`; presumably a mono signal
        sampled at 16 kHz (TODO confirm against the caller).
    model : object
        Model exposing `input_shape` and `predict` (a Keras model in this
        script).
    scaler : object
        Fitted scaler exposing `transform`, applied to the magnitude
        spectrogram frame by frame.

    Returns
    -------
    integer
        Index of the highest-scoring model output for the single input
        sample; the script reports it as the speaker count estimate.
    """
    # magnitude STFT, transposed so that rows are time frames
    magnitude = np.abs(librosa.stft(audio, n_fft=400, hop_length=160))
    frames = scaler.transform(magnitude.T)

    # keep at most the first 500 frames (model input: 500 frames x 201 bins)
    frames = frames[:500, :]

    # scale by the average per-frame L2 norm; eps keeps the divisor non-zero
    frame_norms = np.linalg.norm(frames, axis=1) + eps
    frames /= np.mean(frame_norms)

    # prepend the sample (batch) axis
    batch = np.expand_dims(frames, axis=0)

    # models with a rank-4 input additionally expect an axis after the batch
    if len(model.input_shape) == 4:
        batch = np.expand_dims(batch, axis=1)

    scores = model.predict(batch, verbose=0)
    return np.argmax(scores, axis=1)[0]

if __name__ == '__main__':
    # Command-line entry point: load a trained model and its feature scaler,
    # read one audio file and print the estimated speaker count.

    # `import sklearn` alone does not guarantee that the `preprocessing`
    # submodule is loaded, so import it explicitly before it is used below.
    import sklearn.preprocessing

    parser = argparse.ArgumentParser(
        description='Load keras model and predict speaker count'
    )

    parser.add_argument(
        'audio',
        help='audio file (samplerate 16 kHz) of 5 seconds duration'
    )

    parser.add_argument(
        '--model', default='CRNN',
        help='model name'
    )

    args = parser.parse_args()

    # load model; the custom objects are required to deserialize the
    # metric and activation that were referenced by name when it was saved
    model = keras.models.load_model(
        os.path.join('models', args.model + '.h5'),
        custom_objects={
            'class_mae': class_mae,
            'exp': K.exp
        }
    )

    # print model configuration
    model.summary()

    # NOTE(review): a "save as svg file" step was mentioned here, but no
    # code for it exists in this file, so nothing is written to disk.

    # load standardisation parameters
    scaler = sklearn.preprocessing.StandardScaler()
    with np.load(os.path.join("models", 'scaler.npz')) as data:
        scaler.mean_ = data['arr_0']
        scaler.scale_ = data['arr_1']

    # read audio; always_2d yields a (frames, channels) array even for mono
    # NOTE(review): `rate` is not checked against the 16 kHz that the help
    # text asks for.
    audio, rate = sf.read(args.audio, always_2d=True)

    # downmix to mono
    audio = np.mean(audio, axis=1)
    estimate = count(audio, model, scaler)
    print("Speaker Count Estimate: ", estimate)