#!/usr/bin/env python
# coding: utf8

"""
    Module for building a data preprocessing pipeline using the tensorflow
    data API. Data preprocessing such as audio loading, spectrogram
    computation, cropping, feature caching or data augmentation is done
    using a tensorflow dataset object that outputs a tuple (input_, output)
    where:

    -   input_ is a dictionary with a single key that contains the (batched)
        mix spectrogram of audio samples
    -   output is a dictionary of spectrograms of the isolated tracks
        (ground truth)
"""

import time
import os
from os.path import exists, join, sep as SEPARATOR

# pylint: disable=import-error
import pandas as pd
import numpy as np
import tensorflow as tf
# pylint: enable=import-error

from .audio.convertor import (
    db_uint_spectrogram_to_gain,
    spectrogram_to_db_uint)
from .audio.spectrogram import (
    compute_spectrogram_tf,
    random_pitch_shift,
    random_time_stretch)
from .utils.logging import get_logger
from .utils.tensor import (
    check_tensor_shape,
    dataset_from_csv,
    set_tensor_shape,
    sync_apply)

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

# Default audio parameters to use.
# The keys below are the ones read by DatasetBuilder.__init__.
DEFAULT_AUDIO_PARAMS = {
    # Names of the isolated tracks; DatasetBuilder.map_features uses them to
    # select the '<name>_spectrogram' entries returned as ground truth.
    'instrument_list': ('vocals', 'accompaniment'),
    # Name of the mixture track; its spectrogram is the model input.
    'mix_name': 'mix',
    # Sample rate handed to the audio adapter when loading waveforms
    # (presumably in Hz -- confirm against the adapter).
    'sample_rate': 44100,
    # Frame length and frame step forwarded to compute_spectrogram_tf
    # (presumably expressed in samples -- confirm against that helper).
    'frame_length': 4096,
    'frame_step': 1024,
    # Length of a training segment, in spectrogram frames.
    'T': 512,
    # Number of frequency bins kept (should be less than
    # frame_length / 2 + 1).
    'F': 1024
}


def get_training_dataset(audio_params, audio_adapter, audio_path):
    """ Creates the dataset used for training.

    The builder is configured from the optional 'chunk_duration' (default
    20.0) and 'random_seed' (default 0) entries of the given parameters,
    and the dataset is built from the 'train_csv', 'training_cache',
    'batch_size' and 'n_chunks_per_song' (default 2) entries. Random data
    augmentation and cache waiting are both disabled.

    :param audio_params: Audio parameters.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built dataset.
    """
    builder_options = {
        'chunk_duration': audio_params.get('chunk_duration', 20.0),
        'random_seed': audio_params.get('random_seed', 0)}
    build_options = {
        'cache_directory': audio_params.get('training_cache'),
        'batch_size': audio_params.get('batch_size'),
        'n_chunks_per_song': audio_params.get('n_chunks_per_song', 2),
        'random_data_augmentation': False,
        'convert_to_uint': True,
        'wait_for_cache': False}
    training_builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path, **builder_options)
    return training_builder.build(
        audio_params.get('train_csv'), **build_options)


def get_validation_dataset(audio_params, audio_adapter, audio_path):
    """ Creates the dataset used for validation.

    Chunks of 12 seconds are used, with a single chunk per song. The
    dataset is built from the 'validation_csv', 'validation_cache' and
    'batch_size' entries of the given parameters, and is finite, not
    shuffled, centrally cropped and not augmented.

    :param audio_params: Audio parameters.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built dataset.
    """
    build_options = {
        'batch_size': audio_params.get('batch_size'),
        'cache_directory': audio_params.get('validation_cache'),
        'convert_to_uint': True,
        'infinite_generator': False,
        'n_chunks_per_song': 1,
        # Evaluation must not be randomized: no data augmentation,
        # no random crop and no shuffling.
        'random_data_augmentation': False,
        'random_time_crop': False,
        'shuffle': False}
    validation_builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path, chunk_duration=12.0)
    return validation_builder.build(
        audio_params.get('validation_csv'), **build_options)


