__author__ = "Peter F. Neher"
# Simple example for segmentation (U-Net) using Caffe2 including training, testing, saving and loading of networks

import numpy as np
from caffe2.python import workspace, core, model_helper, brew, optimizer, utils
from caffe2.proto import caffe2_pb2
import matplotlib.pyplot as plt

def get_data(batchsize, numsq=1):
    """Randomly generate noisy segmentation images together with their ground truth.

    Every ground-truth sample starts as a 64x64 image with 3 channels in which
    channel 0 (background) is 1 everywhere.  Between 0 and ``numsq`` squares
    (side length 5..19) are then painted with a random foreground class
    (channel 1 or 2).  A later square overwrites any earlier square it
    overlaps, so every ground-truth pixel is one-hot.  The input image is the
    ground truth plus Gaussian noise (std 0.004).

    Args:
        batchsize: number of samples to generate.
        numsq: maximum number of squares per sample; the actual count is drawn
            uniformly from 0..numsq.

    Returns:
        Tuple ``(data, gt_segmentation)`` of float32 arrays, both of shape
        ``(batchsize, 64, 64, 3)`` (channels last).
    """
    sz = 64
    classes = 3
    data = np.empty((batchsize, sz, sz, classes), dtype='float32')
    gt_segmentation = np.empty_like(data)
    for n in range(batchsize):
        label = np.zeros((sz, sz, classes))
        if classes > 1:
            label[:, :, 0] = 1
        for _ in range(np.random.randint(0, numsq + 1)):
            s = np.random.randint(5, 20)
            x = np.random.randint(0, sz - s)
            y = np.random.randint(0, sz - s)
            c = np.random.randint(1, classes)
            # Clear ALL channels inside the square, not just the background:
            # otherwise an overlap with an earlier square of another class
            # leaves two classes set on the same pixel.
            label[x:x + s, y:y + s, :] = 0
            label[x:x + s, y:y + s, c] = 1

        noise = np.random.normal(0, 0.004, (sz, sz, classes))
        # Assignment casts float64 -> float32 exactly like astype('float32').
        data[n] = label + noise
        gt_segmentation[n] = label
    return data, gt_segmentation

# Conv -> ReLU -> SpatialBN unit used at every level of the U-Net.
# Blobs are named 'conv_<tag>', 'nonl_<tag>' and '<bn_prefix>_<tag>'.
def _conv_relu_bn(m, blob_in, tag, bn_prefix, dim_in, dim_out, kernel, pad,
                  weight_init, is_test):
    """Add conv -> relu -> spatial_bn to ``m`` and return the BN output blob."""
    conv = brew.conv(m, blob_in, 'conv_' + tag, dim_in=dim_in, dim_out=dim_out,
                     kernel=kernel, pad=pad, weight_init=weight_init)
    nonl = brew.relu(m, conv, 'nonl_' + tag)
    return brew.spatial_bn(m, nonl, bn_prefix + '_' + tag, dim_in=dim_out,
                           epsilon=1e-3, momentum=0.1, is_test=is_test)


def create_unet_model(m, device_opts, is_test):
    """Build a U-Net (Ronneberger et al.) into ``m`` and return its logits blob.

    The net reads the external blob 'data' and produces 'output_segmentation'
    (per-class logits, 1x1 conv) and 'output_sigmoid' (their sigmoid).  Both
    are registered as external outputs of ``m.net``.

    Args:
        m: the model helper to add operators and parameters to (the callers in
            this file construct it with ``arg_scope={"order": "NHWC"}``).
        device_opts: DeviceOption under which all operators are created.
        is_test: 0 for training, 1 for inference; forwarded to every
            SpatialBN layer and to the dropout layer.

    Returns:
        The 'output_segmentation' blob reference.
    """
    base_n_filters = 16
    kernel_size = 3
    # Integer division: a float pad (true division) is not a valid conv arg.
    pad = (kernel_size - 1) // 2
    do_dropout = True
    num_classes = 3

    weight_init = ("MSRAFill", {})

    def double_conv(blob_in, level, bn_prefix, dim_in, dim_out):
        # Two conv-relu-bn units at one resolution level; returns the second.
        first = _conv_relu_bn(m, blob_in, '%d_1' % level, bn_prefix, dim_in,
                              dim_out, kernel_size, pad, weight_init, is_test)
        return _conv_relu_bn(m, first, '%d_2' % level, bn_prefix, dim_out,
                             dim_out, kernel_size, pad, weight_init, is_test)

    with core.DeviceScope(device_opts):
        # Contracting path (levels 1-4): filters double per level and each
        # level ends in a 2x2 max-pool.  Level outputs are kept as skips.
        skips = []
        blob = 'data'
        # NOTE(review): the input channel count equals num_classes only because
        # get_data produces images with one channel per class.
        dim_in = num_classes
        for level in range(1, 5):
            dim_out = base_n_filters * 2 ** (level - 1)
            contr = double_conv(blob, level, 'contr', dim_in, dim_out)
            skips.append(contr)
            blob = brew.max_pool(m, contr, 'pool%d' % level, kernel=2, stride=2)
            dim_in = dim_out

        if do_dropout:
            # is_test must be forwarded, otherwise dropout stays active in the
            # inference net as well.
            blob = brew.dropout(m, blob, 'drop', ratio=0.4, is_test=is_test)

        # Bottleneck (level 5).
        up_dim = base_n_filters * 16
        blob = double_conv(blob, 5, 'encode', dim_in, up_dim)

        # Expanding path (levels 6-9): upsample 2x, concatenate with the
        # matching skip (which has up_dim // 2 channels), then two conv units.
        for level, skip in zip(range(6, 10), reversed(skips)):
            upscale = brew.conv_transpose(m, blob, 'upscale%d' % (level - 1),
                                          dim_in=up_dim, dim_out=up_dim,
                                          kernel=2, stride=2,
                                          weight_init=weight_init)
            concat = brew.concat(m, [upscale, skip], 'concat%d' % level)
            dim_out = up_dim // 2
            blob = double_conv(concat, level, 'expand', up_dim + dim_out, dim_out)
            up_dim = dim_out

        output_segmentation = brew.conv(m, blob, 'output_segmentation',
                                        dim_in=base_n_filters,
                                        dim_out=num_classes, kernel=1, pad=0,
                                        stride=1, weight_init=weight_init)
        m.net.AddExternalOutput(output_segmentation)

        output_sigmoid = m.Sigmoid(output_segmentation, 'output_sigmoid')
        m.net.AddExternalOutput(output_sigmoid)

        return output_segmentation

