import logging

import numpy as np
import theano
import theano.tensor as T
import lasagne
from lasagne.layers import (InputLayer, Conv2DLayer, MaxPool2DLayer, batch_norm,
                            DropoutLayer, GaussianNoiseLayer, ConcatLayer, Upscale2DLayer)
from lasagne.init import HeNormal
from lasagne import nonlinearities
from lasagne.regularization import l2, regularize_network_params

from params import params as P


def output_size_for_input(in_size, depth):
    """Spatial size of the network output for a given input size.

    Each pair of 3x3 'valid' convolutions removes 4 pixels; each pooling
    step halves the size on the way down and each upconvolution doubles
    it on the way back up.
    """
    in_size -= 4
    for _ in range(depth - 1):
        in_size = in_size // 2
        in_size -= 4
    for _ in range(depth - 1):
        in_size = in_size * 2
        in_size -= 4
    return in_size


NET_DEPTH = P.DEPTH          # Default 5
INPUT_SIZE = P.INPUT_SIZE    # Default 512
OUTPUT_SIZE = output_size_for_input(INPUT_SIZE, NET_DEPTH)


def filter_for_depth(depth):
    # The number of filters doubles at every depth level.
    return 2 ** (P.BRANCHING_FACTOR + depth)


def define_network(input_var):
    batch_size = None
    net = {}

    net['input'] = InputLayer(shape=(batch_size, P.CHANNELS, P.INPUT_SIZE, P.INPUT_SIZE),
                              input_var=input_var)

    nonlinearity = nonlinearities.leaky_rectify

    if P.GAUSSIAN_NOISE > 0:
        net['input'] = GaussianNoiseLayer(net['input'], sigma=P.GAUSSIAN_NOISE)

    def contraction(depth, deepest):
        n_filters = filter_for_depth(depth)
        incoming = net['input'] if depth == 0 else net['pool{}'.format(depth - 1)]

        net['conv{}_1'.format(depth)] = Conv2DLayer(incoming,
                                                    num_filters=n_filters, filter_size=3, pad='valid',
                                                    W=HeNormal(gain='relu'),
                                                    nonlinearity=nonlinearity)
        net['conv{}_2'.format(depth)] = Conv2DLayer(net['conv{}_1'.format(depth)],
                                                    num_filters=n_filters, filter_size=3, pad='valid',
                                                    W=HeNormal(gain='relu'),
                                                    nonlinearity=nonlinearity)

        if P.BATCH_NORMALIZATION:
            net['conv{}_2'.format(depth)] = batch_norm(net['conv{}_2'.format(depth)],
                                                       alpha=P.BATCH_NORMALIZATION_ALPHA)

        if not deepest:
            net['pool{}'.format(depth)] = MaxPool2DLayer(net['conv{}_2'.format(depth)],
                                                         pool_size=2, stride=2)

    def expansion(depth, deepest):
        n_filters = filter_for_depth(depth)

        incoming = (net['conv{}_2'.format(depth + 1)] if deepest
                    else net['_conv{}_2'.format(depth + 1)])

        # Upscale by 4, then a 2x2 convolution with stride 2: a learned
        # upconvolution with a net 2x upsampling effect.
        upscaling = Upscale2DLayer(incoming, 4)
        net['upconv{}'.format(depth)] = Conv2DLayer(upscaling,
                                                    num_filters=n_filters, filter_size=2, stride=2,
                                                    W=HeNormal(gain='relu'),
                                                    nonlinearity=nonlinearity)

        if P.SPATIAL_DROPOUT > 0:
            bridge_from = DropoutLayer(net['conv{}_2'.format(depth)], P.SPATIAL_DROPOUT)
        else:
            bridge_from = net['conv{}_2'.format(depth)]

        # Skip connection: center-crop the contraction feature map to match the
        # upconvolved map, then concatenate along the channel axis.
        net['bridge{}'.format(depth)] = ConcatLayer([net['upconv{}'.format(depth)], bridge_from],
                                                    axis=1, cropping=[None, None, 'center', 'center'])

        net['_conv{}_1'.format(depth)] = Conv2DLayer(net['bridge{}'.format(depth)],
                                                     num_filters=n_filters, filter_size=3, pad='valid',
                                                     W=HeNormal(gain='relu'),
                                                     nonlinearity=nonlinearity)

        #if P.BATCH_NORMALIZATION:
        #    net['_conv{}_1'.format(depth)] = batch_norm(net['_conv{}_1'.format(depth)])

        if P.DROPOUT > 0:
            net['_conv{}_1'.format(depth)] = DropoutLayer(net['_conv{}_1'.format(depth)], P.DROPOUT)

        net['_conv{}_2'.format(depth)] = Conv2DLayer(net['_conv{}_1'.format(depth)],
                                                     num_filters=n_filters, filter_size=3, pad='valid',
                                                     W=HeNormal(gain='relu'),
                                                     nonlinearity=nonlinearity)

    for d in range(NET_DEPTH):
        # There is no pooling at the deepest layer.
        deepest = d == NET_DEPTH - 1
        contraction(d, deepest)

    for d in reversed(range(NET_DEPTH - 1)):
        deepest = d == NET_DEPTH - 2
        expansion(d, deepest)

    # Output layer: 1x1 convolution down to N_CLASSES channels; the softmax is
    # applied later in score_metrics, after flattening.
    net['out'] = Conv2DLayer(net['_conv0_2'], num_filters=P.N_CLASSES,
                             filter_size=(1, 1), pad='valid', nonlinearity=None)

    #import network_repr
    #print network_repr.get_network_str(net['out'])
    logging.info('Network output shape ' + str(lasagne.layers.get_output_shape(net['out'])))
    return net


def score_metrics(out, target_var, weight_map, l2_loss=0):
    _EPSILON = 1e-8

    # Flatten (batch, class, height, width) to (pixel, class) so softmax and
    # cross-entropy can be applied per pixel.
    out_flat = out.dimshuffle(1, 0, 2, 3).flatten(ndim=2).dimshuffle(1, 0)
    target_flat = target_var.dimshuffle(1, 0, 2, 3).flatten(ndim=1)
    weight_flat = weight_map.dimshuffle(1, 0, 2, 3).flatten(ndim=1)

    prediction = lasagne.nonlinearities.softmax(out_flat)
    prediction_binary = T.argmax(prediction, axis=1)

    # Dice score for binary (0/1) labels: prediction + target equals 2 exactly
    # where both are 1, i.e. at the true positives.
    dice_score = (T.sum(T.eq(2, prediction_binary + target_flat)) * 2.0
                  / (T.sum(prediction_binary) + T.sum(target_flat)))

    # Per-pixel weighted cross-entropy, clipped for numerical stability.
    loss = lasagne.objectives.categorical_crossentropy(
        T.clip(prediction, _EPSILON, 1 - _EPSILON), target_flat)
    loss = loss * weight_flat
    loss = loss.mean()
    loss += l2_loss

    accuracy = T.mean(T.eq(prediction_binary, target_flat), dtype=theano.config.floatX)

    return loss, accuracy, dice_score, target_flat, prediction, prediction_binary


def define_updates(network, input_var, target_var, weight_var):
    params = lasagne.layers.get_all_params(network, trainable=True)

    out = lasagne.layers.get_output(network)
    test_out = lasagne.layers.get_output(network, deterministic=True)

    l2_loss = P.L2_LAMBDA * regularize_network_params(network, l2)

    train_metrics = score_metrics(out, target_var, weight_var, l2_loss)
    loss, acc, dice_score, target_prediction, prediction, prediction_binary = train_metrics

    val_metrics = score_metrics(test_out, target_var, weight_var, l2_loss)
    # Bug fix: this previously unpacked train_metrics again, so the validation
    # function never used the deterministic output.
    t_loss, t_acc, t_dice_score, t_target_prediction, t_prediction, t_prediction_binary = val_metrics

    # Shared variable so the learning rate can be adjusted during training.
    l_r = theano.shared(np.array(P.LEARNING_RATE, dtype=theano.config.floatX))

    if P.OPTIMIZATION == 'nesterov':
        updates = lasagne.updates.nesterov_momentum(
            loss, params, learning_rate=l_r, momentum=P.MOMENTUM)
    elif P.OPTIMIZATION == 'adam':
        updates = lasagne.updates.adam(loss, params, learning_rate=l_r)
    else:
        # Fail fast instead of hitting a NameError on `updates` below.
        raise ValueError('Unknown optimization method: {}'.format(P.OPTIMIZATION))

    logging.info("Defining train function")
    train_fn = theano.function([input_var, target_var, weight_var],
                               [loss, l2_loss, acc, dice_score,
                                target_prediction, prediction, prediction_binary],
                               updates=updates)

    logging.info("Defining validation function")
    val_fn = theano.function([input_var, target_var, weight_var],
                             [t_loss, l2_loss, t_acc, t_dice_score,
                              t_target_prediction, t_prediction, t_prediction_binary])

    return train_fn, val_fn, l_r


def define_predict(network, input_var):
    out = lasagne.layers.get_output(network, deterministic=True)
    out_flat = out.dimshuffle(1, 0, 2, 3).flatten(ndim=2).dimshuffle(1, 0)
    prediction = lasagne.nonlinearities.softmax(out_flat)

    logging.info("Defining predict")
    predict_fn = theano.function([input_var], [prediction])

    return predict_fn
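

# ---------------------------------------------------------------------------
# Usage sketch. This block is an illustration added for clarity, not part of
# the original training pipeline: the tensor types (tensor4 / itensor4) and
# the dummy-data shapes are assumptions inferred from the dimshuffle patterns
# above, and it presumes the `params` module supplies every field referenced
# in this file (CHANNELS, N_CLASSES, LEARNING_RATE, ...).
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    input_var = T.tensor4('inputs')     # (batch, channel, height, width) floats
    target_var = T.itensor4('targets')  # integer class label per pixel
    weight_var = T.tensor4('weights')   # per-pixel loss weight

    net = define_network(input_var)
    train_fn, val_fn, l_r = define_updates(net['out'], input_var, target_var, weight_var)
    predict_fn = define_predict(net['out'], input_var)

    # Random stand-in batch: inputs are INPUT_SIZE wide, while targets and
    # weights match the smaller OUTPUT_SIZE left by the 'valid' convolutions.
    images = np.random.randn(1, P.CHANNELS, INPUT_SIZE, INPUT_SIZE).astype(theano.config.floatX)
    labels = np.random.randint(0, P.N_CLASSES,
                               size=(1, 1, OUTPUT_SIZE, OUTPUT_SIZE)).astype('int32')
    weights = np.ones((1, 1, OUTPUT_SIZE, OUTPUT_SIZE), dtype=theano.config.floatX)

    loss, l2_loss, acc, dice, _, _, _ = train_fn(images, labels, weights)
    logging.info('loss=%.4f acc=%.4f dice=%.4f', float(loss), float(acc), float(dice))

    # l_r is shared, so the learning rate can be decayed in place between epochs.
    l_r.set_value(np.array(l_r.get_value() * 0.1, dtype=theano.config.floatX))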