# coding=utf-8
# Copyright 2020 The Tensor2Robot Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as python3
"""TFModel abstract subclasses."""

import abc
from typing import Callable, Optional, Text

from absl import flags
import gin
import six
from tensor2robot.models import abstract_model
from tensor2robot.utils import tensorspec_utils
import tensorflow.compat.v1 as tf

FLAGS = flags.FLAGS
TRAIN = tf.estimator.ModeKeys.TRAIN
EVAL = tf.estimator.ModeKeys.EVAL
PREDICT = tf.estimator.ModeKeys.PREDICT

RunConfigType = abstract_model.RunConfigType
ParamsType = abstract_model.ParamsType
DictOrSpec = abstract_model.DictOrSpec
ModelTrainOutputType = abstract_model.ModelTrainOutputType
ExportOutputType = abstract_model.ExportOutputType


@gin.configurable
@six.add_metaclass(abc.ABCMeta)
class CriticModel(abstract_model.AbstractT2RModel):
  """Critic model with continuous actions trained using MC returns."""

  def __init__(
      self,
      loss_function: Callable[..., tf.Tensor] = tf.losses.mean_squared_error,
      action_batch_size: Optional[int] = None,
      **kwargs):
    """Constructor for CriticModel.

    Args:
      loss_function: Python function taking in (labels, predictions) that
        builds loss tensor. It is invoked with the keyword arguments `labels`
        and `predictions`.
      action_batch_size: If specified, a tiling of actions for prediction
        along a sub-dimension. When None, action specs are not tiled.
      **kwargs: Additional arguments for the TFModel parent class.
    """
    # The explicit two-argument form is kept on purpose: the class object is
    # re-created by six.add_metaclass and wrapped by gin.configurable.
    super(CriticModel, self).__init__(**kwargs)
    self._loss_function = loss_function
    self._action_batch_size = action_batch_size
    # Tiling in PREDICT mode is enabled purely by providing a batch size.
    self._tile_actions_for_predict = action_batch_size is not None

  # Rigid separation of state and action features, as they are treated
  # differently. State features are often duplicated across the examples in a
  # batch via a broadcast rule, while action features are unique to each
  # example in a batch. The state features and action features themselves
  # should not be nested.

  @abc.abstractmethod
  def get_action_specification(self):
    """Gets model inputs (including context) for the action for inference.

    Returns:
      action_params_spec: A named tuple with fields for the action. The action
        features holds all tensors that are unique to each action.
    """

  @abc.abstractmethod
  def get_state_specification(self):
    """Gets model inputs (including context) for the state for inference.

    Returns:
      state_params_spec: A named tuple with fields for the state. The state
        features are shared by all potential actions.
    """

  def pack_state_action_to_feature_spec(self, state_params, action_params):
    """Gets a feature spec namedtuple from the state and action.

    Args:
      state_params: Instance of state_spec_class.
      action_params: Instance of action_spec_class.

    Returns:
      feature_spec: A TensorSpecStruct with the fields `state` and `action`,
        holding the features for both the state and the action.
    """
    return tensorspec_utils.TensorSpecStruct(
        state=state_params, action=action_params)

  def get_feature_specification(self, mode: Text):
    """Gets model inputs (incl. context) for inference.

    Args:
      mode: The mode for feature specifications.

    Returns:
      feature_spec: A TensorSpecStruct with fields for both the state and
        action. The state features are shared by all potential actions. The
        action component holds all tensors that are unique to each potential
        action. In PREDICT mode, when an action_batch_size was passed to the
        constructor, every action spec shape is prefixed with a dimension of
        size action_batch_size.
    """
    # Query each specification exactly once; the tiled variant is derived from
    # the same action spec instead of re-building everything from scratch.
    state_spec = self.get_state_specification()
    action_spec = self.get_action_specification()

    if mode == PREDICT and self._tile_actions_for_predict:

      def _expand_spec(spec):
        """Prepends the action batch dimension to the shape of `spec`."""
        new_shape = (
            tf.TensorShape([self._action_batch_size]).concatenate(spec.shape))
        return tensorspec_utils.ExtendedTensorSpec.from_spec(
            spec, shape=new_shape)

      action_spec = tf.nest.map_structure(_expand_spec, action_spec)

    return tensorspec_utils.TensorSpecStruct(
        state=state_spec, action=action_spec)

  @abc.abstractmethod
  def q_func(self,
             features,
             scope: Text,
             mode: Text,
             config: RunConfigType = None,
             params: ParamsType = None,
             reuse=tf.AUTO_REUSE):
    """Q(state, action) value function.

    We only need to define the q_func and loss_fn to have a proper model.
    For more specialization please overwrite inference_network_fn, model_*_fn.

    Args:
      features: This is the first item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of the self.get_feature_specification.
      scope: String specifying variable scope.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in
        your model_fn based on configuration such as num_ps_replicas, or
        model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.
      reuse: Whether or not to reuse variables under variable scope 'scope'.

    Returns:
      outputs: A {key: Tensor} mapping. The key 'q_predicted' is required.
        inference_network_fn also accepts a tuple whose first element is this
        mapping and whose second element is the update ops.
    """

  def loss_fn(self, features, labels, inference_outputs):
    """Convenience function for critic models.

    We only need to define the q_func and loss_fn to have a proper model.
    For more specialization please overwrite inference_network_fn, model_*_fn.

    Args:
      features: TensorSpecStruct encapsulating input features. Unused here.
      labels: This is the second item returned from the input_fn and parsed by
        self._extract_and_validate_inputs. A dictionary which fulfills the
        requirements of the self.get_labels_specification. Its `reward` field
        is used as the regression target.
      inference_outputs: A dict containing the output tensors of
        model_inference_fn. The key 'q_predicted' is read as the prediction.

    Returns:
      A scalar loss tensor.
    """
    del features
    return self._loss_function(
        labels=labels.reward, predictions=inference_outputs['q_predicted'])

  def inference_network_fn(self,
                           features,
                           labels,
                           mode: Text,
                           config: RunConfigType = None,
                           params: ParamsType = None):
    """See base class.

    Evaluates q_func under the variable scope 'q_func' and validates its
    result.

    Args:
      features: Input features forwarded unchanged to q_func.
      labels: Unused.
      mode: (ModeKeys) Forwarded to q_func.
      config: Optional run configuration, forwarded to q_func.
      params: Optional hyper parameters, forwarded to q_func and used to
        decide whether summaries are written.

    Returns:
      A tuple (outputs, update_ops). outputs is the dict produced by q_func;
      update_ops is the second element of the q_func result if that result is
      a tuple, otherwise None.

    Raises:
      ValueError: If the q_func outputs are not a dict or do not contain the
        key 'q_predicted'.
    """
    del labels
    outputs = self.q_func(
        features=features,
        mode=mode,
        scope='q_func',
        config=config,
        params=params,
        reuse=tf.AUTO_REUSE)

    # q_func may return either the outputs dict alone or a tuple of
    # (outputs, update_ops).
    if isinstance(outputs, tuple):
      update_ops = outputs[1]
      outputs = outputs[0]
    else:
      update_ops = None

    if not isinstance(outputs, dict):
      raise ValueError('The output of q_func is expected to be a dict.')

    if 'q_predicted' not in outputs:
      raise ValueError('For critic models q_predicted is a required key in '
                       'outputs but is not in {}.'.format(list(outputs.keys())))

    if self.use_summaries(params):
      tf.summary.histogram('q_t_predicted', outputs['q_predicted'])
    return outputs, update_ops

  def model_train_fn(self,
                     features,
                     labels,
                     inference_outputs,
                     mode: Text,
                     config: RunConfigType = None,
                     params: ParamsType = None):
    """See base class.

    Args:
      features: Input features, forwarded to loss_fn.
      labels: Labels, forwarded to loss_fn.
      inference_outputs: Outputs of the inference network, forwarded to
        loss_fn.
      mode: Unused.
      config: Unused.
      params: Unused.

    Returns:
      The loss computed by self.loss_fn.
    """
    del mode, config, params
    loss = self.loss_fn(features, labels, inference_outputs)
    return loss