# pylint: disable=g-bad-file-header
# Copyright 2019 The dm_env Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Classes that describe numpy arrays."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import numpy as np

# pylint: disable=g-import-not-at-top
if sys.version_info >= (3, 3):
  # `inspect.signature` was added in Python 3.3
  import inspect
else:
  import funcsigs as inspect
# pylint: enable=g-import-not-at-top

_INVALID_SHAPE = 'Expected shape %r but found %r'
_INVALID_DTYPE = 'Expected dtype %r but found %r'
_OUT_OF_BOUNDS = 'Values were not all within bounds %s <= %s <= %s'
_VAR_ARGS_NOT_ALLOWED = 'Spec subclasses must not accept *args.'
_VAR_KWARGS_NOT_ALLOWED = 'Spec subclasses must not accept **kwargs.'


class Array(object):
  """Describes a numpy array or scalar shape and dtype.

  An `Array` spec allows an API to describe the arrays that it accepts or
  returns, before that array exists.
  The equivalent version describing a `tf.Tensor` is `TensorSpec`.
  """
  __slots__ = ('_shape', '_dtype', '_name')
  # `__eq__` is overridden below, so instances are explicitly unhashable.
  __hash__ = None

  def __init__(self, shape, dtype, name=None):
    """Initializes a new `Array` spec.

    Args:
      shape: An iterable specifying the array shape.
      dtype: numpy dtype or string specifying the array dtype.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      TypeError: If `shape` is not an iterable of elements convertible to int,
      or if `dtype` is not convertible to a numpy dtype.
    """
    self._shape = tuple(int(dim) for dim in shape)
    self._dtype = np.dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """Returns a `tuple` specifying the array shape."""
    return self._shape

  @property
  def dtype(self):
    """Returns a numpy dtype specifying the array dtype."""
    return self._dtype

  @property
  def name(self):
    """Returns the name of the Array."""
    return self._name

  def __repr__(self):
    return 'Array(shape={}, dtype={}, name={})'.format(self.shape,
                                                       repr(self.dtype),
                                                       repr(self.name))

  def __eq__(self, other):
    """Checks if the shape and dtype of two specs are equal."""
    if not isinstance(other, Array):
      return False
    # `name` is intentionally not part of the comparison.
    return self.shape == other.shape and self.dtype == other.dtype

  def __ne__(self, other):
    return not self == other

  def _fail_validation(self, message, *args):
    """Formats `message % args`, appends the spec name if set, raises."""
    message %= args
    if self.name:
      message += ' for spec %s' % self.name
    raise ValueError(message)

  def validate(self, value):
    """Checks if value conforms to this spec.

    Args:
      value: a numpy array or value convertible to one via `np.asarray`.

    Returns:
      value, converted if necessary to a numpy array.

    Raises:
      ValueError: if value doesn't conform to this spec.
    """
    value = np.asarray(value)
    if value.shape != self.shape:
      self._fail_validation(_INVALID_SHAPE, self.shape, value.shape)
    if value.dtype != self.dtype:
      self._fail_validation(_INVALID_DTYPE, self.dtype, value.dtype)
    return value

  def generate_value(self):
    """Generate a test value which conforms to this spec."""
    return np.zeros(shape=self.shape, dtype=self.dtype)

  def _get_constructor_kwargs(self):
    """Returns constructor kwargs for instantiating a new copy of this spec."""
    # Get the names and kinds of the constructor parameters.
    params = inspect.signature(type(self)).parameters
    # __init__ must not accept *args or **kwargs, since otherwise we won't be
    # able to infer what the corresponding attribute names are.
    kinds = {value.kind for value in params.values()}
    if inspect.Parameter.VAR_POSITIONAL in kinds:
      raise TypeError(_VAR_ARGS_NOT_ALLOWED)
    elif inspect.Parameter.VAR_KEYWORD in kinds:
      raise TypeError(_VAR_KWARGS_NOT_ALLOWED)
    # Note that we assume direct correspondence between the names of constructor
    # arguments and attributes.
    return {name: getattr(self, name) for name in params.keys()}

  def replace(self, **kwargs):
    """Returns a new copy of `self` with specified attributes replaced.

    Args:
      **kwargs: Optional attributes to replace.

    Returns:
      A new copy of `self`.
    """
    all_kwargs = self._get_constructor_kwargs()
    all_kwargs.update(kwargs)
    return type(self)(**all_kwargs)

  def __reduce__(self):
    return Array, (self._shape, self._dtype, self._name)


