#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2017   Chenglin Xu (NTU, Singapore)
# Updated by Chenglin, Dec 2018, Jul 2019

"""
1. Extract features (magnitude, log magnitude).
2. Convert them to TFRecords format.
3. Calculate global CMVN (same as Kaldi).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import multiprocessing
import os,sys
import numpy as np
import tensorflow as tf

from utils.audioread import audioread
from utils.sigproc import framesig,magspec
from utils.normhamming import normhamming
import time

def make_sequence(feats, feats_aux, labels=None):
    """
    Pack the features of one utterance into a tf.train.SequenceExample.

    Args:
        feats: sequence of float vectors stored under the key 'inputs'
            (i.e. magnitude of mixture speech)
        feats_aux: sequence of float vectors stored under 'inputs_aux'
            (inputs to the auxiliary network that learns the target
            speaker representation)
        labels: optional sequence of float vectors stored under 'labels'
            (reference for the target speaker); when None, e.g. for test
            data, the 'labels' key is omitted
    Returns:
        A tf.train.SequenceExample
    """

    def as_feature_list(rows):
        # One float-list Feature per row, wrapped into a FeatureList.
        return tf.train.FeatureList(feature=[
            tf.train.Feature(float_list=tf.train.FloatList(value=row))
            for row in rows
        ])

    feature_list = {
        'inputs': as_feature_list(feats),
        'inputs_aux': as_feature_list(feats_aux),
    }
    if labels is not None:
        feature_list['labels'] = as_feature_list(labels)

    return tf.train.SequenceExample(
        feature_lists=tf.train.FeatureLists(feature_list=feature_list))

def cal_phase_mag(filename, dur=None):
    '''
    Extract phase and magnitude features for one utterance.

    Args:
        filename: path of the audio file, read with audioread.
        dur: when set to a non-zero value, only the first rate*dur samples
            of the signal are kept (presumably dur is in seconds -- confirm
            against audioread's returned rate). None or 0 keeps the whole
            signal.
    Returns:
        (phase, feats) as returned by magspec for the windowed frames.
    '''

    rate, sig, _ = audioread(filename)
    # Treat None (the default) and 0 alike as "no truncation"; the former
    # `dur != 0` test let None through and failed on `rate * None`.
    if dur:
        sig = sig[:int(rate * dur)]
    frames = framesig(sig, FLAGS.FFT_LEN, FLAGS.FRAME_SHIFT, normhamming, True)
    phase, feats = magspec(frames, FLAGS.FFT_LEN)

    return phase, feats

def cal_intermedia_mean_var(feats):
    """
    Serialize the per-utterance CMVN accumulators as one string.

    The result has three '+'-separated fields: the number of rows in
    feats, the per-column sums, and the per-column sums of squares.
    Values inside a field are separated by single spaces.
    """
    column_sums = np.sum(feats, 0)
    column_square_sums = np.sum(np.square(feats), 0)
    num_frames = np.shape(feats)[0]
    fields = [
        str(num_frames),
        ' '.join(map(str, column_sums)),
        ' '.join(map(str, column_square_sums)),
    ]
    return '+'.join(fields)

def extract_mag_feats(item, mean_var_dict, mean_var_dict_aux):
    """
    Extract features for one list entry and write them to a TFRecords file.

    Args:
        item: one line of the list file holding whitespace-separated audio
            paths: mixture, auxiliary utterance and, when FLAGS.with_labels
            is set, the clean reference.
        mean_var_dict: dict-like object; receives the serialized CMVN
            accumulators of the mixture features, keyed by the mixture
            utterance id (file name without extension).
        mean_var_dict_aux: dict-like object; receives the accumulators of
            the auxiliary features, keyed by the same mixture utterance id.
    Returns:
        The tuple (mean_var_dict, mean_var_dict_aux).
    """
    tokens = item.strip().split()

    (_, name) = os.path.split(tokens[0])
    (uttid, _) = os.path.splitext(name)
    # extract feats for mixture
    phase_mix, feats = cal_phase_mag(tokens[0], dur=FLAGS.dur)
    mean_var_dict[uttid] = cal_intermedia_mean_var(feats)

    # extract auxiliary feats for auxiliary network (its phase is not needed)
    _, feats_aux = cal_phase_mag(tokens[1], dur=FLAGS.dur)
    # Accumulators for the auxiliary inputs are stored under the mixture id,
    # so every list entry contributes exactly one record.
    mean_var_dict_aux[uttid] = cal_intermedia_mean_var(feats_aux)

    # extract mag for clean as labels
    if FLAGS.with_labels:
        phase_clean, labels = cal_phase_mag(tokens[2], dur=FLAGS.dur)

        if FLAGS.apply_psm:
            # phase sensitive target: scale the clean magnitude by the
            # cosine of the mixture/clean phase difference
            labels = labels * np.cos(phase_mix - phase_clean)
    else:
        labels = None

    # Serialize before opening the output so a failure here does not leave
    # an empty tfrecords file behind.
    serialized = make_sequence(feats, feats_aux, labels).SerializeToString()

    # tfrecords to save the sequence consisting of feats and labels
    # (optional for test)
    tfrecords_name = os.path.join(FLAGS.output_dir, FLAGS.data_type, uttid+".tfrecords")
    # The context manager closes (and thereby flushes) the writer even when
    # the write raises, instead of relying on garbage collection.
    with tf.python_io.TFRecordWriter(tfrecords_name) as writer:
        writer.write(serialized)

    return mean_var_dict, mean_var_dict_aux

def cal_global_mean_std(filename, mean_var_dict):
    """
    Combine per-utterance accumulators into a global mean/std and save them.

    Args:
        filename: target passed to np.savez; the arrays are stored under
            the keys 'mean_inputs' and 'stddev_inputs'.
        mean_var_dict: iterable of strings of the form
            "<num_frames>+<sums>+<sums of squares>", the two vector fields
            being space separated with FLAGS.FFT_LEN/2+1 values each.
    Raises:
        ValueError: if an entry is malformed or no frames were accumulated.
    """
    import re

    # Accumulate in float64: summing squares over many frames in float32 and
    # then taking E[x^2] - E[x]^2 loses most significant digits.
    sums = np.zeros((2, int(FLAGS.FFT_LEN/2+1)), dtype=np.float64)
    frames = 0.0
    for line in mean_var_dict:
        # Large values are printed in scientific notation ("1e+16"); only a
        # '+' that is not part of an exponent separates the fields.
        tokens = re.split(r'(?<![eE])\+', line.strip())
        if len(tokens) != 3:
            raise ValueError('Malformed mean/var entry: %r' % (line,))
        frames += float(tokens[0])
        sums[0] += np.array([float(i) for i in tokens[1].split()], dtype=np.float64)
        sums[1] += np.array([float(i) for i in tokens[2].split()], dtype=np.float64)

    if frames <= 0:
        # Without frames the statistics would be 0/0 and NaN would be saved.
        raise ValueError('No frames accumulated, cannot compute CMVN for %s' % (filename,))

    mean = sums[0] / frames
    var = sums[1] / frames - mean ** 2
    # guard against zero/negative variance caused by rounding
    var[var<=0] = 1.0e-20
    std = np.sqrt(var)

    # The saved statistics stay float32, as before.
    mean = mean.astype(np.float32)
    std = std.astype(np.float32)

    print(mean)
    print(std)
    np.savez(filename, mean_inputs=mean, stddev_inputs=std)

def _init_worker(flags):
    """Publish the parsed flags as the module-level FLAGS inside a pool worker."""
    global FLAGS
    FLAGS = flags


def main(unused_argv):
    """
    Extract features for every entry of FLAGS.list_path in a process pool,
    then compute and save the global CMVN statistics for the mixture and
    the auxiliary inputs.

    Args:
        unused_argv: remaining command-line arguments handed over by
            tf.app.run; not used.
    """
    print('Extract starts ...')
    print(time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime()))

    output_dir = os.path.join(FLAGS.output_dir, FLAGS.data_type)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Close the list file deterministically and skip blank lines, which
    # would otherwise fail when the entry is split into paths.
    with open(FLAGS.list_path) as list_file:
        lists = [line for line in list_file if line.strip()]

    cmvn_path = FLAGS.inputs_cmvn
    cmvn_aux_path = FLAGS.inputs_cmvn.replace('cmvn', 'cmvn_aux')

    # check whether the cmvn file for training exist, remove if exist.
    # NOTE: np.savez appends '.npz' when the name lacks that suffix, so this
    # only removes files stored under the exact name given.
    for stale_path in (cmvn_path, cmvn_aux_path):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    manager = multiprocessing.Manager()
    mean_var_dict = manager.dict()
    mean_var_dict_aux = manager.dict()
    # The initializer hands FLAGS to the workers so they also see it when
    # the start method does not inherit module globals (spawn/forkserver).
    pool = multiprocessing.Pool(FLAGS.num_threads,
                                initializer=_init_worker,
                                initargs=(FLAGS,))
    # Pass the callable and its arguments separately: calling
    # extract_mag_feats(...) here would run everything serially in this
    # process and submit its return value as the "function".
    workers = [
        pool.apply_async(extract_mag_feats,
                         (item, mean_var_dict, mean_var_dict_aux))
        for item in lists
    ]
    pool.close()
    pool.join()
    # Re-raise any exception that occurred in a worker instead of silently
    # computing CMVN over an incomplete set of utterances.
    for worker in workers:
        worker.get()

    # convert the utterance level intermediates for mean and var to global mean and std, then save
    cal_global_mean_std(cmvn_path, mean_var_dict.values())
    cal_global_mean_std(cmvn_aux_path, mean_var_dict_aux.values())

    print(time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime()))
    print('Extract ends.')

if __name__ == '__main__':
    # Command-line options as (flag, type, default, help). They are
    # registered in this order, which is also the order of the --help text.
    _options = (
        ('--with_labels', int, 1,
         'Whether extract features for the targets as labels, default to prepare labels.'),
        ('--data_type', str, 'tr',
         'tr, cv, tt.'),
        ('--apply_psm', int, 1,
         'Whether use phase sensitive mask.'),
        ('--inputs_cmvn', str, 'data/inputs_utts.cmvn',
         'Path to save CMVN for the inputs'),
        ('--list_path', str, 'lists/tr_mix.lst',
         'List of the paired mix, aux, clean data'),
        ('--output_dir', str, 'data/tfrecords',
         'Directory to save the features into tfrecords format'),
        ('--FFT_LEN', int, 512,
         'The length of fft window.'),
        ('--FRAME_SHIFT', int, 256,
         'The shift of samples when calculating fft.'),
        ('--num_threads', int, 10,
         'The number of threads to convert tfrecords files.'),
        ('--dur', int, 0,
         'Duration of each file, cut to fixed length wav for mix, aux, clean'),
    )

    parser = argparse.ArgumentParser()
    for _flag, _type, _default, _help in _options:
        parser.add_argument(_flag, type=_type, default=_default, help=_help)

    # Arguments the parser does not know are forwarded to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)