# -*- coding: utf-8 -*-
import unittest

from nose.tools import raises

import os
import torch
import tempfile
from kraken.lib import vgsl


class TestVGSL(unittest.TestCase):
    """
    Tests for ``kraken.lib.vgsl.TorchVGSLModel`` helper methods.

    Several helpers under test mutate torch state that outlives a single
    test (grad mode, intra-op thread count). Tests that touch such state
    register a cleanup restoring the previous value, so test order does not
    affect results.
    """
    # Network spec shared by every test; each test builds a fresh model.
    SPEC = '[1,1,0,48 Lbx10 Do O1c57]'

    def test_helper_train(self):
        """
        Tests train/eval mode helper methods.

        ``train()``/``eval()`` flip torch's grad-enabled flag (asserted
        below). That flag persists after the test, so its prior value is
        restored on cleanup. Otherwise later tests would run with autograd
        disabled.
        """
        self.addCleanup(torch.set_grad_enabled, torch.is_grad_enabled())
        rnn = vgsl.TorchVGSLModel(self.SPEC)
        rnn.train()
        self.assertTrue(torch.is_grad_enabled())
        self.assertTrue(rnn.nn.training)
        rnn.eval()
        self.assertFalse(torch.is_grad_enabled())
        self.assertFalse(rnn.nn.training)

    @unittest.skip('works randomly on ci')
    def test_helper_threads(self):
        """
        Test openmp threads helper method.

        The thread count set by the helper is torch-global, so the original
        value is restored on cleanup.
        """
        self.addCleanup(torch.set_num_threads, torch.get_num_threads())
        rnn = vgsl.TorchVGSLModel(self.SPEC)
        rnn.set_num_threads(4)
        self.assertEqual(torch.get_num_threads(), 4)

    def test_save_model(self):
        """
        Test model serialization: ``save_model`` creates the target file.
        """
        rnn = vgsl.TorchVGSLModel(self.SPEC)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, 'foo.mlmodel')
            rnn.save_model(model_path)
            self.assertTrue(os.path.exists(model_path))

    def test_resize(self):
        """
        Tests resizing of output layers.
        """
        rnn = vgsl.TorchVGSLModel(self.SPEC)
        rnn.resize_output(80)
        self.assertEqual(rnn.nn[-1].lin.out_features, 80)

    def test_del_resize(self):
        """
        Tests resizing of output layers with entry deletion.
        """
        rnn = vgsl.TorchVGSLModel(self.SPEC)
        rnn.resize_output(80, [2, 4, 5, 6, 7, 12, 25])
        self.assertEqual(rnn.nn[-1].lin.out_features, 80)