"""Tests for ``memory_augmented_neural_network``: batch-size invariance and output shape."""
import pytest

import theano
import theano.tensor as T
import numpy as np

from .model import memory_augmented_neural_network

# Hyperparameters shared by every model built in this module.
NB_CLASS = 5
MEMORY_SHAPE = (128, 40)
CONTROLLER_SIZE = 200
INPUT_SIZE = 20 * 20
NB_READS = 4
BATCH_SIZE = 16
TIMESTEPS = 50
SEED = 1234


def _build_model(input_var, target_var, batch_size):
    """Build the model graph with this module's shared hyperparameters.

    Returns whatever ``memory_augmented_neural_network`` returns; the tests
    below unpack it as ``(output_var, _, params)``.
    """
    return memory_augmented_neural_network(
        input_var, target_var,
        batch_size=batch_size,
        nb_class=NB_CLASS,
        memory_shape=MEMORY_SHAPE,
        controller_size=CONTROLLER_SIZE,
        input_size=INPUT_SIZE,
        nb_reads=NB_READS)


def _random_batch(rng, batch_size=BATCH_SIZE, timesteps=TIMESTEPS, input_size=INPUT_SIZE):
    """Return a random ``(inputs, targets)`` pair.

    ``inputs`` has shape ``(batch_size, timesteps, input_size)`` and ``targets``
    has shape ``(batch_size, timesteps)`` with int32 labels in ``[0, NB_CLASS)``.
    Inputs are cast to ``theano.config.floatX`` because ``T.tensor3`` uses that
    dtype, and ``theano.function`` rejects float64 data when floatX is float32
    rather than silently downcasting it.
    """
    inputs = rng.rand(batch_size, timesteps, input_size).astype(theano.config.floatX)
    targets = rng.randint(NB_CLASS, size=(batch_size, timesteps)).astype('int32')
    return inputs, targets


def test_batch_size():
    """A batch of 16 sequences must match 16 independent batch-of-1 runs."""
    input_var_1, input_var_2 = T.tensor3s('input1', 'input2')
    target_var_1, target_var_2 = T.imatrices('target1', 'target2')

    # First model with `batch_size=16`, second with `batch_size=1`.
    output_var_1, _, params1 = _build_model(input_var_1, target_var_1, batch_size=BATCH_SIZE)
    output_var_2, _, params2 = _build_model(input_var_2, target_var_2, batch_size=1)

    # Share weights so the two graphs differ only in batch size. Check the
    # lengths first: zip would silently drop unmatched trailing parameters.
    assert len(params1) == len(params2)
    for param1, param2 in zip(params1, params2):
        param2.set_value(param1.get_value())

    posterior_fn1 = theano.function([input_var_1, target_var_1], output_var_1)
    posterior_fn2 = theano.function([input_var_2, target_var_2], output_var_2)

    rng = np.random.RandomState(SEED)
    # Input has shape (batch_size, timesteps, input_size).
    test_input, test_target = _random_batch(rng)

    test_output1 = posterior_fn1(test_input, test_target)
    test_output2 = np.zeros_like(test_output1)
    for i in range(BATCH_SIZE):
        # Slicing with i:i + 1 keeps the leading batch axis (size 1).
        test_output2[i] = posterior_fn2(test_input[i:i + 1], test_target[i:i + 1])

    assert np.allclose(test_output1, test_output2)


def test_shape():
    """Output is (batch, timesteps, nb_class); mis-shaped inputs raise ValueError."""
    input_var = T.tensor3('input')
    target_var = T.imatrix('target')
    output_var, _, _ = _build_model(input_var, target_var, batch_size=BATCH_SIZE)
    posterior_fn = theano.function([input_var, target_var], output_var)

    rng = np.random.RandomState(SEED)
    test_input, test_target = _random_batch(rng)
    # Same dtype as the valid input, so any error comes from the shape and
    # not from a dtype mismatch (which Theano reports as TypeError).
    test_input_invalid_batch_size, _ = _random_batch(rng, batch_size=BATCH_SIZE + 1)
    test_input_invalid_depth, _ = _random_batch(rng, input_size=INPUT_SIZE - 1)

    test_output = posterior_fn(test_input, test_target)
    assert test_output.shape == (BATCH_SIZE, TIMESTEPS, NB_CLASS)

    with pytest.raises(ValueError):
        posterior_fn(test_input_invalid_batch_size, test_target)
    with pytest.raises(ValueError):
        posterior_fn(test_input_invalid_depth, test_target)