""" Original code from OSVOS (https://github.com/scaelles/OSVOS-TensorFlow) Sergi Caelles (scaelles@vision.ee.ethz.ch) Modified code for liver and lesion segmentation: Miriam Bellver (miriam.bellver@bsc.es) """ import tensorflow as tf import numpy as np from tensorflow.contrib.layers.python.layers import utils from tensorflow.contrib.layers.python.layers import initializers import sys from datetime import datetime import os import scipy.misc from PIL import Image slim = tf.contrib.slim from tensorflow.contrib.slim.nets import resnet_v1 import scipy.io import scipy.misc DTYPE = tf.float32 # set parameters s.t. deconvolutional layers compute bilinear interpolation # N.B. this is for deconvolution without groups def interp_surgery(variables): interp_tensors = [] for v in variables: if '-up' in v.name: h, w, k, m = v.get_shape() tmp = np.zeros((m, k, h, w)) if m != k: print 'input + output channels need to be the same' raise if h != w: print 'filters need to be square' raise up_filter = upsample_filt(int(h)) tmp[range(m), range(k), :, :] = up_filter interp_tensors.append(tf.assign(v, tmp.transpose((2, 3, 1, 0)), validate_shape=True, use_locking=True)) return interp_tensors def det_lesion_arg_scope(weight_decay=0.0002): """Defines the arg scope. Args: weight_decay: The l2 regularization coefficient. Returns: An arg_scope. """ with slim.arg_scope([slim.conv2d, slim.convolution2d_transpose], activation_fn=tf.nn.relu, weights_initializer=tf.random_normal_initializer(stddev=0.001), weights_regularizer=slim.l2_regularizer(weight_decay), biases_initializer=tf.zeros_initializer, biases_regularizer=None, padding='SAME') as arg_sc: return arg_sc def binary_cross_entropy(output, target, epsilon=1e-8, name='bce_loss'): """Defines the binary cross entropy loss Args: output: the output of the network target: the ground truth Returns: A scalar with the loss, the output and the target """ target = tf.cast(target, tf.float32) output = tf.cast(tf.squeeze(output), tf.float32) with tf.name_scope(name): return tf.reduce_mean(-(target * tf.log(output + epsilon) + (1. - target) * tf.log(1. 
def preprocess_img(image, x_bb, y_bb, ids=None):
    """Preprocesses the image to adapt it to network requirements.
    Args:
        image: Paths of the .mat slices we want to input to the network.
        x_bb, y_bb: Top-left coordinates of the 80x80 bounding boxes to crop.
        ids: Optional data augmentation id per sample.
    Returns:
        Image batch ready to input to the network (N, 80, 80, 3).
    """
    if ids is None:
        ids = np.ones(np.array(image).shape[0])
    images = [[] for i in range(np.array(image).shape[0])]

    for j in range(np.array(image).shape[0]):
        aug_id = str(ids[j])
        for i in range(3):
            aux = np.array(scipy.io.loadmat(image[j])['section'], dtype=np.float32)
            crop = aux[int(float(x_bb[j])):int(float(x_bb[j]) + 80),
                       int(float(y_bb[j])):int(float(y_bb[j]) + 80)]
            # Different data augmentation options, selected by the sample id
            if aug_id == '2':
                crop = np.fliplr(crop)
            elif aug_id == '3':
                crop = np.flipud(crop)
            elif aug_id == '4':
                crop = np.flipud(crop)
                crop = np.fliplr(crop)
            elif aug_id == '5':
                crop = np.rot90(crop)
            elif aug_id == '6':
                crop = np.rot90(crop, 2)
            elif aug_id == '7':
                crop = np.fliplr(crop)
                crop = np.rot90(crop)
            elif aug_id == '8':
                crop = np.fliplr(crop)
                crop = np.rot90(crop, 2)
            images[j].append(crop)

    in_ = np.array(images)
    in_ = in_.transpose((0, 2, 3, 1))
    in_ = np.subtract(in_, np.array((104.00699, 116.66877, 122.67892), dtype=np.float32))
    return in_


def preprocess_labels(label, x_bb=None, y_bb=None):
    """Preprocesses the labels to adapt them to the loss computation requirements.
    Args:
        label: Label corresponding to the input image, (W, H) numpy array.
        x_bb, y_bb: Top-left crop coordinates, used when labels are given as file paths.
    Returns:
        Label ready to compute the loss (1, W, H, 1).
    """
    labels = [[] for i in range(np.array(label).shape[0])]

    for j in range(np.array(label).shape[0]):
        if type(label) is not np.ndarray:
            for i in range(3):
                aux = np.array(Image.open(label[j][i]), dtype=np.uint8)
                crop = aux[int(float(x_bb[j])):int(float(x_bb[j]) + 80),
                           int(float(y_bb[j])):int(float(y_bb[j]) + 80)]
                labels[j].append(crop)

    label = np.array(labels[0])
    label = label.transpose((1, 2, 0))
    max_mask = np.max(label) * 0.5
    label = np.greater(label, max_mask)
    label = np.expand_dims(label, axis=0)

    return label
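
# Illustrative sketch (assumption: this mirrors the augmentation-id convention
# of preprocess_img above): applies each id to a toy crop so the eight
# flip/rotation variants can be inspected.
def _demo_augmentation_ids():
    crop = np.arange(4.0).reshape(2, 2)
    variants = {
        '1': crop,                          # identity
        '2': np.fliplr(crop),               # horizontal flip
        '3': np.flipud(crop),               # vertical flip
        '4': np.fliplr(np.flipud(crop)),    # both flips (180 degree rotation)
        '5': np.rot90(crop),                # 90 degree rotation
        '6': np.rot90(crop, 2),             # 180 degree rotation
        '7': np.rot90(np.fliplr(crop)),     # horizontal flip, then 90 degrees
        '8': np.rot90(np.fliplr(crop), 2),  # horizontal flip, then 180 degrees
    }
    for aug_id in sorted(variants):
        print(aug_id, variants[aug_id].tolist())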
def det_lesion_resnet(inputs, is_training_option=False, scope='det_lesion'):
    """Defines the network.
    Args:
        inputs: Tensorflow placeholder that contains the input image.
        scope: Scope name for the network.
    Returns:
        net: Output Tensor of the network.
        end_points: Dictionary with all Tensors of the network.
    """
    with tf.variable_scope(scope, 'det_lesion', [inputs]) as sc:
        end_points_collection = sc.name + '_end_points'
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            net, end_points = resnet_v1.resnet_v1_50(inputs, is_training=is_training_option)
            net = slim.flatten(net, scope='flatten5')
            net = slim.fully_connected(net, 1, activation_fn=tf.nn.sigmoid,
                                       weights_initializer=initializers.xavier_initializer(),
                                       scope='output')
            utils.collect_named_outputs(end_points_collection, 'det_lesion/output', net)
    end_points = slim.utils.convert_collection_to_dict(end_points_collection)
    return net, end_points


def load_resnet_imagenet(ckpt_path):
    """Initializes the network parameters from the ResNet-50 pre-trained model
    provided by TF-SLIM.
    Args:
        ckpt_path: Path to the checkpoint.
    Returns:
        Function that takes a session and initializes the network.
    """
    reader = tf.train.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    vars_corresp = dict()
    for v in var_to_shape_map:
        if "bottleneck_v1" in v or "conv1" in v:
            vars_corresp[v] = slim.get_model_variables(
                v.replace("resnet_v1_50", "det_lesion/resnet_v1_50"))[0]
    init_fn = slim.assign_from_checkpoint_fn(ckpt_path, vars_corresp)
    return init_fn


def my_accuracy(output, target, name='accuracy'):
    """Soft accuracy for detection: the mean agreement between the predicted
    probability and the binary target.
    Args:
        output: The output of the network.
        target: The ground truth.
    Returns:
        The soft accuracy.
    """
    target = tf.cast(target, tf.float32)
    output = tf.squeeze(output)
    with tf.name_scope(name):
        return tf.reduce_mean((target * output) + (1. - target) * (1. - output))


def train(dataset, initial_ckpt, learning_rate, logs_path, max_training_iters, save_step,
          display_step, global_step, iter_mean_grad=1, batch_size=1, momentum=0.9,
          resume_training=False, config=None, finetune=1):
    """Trains the detection network.
    Args:
        dataset: Reference to a Dataset object instance.
        initial_ckpt: Path to the checkpoint used to initialize the network
            (may be a parent network or a pre-trained ImageNet model).
        learning_rate: Value for the learning rate. It can be a number or an
            instance of a learning rate object.
        logs_path: Path to store the checkpoints.
        max_training_iters: Number of training iterations.
        save_step: A checkpoint will be created every save_step iterations.
        display_step: Information about the training will be displayed every
            display_step iterations.
        global_step: Reference to a Variable that keeps track of the training steps.
        iter_mean_grad: Number of gradient computations that are averaged before
            updating the weights.
        batch_size: Size of the training batch.
        momentum: Value of the momentum parameter for the Momentum optimizer.
        resume_training: Boolean to try to restore from a previous checkpoint (True)
            or not (False).
        config: Reference to a Configuration object used in the creation of a Session.
        finetune: Selects the type of training: 0 for the parent network and
            1 for fine-tuning.
    Returns:
    """
    model_name = os.path.join(logs_path, "det_lesion.ckpt")
    if config is None:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

    tf.logging.set_verbosity(tf.logging.INFO)

    # Prepare the input data
    input_image = tf.placeholder(tf.float32, [batch_size, 80, 80, 3])
    input_label = tf.placeholder(tf.float32, [batch_size])
    is_training = tf.placeholder(tf.bool, shape=())

    tf.summary.histogram('input_label', input_label)

    # Create the network
    with slim.arg_scope(det_lesion_arg_scope()):
        net, end_points = det_lesion_resnet(input_image, is_training_option=is_training)

    # Initialize weights from pre-trained model
    if finetune == 0:
        init_weights = load_resnet_imagenet(initial_ckpt)

    # Define loss
    with tf.name_scope('losses'):
        loss, output, target = binary_cross_entropy(net, input_label)
        total_loss = loss + tf.add_n(tf.losses.get_regularization_losses())
        tf.summary.scalar('losses/total_loss', total_loss)
        tf.summary.histogram('losses/output', output)
        tf.summary.histogram('losses/target', target)

    # Define optimization method
    with tf.name_scope('optimization'):
        tf.summary.scalar('learning_rate', learning_rate)
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
        grads_and_vars = optimizer.compute_gradients(total_loss)
        with tf.name_scope('grad_accumulator'):
            # Keyed by variable index so the mapping to grads_and_vars survives
            # variables whose gradient is None.
            grad_accumulator = {}
            for ind in range(0, len(grads_and_vars)):
                if grads_and_vars[ind][0] is not None:
                    grad_accumulator[ind] = tf.ConditionalAccumulator(grads_and_vars[ind][0].dtype)
        with tf.name_scope('apply_gradient'):
            grad_accumulator_ops = []
            for var_ind in grad_accumulator:
                var_name = str(grads_and_vars[var_ind][1].name).split(':')[0]
                var_grad = grads_and_vars[var_ind][0]
                # Layer-wise learning rate multiplier: biases train twice as fast.
                if "biases" in var_name:
                    aux_layer_lr = 2.0
                else:
                    aux_layer_lr = 1.0
                grad_accumulator_ops.append(
                    grad_accumulator[var_ind].apply_grad(var_grad * aux_layer_lr,
                                                         local_step=global_step))
        with tf.name_scope('take_gradients'):
            mean_grads_and_vars = []
            for var_ind in grad_accumulator:
                mean_grads_and_vars.append((grad_accumulator[var_ind].take_grad(iter_mean_grad),
                                            grads_and_vars[var_ind][1]))
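            # Note: take_grad() blocks until `iter_mean_grad` gradients have been
            # pushed into each accumulator and returns their average, so the op
            # below performs one weight update per `iter_mean_grad`
            # forward/backward passes (see the accumulation loop in the training
            # loop further down).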
            apply_gradient_op = optimizer.apply_gradients(mean_grads_and_vars,
                                                          global_step=global_step)

    with tf.name_scope('metrics'):
        acc_op = my_accuracy(net, input_label)
        tf.summary.scalar('metrics/accuracy', acc_op)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
        tf.logging.info('Gathering update_ops')
        with tf.control_dependencies(tf.tuple(update_ops)):
            total_loss = tf.identity(total_loss)

    merged_summary_op = tf.summary.merge_all()

    # Initialize variables
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        print('Init variable')
        sess.run(init)

        # op to write logs to Tensorboard
        summary_writer = tf.summary.FileWriter(logs_path + '/train', graph=tf.get_default_graph())
        test_writer = tf.summary.FileWriter(logs_path + '/test')

        # Create saver to manage checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        last_ckpt_path = tf.train.latest_checkpoint(logs_path)
        if last_ckpt_path is not None and resume_training:
            # Load last checkpoint
            print('Initializing from previous checkpoint...')
            saver.restore(sess, last_ckpt_path)
            step = global_step.eval() + 1
        else:
            # Load pre-trained model
            if finetune == 0:
                print('Initializing from pre-trained imagenet model...')
                init_weights(sess)
            else:
                print('Initializing from pre-trained model...')
                var_list = []
                for var in tf.global_variables():
                    var_type = var.name.split('/')[-1]
                    if 'weights' in var_type or 'bias' in var_type:
                        var_list.append(var)
                saver_res = tf.train.Saver(var_list=var_list)
                saver_res.restore(sess, initial_ckpt)
            step = 1
        sess.run(interp_surgery(tf.global_variables()))
        print('Weights initialized')

        print('Start training')
        while step < max_training_iters + 1:
            # Average the gradient
            for iter_steps in range(0, iter_mean_grad):
                batch_image, batch_label, x_bb_train, y_bb_train, ids_train = \
                    dataset.next_batch(batch_size, 'train', 0.5)
                batch_image_val, batch_label_val, x_bb_val, y_bb_val, ids_val = \
                    dataset.next_batch(batch_size, 'val', 0.5)
                image = preprocess_img(batch_image, x_bb_train, y_bb_train, ids_train)
                label = batch_label
                val_image = preprocess_img(batch_image_val, x_bb_val, y_bb_val)
                label_val = batch_label_val
                run_res = sess.run([total_loss, merged_summary_op, acc_op] + grad_accumulator_ops,
                                   feed_dict={input_image: image, input_label: label,
                                              is_training: True})
                batch_loss = run_res[0]
                summary = run_res[1]
                acc = run_res[2]
                if step % display_step == 0:
                    val_run_res = sess.run([total_loss, merged_summary_op, acc_op],
                                           feed_dict={input_image: val_image,
                                                      input_label: label_val,
                                                      is_training: False})
                    val_batch_loss = val_run_res[0]
                    val_summary = val_run_res[1]
                    val_acc = val_run_res[2]

            # Apply the gradients
            sess.run(apply_gradient_op)

            # Save summary reports
            summary_writer.add_summary(summary, step)
            if step % display_step == 0:
                test_writer.add_summary(val_summary, step)

            # Display training status
            if step % display_step == 0:
                print("{} Iter {}: Training Loss = {:.4f}".format(datetime.now(), step, batch_loss),
                      file=sys.stderr)
                print("{} Iter {}: Validation Loss = {:.4f}".format(datetime.now(), step, val_batch_loss),
                      file=sys.stderr)
                print("{} Iter {}: Training Accuracy = {:.4f}".format(datetime.now(), step, acc),
                      file=sys.stderr)
                print("{} Iter {}: Validation Accuracy = {:.4f}".format(datetime.now(), step, val_acc),
                      file=sys.stderr)

            # Save a checkpoint
            if step % save_step == 0:
                save_path = saver.save(sess, model_name, global_step=global_step)
                print("Model
saved in file: %s" % save_path step += 1 if (step-1) % save_step != 0: save_path = saver.save(sess, model_name, global_step=global_step) print "Model saved in file: %s" % save_path print('Finished training.') def validate(dataset, checkpoint_path, result_path, number_slices=1, config=None): """Test one sequence Args: dataset: Reference to a Dataset object instance checkpoint_path: Path of the checkpoint to use for the evaluation result_path: Path to save the output images config: Reference to a Configuration object used in the creation of a Session Returns: net: """ if config is None: config = tf.ConfigProto() config.gpu_options.allow_growth = True # config.log_device_placement = True config.allow_soft_placement = True tf.logging.set_verbosity(tf.logging.INFO) # Input data batch_size = 64 number_of_slices = number_slices depth_input = number_of_slices if number_of_slices < 3: depth_input = 3 pos_size = dataset.get_val_pos_size() neg_size = dataset.get_val_neg_size() input_image = tf.placeholder(tf.float32, [batch_size, None, None, depth_input]) # Create the cnn with slim.arg_scope(det_lesion_arg_scope()): net, end_points = det_lesion_resnet(input_image, is_training_option=False) probabilities = end_points['det_lesion/output'] global_step = tf.Variable(0, name='global_step', trainable=False) # Create a saver to load the network saver = tf.train.Saver([v for v in tf.global_variables() if '-up' not in v.name and '-cr' not in v.name]) with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) sess.run(interp_surgery(tf.global_variables())) saver.restore(sess, checkpoint_path) if not os.path.exists(result_path): os.makedirs(result_path) results_file_soft = open(os.path.join(result_path, 'soft_results.txt'), 'w') results_file_hard = open(os.path.join(result_path, 'hard_results.txt'), 'w') # Test positive windows count_patches = 0 for frame in range(0, pos_size/batch_size + (pos_size % batch_size > 0)): img, label, x_bb, y_bb = dataset.next_batch(batch_size, 'val', 1) curr_ct_scan = img[0] print 'Testing ' + curr_ct_scan image = preprocess_img(img, x_bb, y_bb) res = sess.run(probabilities, feed_dict={input_image: image}) label = np.array(label).astype(np.float32).reshape(batch_size, 1) for i in range(0, batch_size): count_patches +=1 img_part = img[i] res_part = res[i][0] label_part = label[i][0] if count_patches < (pos_size + 1): results_file_soft.write(img_part.split('images_volumes/')[-1] + ' ' + str(x_bb[i]) + ' ' + str(y_bb[i]) + ' ' + str(res_part) + ' ' + str(label_part) + '\n') if res_part > 0.5: results_file_hard.write(img_part.split('images_volumes/')[-1] + ' ' + str(x_bb[i]) + ' ' + str(y_bb[i]) + '\n') # Test negative windows count_patches = 0 for frame in range(0, neg_size/batch_size + (neg_size % batch_size > 0)): img, label, x_bb, y_bb = dataset.next_batch(batch_size, 'val', 0) curr_ct_scan = img[0] print 'Testing ' + curr_ct_scan image = preprocess_img(img, x_bb, y_bb) res = sess.run(probabilities, feed_dict={input_image: image}) label = np.array(label).astype(np.float32).reshape(batch_size, 1) for i in range(0, batch_size): count_patches += 1 img_part = img[i] res_part = res[i][0] label_part = label[i][0] if count_patches < (neg_size + 1): results_file_soft.write(img_part.split('images_volumes/')[-1] + ' ' + str(x_bb[i]) + ' ' + str(y_bb[i]) + ' ' + str(res_part) + ' ' + str(label_part) + '\n') if res_part > 0.5: results_file_hard.write(img_part.split('images_volumes/')[-1] + ' ' + str(x_bb[i]) + ' ' + str(y_bb[i]) + '\n') results_file_soft.close() 
        results_file_hard.close()


def test(dataset, checkpoint_path, result_path, number_slices=1, volume=False, config=None):
    """Runs the detection network over all the test windows.
    Args:
        dataset: Reference to a Dataset object instance.
        checkpoint_path: Path of the checkpoint to use for the evaluation.
        result_path: Path to save the output results.
        config: Reference to a Configuration object used in the creation of a Session.
    Returns:
    """
    if config is None:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # config.log_device_placement = True
        config.allow_soft_placement = True
    tf.logging.set_verbosity(tf.logging.INFO)

    # Input data
    batch_size = 64
    number_of_slices = number_slices
    depth_input = number_of_slices
    if number_of_slices < 3:
        depth_input = 3

    total_size = dataset.get_val_pos_size()
    input_image = tf.placeholder(tf.float32, [batch_size, None, None, depth_input])

    # Create the cnn
    with slim.arg_scope(det_lesion_arg_scope()):
        net, end_points = det_lesion_resnet(input_image, is_training_option=False)
    probabilities = end_points['det_lesion/output']
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Create a saver to load the network
    saver = tf.train.Saver([v for v in tf.global_variables()
                            if '-up' not in v.name and '-cr' not in v.name])

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(interp_surgery(tf.global_variables()))
        saver.restore(sess, checkpoint_path)
        if not os.path.exists(result_path):
            os.makedirs(result_path)

        results_file_soft = open(os.path.join(result_path, 'soft_results.txt'), 'w')
        results_file_hard = open(os.path.join(result_path, 'hard_results.txt'), 'w')

        # Test all windows
        count_patches = 0
        for frame in range(0, total_size // batch_size + (total_size % batch_size > 0)):
            img, x_bb, y_bb = dataset.next_batch(batch_size, 'test', 1)
            curr_ct_scan = img[0]
            print('Testing ' + curr_ct_scan)
            image = preprocess_img(img, x_bb, y_bb)
            res = sess.run(probabilities, feed_dict={input_image: image})

            for i in range(0, batch_size):
                count_patches += 1
                img_part = img[i]
                res_part = res[i][0]
                if count_patches < (total_size + 1):
                    results_file_soft.write(img_part.split('images_volumes/')[-1] + ' ' +
                                            str(x_bb[i]) + ' ' + str(y_bb[i]) + ' ' +
                                            str(res_part) + '\n')
                    if res_part > 0.5:
                        results_file_hard.write(img_part.split('images_volumes/')[-1] + ' ' +
                                                str(x_bb[i]) + ' ' + str(y_bb[i]) + '\n')

        results_file_soft.close()
        results_file_hard.close()
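

if __name__ == '__main__':
    # Usage sketch only. `Dataset` is assumed to be the data loader used by the
    # rest of the repository, and every path and hyperparameter below is a
    # placeholder; adapt them to your setup before uncommenting.
    # dataset = Dataset(...)
    # global_step = tf.Variable(0, name='global_step', trainable=False)
    # train(dataset, 'models/resnet_v1_50.ckpt', 1e-3, 'logs/det_lesion',
    #       max_training_iters=5000, save_step=1000, display_step=100,
    #       global_step=global_step, batch_size=16, finetune=0)
    # validate(dataset, 'logs/det_lesion/det_lesion.ckpt-5000', 'results/det_lesion')
    pass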