# Copyright 2018/2019 The RLgraph authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import unittest from six.moves import xrange as range_ from rlgraph.spaces import * from rlgraph.utils.ops import FLAT_TUPLE_CLOSE, FLAT_TUPLE_OPEN class TestSpaces(unittest.TestCase): """ Tests creation, sampling and shapes of Spaces. """ def test_box_spaces(self): """ Tests all BoxSpaces via sample/contains loop. With and without batch-rank, different batch sizes, and different los/high combinations (including no bounds). """ for class_ in [FloatBox, IntBox, BoolBox, TextBox]: for add_batch_rank in [False, True]: # TODO: Test time-rank more thoroughly. for add_time_rank in [False, True]: if class_ != BoolBox and class_ != TextBox: for low, high in [(None, None), (-1.0, 10.0), ((1.0, 2.0), (3.0, 4.0)), (((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)), ((7.0, 8.0, 9.0), (10.0, 11.0, 12.0)))]: space = class_(low=low, high=high, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank) if add_batch_rank is False: sample = space.sample() self.assertTrue(space.contains(sample)) else: for batch_size in range_(1, 4): samples = space.sample(size=batch_size) for s in samples: self.assertTrue(space.contains(s)) # TODO: test zero() method perperly for all cases #all_0s = space.zeros() #self.assertTrue(all(v == 0 for v in all_0s)) else: space = class_(add_batch_rank=add_batch_rank, add_time_rank=add_time_rank) if add_batch_rank is False: sample = space.sample() self.assertTrue(space.contains(sample)) else: for batch_size in range_(1, 4): samples = space.sample(size=batch_size) for s in samples: self.assertTrue(space.contains(s)) def test_complex_space_sampling_and_check_via_contains(self): """ Tests a complex Space on sampling and `contains` functionality. """ space = Dict( a=dict(aa=float, ab=bool), b=dict(ba=float), c=float, d=IntBox(low=0, high=1), e=IntBox(5), f=FloatBox(shape=(2, 2)), g=Tuple(float, FloatBox(shape=())), add_batch_rank=True ) samples = space.sample(size=100, horizontal=True) for i in range_(len(samples)): self.assertTrue(space.contains(samples[i])) def test_container_space_flattening_with_mapping(self): space = Tuple( Dict( a=bool, b=IntBox(4), c=Dict( d=FloatBox(shape=()) ) ), BoolBox(), IntBox(2), FloatBox(shape=(3, 2)), Tuple( BoolBox(), BoolBox() ) ) def mapping_func(key, primitive_space): # Just map a primitive Space to its flat_dim property. return primitive_space.flat_dim result = "" flat_space_and_mapped = space.flatten(mapping=mapping_func, scope_separator_at_start=False) for key, value in flat_space_and_mapped.items(): result += "{}:{},".format(key, value) tuple_txt = [FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE] * 10 expected = "{}0{}/a:1,{}0{}/b:1,{}0{}/c/d:1,{}1{}:1,{}2{}:1,{}3{}:6,{}4{}/{}0{}:1,{}4{}/{}1{}:1,".\ format(*tuple_txt) self.assertTrue(result == expected) def test_container_space_mapping(self): space = Tuple( Dict( a=bool, b=IntBox(4), c=Dict( d=FloatBox(shape=()) ) ), BoolBox(), IntBox(2), FloatBox(shape=(3, 2)), Tuple( BoolBox(), BoolBox() ) ) def mapping_func(key, primitive_space): # Change each primitive space to IntBox(5). return IntBox(5) mapped_space = space.map(mapping=mapping_func) self.assertTrue(isinstance(mapped_space[0]["a"], IntBox)) self.assertTrue(mapped_space[0]["a"].num_categories == 5) self.assertTrue(mapped_space[3].num_categories == 5) self.assertTrue(mapped_space[4][0].num_categories == 5) self.assertTrue(mapped_space[4][1].num_categories == 5) # Same on Dict. space = Dict( a=bool, b=IntBox(4), c=Dict( d=FloatBox(shape=()) ) ) mapped_space = space.map(mapping=mapping_func) self.assertTrue(isinstance(mapped_space["a"], IntBox)) self.assertTrue(mapped_space["a"].num_categories == 5) self.assertTrue(isinstance(mapped_space["b"], IntBox)) self.assertTrue(mapped_space["c"]["d"].num_categories == 5)