class InstrumentDatasetBuilder(object):
    """ Provides the map and filter functions that operate on the data of
    a single instrument. Every function takes a sample dictionary; mappers
    return an updated copy of it and never modify their argument. """

    def __init__(self, parent, instrument):
        """ Default constructor.

        :param parent: Parent dataset builder, whose audio settings
                       (adapter, chunk duration, sample rate, frame
                       parameters, T and F) are read by the methods.
        :param instrument: Target instrument.
        """
        self._parent = parent
        self._instrument = instrument
        base_key = f'{instrument}_spectrogram'
        self._spectrogram_key = base_key
        self._min_spectrogram_key = f'min_{base_key}'
        self._max_spectrogram_key = f'max_{base_key}'

    def _with_spectrogram(self, sample, spectrogram):
        """ Returns a copy of sample holding the given spectrogram. """
        updated = dict(sample)
        updated[self._spectrogram_key] = spectrogram
        return updated

    def _expected_shape(self):
        """ Returns the (T, F, 2) shape expected for a spectrogram. """
        return (self._parent._T, self._parent._F, 2)

    def load_waveform(self, sample):
        """ Loads the instrument waveform of the given sample.

        The audio is read from the '<instrument>_path' entry, starting at
        the 'start' entry, for the chunk duration of the parent builder.
        Whatever the adapter returns is merged into the sample.
        """
        parent = self._parent
        loaded = parent._audio_adapter.load_tf_waveform(
            sample[f'{self._instrument}_path'],
            offset=sample['start'],
            duration=parent._chunk_duration,
            sample_rate=parent._sample_rate,
            waveform_name='waveform')
        return dict(sample, **loaded)

    def compute_spectrogram(self, sample):
        """ Computes the spectrogram of the sample 'waveform' entry. """
        parent = self._parent
        spectrogram = compute_spectrogram_tf(
            sample['waveform'],
            frame_length=parent._frame_length,
            frame_step=parent._frame_step,
            spec_exponent=1.,
            window_exponent=1.)
        return self._with_spectrogram(sample, spectrogram)

    def filter_frequencies(self, sample):
        """ Keeps the F first bins of the spectrogram second axis. """
        n_bins = self._parent._F
        spectrogram = sample[self._spectrogram_key]
        return self._with_spectrogram(sample, spectrogram[:, :n_bins, :])

    def convert_to_uint(self, sample):
        """ Converts the spectrogram of the given sample from float to
        uint. The entries produced by the conversion (stored under the
        spectrogram, min and max keys) are merged into the sample. """
        converted = spectrogram_to_db_uint(
            sample[self._spectrogram_key],
            tensor_key=self._spectrogram_key,
            min_key=self._min_spectrogram_key,
            max_key=self._max_spectrogram_key)
        return dict(sample, **converted)

    def filter_infinity(self, sample):
        """ Filter predicate rejecting samples whose minimum spectrogram
        value is infinite. """
        minimum = sample[self._min_spectrogram_key]
        return tf.logical_not(tf.math.is_inf(minimum))

    def convert_to_float32(self, sample):
        """ Converts the spectrogram of the given sample from uint back
        to float, using the stored min and max values. """
        restored = db_uint_spectrogram_to_gain(
            sample[self._spectrogram_key],
            sample[self._min_spectrogram_key],
            sample[self._max_spectrogram_key])
        return self._with_spectrogram(sample, restored)

    def time_crop(self, sample):
        """ Crops T frames around the middle of the spectrogram first
        axis. The crop start is clamped to zero for short spectrograms. """
        n_frames = self._parent._T
        spectrogram = sample[self._spectrogram_key]
        # Start of the central segment.
        centered = tf.shape(spectrogram)[0] / 2 - n_frames / 2
        begin = tf.cast(tf.maximum(centered, 0), tf.int32)
        return self._with_spectrogram(
            sample, spectrogram[begin:begin + n_frames, :, :])

    def filter_shape(self, sample):
        """ Filter predicate checking the spectrogram against the
        expected (T, F, 2) shape. """
        return check_tensor_shape(
            sample[self._spectrogram_key], self._expected_shape())

    def reshape_spectrogram(self, sample):
        """ Sets the (T, F, 2) shape on the spectrogram. """
        reshaped = set_tensor_shape(
            sample[self._spectrogram_key], self._expected_shape())
        return self._with_spectrogram(sample, reshaped)


