# Copyright 2017 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Classes for FFN model definition.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf from . import optimizer class FFNModel(object): """Base class for FFN models.""" # Dimensionality of the model (2 or 3). dim = None ############################################################################ # (x, y, z) tuples defining various properties of the network. # Note that 3-tuples should be used even for 2D networks, in which case # the third (z) value is ignored. # How far to move the field of view in the respective directions. deltas = None # Size of the input image and seed subvolumes to be used during inference. # This is enough information to execute a single prediction step, without # moving the field of view. input_image_size = None input_seed_size = None # Size of the predicted patch as returned by the model. pred_mask_size = None ########################################################################### # TF op to compute loss optimized during training. This should include all # loss components in case more than just the pixelwise loss is used. loss = None # TF op to call to perform loss optimization on the model. train_op = None def __init__(self, deltas, batch_size=None, define_global_step=True): assert self.dim is not None self.deltas = deltas self.batch_size = batch_size # Initialize the shift collection. This is used during training with the # fixed step size policy. self.shifts = [] for dx in (-self.deltas[0], 0, self.deltas[0]): for dy in (-self.deltas[1], 0, self.deltas[1]): for dz in (-self.deltas[2], 0, self.deltas[2]): if dx == 0 and dy == 0 and dz == 0: continue self.shifts.append((dx, dy, dz)) if define_global_step: self.global_step = tf.Variable(0, name='global_step', trainable=False) # The seed is always a placeholder which is fed externally from the # training/inference drivers. self.input_seed = tf.placeholder(tf.float32, name='seed') self.input_patches = tf.placeholder(tf.float32, name='patches') # For training, labels should be defined as a TF object. self.labels = None # Optional. Provides per-pixel weights with which the loss is multiplied. # If specified, should have the same shape as self.labels. self.loss_weights = None self.logits = None # type: tf.Operation # List of image tensors to save in summaries. The images are concatenated # along the X axis. self._images = [] def set_uniform_io_size(self, patch_size): """Initializes unset input/output sizes to 'patch_size', sets input shapes. This assumes that the inputs and outputs are of equal size, and that exactly one step is executed in every direction during training. Args: patch_size: (x, y, z) specifying the input/output patch size Returns: None """ if self.pred_mask_size is None: self.pred_mask_size = patch_size if self.input_seed_size is None: self.input_seed_size = patch_size if self.input_image_size is None: self.input_image_size = patch_size self.set_input_shapes() def set_input_shapes(self): """Sets the shape inference for input_seed and input_patches. Assumes input_seed_size and input_image_size are already set. """ self.input_seed.set_shape([self.batch_size] + list(self.input_seed_size[::-1]) + [1]) self.input_patches.set_shape([self.batch_size] + list(self.input_image_size[::-1]) + [1]) def set_up_sigmoid_pixelwise_loss(self, logits): """Sets up the loss function of the model.""" assert self.labels is not None assert self.loss_weights is not None pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=self.labels) pixel_loss *= self.loss_weights self.loss = tf.reduce_mean(pixel_loss) tf.summary.scalar('pixel_loss', self.loss) self.loss = tf.verify_tensor_all_finite(self.loss, 'Invalid loss detected') def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7): """Sets up the training op for the model.""" if loss is None: loss = self.loss tf.summary.scalar('optimizer_loss', self.loss) opt = optimizer.optimizer_from_flags() grads_and_vars = opt.compute_gradients(loss) for g, v in grads_and_vars: if g is None: tf.logging.error('Gradient is None: %s', v.op.name) if max_gradient_entry_mag > 0.0: grads_and_vars = [(tf.clip_by_value(g, -max_gradient_entry_mag, +max_gradient_entry_mag), v) for g, v, in grads_and_vars] trainables = tf.trainable_variables() if trainables: for var in trainables: tf.summary.histogram(var.name.replace(':0', ''), var) for grad, var in grads_and_vars: tf.summary.histogram( 'gradients/%s' % var.name.replace(':0', ''), grad) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): self.train_op = opt.apply_gradients(grads_and_vars, global_step=self.global_step, name='train') def show_center_slice(self, image, sigmoid=True): image = image[:, image.get_shape().dims[1] // 2, :, :, :] if sigmoid: image = tf.sigmoid(image) self._images.append(image) def add_summaries(self): pass def update_seed(self, seed, update): """Updates the initial 'seed' with 'update'.""" dx = self.input_seed_size[0] - self.pred_mask_size[0] dy = self.input_seed_size[1] - self.pred_mask_size[1] dz = self.input_seed_size[2] - self.pred_mask_size[2] if dx == 0 and dy == 0 and dz == 0: seed += update else: seed += tf.pad(update, [[0, 0], [dz // 2, dz - dz // 2], [dy // 2, dy - dy // 2], [dx // 2, dx - dx // 2], [0, 0]]) return seed def define_tf_graph(self): """Creates the TensorFlow graph representing the model. If self.labels is not None, the graph should include operations for computing and optimizing the loss. """ raise NotImplementedError( 'DefineTFGraph needs to be defined by a subclass.')