"""Unit tests for the causal_conv op."""

import numpy as np
import tensorflow as tf

from wavenet import time_to_batch, batch_to_time, causal_conv


class TestCausalConv(tf.test.TestCase):
    """Tests for ``causal_conv`` (imported from ``wavenet``)."""

    def testCausalConv(self):
        """Tests that the op is equivalent to a numpy implementation."""
        dilation = 4
        x1 = np.arange(1, 21, dtype=np.float32)
        # Two identical batch items; shape is batch x time x channels = 2x20x1.
        x = np.reshape(np.append(x1, x1), [2, 20, 1])
        taps = np.array([1, 1], dtype=np.float32)
        # Filter shape is width x input channels x output channels = 2x1x1.
        filter_weights = np.reshape(taps, [len(taps), 1, 1])
        out = causal_conv(x, filter_weights, dilation=dilation)

        with self.test_session() as sess:
            result = sess.run(out)

        # Causal convolution using numpy. A dilated filter is the original
        # taps with (dilation - 1) zeros between consecutive taps, e.g.
        # [1, 1] with dilation 4 -> [1, 0, 0, 0, 1].
        dilated = np.zeros((len(taps) - 1) * dilation + 1, dtype=np.float32)
        dilated[::dilation] = taps
        # Use np.correlate rather than np.convolve: the op does not flip the
        # kernel (testNoTimeShift relies on weights [0, 1] acting as an
        # unshifted identity), whereas np.convolve does. The two only agree
        # for symmetric kernels.
        ref = np.correlate(x1, dilated, mode='valid')
        ref = np.reshape(np.append(ref, ref), [2, -1, 1])

        self.assertAllEqual(result, ref)

    def testNoTimeShift(self):
        """Tests that the convolution does not introduce a time shift.

        We give it a time series, choose a filter that should be the identity,
        and assert that the output is not shifted at all relative to the input.
        """
        dilation = 2
        # Input to filter is a time series of values 1..10.
        # Shape is batch item x duration x channels = 1x10x1.
        x = np.reshape(np.arange(1, 11, dtype=np.float32), [1, 10, 1])
        # Default shape ordering for conv filter = HWIO for 2d. Since we use
        # 1d, this just becomes WxIxO where:
        #   W = width AKA number of time steps in time series = 2
        #   I = input channels = 1
        #   O = output channels = 1
        # Since the filter is size 2, for it to be identity-preserving, one
        # value is 1.0, the other 0.0.
        # Named `filter_weights` to avoid shadowing the builtin `filter`.
        filter_weights = np.reshape(
            np.array([0.0, 1.0], dtype=np.float32), [2, 1, 1])

        # Left-pad the time axis by the dilated filter's reach,
        # (width - 1) * dilation, so the output has the same length as `x`.
        pad = (filter_weights.shape[0] - 1) * dilation
        x_padded = np.pad(x, [[0, 0], [pad, 0], [0, 0]], 'constant')

        # Compute the output
        out = causal_conv(x_padded, filter_weights, dilation=dilation)

        with self.test_session() as sess:
            result = sess.run(out)

        # The shapes should be the same.
        self.assertAllEqual(result.shape, x.shape)

        # The output time series should be identical to the input series.
        self.assertAllEqual(result, x)


if __name__ == "__main__":
    # Run this module's tests through TensorFlow's test entry point.
    tf.test.main()