########################################################################################
#
# Hierarchical Attentive Recurrent Tracking
# Copyright (C) 2017 Adam R. Kosiorek, Oxford Robotics Institute, University of Oxford
# email: adamk@robots.ox.ac.uk
# webpage: http://ori.ox.ac.uk
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
########################################################################################

import tensorflow as tf

from neurocity.component.model import base


class Model(base.BaseModel):
    """An abstraction of a Neural Network model.

    It simplifies layer management and supports train/test modes. To define a
    new model, derive a class and overload the `_build` method, which should be
    responsible for constructing the model, e.g.:

        class MLP(Model):
            def _build(self):
                self.inpt = tf.placeholder(tf.float32, (32, 100), name='inpt')
                l1 = AffineLayer(self.inpt, 200)
                l2 = AffineLayer(l1, 10)

    Both layers will be registered in the model.layers attribute.

    It is important to call:
        model.train_mode() - before training
        model.test_mode()  - before testing

    Note: `mode` can be overwritten by calling one of the global setters:
    neurocity.train_mode() or neurocity.test_mode()
    """

    def __init__(self, name='Model'):
        """Constructs the model graph inside a variable scope called `name`.

        Runs `self._build()` (supplied by the subclass) while this model is
        set as the globally-current model, so components created inside
        `_build` can register themselves with it. The variables created during
        the build are identified by diffing TF's global collections taken
        before and after the build.

        :param name: string, name of the tf.variable_scope the model builds in.
        """
        super(Model, self).__init__()
        self.name = name
        # Filled in while `_build` runs; components register themselves here.
        self.layers = []

        # Snapshot the global collections so the post-build set difference
        # yields exactly the variables this model created.
        model_vars = set(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
        trainable_vars = set(tf.trainable_variables())
        with tf.variable_scope(self.name):
            with self:
                self._build()
        # __exit__ has restored the previously-current model (see
        # base.reset_model); register this freshly-built model with it.
        # NOTE(review): placement of this call relative to the `with` blocks
        # was reconstructed from a whitespace-mangled source — confirm against
        # the upstream file.
        base.get_model().register(self)

        # Variables created by `_build`: the trainable ones, plus
        # non-trainable model variables (e.g. BatchNorm moving statistics),
        # kept as two disjoint sets.
        self.trainable_vars = set(tf.trainable_variables()) - trainable_vars
        self.model_vars = set(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) \
            - model_vars - self.trainable_vars

    @property
    def vars(self):
        """All variables used by the model: trainable plus model variables."""
        return self.trainable_vars.union(self.model_vars)

    def saver(self, **kwargs):
        """Returns a Saver for all (trainable and model) variables used by the
        model. Model variables include e.g. moving mean and average in
        BatchNorm.

        :param kwargs: keyword args forwarded to the tf.train.Saver constructor.
        :return: tf.Saver
        """
        return tf.train.Saver(self.vars, **kwargs)

    def _build(self):
        # Subclasses must overload this to construct the network graph.
        raise NotImplementedError

    def __enter__(self):
        # Make this model the globally-current one so components built inside
        # the `with` block register with it.
        base.set_model(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always restore the previously-current model. Returning True only on
        # the no-exception path means exceptions raised inside the `with`
        # block still propagate (a falsy return does not suppress them).
        base.reset_model()
        if exc_type is None:
            return True