# -*- coding: utf-8 -*-
#
"""
Unit tests for the MLE (maximum-likelihood / cross-entropy) sequence losses.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# pylint: disable=invalid-name

import numpy as np

import tensorflow as tf

import texar as tx

class MLELossesTest(tf.test.TestCase):
    """Tests the sequence MLE (cross-entropy) losses in `texar.losses`.

    Each test builds a loss under several reduction settings and checks the
    rank and static shape of the result.
    """

    def setUp(self):
        super(MLELossesTest, self).setUp()
        self._batch_size = 64
        self._max_time = 16
        self._num_classes = 100
        # Sparse labels: every position is class 1, shape [batch, time].
        self._labels = tf.ones([self._batch_size, self._max_time],
                               dtype=tf.int32)
        # `tf.one_hot` appends the class axis, so this is already
        # [batch, time, num_classes]; no reshape is needed.
        self._one_hot_labels = tf.one_hot(
            self._labels, self._num_classes, dtype=tf.float32)
        self._logits = tf.random_uniform(
            [self._batch_size, self._max_time, self._num_classes])
        # `maxval` is exclusive for integer dtypes, so lengths lie in
        # [0, max_time).
        self._sequence_length = tf.random_uniform(
            [self._batch_size], maxval=self._max_time, dtype=tf.int32)

    def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length):
        """Checks the output rank/shape of `loss_fn` under reduction settings.

        Args:
            loss_fn: A sequence loss callable invoked as
                `loss_fn(labels, logits, sequence_length, **kwargs)`.
            labels: Labels tensor passed through to `loss_fn`.
            logits: Logits tensor passed through to `loss_fn`.
            sequence_length: Per-example lengths passed through to `loss_fn`.
        """
        with self.test_session() as sess:
            # Default reduction settings produce a scalar.
            loss = loss_fn(labels, logits, sequence_length)
            rank = sess.run(tf.rank(loss))
            self.assertEqual(rank, 0)

            # Time axis kept, batch axis reduced: [max_time].
            loss = loss_fn(
                labels, logits, sequence_length, sum_over_timesteps=False)
            rank = sess.run(tf.rank(loss))
            self.assertEqual(rank, 1)
            self.assertEqual(loss.shape, tf.TensorShape([self._max_time]))

            # Batch axis kept, time axis averaged: [batch_size].
            loss = loss_fn(
                labels, logits, sequence_length, sum_over_timesteps=False,
                average_across_timesteps=True, average_across_batch=False)
            rank = sess.run(tf.rank(loss))
            self.assertEqual(rank, 1)
            self.assertEqual(loss.shape, tf.TensorShape([self._batch_size]))

            # No reduction at all: [batch_size, max_time].
            loss = loss_fn(
                labels, logits, sequence_length, sum_over_timesteps=False,
                average_across_batch=False)
            rank = sess.run(tf.rank(loss))
            self.assertEqual(rank, 2)
            self.assertEqual(loss.shape,
                             tf.TensorShape([self._batch_size, self._max_time]))

            # Same tensors reinterpreted as time-major: the leading axis is
            # read as time, so the batch axis is axis 1 (size
            # `self._max_time` here) and `sequence_length` needs that many
            # entries. The unreduced loss keeps the input's 2-D layout.
            sequence_length_time = tf.random_uniform(
                [self._max_time], maxval=self._max_time, dtype=tf.int32)
            loss = loss_fn(
                labels, logits, sequence_length_time, sum_over_timesteps=False,
                average_across_batch=False, time_major=True)
            rank = sess.run(tf.rank(loss))
            self.assertEqual(rank, 2)
            self.assertEqual(loss.shape,
                             tf.TensorShape([self._batch_size, self._max_time]))

    def test_sequence_softmax_cross_entropy(self):
        """Tests `texar.losses.sequence_softmax_cross_entropy`.
        """
        self._test_sequence_loss(
            tx.losses.sequence_softmax_cross_entropy,
            self._one_hot_labels, self._logits, self._sequence_length)

    def test_sequence_sparse_softmax_cross_entropy(self):
        """Tests `texar.losses.sequence_sparse_softmax_cross_entropy`.
        """
        self._test_sequence_loss(
            tx.losses.sequence_sparse_softmax_cross_entropy,
            self._labels, self._logits, self._sequence_length)

    def test_sequence_sigmoid_cross_entropy(self):
        """Tests `texar.losses.sequence_sigmoid_cross_entropy`.
        """
        # Multi-label case: [batch, time, num_classes] inputs.
        self._test_sequence_loss(
            tx.losses.sequence_sigmoid_cross_entropy,
            self._one_hot_labels, self._logits, self._sequence_length)

        # Single-output case: [batch, time] inputs (class axis sliced away).
        self._test_sequence_loss(
            tx.losses.sequence_sigmoid_cross_entropy,
            self._one_hot_labels[:, :, 0],
            self._logits[:, :, 0],
            self._sequence_length)

        # Labels whose static shape is unknown at graph-build time must
        # still reduce to a scalar by default.
        labels = tf.placeholder(dtype=tf.int32, shape=None)
        loss = tx.losses.sequence_sigmoid_cross_entropy(
            logits=self._logits[:, :, 0],
            labels=tf.to_float(labels),
            sequence_length=self._sequence_length)
        with self.test_session() as sess:
            # Feed with the placeholder's own dtype instead of relying on an
            # implicit float64 -> int32 conversion.
            feed_labels = np.ones(
                [self._batch_size, self._max_time], dtype=np.int32)
            rank = sess.run(tf.rank(loss), feed_dict={labels: feed_labels})
            self.assertEqual(rank, 0)


if __name__ == "__main__":
    # Run this module's tests through TensorFlow's test entry point.
    tf.test.main()