# coding=utf-8 # Copyright 2020 The Tensor2Robot Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as python3 """TFModel abstract subclasses.""" import abc from typing import Any, Dict, Optional, Text from absl import flags import gin import six from tensor2robot.models import abstract_model from tensor2robot.utils import tensorspec_utils import tensorflow.compat.v1 as tf FLAGS = flags.FLAGS TRAIN = tf.estimator.ModeKeys.TRAIN EVAL = tf.estimator.ModeKeys.EVAL PREDICT = tf.estimator.ModeKeys.PREDICT RunConfigType = abstract_model.RunConfigType ParamsType = abstract_model.ParamsType DictOrSpec = abstract_model.DictOrSpec ModelTrainOutputType = abstract_model.ModelTrainOutputType ExportOutputType = abstract_model.ExportOutputType @gin.configurable @six.add_metaclass(abc.ABCMeta) class ClassificationModel(abstract_model.AbstractT2RModel): """Classification model.""" def __init__(self, loss_function=tf.losses.log_loss, **kwargs): """Constructor for ClassificationModel. Args: loss_function: Python function taking in (labels, predictions) that builds loss tensor. **kwargs: Additional arguments for the TFModel parent class. """ super(ClassificationModel, self).__init__(**kwargs) self._loss_function = loss_function self._label_specification = None self._state_specification = None def get_label_specification( self, mode): del mode return self._label_specification def get_feature_specification( self, mode): """Gets model inputs (including context) for inference. Arguments: mode: The mode for feature specifications Returns: feature_spec: A named tuple with fields for the state. """ del mode return tensorspec_utils.TensorSpecStruct(state=self.state_specification) @property def state_specification(self): return self._state_specification @state_specification.setter def state_specification(self, value): self._state_specification = value @state_specification.setter def label_specification(self, value): self._label_specification = value @abc.abstractmethod def a_func(self, features, scope, mode, config = None, params = None): """The F(state) function. We only need to define the a_func and loss_fn to have a proper model. For more specialization please overwrite inference_network_fn, model_*_fn. Args: features: This is the first item returned from the input_fn and parsed by tensorspec_utils.validate_and_pack. A spec_structure which fulfills the requirements of the self.get_feature_specification. scope: String specifying variable scope. mode: (ModeKeys) Specifies if this is training, evaluation or prediction. config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will receive what is passed to Estimator in config parameter, or the default config (tf.estimator.RunConfig). Allows updating things in your model_fn based on configuration such as num_ps_replicas, or model_dir. params: An optional dict of hyper parameters. Returns: outputs: A {key: Tensor} mapping. The key 'q_predicted' is required. """ def loss_fn(self, labels, inference_outputs): """Convenience function for classification models. We only need to define the a_func and loss_fn to have a proper model. For more specialization please overwrite inference_network_fn, model_*_fn. Args: labels: This is the second item returned from the input_fn and parsed by self._extract_and_validate_inputs. A dictionary which fulfills the requirements of the self.get_labels_spefication. inference_outputs: A dict containing the output tensors of model_inference_fn. Returns: A scalar loss tensor. """ return self._loss_function( labels=labels.classes, predictions=inference_outputs['a_predicted']) def inference_network_fn(self, features, labels, mode, config = None, params = None): """See base class.""" del labels outputs = self.a_func( features, scope='a_func', mode=mode, params=params, config=config) if not isinstance(outputs, dict): raise ValueError('The output of a_func is expected to be a dict.') if 'a_predicted' not in outputs: raise ValueError('For classification models a_predicted is a required ' 'key in outputs but is not in {}.'.format( list(outputs.keys()))) if self.use_summaries(params): tf.summary.histogram('a_t_predicted', outputs['a_predicted']) return outputs def model_train_fn(self, features, labels, inference_outputs, mode, config = None, params = None): """See base class.""" del features, mode, config, params loss = self.loss_fn(labels, inference_outputs) return loss def create_export_outputs_fn(self, features, inference_outputs, mode, config = None, params = None): """See base class.""" del features, mode, config, params predictions = {'prediction': inference_outputs['a_predicted']} return predictions def pack_state_to_feature_spec(self, state_params ): """Packs the state feature spec from the state. Args: state_params: Instance of state_spec_class. Returns: feature_spec: An instance of self.feature_spec_class. This contains features for the state. """ feature_spec = tensorspec_utils.TensorSpecStruct(state=state_params) return feature_spec def model_eval_fn(self, features, labels, inference_outputs, train_loss, train_outputs, mode, config = None, params = None): """See base class.""" eval_mse = tf.metrics.mean_squared_error( labels=labels.classes, predictions=inference_outputs['a_predicted'], name='eval_mse') predictions_rounded = tf.round(inference_outputs['a_predicted']) eval_precision = tf.metrics.precision( labels=labels.classes, predictions=predictions_rounded, name='eval_precision') eval_accuracy = tf.metrics.accuracy( labels=labels.classes, predictions=predictions_rounded, name='eval_accuracy') eval_recall = tf.metrics.recall( labels=labels.classes, predictions=predictions_rounded, name='eval_recall') metric_fn = { 'eval_mse': eval_mse, 'eval_precision': eval_precision, 'eval_accuracy': eval_accuracy, 'eval_recall': eval_recall } return metric_fn