#!/usr/bin/env python
# coding: utf8

""" Unit testing for Separator class. """

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

import filecmp
import itertools
from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory

import pytest
import numpy as np

import tensorflow as tf

from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.separator import Separator

# Audio fixtures: one stereo example and one mono example.
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
# STFT implementations the separator is exercised with.
BACKENDS = ["tensorflow", "librosa"]
# Pretrained model configurations under test.
MODELS = ['spleeter:2stems', 'spleeter:4stems', 'spleeter:5stems']

# Stems each model configuration is expected to produce.
MODEL_TO_INST = {
    'spleeter:2stems': ('vocals', 'accompaniment'),
    'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'),
    'spleeter:5stems': ('vocals', 'drums', 'bass', 'piano', 'other'),
}


# Cartesian products used as pytest parametrizations; file varies slowest,
# then model, then backend (same ordering as itertools.product).
MODELS_AND_TEST_FILES = [
    (audio_file, model)
    for audio_file in TEST_AUDIO_DESCRIPTORS
    for model in MODELS
]
TEST_CONFIGURATIONS = [
    (audio_file, model, stft_backend)
    for audio_file, model in MODELS_AND_TEST_FILES
    for stft_backend in BACKENDS
]


# Print the TensorFlow version at import (collection) time, presumably so a
# test log records which TF release the suite ran against.
# NOTE(review): pytest captures stdout by default, so this line is normally
# only visible when running with `-s` or when collection fails.
print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))


@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_separate(test_file, configuration, backend):
    """ Test separation from raw data.

    Loads ``test_file`` into a waveform, separates it with the model
    ``configuration`` using the given STFT ``backend``, and checks that:

    * the prediction contains exactly the stems listed in ``MODEL_TO_INST``;
    * every stem matches the input on all axes but the last (the last axis
      is excluded, presumably because mono input may come back with a
      different channel count -- TODO confirm);
    * no stem is numerically close to the input mix;
    * no two stems are numerically close to each other.
    """
    # The session is entered only for its context (default session); its
    # handle was never used, so it is no longer bound to a name.
    with tf.Session():
        instruments = MODEL_TO_INST[configuration]
        adapter = get_default_audio_adapter()
        waveform, _ = adapter.load(test_file)
        separator = Separator(configuration, stft_backend=backend)
        prediction = separator.separate(waveform, test_file)

        # Same stems, none missing, none extra (stem names are distinct).
        assert set(prediction) == set(instruments)

        for instrument in instruments:
            track = prediction[instrument]
            assert waveform.shape[:-1] == track.shape[:-1]
            assert not np.allclose(waveform, track)

        # Ordered pairs on purpose: np.allclose(a, b) is not symmetric
        # (tolerance scales with |b|), so both directions are checked,
        # exactly as the former nested loop did.
        for first, second in itertools.permutations(instruments, 2):
            assert not np.allclose(prediction[first], prediction[second])


@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_separate_to_file(test_file, configuration, backend):
    """ Test file based separation with the default filename format.

    Separates ``test_file`` into a temporary directory and expects one
    ``<directory>/<basename>/<instrument>.wav`` file per stem listed in
    ``MODEL_TO_INST`` for ``configuration``.
    """
    # Entered for its context only; the session handle is not needed.
    with tf.Session():
        instruments = MODEL_TO_INST[configuration]
        separator = Separator(configuration, stft_backend=backend)
        name = splitext(basename(test_file))[0]
        with TemporaryDirectory() as directory:
            separator.separate_to_file(test_file, directory)
            for instrument in instruments:
                # Build the path from components rather than embedding '/'
                # in a format string handed to join().
                expected = join(directory, name, '{}.wav'.format(instrument))
                assert exists(expected), expected


@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_filename_format(test_file, configuration, backend):
    """ Test custom filename format.

    Separates ``test_file`` with
    ``filename_format='export/{filename}/{instrument}.{codec}'`` and expects
    one ``<directory>/export/<basename>/<instrument>.wav`` file per stem
    listed in ``MODEL_TO_INST`` for ``configuration``.
    """
    # Entered for its context only; the session handle is not needed.
    with tf.Session():
        instruments = MODEL_TO_INST[configuration]
        separator = Separator(configuration, stft_backend=backend)
        name = splitext(basename(test_file))[0]
        with TemporaryDirectory() as directory:
            separator.separate_to_file(
                test_file,
                directory,
                filename_format='export/{filename}/{instrument}.{codec}')
            for instrument in instruments:
                # Build the path from components rather than embedding '/'
                # in a format string handed to join().
                expected = join(
                    directory, 'export', name, '{}.wav'.format(instrument))
                assert exists(expected), expected


@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
def test_filename_conflict(test_file, configuration):
    """ Test error handling with static pattern.

    The ``filename_format`` used here contains no placeholders, so it cannot
    distinguish stems (presumably every stem would map to the same output
    path). ``separate_to_file`` is expected to raise ``SpleeterError``.
    Only the default STFT backend is exercised.
    """
    # Entered for its context only; the session handle is not needed.
    with tf.Session():
        separator = Separator(configuration)
        with TemporaryDirectory() as directory:
            with pytest.raises(SpleeterError):
                separator.separate_to_file(
                    test_file,
                    directory,
                    filename_format='I wanna be your lover')