# coding=utf-8 # Copyright 2020 The Trax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for trax.rl.serialization_utils.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized import gin import gym from jax import numpy as jnp import numpy as np from tensorflow import test from trax import shapes from trax.layers import base as layers_base from trax.models import transformer from trax.rl import serialization_utils from trax.rl import space_serializer # pylint: disable=invalid-name def TestModel(extra_dim): """Dummy sequence model for testing.""" def f(inputs): # Cast the input to float32 - this is for simulating discrete-input models. inputs = inputs.astype(np.float32) # Add an extra dimension if requested, e.g. the logit dimension for output # symbols. if extra_dim is not None: return jnp.broadcast_to(inputs[:, :, None], inputs.shape + (extra_dim,)) else: return inputs return layers_base.Fn('TestModel', f) # pylint: enable=invalid-name class SerializationTest(parameterized.TestCase): def setUp(self): super(SerializationTest, self).setUp() self._serializer = space_serializer.create( gym.spaces.Discrete(2), vocab_size=2 ) self._repr_length = 100 self._serialization_utils_kwargs = { 'observation_serializer': self._serializer, 'action_serializer': self._serializer, 'representation_length': self._repr_length, } def test_serialized_model_discrete(self): vocab_size = 3 obs = np.array([[[0, 1], [1, 1], [1, 0], [0, 0]]]) act = np.array([[1, 0, 0]]) mask = np.array([[1, 1, 1, 0]]) test_model_inputs = [] # pylint: disable=invalid-name def TestModelSavingInputs(): def f(inputs): # Save the inputs for a later check. test_model_inputs.append(inputs) # Change type to np.float32 and add the logit dimension. return jnp.broadcast_to( inputs.astype(np.float32)[:, :, None], inputs.shape + (vocab_size,) ) return layers_base.Fn('TestModelSavingInputs', f) # pylint: enable=invalid-name obs_serializer = space_serializer.create( gym.spaces.MultiDiscrete([2, 2]), vocab_size=vocab_size ) act_serializer = space_serializer.create( gym.spaces.Discrete(2), vocab_size=vocab_size ) serialized_model = serialization_utils.SerializedModel( TestModelSavingInputs(), # pylint: disable=no-value-for-parameter observation_serializer=obs_serializer, action_serializer=act_serializer, significance_decay=0.9, ) example = (obs, act, obs, mask) serialized_model.init(shapes.signature(example)) (obs_logits, obs_repr, weights) = serialized_model(example) # Check that the model has been called with the correct input. np.testing.assert_array_equal( # The model is called multiple times for determining shapes etc. # Check the last saved input - that should be the actual concrete array # calculated during the forward pass. test_model_inputs[-1], # Should be serialized observations and actions interleaved. [[0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]], ) # Check the output shape. self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,)) # Check that obs_logits are the same as obs_repr, just broadcasted over the # logit dimension. np.testing.assert_array_equal(np.min(obs_logits, axis=-1), obs_repr) np.testing.assert_array_equal(np.max(obs_logits, axis=-1), obs_repr) # Check that the observations are correct. np.testing.assert_array_equal(obs_repr, obs) # Check weights. np.testing.assert_array_equal(weights, [[[1, 1], [1, 1], [1, 1], [0, 0]]]) def test_serialized_model_continuous(self): precision = 3 gin.bind_parameter('BoxSpaceSerializer.precision', precision) vocab_size = 32 obs = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0, 0]]]) act = np.array([[0, 1, 0]]) mask = np.array([[1, 1, 1, 0]]) obs_serializer = space_serializer.create( gym.spaces.Box(shape=(2,), low=-2, high=2), vocab_size=vocab_size ) act_serializer = space_serializer.create( gym.spaces.Discrete(2), vocab_size=vocab_size ) serialized_model = serialization_utils.SerializedModel( TestModel(extra_dim=vocab_size), # pylint: disable=no-value-for-parameter observation_serializer=obs_serializer, action_serializer=act_serializer, significance_decay=0.9, ) example = (obs, act, obs, mask) serialized_model.init(shapes.signature(example)) (obs_logits, obs_repr, weights) = serialized_model(example) self.assertEqual(obs_logits.shape, obs_repr.shape + (vocab_size,)) self.assertEqual( obs_repr.shape, (1, obs.shape[1], obs.shape[2] * precision) ) self.assertEqual(obs_repr.shape, weights.shape) def test_extract_inner_model(self): vocab_size = 3 inner_model = transformer.TransformerLM( vocab_size=vocab_size, d_model=2, d_ff=2, n_layers=0 ) obs_serializer = space_serializer.create( gym.spaces.Discrete(2), vocab_size=vocab_size ) act_serializer = space_serializer.create( gym.spaces.Discrete(2), vocab_size=vocab_size ) serialized_model = serialization_utils.SerializedModel( inner_model, observation_serializer=obs_serializer, action_serializer=act_serializer, significance_decay=0.9, ) obs_sig = shapes.ShapeDtype((1, 2)) act_sig = shapes.ShapeDtype((1, 1)) (weights, state) = serialized_model.init( input_signature=(obs_sig, act_sig, obs_sig, obs_sig), ) (inner_weights, inner_state) = map( serialization_utils.extract_inner_model, (weights, state) ) inner_model(jnp.array([[0]]), weights=inner_weights, state=inner_state) @parameterized.named_parameters(('raw', None), ('serialized', 32)) def test_wrapped_policy_continuous(self, vocab_size): precision = 3 n_controls = 2 n_actions = 4 gin.bind_parameter('BoxSpaceSerializer.precision', precision) obs = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]]) act = np.array([[[0, 1], [2, 0], [1, 3]]]) wrapped_policy = serialization_utils.wrap_policy( TestModel(extra_dim=vocab_size), # pylint: disable=no-value-for-parameter observation_space=gym.spaces.Box(shape=(2,), low=-2, high=2), action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls), vocab_size=vocab_size, ) example = (obs, act) wrapped_policy.init(shapes.signature(example)) (act_logits, values) = wrapped_policy(example) self.assertEqual(act_logits.shape, obs.shape[:2] + (n_controls, n_actions)) self.assertEqual(values.shape, obs.shape[:2]) def test_analyzes_discrete_action_space(self): space = gym.spaces.Discrete(n=5) (n_controls, n_actions) = serialization_utils.analyze_action_space(space) self.assertEqual(n_controls, 1) self.assertEqual(n_actions, 5) def test_analyzes_multi_discrete_action_space_with_equal_categories(self): space = gym.spaces.MultiDiscrete(nvec=(3, 3)) (n_controls, n_actions) = serialization_utils.analyze_action_space(space) self.assertEqual(n_controls, 2) self.assertEqual(n_actions, 3) def test_doesnt_analyze_multi_disccrete_action_space_with_inequal_categories( self ): space = gym.spaces.MultiDiscrete(nvec=(2, 3)) with self.assertRaises(AssertionError): serialization_utils.analyze_action_space(space) def test_doesnt_analyze_box_action_space(self): space = gym.spaces.Box(shape=(2, 3), low=0, high=1) with self.assertRaises(AssertionError): serialization_utils.analyze_action_space(space) if __name__ == '__main__': test.main()