from __future__ import print_function

import sys
sys.path.append('./vggish')

import os

import tensorflow as tf
import numpy as np

from abc import ABC
from abc import abstractmethod

import vggish_slim
import vggish_input
import vggish_postprocess

from audio_records import encodes_example
from audio_util import maybe_create_directory


class ExtractorBase(ABC):
    """Base class for Extractors.

    Concrete extractors implement the context-manager protocol and
    ``wavfile_to_features``; ``create_records`` is shared and turns the
    extracted features of many wav files into a single TFRecord file.
    """
    def __init__(self):
        super(ExtractorBase, self).__init__()

    @abstractmethod
    def __enter__(self):
        return self

    @abstractmethod
    def __exit__(self, type, value, traceback):
        pass

    @abstractmethod
    def wavfile_to_features(self, wav_file):
        """Extract features from wav file.

        Implementations are expected to return an array-like object with a
        ``shape`` attribute whose first axis indexes individual feature
        examples; ``create_records`` writes one record per row.
        """
        pass

    def create_records(self, record_path, wav_files, wav_labels):
        """Create TF Records from wav files and corresponding labels.

        Args:
            record_path: destination TFRecord file; its parent directory is
                created if needed.
            wav_files: iterable of wav file paths.
            wav_labels: sequence of labels, paired positionally with
                ``wav_files`` (extra items in the longer one are ignored,
                as with ``zip``).

        Every row returned by ``wavfile_to_features`` becomes one record,
        labelled with its file's label. Files yielding zero rows are logged
        and skipped. The writer is closed even if extraction fails midway,
        so already-written records are flushed and the file handle is not
        leaked.
        """
        record_dir = os.path.dirname(record_path)
        # A bare filename has no directory component; nothing to create then.
        if record_dir:
            maybe_create_directory(record_dir)

        total = len(wav_labels)
        writer = tf.python_io.TFRecordWriter(record_path)
        try:
            pairs = zip(wav_files, wav_labels)
            for index, (wav_file, wav_label) in enumerate(pairs, start=1):
                tf.logging.info('[{}/{}] Extracting features:'
                    ' label: {} - {}'.format(index, total, wav_label, wav_file))
                features = self.wavfile_to_features(wav_file)
                # The original author notes one feature row per second of audio.
                if features.shape[0] == 0:
                    tf.logging.warning('No features extracted:'
                        ' label: {} - {}'.format(wav_label, wav_file))
                    continue
                # Every row of this file shares the file's label.
                label = np.int64(wav_label)
                for feature in features:
                    example = encodes_example(np.float64(feature), label)
                    writer.write(example.SerializeToString())
        finally:
            writer.close()


class MelExtractor(ExtractorBase):
    """Feature extractor producing log-mel examples from a wav file.

    Holds no resources, so entering/exiting the context is a no-op; the
    context-manager methods exist only to satisfy ``ExtractorBase``.
    """

    def __init__(self):
        super().__init__()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Nothing to release; returning None lets exceptions propagate.
        return None

    @staticmethod
    def wavfile_to_features(wav_file):
        """Return ``vggish_input.wavfile_to_examples(wav_file)``.

        Raises AssertionError (when assertions are enabled) if ``wav_file``
        does not exist.
        """
        assert os.path.exists(wav_file), '{} not exists!'.format(wav_file)
        return vggish_input.wavfile_to_examples(wav_file)