def add_training_operators(output_segmentation, model, device_opts,
                           weight_decay=0.0):
    """Attach loss, gradient, optional weight-decay and Adam ops to ``model``.

    The loss is SigmoidCrossEntropyWithLogits between ``output_segmentation``
    and the external blob 'gt_segmentation', written to 'loss'.  It is reduced
    by AveragedLoss to the blob 'avg_loss', and gradients are taken from that
    averaged value.

    Args:
        output_segmentation: logits blob produced by the network.
        model: the model helper to extend.
        device_opts: DeviceOption under which the ops are created.
        weight_decay: L2 weight-decay strength; 0 (default) disables it.

    Returns:
        The 'avg_loss' blob reference.
    """
    with core.DeviceScope(device_opts):
        loss = model.SigmoidCrossEntropyWithLogits(
            [output_segmentation, "gt_segmentation"], 'loss')
        avg_loss = model.AveragedLoss(loss, "avg_loss")
        # Differentiate the averaged objective that is reported, not the
        # unreduced 'loss' blob.
        model.AddGradientOperators([avg_loss])
        if weight_decay > 0:
            # Weight decay rewrites the gradients in place, so it must be added
            # before the optimizer's update ops that consume them.
            brew.add_weight_decay(model, weight_decay)
        optimizer.build_adam(model, base_learning_rate=0.01)
    return avg_loss

def train(INIT_NET, PREDICT_NET, epochs, batch_size, device_opts):
    """Train the U-Net on synthetic data, then build and save an inference net.

    Each "epoch" is a single optimizer iteration on a freshly generated batch
    of ``batch_size`` samples; the averaged loss is printed after each one.
    After training, a second model ('test_net', is_test=1, init_params=False)
    is built with the same blob names.  It is saved with ``save_net`` to
    INIT_NET / PREDICT_NET.

    Args:
        INIT_NET: output path of the serialized parameter-init net.
        PREDICT_NET: output path of the serialized inference net.
        epochs: number of training iterations.
        batch_size: samples per iteration.
        device_opts: DeviceOption used for blobs and operators.
    """
    # Feed an initial batch so the input blobs exist before the net is created.
    data, gt_segmentation = get_data(batch_size)
    workspace.FeedBlob("data", data, device_option=device_opts)
    workspace.FeedBlob("gt_segmentation", gt_segmentation, device_option=device_opts)

    train_model = model_helper.ModelHelper(name="train_net", arg_scope={"order": "NHWC"})
    output_segmentation = create_unet_model(train_model, device_opts=device_opts, is_test=0)
    with core.DeviceScope(device_opts):
        loss = train_model.SigmoidCrossEntropyWithLogits(
            [output_segmentation, "gt_segmentation"], 'loss')
        avg_loss = train_model.AveragedLoss(loss, "avg_loss")
        train_model.AddGradientOperators([avg_loss])
        # Ops run in insertion order: weight decay modifies the gradients in
        # place, so it must come after the gradient ops but BEFORE the Adam
        # update ops.  Added after the optimizer, it would have no effect.
        brew.add_weight_decay(train_model, 0.001)
        optimizer.build_adam(train_model, base_learning_rate=0.01)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    print('\ntraining for %s epochs' % epochs)
    for j in range(epochs):
        data, gt_segmentation = get_data(batch_size, 4)

        workspace.FeedBlob("data", data, device_option=device_opts)
        workspace.FeedBlob("gt_segmentation", gt_segmentation, device_option=device_opts)

        workspace.RunNet(train_model.net, 1)  # one iteration on this batch
        print(str(j) + ': ' + str(workspace.FetchBlob("avg_loss")))

    print('training done')
    # init_params=False: the test net refers to the trained blobs already in
    # the workspace by name instead of initializing new ones.
    test_model = model_helper.ModelHelper(name="test_net", arg_scope={"order": "NHWC"}, init_params=False)
    create_unet_model(test_model, device_opts=device_opts, is_test=1)
    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net, overwrite=True)

    print('\nsaving test model')
    save_net(INIT_NET, PREDICT_NET, test_model)

