import os
import glob
import time
import shutil
import numpy as np
from collections import OrderedDict
import __future__
import logging
import matplotlib
import tensorflow as tf
import csv

from tensorflow.python import debug as tf_debug
from layers import *
from ops import *
from lib import _dice_eval, _save, _save_nii_prediction, _jaccard, _dice, _label_decomp, _indicator_eval, read_nii_image

np.random.seed(0)
contour_map = { # maps each label name to its numeric value, used when writing outputs
    "bg": 0,
    "la_myo": 1,
    "la_blood": 2,
    "lv_blood": 3,
    "aa": 4
}

verbose = True
logging.basicConfig(filename = "curr_log", level=logging.DEBUG, format='%(asctime)s %(message)s')
if verbose:
    logging.getLogger().addHandler(logging.StreamHandler())
raw_size = [256, 256, 3] # original raw input size
volume_size = [256, 256, 3] # volume size after processing, for the tfrecord file
label_size = [256, 256, 1] # size of label
decomp_feature = { # configuration for decoding tf_record file
            'dsize_dim0': tf.FixedLenFeature([], tf.int64),
            'dsize_dim1': tf.FixedLenFeature([], tf.int64),
            'dsize_dim2': tf.FixedLenFeature([], tf.int64),
            'lsize_dim0': tf.FixedLenFeature([], tf.int64),
            'lsize_dim1': tf.FixedLenFeature([], tf.int64),
            'lsize_dim2': tf.FixedLenFeature([], tf.int64),
            'data_vol': tf.FixedLenFeature([], tf.string),
            'label_vol': tf.FixedLenFeature([], tf.string)}
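# note: 'data_vol' and 'label_vol' are stored as raw float32 bytes; they are decoded with
# tf.decode_raw and reshaped to raw_size inside Trainer.next_batch below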

class Full_DRN(object):

    def __init__(self, channels, n_class, batch_size, cost_kwargs={}, network_config = {}):

        ##### Done this function

        tf.reset_default_graph()

        self.n_class = n_class # note: the background counts as one of the classes
        self.batch_size = batch_size

        self.mr_front_weights = [] # conv weights of MR path
        self.ct_front_weights = [] # conv weights of CT path
        self.cls_weights = []   # weights of feature discriminator
        self.m_cls_weights = [] # weights for segmentation mask discriminator
        self.joint_weights = [] # weights of the joint part shared between CT and MRI, i.e. the final segmenter in our case

        self.mr = tf.placeholder("float", shape=[None, volume_size[0], volume_size[1], channels], name = "mr_ph")
        self.ct = tf.placeholder("float", shape=[None, volume_size[0], volume_size[1], channels])
        self.ct_y = tf.placeholder("float", shape=[None, label_size[0], label_size[1], self.n_class])
        self.mr_y = tf.placeholder("float", shape=[None, label_size[0], label_size[1], self.n_class])

        self.mr_front_bn = tf.placeholder_with_default(False, shape = None, name = "main_batchnorm_training_switch")
        self.joint_bn = tf.placeholder_with_default(False, shape = None, name = "joint_batchnorm_training_switch")
        self.ct_front_bn = tf.placeholder_with_default(True, shape = None, name = "adapt_batchnorm_training_switch")

        # these two switches are unused; they are never fed when running the graph
        self.cls_bn = tf.placeholder_with_default(True, shape = None, name = "cls_batchnorm_training_switch")
        self.m_cls_bn = tf.placeholder_with_default(True, shape = None, name = "mask_cls_batchnorm_training_switch")

        self.network_config = network_config
        self.mr_front_trainable = self.network_config["mr_front_trainable"]
        self.ct_front_trainable = self.network_config["ct_front_trainable"]
        self.joint_trainable = self.network_config["joint_trainable"]
        self.cls_trainable = self.network_config["cls_trainable"]
        self.m_cls_trainable = self.network_config["m_cls_trainable"]

        self.keep_prob = tf.placeholder(tf.float32) # dropout keep probability

        # Get features from MRI and CT path, for early layers
        _mr_c4_2, _ct_c4_2, _mr_c6_2, _ct_c6_2 = self.create_zip_network(input_channel = channels,\
                                        feature_base = 16, num_cls = n_class, keep_prob = self.keep_prob,\
                                        main_bn = self.mr_front_bn, main_trainable = self.mr_front_trainable,\
                                        adapt_bn = self.ct_front_bn, adapt_trainable = self.ct_front_trainable)

        # Get features from MRI and CT, from the shared higher layers
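        # (tf.AUTO_REUSE below makes the two create_second_half calls share one set of joint segmenter variables)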
        with tf.variable_scope("", reuse = tf.AUTO_REUSE) as scope:
            _ct_c9_2, _ct_b8, _ct_b7, _ct_logits = self.create_second_half( _ct_c6_2, feature_base = 16, input_channel = 3, num_cls = n_class, keep_prob = self.keep_prob, joint_bn = self.joint_bn, joint_trainable = self.joint_trainable)
            _mr_c9_2, _mr_b8, _mr_b7, _mr_logits = self.create_second_half( _mr_c6_2, feature_base = 16, input_channel = 3, num_cls = n_class, keep_prob = self.keep_prob, joint_bn = self.joint_bn, joint_trainable = self.joint_trainable)

        self.ct_conv9_2 = _ct_c9_2
        self.mr_conv9_2 = _mr_c9_2

        with tf.variable_scope("cls_scope", reuse = tf.AUTO_REUSE) as scope:
            self._ct_class_logits = self.create_classifier(_ct_c4_2, _ct_c6_2, _ct_b7, _ct_c9_2, _ct_logits)
            self._mr_class_logits = self.create_classifier(_mr_c4_2, _mr_c6_2, _mr_b7, _mr_c9_2, _mr_logits)
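            # sharing "cls_scope" with AUTO_REUSE means CT and MRI features are scored by the same discriminator weights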

        self.predictor = pixel_wise_softmax_2(_ct_logits) # segmentation probability maps of CT
        self.compact_pred = tf.argmax(self.predictor, 3) # predictions

        self.compact_y = tf.argmax(self.ct_y, 3) # ground truth
        self.ct_dice_eval, self.ct_dice_eval_arr = _dice_eval(self.compact_pred, self.ct_y, self.n_class) # used for monitoring training process
        self.ct_dice_eval_c1 = self.ct_dice_eval_arr[1]
        self.ct_dice_eval_c2 = self.ct_dice_eval_arr[2]
        self.ct_dice_eval_c3 = self.ct_dice_eval_arr[3]
        self.ct_dice_eval_c4 = self.ct_dice_eval_arr[4]
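        # the per-class Dice scores above follow the label indices defined in contour_map and are logged to tensorboard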

        self.mr_seg_valid = pixel_wise_softmax_2(_mr_logits) # segmentation probability maps of MRI
        self.compact_mr_valid = tf.argmax(self.mr_seg_valid, 3)

        self.compact_mr_y = tf.argmax(self.mr_y, 3)
        self.mr_dice_eval, self.mr_dice_eval_arr = _dice_eval(self.compact_mr_valid, self.mr_y, self.n_class)

        with tf.variable_scope("mask_cls_scope", reuse = tf.AUTO_REUSE) as scope:
            self._ct_mask_logits = self.create_mask_critic(_ct_logits, num_cls = n_class)  # auxiliary discriminator logits for masks
            self._mr_mask_logits = self.create_mask_critic(_mr_logits, num_cls = n_class)

        self.cost_kwargs = cost_kwargs
        self.dis_loss, self.ct_gen_loss, self.fixed_coeff_reg, self.dis_reg, self.gen_reg = self._get_cost(_ct_logits, _mr_logits,  self._ct_class_logits, self._mr_class_logits,\
                                                self._ct_mask_logits, self._mr_mask_logits, self.cost_kwargs) # get cost

        self.confusion_matrix = tf.confusion_matrix( tf.reshape(self.compact_y,[-1]), tf.reshape(self.compact_pred, [-1]), num_classes = self.n_class )

    def create_zip_network(self, main_bn, main_trainable, adapt_bn, adapt_trainable, num_cls, feature_base = 16,  input_channel = 3, keep_prob = 0.75):

        # MR path starts from here
        with tf.variable_scope('group_1') as scope:
            w1_1 = weight_variable(shape = [3, 3, input_channel, feature_base], trainable = main_trainable)
            conv1_1 = conv2d(self.mr, w1_1, keep_prob )
            wr1_1 = weight_variable(shape = [ 3, 3, feature_base,feature_base], trainable = main_trainable)
            wr1_2 = weight_variable(shape = [3, 3, feature_base, feature_base], trainable = main_trainable)
            block1_1 = residual_block(conv1_1, wr1_1, wr1_2, keep_prob , is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_1_1'   ) # here the scope is for bn
            out1 = max_pool2d(block1_1, n = 2)
            self.mr_front_weights.append(w1_1)
            self.mr_front_weights.append(wr1_1)
            self.mr_front_weights.append(wr1_2)

        with tf.variable_scope('group_2') as scope:
            wr2_1 = weight_variable(shape = [3, 3, feature_base, feature_base * 2], trainable = main_trainable)
            wr2_2 = weight_variable(shape = [3, 3, feature_base * 2, feature_base * 2], trainable = main_trainable)
            block2_1 = residual_block(out1, wr2_1, wr2_2, inc_dim = True,keep_prob = keep_prob, leak = True, is_train = main_bn, bn_trainable = main_trainable, scope = 'pred_2_1'  )
            out2 = max_pool2d(block2_1, n = 2)
            self.mr_front_weights.append(wr2_1)
            self.mr_front_weights.append(wr2_2)

        with tf.variable_scope('group_3') as scope:
            wr3_1 = weight_variable( shape = [3, 3, feature_base * 2, feature_base * 4], trainable = main_trainable  )
            wr3_2 = weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4], trainable = main_trainable  )
            block3_1 = residual_block( out2, wr3_1, wr3_2, keep_prob, inc_dim = True, is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_3_1'       )
            wr3_3 = weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4], trainable = main_trainable  )
            wr3_4 = weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4], trainable = main_trainable  )
            block3_2 = residual_block( block3_1, wr3_3, wr3_4,keep_prob = keep_prob, is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_3_2'      )
            out3 = max_pool2d(block3_2, n = 2)
            self.mr_front_weights.append(wr3_1)
            self.mr_front_weights.append(wr3_2)
            self.mr_front_weights.append(wr3_3)
            self.mr_front_weights.append(wr3_4)

        with tf.variable_scope('group_4') as scope:
            wr4_1 = weight_variable( shape = [3, 3, feature_base * 4, feature_base * 8], trainable = main_trainable  )
            wr4_2 = weight_variable( shape = [3, 3, feature_base * 8, feature_base * 8], trainable = main_trainable   )
            block4_1 = residual_block( out3, wr4_1, wr4_2, keep_prob,  inc_dim = True, is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_4_1'   )
            wr4_3 = weight_variable( shape = [3, 3, feature_base * 8, feature_base * 8], trainable = main_trainable  )
            wr4_4 = weight_variable( shape = [3, 3, feature_base * 8, feature_base * 8], trainable = main_trainable  )
            block4_2 = residual_block( block4_1, wr4_3, wr4_4, keep_prob, is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_4_2'    )
            self.mr_front_weights.append(wr4_1)
            self.mr_front_weights.append(wr4_2)
            self.mr_front_weights.append(wr4_3)
            self.mr_front_weights.append(wr4_4)

        with tf.variable_scope('group_5') as scope:
            wr5_1 = sharable_weight_variable( shape = [3, 3, feature_base * 8, feature_base * 16], trainable = main_trainable, name = "Variable"  )
            wr5_2 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = main_trainable , name = "Variable_1"  )
            block5_1 = residual_block( block4_2, wr5_1, wr5_2, keep_prob = keep_prob, inc_dim = True, leak = True,  is_train = main_bn, bn_trainable = main_trainable, scope = 'pred_5_1'     )
            wr5_3 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = main_trainable , name = "Variable_2"  )
            wr5_4 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = main_trainable , name = "Variable_3"  )
            block5_2 = residual_block( block5_1, wr5_3, wr5_4, keep_prob = keep_prob, is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_5_2'  )
            self.mr_front_weights.append( wr5_1  )
            self.mr_front_weights.append( wr5_2  )
            self.mr_front_weights.append( wr5_3  )
            self.mr_front_weights.append( wr5_4  )

        with tf.variable_scope('group_6') as scope:
            wr6_1 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = main_trainable , name = "Variable"  )
            wr6_2 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = main_trainable , name = "Variable_1"  )
            block6_1 = residual_block( block5_2, wr6_1, wr6_2, keep_prob = keep_prob,  is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_6_1'       )
            wr6_3 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = main_trainable , name = "Variable_2"  )
            wr6_4 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = main_trainable,  name = "Variable_3"   )
            block6_2 = residual_block( block6_1, wr6_3, wr6_4, keep_prob = keep_prob, is_train = main_bn, leak = True, bn_trainable = main_trainable , scope = 'pred_6_2'       )
            self.mr_front_weights.append( wr6_1  )
            self.mr_front_weights.append( wr6_2  )
            self.mr_front_weights.append( wr6_3  )
            self.mr_front_weights.append( wr6_4  )

        # DAM for CT path starts from here
        with tf.variable_scope('adapt_1') as scope:
            w1_1a = sharable_weight_variable(shape = [3, 3, input_channel, feature_base ], trainable = adapt_trainable, name = "Variable")
            conv1_1a = conv2d(self.ct, w1_1a, keep_prob )
            wr1_1a = sharable_weight_variable(shape = [ 3, 3, feature_base ,feature_base ], trainable = adapt_trainable, name = "Variable_1")
            wr1_2a = sharable_weight_variable(shape = [3, 3, feature_base , feature_base ], trainable = adapt_trainable, name = "Variable_2")
            block1_1a = residual_block(conv1_1a, wr1_1a, wr1_2a, keep_prob , is_train = adapt_bn, leak = True, bn_trainable = adapt_trainable, scope = 'adapt_1'   )
            out1a = max_pool2d(block1_1a, n = 2)
            self.ct_front_weights.append(w1_1a)
            self.ct_front_weights.append(wr1_1a)
            self.ct_front_weights.append(wr1_2a)

        with tf.variable_scope('adapt_2') as scope:
            wr2_1a = sharable_weight_variable(shape = [3, 3, feature_base , feature_base * 2], trainable = adapt_trainable, name = "Variable")
            wr2_2a = sharable_weight_variable(shape = [3, 3, feature_base * 2, feature_base * 2], trainable = adapt_trainable, name = "Variable_1")
            block2_1a = residual_block(out1a, wr2_1a, wr2_2a, inc_dim = True,keep_prob = keep_prob, leak = True, is_train = adapt_bn, bn_trainable = adapt_trainable, scope = 'adapt_2'   )
            out2a = max_pool2d(block2_1a, n = 2)
            self.ct_front_weights.append(wr2_1a)
            self.ct_front_weights.append(wr2_2a)

        with tf.variable_scope('adapt_3') as scope:
            wr3_1a = sharable_weight_variable( shape = [3, 3, feature_base * 2, feature_base * 4], trainable = adapt_trainable, name = "Variable"  )
            wr3_2a = sharable_weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4], trainable = adapt_trainable, name = "Variable_1"  )
            block3_1a = residual_block( out2a, wr3_1a, wr3_2a, keep_prob, inc_dim = True, leak = True, is_train = adapt_bn, bn_trainable = adapt_trainable , scope = 'adapt_3_1'   )
            wr3_3a = sharable_weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4], trainable = adapt_trainable, name = "Variable_2"  )
            wr3_4a = sharable_weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4], trainable = adapt_trainable , name = "Variable_3"  )
            block3_2a = residual_block( block3_1a, wr3_3a, wr3_4a,keep_prob = keep_prob, leak = True, is_train = adapt_bn, bn_trainable = adapt_trainable, scope = 'adapt_3_2'    )

            out3a = max_pool2d(block3_2a, n = 2)
            self.ct_front_weights.append(wr3_1a)
            self.ct_front_weights.append(wr3_2a)
            self.ct_front_weights.append(wr3_3a)
            self.ct_front_weights.append(wr3_4a)

        with tf.variable_scope('adapt_4') as scope:
            wr4_1a = sharable_weight_variable( shape = [3, 3, feature_base * 4, feature_base * 8], trainable = adapt_trainable, name  = "Variable"  )
            wr4_2a = sharable_weight_variable( shape = [3, 3, feature_base * 8, feature_base * 8], trainable = adapt_trainable , name  = "Variable_1"   )
            block4_1a = residual_block( out3a, wr4_1a, wr4_2a, keep_prob, inc_dim = True, leak = True, is_train = adapt_bn, bn_trainable = adapt_trainable, scope = 'adapt_4_1'     )

            wr4_3a = sharable_weight_variable( shape = [3, 3, feature_base * 8, feature_base * 8], trainable = adapt_trainable , name  = "Variable_2"  )
            wr4_4a = sharable_weight_variable( shape = [3, 3, feature_base * 8, feature_base * 8], trainable = adapt_trainable  , name  = "Variable_3" )
            block4_2a = residual_block( block4_1a, wr4_3a, wr4_4a, keep_prob, is_train = adapt_bn, leak = True, bn_trainable = adapt_trainable, scope = 'adapt_4_2'      )
            self.ct_front_weights.append(wr4_1a)
            self.ct_front_weights.append(wr4_2a)
            self.ct_front_weights.append(wr4_3a)
            self.ct_front_weights.append(wr4_4a)

        with tf.variable_scope('adapt_5') as scope:
            wr5_1a = sharable_weight_variable( shape = [3, 3, feature_base * 8, feature_base * 16], trainable = adapt_trainable, name = "Variable"  )
            wr5_2a = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = adapt_trainable , name = "Variable_1"  )
            block5_1a = residual_block( block4_2a, wr5_1a, wr5_2a, keep_prob = keep_prob, leak = True, inc_dim = True,  is_train = adapt_bn, bn_trainable = adapt_trainable, scope = 'adapt_5_1'     )

            wr5_3a = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = adapt_trainable , name = "Variable_2"  )
            wr5_4a = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = adapt_trainable , name = "Variable_3"  )
            block5_2a = residual_block( block5_1a, wr5_3a, wr5_4a, keep_prob = keep_prob, leak = True, is_train = adapt_bn, bn_trainable = adapt_trainable , scope = 'adapt_5_2'  )
            self.ct_front_weights.append( wr5_1a  )
            self.ct_front_weights.append( wr5_2a  )
            self.ct_front_weights.append( wr5_3a  )
            self.ct_front_weights.append( wr5_4a  )

        with tf.variable_scope('adapt_6') as scope:
            wr6_1a = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = adapt_trainable , name = "Variable"  )
            wr6_2a = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = adapt_trainable , name = "Variable_1"  )
            block6_1a = residual_block( block5_2a, wr6_1a, wr6_2a, keep_prob = keep_prob, leak = True,  is_train = adapt_bn, bn_trainable = adapt_trainable , scope = 'adapt_6_1'       )

            wr6_3a = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = adapt_trainable , name = "Variable_2"  )
            wr6_4a = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 16], trainable = adapt_trainable,  name = "Variable_3"   )
            block6_2a = residual_block( block6_1a, wr6_3a, wr6_4a, keep_prob = keep_prob, leak = True, is_train = adapt_bn, bn_trainable = adapt_trainable , scope = 'adapt_6_2'       )
            self.ct_front_weights.append( wr6_1a  )
            self.ct_front_weights.append( wr6_2a  )
            self.ct_front_weights.append( wr6_3a  )
            self.ct_front_weights.append( wr6_4a  )

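        # return early (group/adapt 4) and late (group/adapt 6) features for both paths: (mr_c4_2, ct_c4_2, mr_c6_2, ct_c6_2)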
        return block4_2, block4_2a, block6_2, block6_2a

    def create_second_half(self, input_feature, joint_bn, joint_trainable, num_cls, feature_base = 16,  input_channel = 3, keep_prob = 0.75):

        with tf.variable_scope('group_7', reuse = tf.AUTO_REUSE) as scope:
            wr7_1 = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base * 32], trainable = joint_trainable  , name = "Variable" )
            wr7_2 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable  , name = "Variable_1" )
            block7_1 = residual_block( input_feature, wr7_1, wr7_2, keep_prob = keep_prob, leak = True, inc_dim = True,  is_train = joint_bn, bn_trainable = joint_trainable , scope = 'pred_7_1'     )
            wr7_3 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable , name = "Variable_2"  )
            wr7_4 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable  , name = "Variable_3" )
            block7_2 = residual_block( block7_1, wr7_3, wr7_4, keep_prob = keep_prob, leak = True,  is_train = joint_bn, bn_trainable = joint_trainable , scope = 'pred_7_2'      )
            self.mr_front_weights.append( wr7_1  )
            self.mr_front_weights.append( wr7_2  )
            self.mr_front_weights.append( wr7_3  )
            self.mr_front_weights.append( wr7_4  )

        with tf.variable_scope('group_8', reuse = tf.AUTO_REUSE) as scope:
            wr8_1 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable , name = "Variable"  )
            wr8_2 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable  , name = "Variable_1" )
            block8_1 = DR_block( block7_2, wr8_1, wr8_2, keep_prob = keep_prob, leak = True, is_train = joint_bn, rate = 2, bn_trainable = joint_trainable , scope = 'pred_8_1'       )
            wr8_3 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable  , name = "Variable_2" )
            wr8_4 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable  , name = "Variable_3" )
            block8_2 = DR_block( block8_1, wr8_3, wr8_4, keep_prob = keep_prob, leak = True,  is_train = joint_bn, rate = 2, bn_trainable = joint_trainable , scope = 'pred_8_2'   )
            self.mr_front_weights.append( wr8_1  )
            self.mr_front_weights.append( wr8_2  )
            self.mr_front_weights.append( wr8_3  )
            self.mr_front_weights.append( wr8_4  )

        with tf.variable_scope('group_9', reuse = tf.AUTO_REUSE) as scope:
            w9_1 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable  , name = "Variable" )
            conv9_1 = conv_bn_relu2d( block8_2, w9_1, keep_prob, leak = True, is_train = joint_bn, bn_trainable = joint_trainable , scope = 'pred_9_1'   )
            w9_2 = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base * 32], trainable = joint_trainable  , name = "Variable_1" )
            conv9_2 = conv_bn_relu2d( conv9_1, w9_2, keep_prob, leak = True, is_train = joint_bn, bn_trainable = joint_trainable , scope = 'pred_9_2'    )
            self.mr_front_weights.append( w9_1  )
            self.mr_front_weights.append( w9_2  )

        with tf.variable_scope('group_10', reuse = tf.AUTO_REUSE) as scope:
            local_size = 8 * 8
            w10_1 = sharable_weight_variable( shape = [3, 3, feature_base * 32, local_size * num_cls * 8], trainable = joint_trainable , name = "Variable" )
            conv10_1 = conv2d( conv9_2, w10_1, keep_prob_ = keep_prob, padding = 'SYMMETRIC')
            self.mr_front_weights.append(w10_1)
            flat_conv10_1 = PS(conv10_1, r = 8, n_channel = num_cls * 8, batch_size = self.batch_size) # phase shift
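            # PS is assumed to be a sub-pixel (periodic shuffling) upsampling op: it rearranges channels into an 8x larger spatial map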

        with tf.variable_scope('output', reuse = tf.AUTO_REUSE) as scope:
            w11_1 = sharable_weight_variable( shape = [5, 5, num_cls * 8, num_cls], trainable = joint_trainable  , name = "Variable" )
            logits = conv2d( flat_conv10_1, w11_1, keep_prob_ = 1., padding = 'SYMMETRIC'  )

        return conv9_2, block8_2, block7_2, logits

    def create_classifier(self, input_conv4, input_conv6, input_b7, input_conv9, seg_logits, feature_base = 16, keep_prob = 0.75, cls_bn = True, cls_trainable = True):
        """
        domain discriminator for MRI features and CT features
        """
        with tf.variable_scope('cls_0') as scope:
            flat_input_conv4 = PS(input_conv4, r=8, n_channel=2, batch_size=self.batch_size)  # 2
            flat_input_conv4 = tf.tile(flat_input_conv4, [1, 1, 1, 3]) # 6 in total
            flat_input_conv6 = PS(input_conv6, r=8, n_channel=4, batch_size=self.batch_size)  # 10 in total
            flat_input_b7 = PS(input_b7, r=8, n_channel=8, batch_size=self.batch_size)  # 18 in total
            flat_input_conv9 = PS(input_conv9, r = 8, n_channel = 8, batch_size = self.batch_size) # 26 in total

            input_comp = simple_concat2d(flat_input_conv4, flat_input_conv6) # 10
            input_comp = simple_concat2d(input_comp, flat_input_b7) # 18
            input_comp = simple_concat2d(input_comp, flat_input_conv9) # 26
            input_comp = simple_concat2d(input_comp, seg_logits) # 31 in total
            input_comp = simple_concat2d(input_comp, tf.expand_dims(tf.cast(tf.argmax(seg_logits, 3), tf.float32), 3))  # +1 channel for the compact prediction, 32 in total

        with tf.variable_scope('cls_1') as scope:
            wr1_1c = sharable_weight_variable( shape = [3, 3, feature_base * 2, feature_base * 4], trainable = cls_trainable  , name = "Variable" )
            wr1_2c = sharable_weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4], trainable = cls_trainable  , name = "Variable_1" )
            block1_1c = residual_block( input_comp, wr1_1c, wr1_2c, keep_prob = keep_prob, inc_dim = True, is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_1'   , leak = True   )
            wr1_3d = sharable_weight_variable( shape = [3,3, feature_base * 4, feature_base * 4], trainable = cls_trainable, name = "Variable_2"  )
            out1c = conv_bn_relu2d( block1_1c, wr1_3d, keep_prob, strides = [1,2,2,1], is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_1_3', leak = True  )
            self.cls_weights.append( wr1_1c  )
            self.cls_weights.append( wr1_2c  )
            self.cls_weights.append( wr1_3d  )

        with tf.variable_scope('cls_2') as scope:
            wr2_1c = sharable_weight_variable( shape = [3, 3, feature_base * 4, feature_base *8], trainable = cls_trainable  , name = "Variable" )
            wr2_2c = sharable_weight_variable( shape = [3, 3, feature_base * 8, feature_base *8], trainable = cls_trainable  , name = "Variable_1" )
            block2_1c = residual_block( out1c, wr2_1c, wr2_2c, keep_prob = keep_prob, inc_dim = True, is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_2'   , leak = True   )
            wr2_3d = sharable_weight_variable( shape = [5,5, feature_base * 8, feature_base * 8], trainable = cls_trainable, name = "Variable_2"  )
            out2c = conv_bn_relu2d( block2_1c, wr2_3d, keep_prob, strides = [1,2,2,1], is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_2_3', leak = True  )
            self.cls_weights.append( wr2_1c  )
            self.cls_weights.append( wr2_2c  )
            self.cls_weights.append( wr2_3d  )
            self.debug_out2c = out2c
            self.debug_wr2_2c = wr2_2c

        with tf.variable_scope('cls_3') as scope:
            wr3_1c = sharable_weight_variable( shape = [3, 3, feature_base * 8, feature_base *16], trainable = cls_trainable  , name = "Variable" )
            wr3_2c = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base *16], trainable = cls_trainable  , name = "Variable_1" )
            block3_1c = residual_block( out2c, wr3_1c, wr3_2c,  keep_prob = keep_prob, inc_dim = True, is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_3'   , leak = True   )
            wr3_3d = sharable_weight_variable( shape = [3,3, feature_base * 16, feature_base * 16], trainable = cls_trainable, name = "Variable_2"  )
            out3c = conv_bn_relu2d( block3_1c, wr3_3d, keep_prob, strides = [1,2,2,1], is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_3_3', leak = True  )
            self.cls_weights.append( wr3_1c  )
            self.cls_weights.append( wr3_2c  )
            self.cls_weights.append( wr3_3d  )

        with tf.variable_scope('cls_4') as scope:
            wr4_1c = sharable_weight_variable( shape = [3, 3, feature_base * 16, feature_base *32], trainable = cls_trainable  , name = "Variable" )
            wr4_2c = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base *32], trainable = cls_trainable  , name = "Variable_1" )
            block4_1c = residual_block( out3c, wr4_1c, wr4_2c,  keep_prob = keep_prob, inc_dim = True, is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_4'   , leak = True   )
            wr4_3d = sharable_weight_variable( shape = [3,3, feature_base * 32, feature_base * 32], trainable = cls_trainable, name = "Variable_2"  )
            out4c = conv_bn_relu2d( block4_1c, wr4_3d, keep_prob, strides = [1,2,2,1], is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_4_3', leak = True  )
            self.cls_weights.append( wr4_1c  )
            self.cls_weights.append( wr4_2c  )
            self.cls_weights.append( wr4_3d  )

        with tf.variable_scope('cls_5') as scope:
            wr5_1c = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base *32], trainable = cls_trainable  , name = "Variable" )
            wr5_2c = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base *32], trainable = cls_trainable  , name = "Variable_1" )
            block5_1c = residual_block( out4c, wr5_1c, wr5_2c,  keep_prob = keep_prob, is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_5'   , leak = True   )
            wr5_3d = sharable_weight_variable( shape = [5,5, feature_base * 32, feature_base * 32], trainable = cls_trainable, name = "Variable_2"  )
            out5c = conv_bn_relu2d( block5_1c, wr5_3d, keep_prob, strides = [1,4,4,1], is_train = cls_bn, bn_trainable = cls_trainable, scope = 'cls_5_3', leak = True  )
            self.cls_weights.append( wr5_1c  )
            self.cls_weights.append( wr5_2c  )
            self.cls_weights.append( wr5_3d  )

        with tf.variable_scope('cls_6') as scope:
            wr6_1c = sharable_weight_variable( shape = [3, 3, feature_base * 32, feature_base *32], trainable = cls_trainable  , name = "Variable" )
            conv_6c = conv_bn_relu2d(out5c, wr6_1c, strides = [1,2,2,1], keep_prob = keep_prob, padding = "SYMMETRIC", scope = 'cls_6', is_train = cls_bn, bn_trainable = cls_trainable, leak = True)
            self.cls_weights.append( wr6_1c  )

        with tf.variable_scope('cls_out') as scope:
            wc_out = sharable_weight_variable( shape = [ feature_base* 32 * 4,1 ], trainable = cls_trainable , name = "Variable"  )
            out6c_flat = tf.reshape(conv_6c, [-1, feature_base * 32 * 4])
            cls_logits = tf.matmul(out6c_flat, wc_out)
            self.cls_weights.append(wc_out)

        return cls_logits

    def create_mask_critic(self, input_mask, feature_base = 16, keep_prob = 0.75, num_cls = 5, m_cls_bn = True, m_cls_trainable = True):
        """
        domain discriminator for MRI and CT segmentation masks
        """
        with tf.variable_scope('mask_cls_1') as scope:
            wr1_1m = sharable_weight_variable( shape = [3, 3, num_cls, feature_base], trainable = m_cls_trainable  , name = "Variable" )
            out1m = conv_bn_relu2d( input_mask, wr1_1m, keep_prob, strides = [1,2,2,1], is_train = m_cls_bn, bn_trainable = m_cls_trainable, scope = 'mask_cls_1', leak = True  ) # use strided conv instead of maxpool for downsampling
            self.m_cls_weights.append( wr1_1m )

        with tf.variable_scope('mask_cls_2') as scope:
            wr2_1m = sharable_weight_variable( shape = [3, 3, feature_base, feature_base ], trainable = m_cls_trainable  , name = "Variable" )
            wr2_2m = sharable_weight_variable( shape = [3, 3, feature_base, feature_base ], trainable = m_cls_trainable  , name = "Variable_1" )
            block2_1m = residual_block( out1m, wr2_1m, wr2_2m, keep_prob = keep_prob, inc_dim = False, is_train = m_cls_bn, bn_trainable = m_cls_trainable, scope = 'm_cls_2'   , leak = True   )
            wr2_3d = sharable_weight_variable( shape = [5,5, feature_base, feature_base * 2], trainable = m_cls_trainable, name = "Variable_2"  )
            out2m = conv_bn_relu2d( block2_1m, wr2_3d, keep_prob, strides = [1,4,4,1], is_train = m_cls_bn, bn_trainable = m_cls_trainable, scope = 'm_cls_2_3', leak = True  )
            self.m_cls_weights.append( wr2_1m  )
            self.m_cls_weights.append( wr2_2m  )
            self.m_cls_weights.append( wr2_3d  )

        with tf.variable_scope('mask_cls_3') as scope:
            wr3_1m = sharable_weight_variable( shape = [3, 3, feature_base * 2, feature_base * 4], trainable = m_cls_trainable  , name = "Variable" )
            wr3_2m = sharable_weight_variable( shape = [3, 3, feature_base * 4, feature_base * 4 ], trainable = m_cls_trainable  , name = "Variable_1" )
            block3_1m = residual_block( out2m, wr3_1m, wr3_2m, keep_prob = keep_prob, inc_dim = True, is_train = m_cls_bn, bn_trainable = m_cls_trainable, scope = 'm_cls_3'   , leak = True   )
            wr3_3d = sharable_weight_variable( shape = [5,5, feature_base * 4, feature_base * 8], trainable = m_cls_trainable, name = "Variable_2"  )
            out3m = conv_bn_relu2d( block3_1m, wr3_3d, keep_prob, strides = [1,4,4,1], is_train = m_cls_bn, bn_trainable = m_cls_trainable, scope = 'm_cls_3_3', leak = True  )
            self.m_cls_weights.append( wr3_1m  )
            self.m_cls_weights.append( wr3_2m  )
            self.m_cls_weights.append( wr3_3d  )

        with tf.variable_scope('mask_cls_4') as scope:
            wr4_1m = sharable_weight_variable( shape = [5, 5, feature_base * 8, feature_base * 16], trainable = m_cls_trainable  , name = "Variable" )
            conv_4m = conv_bn_relu2d(out3m, wr4_1m, strides = [1,4,4,1], keep_prob = keep_prob, padding = "SYMMETRIC", scope = 'm_cls_4', is_train = m_cls_bn, bn_trainable = m_cls_trainable, leak = True)
            self.m_cls_weights.append( wr4_1m  )

        with tf.variable_scope('m_cls_out') as scope:
            wm_out = sharable_weight_variable( shape = [ feature_base* 16 * 4,1 ], trainable = m_cls_trainable , name = "Variable"  )
            out5m_flat = tf.reshape(conv_4m, [-1, feature_base * 16 * 4])
            m_cls_logits = tf.matmul(out5m_flat, wm_out)
            self.m_cls_weights.append(wm_out)

        return m_cls_logits

    def _get_cost(self, ct_logits, mr_logits,  ct_cls_logits, mr_cls_logits, ct_mask_logits, mr_mask_logits, cost_kwargs):

        miu_dis = cost_kwargs["miu_dis"] # coefficient for discriminator loss
        miu_gen = cost_kwargs["miu_gen"] # coefficient for generator loss; used to be 0.5 0.5 1
        lambda_mask_loss = cost_kwargs.pop("lambda_mask_loss", 1.0) # weighting of mask critic score

        self.miu_dis = tf.Variable(miu_dis, name = "miu_dis") # coefficient for discriminator
        self.miu_gen = tf.Variable(miu_gen, name = "miu_gen")

        # loss for main critic and mask critic
        dis_loss = -1 * self.miu_dis * tf.reduce_mean( mr_cls_logits - ct_cls_logits  ) # loss functions of WGAN
        gen_loss = -1 * self.miu_gen * tf.reduce_mean( ct_cls_logits  )

        m_dis_loss = -1 * self.miu_dis * tf.reduce_mean( mr_mask_logits - ct_mask_logits  )
        m_gen_loss = -1 * self.miu_gen * tf.reduce_mean( ct_mask_logits  )
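        # sign convention (WGAN): the critic maximizes E[D(mr)] - E[D(ct)], so its loss is the negative of that
        # difference; the CT adaptation path (generator) maximizes E[D(ct)], hence gen_loss = -E[D(ct)]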

        ############  L2 norm regularizer  ######################
        reg_coeff = cost_kwargs.pop("regularizer", 1.0e-4) # regularizer coefficients for non-GAN parts
        mr_front_reg = sum([tf.nn.l2_loss(variable) for variable in self.mr_front_weights]) # regularizer for MRI variables, fixed for the unsupervised setting
        joint_reg = sum([tf.nn.l2_loss(variable) for variable in self.joint_weights]) # regularizer for joint part, fixed for the unsupervised setting
        fixed_coeff_reg = reg_coeff * (mr_front_reg + joint_reg) # for training observation to confirm the source segmenter is not updated

        gan_reg_coeff = cost_kwargs.pop("gan_regularizer", 1.0e-4)  # regularizer coefficient for GAN parts; empirically it tends to work well when set larger
        gen_reg = gan_reg_coeff * self.miu_gen * sum([tf.nn.l2_loss(variable) for variable in self.ct_front_weights]) # regularizers for WGAN
        dis_reg = gan_reg_coeff * self.miu_dis * sum([tf.nn.l2_loss(variable) for variable in self.cls_weights])
        m_dis_reg = gan_reg_coeff * self.miu_dis * sum([tf.nn.l2_loss(variable) for variable in self.m_cls_weights])

        dis_loss += lambda_mask_loss * m_dis_loss
        gen_loss += lambda_mask_loss * m_gen_loss
        dis_reg += lambda_mask_loss * m_dis_reg
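        # returned terms: dis_loss/dis_reg feed the critic optimizer, gen_loss/gen_reg feed the generator optimizer,
        # and fixed_coeff_reg is only logged to confirm the frozen source segmenter stays unchanged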

        return dis_loss, gen_loss, fixed_coeff_reg, dis_reg, gen_reg

    def _get_variables_by_scope(self):
        """
        Group variables (MR, CT, GAN, etc.) into different collections
        """
        logging.info("the extent of the joint part and segmenter needs to be set manually, including variables and batch norms")

        self.adapt_vars = [] # variables for adaptation (CT)
        self.cls_vars = [] # variables for domain-classifier (i.e. discriminator) for WGAN
        self.seg_vars = [] # variables for segmentation, fixed higher layers in source segmenter
        self.mri_seg_vars = [] # variables for segmentation, MRI early layers, fixed as well

        var_list = tf.contrib.framework.get_variables()
        for var in var_list:
            if "cls" in var.name:
                self.cls_vars.append(var)
            elif "adapt" in var.name:
                self.adapt_vars.append(var)
            elif "output" in var.name:
                self.seg_vars.append(var)
                self.mri_seg_vars.append(var)
            elif "group" in var.name:
                _group_name = var.name.split("/")[0]
                _group_no = float(_group_name.split("_")[-1] ) # group index (currently unused)
                self.mri_seg_vars.append(var)

    def restore(self, sess, model_path, no_gan=False, clear_rms=False):
        """
        Restores a session from a checkpoint

        :param sess: current session instance
        :param model_path: path to file system checkpoint location
        :param no_gan: only restore MRI variables
        :param clear_rms: do not restore RMSprop internal variables; please set this to True
        """
        saver = tf.train.Saver(tf.contrib.framework.get_variables() + tf.get_collection_ref("internal_batchnorm_variables") )
        logging.info("Model restored from file: %s" % model_path)
        if no_gan is True:
            logging.info("Only loading the main variables, without batch norm parameters")
            variables = tf.global_variables()
            reader = tf.pywrap_tensorflow.NewCheckpointReader(model_path)
            var_keep_dic = reader.get_variable_to_shape_map()
            variables_to_restore = []
            for v in variables:
                if v.name.split(':')[0] in var_keep_dic:
                    if ("adapt" in v.name) or ("cls" in v.name) or ("Adam" in v.name):
                        continue
                    if ("group" in v.name) or ("output" in v.name):
                        logging.info("restoring "+str(v.name))
                        variables_to_restore.append(v)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, model_path)

            logging.info("Model restored from file: %s, the pre-trained MRI model (without bn params)" % model_path)

            return 0

        if clear_rms is True:
            logging.info("Calculating RMS parameters from beginning")
            variables = tf.global_variables()
            reader = tf.pywrap_tensorflow.NewCheckpointReader(model_path)
            var_keep_dic = reader.get_variable_to_shape_map()
            variables_to_restore = []
            for v in variables:
                if v.name.split(':')[0] in var_keep_dic:
                    if ("RMS" in v.name) :
                        continue
                    else:
                        logging.info("restoring "+str(v.name))
                        variables_to_restore.append(v)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, model_path)

            logging.info("Model restored from file: %s and RMS variables are ignored" % model_path)
            return 0

        try: # else, just restore as much as possible
            saver.restore(sess, model_path)
            logging.info("Model restored from file: %s" % model_path)
        except:
            variables = tf.global_variables()
            reader = tf.pywrap_tensorflow.NewCheckpointReader(model_path)
            var_keep_dic = reader.get_variable_to_shape_map()
            variables_to_restore = []
            for v in variables:
                if v.name.split(':')[0] in var_keep_dic:
                    skip_flg = False
                    for kwd in self.network_config["restore_skip_kwd"]: # if it is manually specified to be skipped, don't restore it
                        if kwd in v.name:
                            skip_flg = True
                            break
                    if skip_flg is False:
                        variables_to_restore.append(v)
                        logging.info("cannot fully restore the model, restoring "+str(v.name))
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, model_path)

            logging.info("Model restored from file: %s with relaxation" % model_path)

class Trainer(object):
    """
    Train a Full_DRN network instance
    """
    def __init__(self, net, mr_train_list, mr_val_list, ct_train_list, ct_val_list,\
                 adapt_var_list, mr_var_list, old_bn_list, new_bn_list,\
                 test_label_list = None, test_nii_list = None,\
                 num_cls=None, batch_size = 6,\
                 opt_kwargs={}, train_config = {}):

        self.net = net
        self.batch_size = batch_size
        self.num_cls = num_cls # including background
        self.opt_kwargs = opt_kwargs
        self.ct_train_list = ct_train_list # a list of training files
        self.ct_val_list = ct_val_list # a list of validation files
        self.mr_train_list = mr_train_list # a list of training files for MRI
        self.mr_val_list = mr_val_list
        self.test_label_list = test_label_list # test files (npz format)
        self.test_nii_list = test_nii_list # test files (nii format)
        self.adapt_var_list = adapt_var_list # a list of variables in CT path
        self.mr_var_list = mr_var_list # a list of variables in the MRI path, in correspondence with variables in adapt_var_list; used to manually initialize the CT path with the MRI weights
        self.old_bn_list = old_bn_list # a list of batch_norm internal variables in baseline model
        self.new_bn_list = new_bn_list # a list of batch_norm internal variables for the MRI path in current model
        self.ct_train_queue = tf.train.string_input_producer(ct_train_list, num_epochs = None, shuffle = True) # tensorflow input queues for CT and MRI (direct CT label supervision is disabled)
        self.ct_val_queue = tf.train.string_input_producer(ct_val_list, num_epochs = None, shuffle = True)
        self.mr_train_queue = tf.train.string_input_producer(mr_train_list, num_epochs = None, shuffle = True)
        self.mr_val_queue = tf.train.string_input_producer(mr_val_list, num_epochs = None, shuffle = True)
        self.train_config = train_config # configurations for training
        self.lr_update_flag = train_config["lr_update"]

    def next_batch(self, input_queue, capacity = 120, num_threads = 2, min_after_dequeue = 30, label_type = 'float'):

        reader = tf.TFRecordReader()
        fid, serialized_example = reader.read(input_queue)
        parser = tf.parse_single_example(serialized_example, features = decomp_feature)
        dsize_dim0 = tf.cast(parser['dsize_dim0'], tf.int32)
        dsize_dim1 = tf.cast(parser['dsize_dim1'], tf.int32)
        dsize_dim2 = tf.cast(parser['dsize_dim2'], tf.int32)
        lsize_dim0 = tf.cast(parser['lsize_dim0'], tf.int32)
        lsize_dim1 = tf.cast(parser['lsize_dim1'], tf.int32)
        lsize_dim2 = tf.cast(parser['lsize_dim2'], tf.int32)
        data_vol = tf.decode_raw(parser['data_vol'], tf.float32)
        label_vol = tf.decode_raw(parser['label_vol'], tf.float32)

        data_vol = tf.reshape(data_vol, raw_size)
        label_vol = tf.reshape(label_vol, raw_size)
        data_vol = tf.slice(data_vol, [0,0,0],volume_size)
        label_vol = tf.slice(label_vol, [0,0,1], label_size)
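        # the label volume also has 3 slices; keep only the centre slice (offset 1 along depth) as the 2D label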

        data_feed, label_feed, fid_feed = tf.train.shuffle_batch([data_vol, label_vol, fid], batch_size =self.batch_size , capacity = capacity, \
                                                            num_threads = num_threads, min_after_dequeue = min_after_dequeue)

        pair_feed = tf.concat([data_feed, label_feed], axis = 3) # concatenate them

        return pair_feed, fid_feed

    def _get_optimizer(self, training_iters, global_step):
        """
        Use RMSprop instead of Adam for training WGAN
        """
        learning_rate = self.opt_kwargs.pop("learning_rate", None) # typically configured as 0.0002
        self.LR_refresh = learning_rate
        self.learning_rate_node = tf.Variable(learning_rate)


        # optimizer for discriminator/ domain classifier
        dis_optimizer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate_node,
                                                            **self.opt_kwargs).minimize(self.net.dis_loss + 1.0 / self.train_config['dis_sub_iter'] * self.net.dis_reg,
                                                            global_step=global_step,\
                                                            var_list = self.net.cls_vars)

        # optimizer for training generator
        gen_optimizer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate_node,
                                                            **self.opt_kwargs).minimize(self.net.ct_gen_loss + 1.0 / self.train_config['gen_sub_iter'] * self.net.gen_reg,
                                                            global_step=global_step,\
                                                            var_list = self.net.adapt_vars)
        # clip operation for WGAN to enforce the Lipschitz constraint
        self.clip_op = [tf.assign(var, tf.clip_by_value(var, -0.03, 0.03)) for var in self.net.cls_vars if "Variable" in var.name]

        return dis_optimizer, gen_optimizer

    def _initialize(self, training_iters, output_path):
        """
        initialization and tensorboard setting
        """
        self.global_step = tf.Variable(0)

        scalar_summaries = [] # tensorboard summaries
        scalar_summaries.append(tf.summary.scalar('fixed_coeff_reg', self.net.fixed_coeff_reg)) # regularizer of MRI segmenter weights, used to monitor that the MRI weights stay unchanged
        scalar_summaries.append(tf.summary.scalar('discriminator_loss', self.net.dis_loss))
        scalar_summaries.append(tf.summary.scalar('generator_loss', self.net.ct_gen_loss))

        scalar_summaries.append(tf.summary.scalar('ct_dice_eval_c1_lv_myo', self.net.ct_dice_eval_c1))
        scalar_summaries.append(tf.summary.scalar('ct_dice_eval_c2_la_blood', self.net.ct_dice_eval_c2))
        scalar_summaries.append(tf.summary.scalar('ct_dice_eval_c3_lv_blood', self.net.ct_dice_eval_c3))
        scalar_summaries.append(tf.summary.scalar('ct_dice_eval_c4_aa', self.net.ct_dice_eval_c4))

        scalar_summaries.append(tf.summary.scalar('mri_dice', self.net.mr_dice_eval)) # set to show absolute value for mr segmentation

        train_images = []
        train_images.append(tf.summary.image('ct_pred', tf.expand_dims(tf.cast(self.net.compact_pred, tf.float32), 3 )) ) # ct prediction
        train_images.append(tf.summary.image('ct_image', tf.expand_dims(tf.cast(self.net.ct[:,:,:,1], tf.float32), 3 )) )
        train_images.append(tf.summary.image('ct_gt', tf.expand_dims(tf.cast(self.net.compact_y, tf.float32), 3))) # ground truth for CT segmentation
        train_images.append(tf.summary.image('mri_validation_pred', tf.expand_dims(tf.cast(self.net.compact_mr_valid, tf.float32), 3 )) ) # mri segmentation for debugging
        train_images.append(tf.summary.image('mri_image', tf.expand_dims(tf.cast(self.net.mr[:,:,:,1], tf.float32), 3 )) )
        train_images.append(tf.summary.image('mri_gt', tf.expand_dims(tf.cast(self.net.compact_mr_y, tf.float32), 3))) # ground truth for MRI segmentation

        val_images = []
        val_images.append(tf.summary.image('ct_val_pred', tf.expand_dims(tf.cast(self.net.compact_pred, tf.float32), 3))) # prediction for validation
        val_images.append(tf.summary.image('ct_image', tf.expand_dims(tf.cast(self.net.ct[:,:,:,1], tf.float32), 3)))
        val_images.append(tf.summary.image('ct_val_gt', tf.expand_dims(tf.cast(self.net.compact_y, tf.float32), 3)))

        self.net._get_variables_by_scope() # get variable groups
        self.dis_optimizer, self.gen_optimizer = self._get_optimizer(training_iters, self.global_step) # get optimizers

        scalar_summaries.append(tf.summary.scalar('learning_rate', self.learning_rate_node))

        # get summary writers
        self.scalar_summary_op = tf.summary.merge(scalar_summaries)
        self.train_image_summary_op = tf.summary.merge(train_images)
        self.val_image_summary_op = tf.summary.merge(val_images)

        # variable initializers
        init_glb = tf.global_variables_initializer()
        init_loc = tf.variables_initializer(tf.local_variables())

        return init_glb, init_loc


    def _adapt_copy_weights(self, internal = False):

        if internal is False:
            if len(self.mr_var_list) != len(self.adapt_var_list):
                raise ValueError("cannot copy weight to adaptation because of mismatched variable lists")
            with tf.variable_scope("", reuse = True):
                for idx in range(len(self.mr_var_list)):
                    logging.info("Now initializing adaptation variable %s with mainstream variable %s"%( self.adapt_var_list[idx], self.mr_var_list[idx]   ))
                    _curr_mr_var = tf.get_default_graph().get_tensor_by_name(self.mr_var_list[idx])
                    _curr_adapt_var = tf.get_default_graph().get_tensor_by_name(self.adapt_var_list[idx])
                    upd_op = tf.assign(_curr_adapt_var,_curr_mr_var)
                    upd_op.eval()

        else:
            logging.info("automatically searching for variable correspondences")
            all_var_list = tf.contrib.framework.get_variables()
            self.mr_var_list = []
            self.adapt_var_list = []
            for v in all_var_list:
                if ("RMS" in v.name) or ("Adam" in v.name):
                    continue
                else:
                    if "group" in v.name:
                        self.mr_var_list.append(v)
                    elif "adapt" in v.name:
                        self.adapt_var_list.append(v)
                    else:
                        continue
            if len(self.mr_var_list) != len(self.adapt_var_list):
                raise ValueError("cannot copy weight to adaptation because of mismatched variable lists")
            for _curr_adapt_var, _curr_mr_var in zip(self.adapt_var_list, self.mr_var_list):
                upd_op = tf.assign(_curr_adapt_var,_curr_mr_var)
                upd_op.eval()

        logging.info("adaptation module has been initialized! Please remember that it is a one-time operation")


    def _load_batch_norm_weights(self, output_path):
        """
        convenience function for loading weights from an earlier version of the baseline model for the CT/MR segmentation network
        old_bn_list: a list of bn variable names in the baseline model
        new_bn_list: a list of bn variable names in the current model
        """
        if len(self.old_bn_list) != len(self.new_bn_list):
            raise ValueError("two mappings mismatch")

        checkpoint = tf.train.get_checkpoint_state(output_path)
        self.copy_bn_dict = {}
        for old_var, new_var in zip(self.old_bn_list, self.new_bn_list):
            n_group = new_var.split("_")[1]
            new_var = "group_" + n_group + "/" + new_var
            self.copy_bn_dict[old_var] = new_var
            old_variable = tf.contrib.framework.load_variable( output_path, old_var  )
            new_variable = tf.get_default_graph().get_tensor_by_name(new_var)
            upd_op = tf.assign(new_variable, old_variable)
            upd_op.eval()

            logging.info("%s has been copied to %s"%(old_var, new_var))

        return 0

    def train(self, output_path, restore=True, restored_path=None, training_iters=200, epochs=1000, dropout=0.75, display_step=5):

        self.output_path = output_path
        if not os.path.exists(output_path):
            logging.info("Allocating '{:}'".format(output_path))
            os.makedirs(output_path)

        self._initialize_logs()
        save_path = os.path.join(output_path, "model.cpkt")
        if epochs == 0:
            return save_path
        init_glb, init_loc = self._initialize(training_iters, output_path)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True # set False to let TF pre-allocate GPU memory
        with tf.Session(config=config) as sess:
            sess.run([ init_glb, init_loc] )
            coord = tf.train.Coordinator()
            # For restoring models, there are three situations:
            # 1. warming up the discriminator, initializing from the MRI segmenter: "restore_from_baseline=True, clear_rms=True"
            #    if restore_from_baseline is set True (clear_rms can be either value),
            #    this restores the pre-trained MRI segmenter (without BN); it works together with the following lines 1076-1079 to manually load BN
            # 2. after warming up the discriminator, start training the GAN: "restore_from_baseline=False, clear_rms=True"
            #    this restores the entire GAN system with the warmed-up discriminator (excluding RMS variables from the optimizer)
            # 3. fine-tune the GAN from a breakpoint: "restore_from_baseline=False, clear_rms=False"
            if restore:
                if restored_path is None:
                    raise Exception("No restore path is provided")
                ckpt = tf.train.get_checkpoint_state(restored_path)
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(sess, ckpt.model_checkpoint_path, no_gan = self.train_config["restore_from_baseline"], clear_rms = self.train_config["clear_rms"])

            if self.train_config["restore_from_baseline"] is True:  # initialize the MRI and CT parts with the pre-trained MRI segmenter; only called once at the beginning of training
                self._load_batch_norm_weights(restored_path)  # load batchnorm variables of MRI-specific and joint part
                print("initializing from baseline model!")
                self._adapt_copy_weights()  # copy MRI weights to CT adaptation layers for initialization

            if self.lr_update_flag is True: # manually reset learning rate when needed
                sess.run( tf.assign(self.learning_rate_node, self.LR_refresh)  )
                logging.info("New learning rate %s has been loaded"%str(self.LR_refresh))

            train_summary_writer = tf.summary.FileWriter(output_path + "/train_log" + self.train_config['tag'], graph=sess.graph)
            val_summary_writer = tf.summary.FileWriter(output_path + "/val_log" + self.train_config['tag'], graph=sess.graph)

            ct_feed_all, ct_feed_fid = self.next_batch(self.ct_train_queue)
            ct_feed_val, ct_feed_val_fid = self.next_batch(self.ct_val_queue)

            mr_feed_all, mr_feed_fid = self.next_batch(self.mr_train_queue)
            mr_feed_val, mr_feed_val_fid = self.next_batch(self.mr_val_queue)

            threads = tf.train.start_queue_runners(sess = sess, coord = coord, start = True)

            # read iteration configurations
            dis_interval = self.train_config['dis_interval'] # frequency of discriminator updates, default 1; if set to 2, the discriminator is updated every 2 iterations
            gen_interval = self.train_config['gen_interval'] # frequency of generator updates, default 1; if set to 2, the generator is updated every 2 iterations

            dis_sub_iter = self.train_config['dis_sub_iter'] # number of sub-iterations in one update, recommended to be larger than gen_sub_iter
            gen_sub_iter = self.train_config['gen_sub_iter']

            # set whether to increase *_sub_iter every <sub_iter_upd_interval> iterations.
            # for example, if this is set to 1 and sub_iter_upd_interval is 100, dis_sub_iter increases by 1 every 100 iterations
            dis_sub_iter_inc = self.train_config.pop('dis_sub_iter_inc', 0)
            gen_sub_iter_inc = self.train_config.pop('gen_sub_iter_inc', 0)

            sub_iter_upd_interval = self.train_config.pop('iter_upd_interval', 999999999999)
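            # worked example (hypothetical values): with dis_sub_iter = 5, dis_sub_iter_inc = 1 and
            # sub_iter_upd_interval = 100, the discriminator runs 5 sub-iterations per update up to
            # step 100, 6 up to step 200, 7 up to step 300, and so on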
            for epoch in range(epochs):
                for step in range((epoch*training_iters), ((epoch+1)*training_iters)):
                    logging.info("Running step %s epoch %s ..."%(str(step), str(epoch)))
                    start = time.time()
                    # following the DCGAN paper, update the discriminator first

                    if dis_interval == 0:
                        pass
                    elif (step % dis_interval == 0) and (step != 0):
                        for itr_dummy in range(dis_sub_iter):
                            # read samples from the pipeline, split off and decompose the labels, then feed the images for the discriminator update
                            ct_batch, ct_fid = sess.run([ct_feed_all, ct_feed_fid])
                            ct_raw_y = ct_batch[:,:,:,3]
                            ct_batch = ct_batch[:,:,:,0:3]
                            ct_batch_y = _label_decomp(self.num_cls, ct_raw_y)

                            mr_batch, mr_fid = sess.run([mr_feed_all, mr_feed_fid])
                            mr_raw_y = mr_batch[:,:,:,3]
                            mr_batch = mr_batch[:,:,:,0:3]
                            mr_batch_y = _label_decomp(self.num_cls, mr_raw_y)

                            _, _ = sess.run((self.dis_optimizer, self.learning_rate_node),
                                                        feed_dict={ self.net.mr: mr_batch,
                                                                    self.net.ct: ct_batch,
                                                                    self.net.mr_front_bn: False,
                                                                    self.net.joint_bn: False,
                                                                    self.net.ct_front_bn: False,
                                                                    self.net.cls_bn: True,
                                                                    self.net.keep_prob: dropout})
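                            # the segmenter-path BN switches (mr_front_bn, joint_bn, ct_front_bn) are kept off
                            # here, so their statistics are frozen during the discriminator step; only the
                            # discriminator's own batch norm (cls_bn) runs in training mode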
                            # clip the discriminator weights to a bounded range after each update
                            sess.run(self.clip_op)
                            logging.info("discriminator updated %s of %s"%(str(itr_dummy),str(dis_sub_iter)))


                    # update generator
                    if gen_interval == 0:
                        pass
                    elif (step % gen_interval == 0) and (step != 0):
                        for _ in range(gen_sub_iter):
                            ct_batch, ct_fid = sess.run([ct_feed_all, ct_feed_fid])
                            ct_raw_y = ct_batch[:,:,:,3]
                            ct_batch = ct_batch[:,:,:,0:3]
                            ct_batch_y = _label_decomp(self.num_cls, ct_raw_y)
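                            # note: only CT images (no labels) are fed to the generator step below, and
                            # ct_front_bn is the only BN switch set to True, so the CT-path batch-norm
                            # statistics adapt during this update while the other paths stay in inference mode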

                            _, _ = sess.run((self.gen_optimizer, self.learning_rate_node),
                                                            feed_dict={ self.net.ct: ct_batch,
                                                                        self.net.mr_front_bn: False,
                                                                        self.net.joint_bn: False,
                                                                        self.net.ct_front_bn: True,
                                                                        self.net.cls_bn: False,
                                                                        self.net.keep_prob: dropout})
                            logging.info("generator updated")

                    # if we need to update iteration configurations, do it here
                    if (step % sub_iter_upd_interval == 0) and (step != 0):
                        dis_sub_iter += dis_sub_iter_inc
                        gen_sub_iter += gen_sub_iter_inc
                        logging.info("sub iterations updated!")

                    logging.info("Training step %s epoch %s has been finished!"%(str(step), str(epoch)))
                    logging.info("Time elapsed %s seconds"%(str(time.time() - start)))

                    # evaluate periodically and write summaries to TensorBoard
                    if step % display_step == 0:

                        # training batch
                        train_ct_batch = sess.run(ct_feed_all)
                        train_ct_raw_y = train_ct_batch[:,:,:,3]
                        train_ct_batch = train_ct_batch[:,:,:,0:3]
                        train_ct_batch_y = _label_decomp(self.num_cls, train_ct_raw_y)

                        mr_batch, mr_fid = sess.run([mr_feed_all, mr_feed_fid])
                        mr_raw_y = mr_batch[:,:,:,3]
                        mr_batch = mr_batch[:,:,:,0:3]
                        mr_batch_y = _label_decomp(self.num_cls, mr_raw_y)

                        self.output_minibatch_stats(sess, train_summary_writer, step, train_ct_batch, train_ct_batch_y, mr_batch, mr_batch_y)

                    if step % display_step == 0:

                        # validation batch
                        ct_batch = sess.run(ct_feed_val)
                        ct_raw_y = ct_batch[:,:,:,3]
                        ct_batch = ct_batch[:,:,:,0:3]
                        ct_batch_y = _label_decomp(self.num_cls, ct_raw_y)

                        mr_batch = sess.run(mr_feed_val)
                        mr_raw_y = mr_batch[:,:,:,3]
                        mr_batch = mr_batch[:,:,:,0:3]
                        mr_batch_y = _label_decomp(self.num_cls, mr_raw_y)

                        self.output_minibatch_stats(sess, val_summary_writer, step, ct_batch, ct_batch_y, mr_batch, mr_batch_y, detail = True)

                    # save and restore the model periodically
                    if step % (self.train_config["checkpoint_space"]) == 0:
                        if step == 0:
                            continue
                        else:
                            save_path = _save(sess, save_path, global_step = self.global_step.eval())
                            print('*********************** save path ******************: ', save_path)
                            logging.info("Model has been saved ...")
                            last_ckpt = tf.train.get_checkpoint_state(output_path)
                            if last_ckpt and last_ckpt.model_checkpoint_path:
                                self.net.restore(sess, last_ckpt.model_checkpoint_path)
                            logging.info("Model has been restored for re-allocation")
                            # learning rate decay
                            _pre_lr = sess.run(self.learning_rate_node)
                            sess.run( tf.assign(self.learning_rate_node, _pre_lr *\
                                        self.train_config['lr_decay_factor'])  )
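                        # e.g. with a hypothetical lr_decay_factor of 0.95, the learning rate after the
                        # k-th checkpoint is 0.95**k times its starting value (exponential decay per checkpoint)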

                logging.info("Global step %s"%str(self.global_step.eval()))

            logging.info("Optimization Finished!")
            coord.request_stop()
            coord.join(threads)
            return save_path

    def output_minibatch_stats(self, sess, summary_writer, step, ct_batch, ct_batch_y, mr_batch, mr_batch_y, detail = False):

        """
        minibatch stats for tensorboard observation
        """
        if detail is not True:
            summary_str, summary_img = sess.run([\
                                                    self.scalar_summary_op,
                                                    self.train_image_summary_op],
                                                    feed_dict={\
                                                    self.net.ct_front_bn : False,
                                                    self.net.mr_front_bn : False,
                                                    self.net.joint_bn : False,
                                                    self.net.cls_bn : False,
                                                    self.net.mr: mr_batch,
                                                    self.net.mr_y: mr_batch_y,
                                                    self.net.ct: ct_batch,
                                                    self.net.ct_y: ct_batch_y,
                                                    self.net.keep_prob: 1.\
                                                    })

        else:
            _, curr_conf_mat, summary_str, summary_img = sess.run([\
                                                    self.net.compact_pred,
                                                    self.net.confusion_matrix,
                                                    self.scalar_summary_op,
                                                    self.train_image_summary_op],
                                                    feed_dict={\
                                                    self.net.ct_front_bn : False,
                                                    self.net.mr_front_bn : False,
                                                    self.net.joint_bn : False,
                                                    self.net.cls_bn : False,
                                                    self.net.mr: mr_batch,
                                                    self.net.mr_y: mr_batch_y,
                                                    self.net.ct: ct_batch,
                                                    self.net.ct_y: ct_batch_y,
                                                    self.net.keep_prob: 1.\
                                                    })


            _indicator_eval(curr_conf_mat)
        summary_writer.add_summary(summary_str, step)
        summary_writer.add_summary(summary_img, step)
        summary_writer.flush()

    def test_eval(self, sess, output_path, flip_correction = True):
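        """
        Evaluate the adapted model on the CT test volumes.
        Accumulates a confusion matrix per subject and over the whole test set,
        writes the overall confusion matrix to cm.csv, and returns the per-class
        Dice and Jaccard means across subjects.
        flip_correction: if True, flip each volume along its first two axes before evaluation.
        """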

        all_cm = np.zeros([self.num_cls, self.num_cls])

        pred_folder = os.path.join(output_path, "dense_pred")
        try:
            os.makedirs(pred_folder)
        except OSError:
            logging.info("prediction folder exists")

        self.test_pair_list = list(zip(self.test_label_list, self.test_nii_list))

        sample_eval_list = [] # evaluation of each sample

        for idx_file, pair in enumerate(self.test_pair_list):
            sample_cm = np.zeros([self.num_cls, self.num_cls]) # confusion matrix for each sample
            label_fid = pair[0]
            nii_fid = pair[1]
            if not os.path.isfile(nii_fid):
                raise Exception("cannot find sample %s"%str(nii_fid))
            raw = read_nii_image(nii_fid)
            raw_y = read_nii_image(label_fid)

            if flip_correction is True:
                raw = np.flip(raw, axis = 0)
                raw = np.flip(raw, axis = 1)
                raw_y = np.flip(raw_y, axis = 0)
                raw_y = np.flip(raw_y, axis = 1)

            tmp_y = np.zeros(raw_y.shape)

            frame_list = [kk for kk in range(1, raw.shape[2] - 1)]
            np.random.shuffle(frame_list)
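            # frame_list covers the interior slice indices (the first and last slice are skipped so that a
            # three-slice stack jj-1..jj+1 can always be formed); the list is shuffled, but since each
            # prediction is written back to tmp_y at its own index jj, the output ordering is unaffected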
            for ii in range( raw.shape[2] // self.net.batch_size ):
                vol = np.zeros( [self.net.batch_size, raw_size[0], raw_size[1], raw_size[2]]  )
                slice_y = np.zeros( [self.net.batch_size, label_size[0], label_size[1]]  )
                for idx, jj in enumerate(frame_list[ ii * self.net.batch_size : (ii + 1) * self.net.batch_size  ]):
                    vol[idx, ...] = raw[ ..., jj -1: jj+2  ].copy()
                    slice_y[idx,...] = raw_y[..., jj ].copy()

                vol_y = _label_decomp(self.num_cls, slice_y)
                pred, curr_conf_mat= sess.run([self.net.compact_pred, self.net.confusion_matrix], feed_dict =\
                                              {self.net.ct: vol, self.net.ct_y: vol_y, self.net.keep_prob: 1.0, self.net.mr_front_bn : False,\
                                               self.net.ct_front_bn: False})

                for idx, jj in enumerate(frame_list[ii * self.net.batch_size: (ii + 1) * self.net.batch_size]):
                    tmp_y[..., jj] = pred[idx, ...].copy()

                sample_cm += curr_conf_mat

            all_cm += sample_cm
            sample_dice = _dice(sample_cm)
            sample_jaccard = _jaccard(sample_cm)
            sample_eval_list.append((sample_dice, sample_jaccard))
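            # a minimal sketch of the per-class metrics, assuming the imported _dice/_jaccard helpers
            # follow the standard confusion-matrix definitions (tp = cm[k, k], fp/fn from the k-th
            # column/row of cm):
            #   dice[k]    = 2 * tp / (2 * tp + fp + fn)
            #   jaccard[k] = tp / (tp + fp + fn)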

        subject_dice_list, subject_jaccard_list = self.sample_metric_stddev(sample_eval_list)

        np.savetxt(os.path.join(output_path, "cm.csv"), all_cm)

        return subject_dice_list, subject_jaccard_list

    def sample_metric_stddev(self, sample_eval_list):
        """
        calculate stddev of each organ across samples
        """
        metric_mat = np.zeros( [len(sample_eval_list), self.num_cls, 2]  )
        for organ, ind in list(contour_map.items()):
            for ii in range(len(sample_eval_list)):
                metric_mat[ii, int(ind), 0] = sample_eval_list[ii][0][int(ind)] # dice
                metric_mat[ii, int(ind), 1] = sample_eval_list[ii][1][int(ind)] # jaccard

        print("------- inside the sample_metric_stddev file ---- ")
        for organ, ind in list(contour_map.items()):
            print(( "organ: %s"%organ ))
            print(( "dice_stddev: %s"%( np.std(metric_mat[:, int(ind), 0] ) ) ))
            print(( "jaccard_stddev: %s"%( np.std(metric_mat[:, int(ind), 1] )  )  ))

        print("------- inside the sample_metric_stddev file ----  ")
        for organ, ind in list(contour_map.items()):
            print(( "organ: %s"%organ ))
            print(( "dice_mean: %s"%( np.mean(metric_mat[:, int(ind), 0] ) ) ))
            print(( "jaccard_mean %s"%( np.mean(metric_mat[:, int(ind), 1] )  )  ))

        print("-------")
        print(( "all_dice_mean: %s"%( np.mean(metric_mat[:, 1:, 0] ) ) ))
        print(("all_jaccard_mean: %s" % (np.mean(metric_mat[:, 1:, 1] ) )))

        subject_level_list = np.mean(metric_mat, axis=0) # shape [num_cls, 2]: per-class means across subjects
        subject_level_list_dice = subject_level_list[:, 0]
        subject_level_list_jaccard = subject_level_list[:, 1]

        return subject_level_list_dice, subject_level_list_jaccard

    def _initialize_logs(self):
        """
        This log is actually useless so ignore it
        """
        self.acc_dice_dict = {}
        self.acc_jaccard_dict = {}
        self.log_eval_fid = os.path.join(self.output_path, "acc_eval.csv")
        for organ, ind in list(contour_map.items()):
            self.acc_dice_dict[organ] = [organ]
            self.acc_jaccard_dict[organ] = [organ]

    def test_model(self, this_model, output_path):
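        """
        Restore the trained model from the checkpoint `this_model`, run test_eval on the
        CT test set, and return the per-class Dice and Jaccard scores.
        """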

        init_glb, init_loc = self._initialize(1, output_path)

        with tf.Session() as sess:
            sess.run([init_glb, init_loc])
            self.net.restore(sess, this_model)
            logging.info("model has been loaded!")
            dice, jac = self.test_eval(sess, output_path)
            logging.info("testing finished")

        return dice, jac