class DatasetBuilder(object):
    """ Builds the tensorflow dataset used to feed a model.

    Starting from a CSV file describing songs, the builder computes the
    segments to extract, loads the corresponding audio for the mix and for
    every instrument, computes and optionally caches their spectrograms,
    then crops, shuffles, augments and batches them (see build()).
    """

    # Margin at beginning and end of songs in seconds.
    MARGIN = 0.5

    # Wait period for cache (in seconds).
    WAIT_PERIOD = 60

    def __init__(
            self,
            audio_params, audio_adapter, audio_path,
            random_seed=0, chunk_duration=20.0):
        """ Default constructor.

        NOTE: Probably need for AudioAdapter.

        :param audio_params: Audio parameters to use; must provide the
                             'T', 'F', 'sample_rate', 'frame_length',
                             'frame_step', 'mix_name' and
                             'instrument_list' entries.
        :param audio_adapter: Audio adapter to use for loading waveforms.
        :param audio_path: Path of the directory that the audio paths of
                           the samples are relative to.
        :param random_seed: Seed used for shuffling and random cropping.
        :param chunk_duration: Duration in seconds of the audio chunk
                               loaded for each segment.
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params['T']
        # Number of frequency bins to be used (should
        # be less than frame_length/2 + 1)
        self._F = audio_params['F']
        self._sample_rate = audio_params['sample_rate']
        self._frame_length = audio_params['frame_length']
        self._frame_step = audio_params['frame_step']
        self._mix_name = audio_params['mix_name']
        # The instrument list may be any sequence (DEFAULT_AUDIO_PARAMS
        # uses a tuple), hence the explicit conversion before concatenation.
        self._instruments = (
            [self._mix_name] + list(audio_params['instrument_list']))
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
        self._audio_params = audio_params
        self._audio_path = audio_path
        self._random_seed = random_seed

    def expand_path(self, sample):
        """ Expands audio paths for the given sample, by prefixing every
        '<instrument>_path' entry with the audio directory. """
        return dict(sample, **{
            f'{instrument}_path': tf.strings.join(
                (self._audio_path, sample[f'{instrument}_path']),
                SEPARATOR)
            for instrument in self._instruments})

    def filter_error(self, sample):
        """ Filter errored sample, using its 'waveform_error' entry. """
        return tf.logical_not(sample['waveform_error'])

    def filter_waveform(self, sample):
        """ Filter waveform from sample, returning a copy of it without
        the 'waveform' entry. """
        return {k: v for k, v in sample.items() if not k == 'waveform'}

    def harmonize_spectrogram(self, sample):
        """ Ensure same size for all spectrograms (mix and instruments),
        by truncating them to the length of the shortest one. """
        # Computed once and shared by every instrument.
        min_length = tf.reduce_min([
            tf.shape(sample[f'{instrument}_spectrogram'])[0]
            for instrument in self._instruments])
        return dict(sample, **{
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram'][:min_length, :, :]
            for instrument in self._instruments})

    def filter_short_segments(self, sample):
        """ Filter out too short segment.

        NOTE: a sample is kept as soon as one spectrogram has at least T
        frames; build() harmonizes spectrogram lengths before applying
        this filter, so they all share the same length at that point.
        """
        return tf.reduce_any([
            tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
            for instrument in self._instruments])

    def random_time_crop(self, sample):
        """ Random time crop of T frames (11.88s with default audio
        parameters), applied identically to all spectrograms. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: tf.image.random_crop(
                x, (self._T, len(self._instruments) * self._F, 2),
                seed=self._random_seed)))

    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample, with a factor between
        0.9 and 1.1 shared by all spectrograms. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: random_time_stretch(
                x, factor_min=0.9, factor_max=1.1)))

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample, with a shift between
        -1.0 and 1.0 shared by all spectrograms. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: random_pitch_shift(
                x, shift_min=-1.0, shift_max=1.0), concat_axis=0))

    def map_features(self, sample):
        """ Select features and annotation of the given sample.

        :param sample: Sample to select features from.
        :returns: An (input, output) tuple of dictionaries: the input
                  holds the mix spectrogram, the output holds the
                  spectrogram of every instrument.
        """
        input_ = {
            f'{self._mix_name}_spectrogram':
                sample[f'{self._mix_name}_spectrogram']}
        output = {
            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
            for instrument in self._audio_params['instrument_list']}
        return (input_, output)

    def _start_mapper(self, index, n_chunks_per_song):
        """ Creates the function adding a 'start' entry to a sample.

        The index is received as an argument, so that each returned
        function is bound to its own value whenever it gets called.

        :param index: Index of the segment within the song.
        :param n_chunks_per_song: Number of segment per song.
        :returns: Function mapping a sample to a copy with a start time.
        """
        if n_chunks_per_song == 1:
            def central(sample):
                # Take central segment.
                return dict(sample, start=tf.maximum(
                    sample['duration'] / 2 - self._chunk_duration / 2,
                    0))
            return central

        def spread(sample):
            # Segments are evenly spread over the song, leaving
            # a margin at both ends.
            return dict(sample, start=tf.maximum(
                index * (
                    sample['duration'] - self._chunk_duration - 2
                    * self.MARGIN) / (n_chunks_per_song - 1)
                + self.MARGIN, 0))
        return spread

    def compute_segments(self, dataset, n_chunks_per_song):
        """ Computes segments for each song of the dataset.

        :param dataset: Dataset to compute segments for.
        :param n_chunks_per_song: Number of segment per song to compute.
        :returns: Segmented dataset.
        :raises ValueError: If n_chunks_per_song is not strictly positive.
        """
        if n_chunks_per_song <= 0:
            raise ValueError('n_chunks_per_song must be positif')
        datasets = [
            dataset.map(self._start_mapper(k, n_chunks_per_song))
            for k in range(n_chunks_per_song)]
        # The dataset of the last segments comes first, followed by the
        # other ones in index order.
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset

    @property
    def instruments(self):
        """ Instrument dataset builder generator. Builders are created
        on first access (mix first, then instruments) and reused.

        :yield InstrumentBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument))
        for builder in self._instrument_builders:
            yield builder

    def cache(self, dataset, cache, wait):
        """ Cache the given dataset if cache is enabled. Eventually waits for
        cache to be available (useful if another process is already computing
        cache) if provided wait flag is True.

        :param dataset: Dataset to be cached if cache is required.
        :param cache: Path of cache directory to be used, None if no cache.
        :param wait: If caching is enabled, True is cache should be waited.
        :returns: Cached dataset if needed, original dataset otherwise.
        """
        if cache is None:
            return dataset
        if wait:
            while not exists(f'{cache}.index'):
                get_logger().info(
                    'Cache not available, wait %s',
                    self.WAIT_PERIOD)
                time.sleep(self.WAIT_PERIOD)
        cache_path = os.path.split(cache)[0]
        # A cache path without directory component refers to the current
        # directory: nothing to create (os.makedirs('') would fail).
        if cache_path:
            os.makedirs(cache_path, exist_ok=True)
        return dataset.cache(cache)

    def build(
            self, csv_path,
            batch_size=8, shuffle=True, convert_to_uint=True,
            random_data_augmentation=False, random_time_crop=True,
            infinite_generator=True, cache_directory=None,
            wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
        """ Builds the dataset described by the given CSV file.

        :param csv_path: Path of the CSV file describing the songs.
        :param batch_size: Size of the batches produced by the dataset.
        :param shuffle: If True, the dataset is shuffled both before
                        loading audio and after cropping.
        :param convert_to_uint: If True, spectrograms are converted to
                                uint before caching (to save space) and
                                converted back to float afterwards.
        :param random_data_augmentation: If True, a random time stretch
                                         and a random pitch shift are
                                         applied to each sample.
        :param random_time_crop: If True, segments of T frames are
                                 randomly cropped, otherwise the central
                                 segment is taken.
        :param infinite_generator: If True, the dataset repeats
                                   indefinitely.
        :param cache_directory: Path of the cache to use, None to
                                disable caching.
        :param wait_for_cache: If True, waits for the cache to exist
                               before using it.
        :param num_parallel_calls: Number of parallel calls used by the
                                   heaviest map operations.
        :param n_chunks_per_song: Number of segment extracted per song.
        :returns: Dataset of batched (input, output) tuples, as produced
                  by map_features.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
        # Shuffle data
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=200000,
                seed=self._random_seed,
                # useless since it is cached :
                reshuffle_each_iteration=True)
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform, compute spectrogram, and filtering error,
        # K bins frequencies, and waveform.
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset
                .map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies))
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.convert_to_uint)
        dataset = self.cache(dataset, cache_directory, wait_for_cache)
        # Check for INFINITY (should not happen). The minimum value read
        # by this filter is only produced by the uint conversion.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitly
        if infinite_generator:
            dataset = dataset.repeat(count=-1)
        # Ensure same size for vocals and mix spectrograms.
        # NOTE: could be done before caching ?
        dataset = dataset.map(self.harmonize_spectrogram)
        # Filter out too short segment.
        # NOTE: could be done before caching ?
        dataset = dataset.filter(self.filter_short_segments)
        # Random time crop of 11.88s
        if random_time_crop:
            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
        else:
            # frame_duration = 11.88/T
            # take central segment (for validation)
            for instrument in self.instruments:
                dataset = dataset.map(instrument.time_crop)
        # Post cache shuffling. Done where the data are the lightest:
        # after croping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed,
                reshuffle_each_iteration=True)
        # Convert back to float32
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N)
        M = 8  # Parallel call post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = (
                dataset
                .map(self.random_time_stretch, num_parallel_calls=M)
                .map(self.random_pitch_shift, num_parallel_calls=M))
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = (
                dataset
                .filter(instrument.filter_shape)
                .map(instrument.reshape_spectrogram))
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
        # error due to unprocessed instrument spectrogram batching).
        dataset = dataset.batch(batch_size)
        return dataset