def save_net(INIT_NET, PREDICT_NET, model):
    """Serialize ``model`` as a predict net plus a self-contained init net.

    PREDICT_NET receives ``model.net`` as a binary NetDef.  INIT_NET receives a
    NetDef with one GivenTensorFill per parameter, holding the parameter's
    current workspace value.  It also gets a ConstantFill that creates a
    placeholder 'data' blob.

    Args:
        INIT_NET: output path of the parameter-init net.
        PREDICT_NET: output path of the predict net.
        model: model helper whose parameters live in the current workspace.
    """
    with open(PREDICT_NET, 'wb') as f:
        f.write(model.net._net.SerializeToString())

    init_net = caffe2_pb2.NetDef()
    # GetAllParams() = trainable params + computed params.  Computed params
    # (e.g. SpatialBN running mean / inverse variance) are not in model.params.
    # Without them the loaded predict net cannot run its batch-norm layers.
    for param in model.GetAllParams():
        blob = workspace.FetchBlob(param)
        op = core.CreateOperator(
            "GivenTensorFill", [], [param],
            arg=[utils.MakeArgument("shape", blob.shape),
                 utils.MakeArgument("values", blob)])
        init_net.op.extend([op])
    # Placeholder input blob with the shape of a single sample.
    init_net.op.extend([core.CreateOperator("ConstantFill", [], ["data"], shape=get_data(1)[0][0, :, :, :].shape)])
    with open(INIT_NET, 'wb') as f:
        f.write(init_net.SerializeToString())

def load_net(INIT_NET, PREDICT_NET, device_opts):
    """Load a saved init/predict NetDef pair into the current workspace.

    The init net is run once to create the parameter blobs.  The predict net is
    then created (overwriting any existing net of the same name).  Both NetDefs
    have their device option replaced by ``device_opts`` first.

    Args:
        INIT_NET: path of the serialized parameter-init NetDef.
        PREDICT_NET: path of the serialized predict NetDef.
        device_opts: DeviceOption to run/create both nets on.

    Returns:
        The name of the created predict net (usable with ``workspace.RunNet``).
    """
    init_def = caffe2_pb2.NetDef()
    # Serialized protobufs are binary: open in 'rb', not text mode.
    with open(INIT_NET, 'rb') as f:
        init_def.ParseFromString(f.read())
    init_def.device_option.CopyFrom(device_opts)
    workspace.RunNetOnce(init_def.SerializeToString())

    net_def = caffe2_pb2.NetDef()
    with open(PREDICT_NET, 'rb') as f:
        net_def.ParseFromString(f.read())
    net_def.device_option.CopyFrom(device_opts)
    workspace.CreateNet(net_def.SerializeToString(), overwrite=True)
    return net_def.name

INIT_NET = '/path/to/init_net.pb'
PREDICT_NET = '/path/to/predict_net.pb'
# Change to core.DeviceOption(caffe2_pb2.CPU, 0) for CPU processing.
device_opts = core.DeviceOption(caffe2_pb2.CUDA, 0)

train(INIT_NET, PREDICT_NET, epochs=100, batch_size=100, device_opts=device_opts)

print('\n********************************************')
print('loading test model')
# NOTE(review): SpatialBN running mean / inverse variance are registered as
# computed params, not in model.params.  The saved INIT_NET must contain them
# (e.g. by serializing model.GetAllParams()), or loading/running the
# batch-norm layers fails here.
load_net(INIT_NET, PREDICT_NET, device_opts=device_opts)

# Show predictions for fresh random samples, one figure at a time, until the
# script is interrupted.
while True:
    data, gt_segmentation = get_data(1, 4)
    workspace.FeedBlob("data", data, device_option=device_opts)
    workspace.RunNet('test_net', 1)
    out1 = workspace.FetchBlob("output_sigmoid")

    fig, sub = plt.subplots(ncols=3, figsize=(15, 5))
    # The noisy input can fall slightly outside [0, 1]; clip it for display.
    panels = (('Input', np.clip(data[0, :, :, :], 0, 1)),
              ('Ground Truth', gt_segmentation[0, :, :, :]),
              ('Segmentation', out1[0, :, :, :]))
    for ax, (title, image) in zip(sub, panels):
        ax.set_title(title)
        ax.imshow(image)

    plt.show()
    # Release the figure so figures don't accumulate when show() doesn't block.
    plt.close(fig)