class BoundedArray(Array):
  """An `Array` spec that specifies minimum and maximum values.

  Example usage:
  ```python
  # Specifying the same minimum and maximum for every element.
  spec = BoundedArray((3, 4), np.float64, minimum=0.0, maximum=1.0)

  # Specifying a different minimum and maximum for each element.
  spec = BoundedArray(
      (2,), np.float64, minimum=[0.1, 0.2], maximum=[0.9, 0.9])

  # Specifying the same minimum and a different maximum for each element.
  spec = BoundedArray(
      (3,), np.float64, minimum=-10.0, maximum=[4.0, 5.0, 3.0])
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by arrays
  with values in the set {0, 1, 2}:
  ```python
  spec = BoundedArray((3, 4), np.int32, minimum=0, maximum=2)
  ```
  """

  __slots__ = ('_minimum', '_maximum')
  # `__eq__` is overridden below, so instances are explicitly unhashable.
  __hash__ = None

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedArray` spec.

    Args:
      shape: An iterable specifying the array shape.
      dtype: numpy dtype or string specifying the array dtype.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedArray, self).__init__(shape, dtype, name)

    # Broadcast against the normalized `self.shape` rather than the raw
    # argument: `shape` may be a one-shot iterable (e.g. a generator) that
    # `Array.__init__` has already consumed.
    try:
      np.broadcast_to(minimum, shape=self.shape)
    except ValueError as numpy_exception:
      raise ValueError('minimum is not compatible with shape. '
                       'Message: {!r}.'.format(numpy_exception))

    try:
      np.broadcast_to(maximum, shape=self.shape)
    except ValueError as numpy_exception:
      raise ValueError('maximum is not compatible with shape. '
                       'Message: {!r}.'.format(numpy_exception))

    # Store the bounds as read-only arrays so they cannot be mutated through
    # the `minimum` / `maximum` properties.
    self._minimum = np.array(minimum, dtype=self.dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype)
    self._maximum.setflags(write=False)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    template = ('BoundedArray(shape={}, dtype={}, name={}, '
                'minimum={}, maximum={})')
    return template.format(self.shape, repr(self.dtype), repr(self.name),
                           self._minimum, self._maximum)

  def __eq__(self, other):
    if not isinstance(other, BoundedArray):
      return False
    # The shape check in the base class runs first; both bounds are
    # broadcastable to that shape, so the elementwise comparisons broadcast.
    return (super(BoundedArray, self).__eq__(other) and
            (self.minimum == other.minimum).all() and
            (self.maximum == other.maximum).all())

  def validate(self, value):
    value = np.asarray(value)
    super(BoundedArray, self).validate(value)
    if (value < self.minimum).any() or (value > self.maximum).any():
      self._fail_validation(_OUT_OF_BOUNDS, self.minimum, value, self.maximum)
    return value

  def generate_value(self):
    return (np.ones(shape=self.shape, dtype=self.dtype) *
            self.dtype.type(self.minimum))

  def __reduce__(self):
    return BoundedArray, (self._shape, self._dtype, self._minimum,
                          self._maximum, self._name)


_NUM_VALUES_NOT_POSITIVE = '`num_values` must be a positive integer, got {}.'
_DTYPE_NOT_INTEGRAL = '`dtype` must be integral, got {}.'
_DTYPE_OVERFLOW = (
    '`dtype` {} is not big enough to hold `num_values` ({}) without overflow.')


class DiscreteArray(BoundedArray):
  """Represents a discrete, scalar, zero-based space.

  Concretely this is a 0-dimensional numpy array containing a single integer
  value between 0 and num_values - 1 (inclusive).
  """

  _REPR_TEMPLATE = (
      'DiscreteArray(shape={self.shape}, dtype={self.dtype}, name={self.name}, '
      'minimum={self.minimum}, maximum={self.maximum}, '
      'num_values={self.num_values})')

  __slots__ = ('_num_values',)

  def __init__(self, num_values, dtype=np.int32, name=None):
    """Initializes a new `DiscreteArray` spec.

    Args:
      num_values: Integer specifying the number of possible values to represent.
      dtype: The dtype of the array. Must be an integral type large enough to
        hold `num_values` without overflow.
      name: Optional string specifying the name of the array.

    Raises:
      ValueError: If `num_values` is not a positive integer, if `dtype` is not
        integral, or if `dtype` is not large enough to hold `num_values` without
        overflow.
    """
    # Check the type before comparing, so non-numeric inputs raise the
    # documented ValueError rather than a TypeError from `<=`.
    if not np.issubdtype(type(num_values), np.integer) or num_values <= 0:
      raise ValueError(_NUM_VALUES_NOT_POSITIVE.format(num_values))

    if not np.issubdtype(dtype, np.integer):
      raise ValueError(_DTYPE_NOT_INTEGRAL.format(dtype))

    num_values = int(num_values)
    maximum = num_values - 1
    dtype = np.dtype(dtype)

    # Compare against the dtype's largest representable value directly.
    # Comparing `np.min_scalar_type(maximum) > dtype` uses safe-cast ordering,
    # which never reports an unsigned type as larger than a signed one, so
    # e.g. 200 values in int8 (or 2**31 + 1 values in int32) slipped through.
    if maximum > np.iinfo(dtype).max:
      raise ValueError(_DTYPE_OVERFLOW.format(dtype, num_values))

    super(DiscreteArray, self).__init__(
        shape=(),
        dtype=dtype,
        minimum=0,
        maximum=maximum,
        name=name)
    self._num_values = num_values

  @property
  def num_values(self):
    """Returns the number of items."""
    return self._num_values

  def __repr__(self):
    return self._REPR_TEMPLATE.format(self=self)  # pytype: disable=missing-parameter

  def __reduce__(self):
    return DiscreteArray, (self._num_values, self._dtype, self._name)