class VGGishExtractor(ExtractorBase):
    """Feature extractor that runs the VGGish model on wav files.

    Each instance owns a private ``tf.Graph`` and ``tf.Session``; use it as a
    context manager (or call ``close``) so the session is released.
    """
    def __init__(self, checkpoint, pca_params, input_tensor_name, output_tensor_name):
        """Create a new Graph and a new Session for every VGGishExtractor object.

        Args:
            checkpoint: VGGish checkpoint path passed to
                ``vggish_slim.load_defined_vggish_slim_checkpoint``.
            pca_params: parameters passed to
                ``vggish_postprocess.Postprocessor``.
            input_tensor_name: name of the graph tensor fed with mel features.
            output_tensor_name: name of the graph tensor holding embeddings.

        If anything after session creation fails, the session is closed
        before the exception propagates, so a failed construction does not
        leak it.
        """
        super(VGGishExtractor, self).__init__()

        self.graph = tf.Graph()
        with self.graph.as_default():
            vggish_slim.define_vggish_slim(training=False)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(graph=self.graph, config=sess_config)
        try:
            # Restore with self.graph as the default graph so any variable
            # lookups / Saver ops the loader creates belong to the graph that
            # self.sess runs. Upstream VGGish's loader relies on
            # tf.global_variables() and tf.train.Saver, both of which use the
            # default graph -- NOTE(review): confirm for this vggish copy.
            with self.graph.as_default():
                vggish_slim.load_defined_vggish_slim_checkpoint(self.sess, checkpoint)

            self.input_tensor = self.graph.get_tensor_by_name(input_tensor_name)
            self.output_tensor = self.graph.get_tensor_by_name(output_tensor_name)

            # postprocessor
            self.postprocess = vggish_postprocess.Postprocessor(pca_params)
        except BaseException:
            # Cleanup only; the original error is re-raised unchanged.
            self.sess.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always release the session; returning None re-raises any exception.
        self.close()

    def mel_to_vggish(self, mel_features):
        """Converting mel features to VGGish features.

        Runs the VGGish graph on ``mel_features`` and applies the
        postprocessor to the resulting embedding batch. If ``mel_features``
        has zero rows, it is returned unchanged (so the result then has the
        mel input's shape, not an embedding shape).
        """
        assert mel_features is not None, 'mel_features is None'
        # mel_features shape is 0, skip
        if mel_features.shape[0] == 0:
            return mel_features
        # Run inference and postprocessing.
        [embedding_batch] = self.sess.run([self.output_tensor],
                                          feed_dict={self.input_tensor: mel_features})
        return self.postprocess.postprocess(embedding_batch)

    def wavfile_to_features(self, wav_file):
        """Extract VGGish feature from wav file.

        ``MelExtractor.wavfile_to_features`` performs the existence check
        (same assertion and message), so it is not repeated here.
        """
        mel_features = MelExtractor.wavfile_to_features(wav_file)
        return self.mel_to_vggish(mel_features)

    def close(self):
        """Close the TensorFlow session owned by this extractor."""
        self.sess.close()


if __name__ == '__main__':
    # Manual smoke test: extract features from a local UrbanSound8K copy and
    # write small TFRecord files. The dataset paths are machine-specific.
    import audio_params
    import vggish_params
    import timeit
    from audio_util import urban_labels

    wav_file = 'F:/3rd-datasets/UrbanSound8K-16bit/audio-classified/siren/90014-8-0-1.wav'
    wav_dir = 'F:/3rd-datasets/UrbanSound8K-16bit/audio-classified/siren'
    wav_filenames = os.listdir(wav_dir)
    wav_files = [os.path.join(wav_dir, wav_filename) for wav_filename in wav_filenames]
    wav_labels = urban_labels(wav_files)

    # Only the first few files are written to records, to keep the test quick.
    num_record_wavs = 10
    record_wav_files = wav_files[:num_record_wavs]
    record_wav_labels = wav_labels[:num_record_wavs]

    # test VGGishExtractor
    time_start = timeit.default_timer()
    with VGGishExtractor(audio_params.VGGISH_CHECKPOINT,
                         audio_params.VGGISH_PCA_PARAMS,
                         vggish_params.INPUT_TENSOR_NAME,
                         vggish_params.OUTPUT_TENSOR_NAME) as ve:

        vggish_features = ve.wavfile_to_features(wav_file)
        print(vggish_features, vggish_features.shape)

        ve.create_records('./vggish_test.records', record_wav_files, record_wav_labels)

    time_end = timeit.default_timer()
    # The timed span includes model loading, the single-file extraction and
    # the record-writing pass over record_wav_files.
    num_processed = 1 + len(record_wav_files)
    cost = time_end - time_start
    print('cost time: {}s, {}s/wav'.format(cost, cost / num_processed))

    # test MelExtractor
    with MelExtractor() as me:
        mel_features = me.wavfile_to_features(wav_file)
        print(mel_features, mel_features.shape)
        me.create_records('./mel_test.records', record_wav_files, record_wav_labels)