from __future__ import division
from __future__ import print_function
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np

from utils import *
from loss_functions import *
import SimpleITK as sitk
from scipy.misc import imsave

class MR2CT(object):
    def __init__(self, sess, batch_size=10, depth_MR=32, height_MR=32,
                 width_MR=32, depth_CT=32, height_CT=24,
                 width_CT=24, l_num=2, wd=0.0005, checkpoint_dir=None, path_patients_h5=None, learning_rate=2e-8):
        """
        Args:
            sess: TensorFlow session
            batch_size: The size of batch. Should be specified before training.
            output_size: (optional) The resolution in pixels of the images. [64]
            y_dim: (optional) Dimension of dim for y. [None]
            z_dim: (optional) Dimension of dim for Z. [100]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024]
            dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
            c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
        """
        self.sess = sess
        self.l_num = l_num
        self.wd = wd
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.depth_MR = depth_MR
        self.height_MR = height_MR
        self.width_MR = width_MR
        self.depth_CT = depth_CT
        self.height_CT = height_CT
        self.width_CT = width_CT
        self.checkpoint_dir = checkpoint_dir
        self.data_generator = Generator_3D_patches(path_patients_h5, self.batch_size)
        self.build_model()
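
    # Generator_3D_patches comes from utils and its implementation is not shown here.
    # A minimal sketch of the interface this class assumes (an assumption, not the
    # actual utils code): an endless iterator over an HDF5 file that yields
    # (X, y) batches shaped [batch, D, H, W, 1]. The dataset names are hypothetical.
    #
    #   import h5py
    #   def Generator_3D_patches(path_h5, batch_size):
    #       with h5py.File(path_h5, 'r') as f:
    #           mr, ct = f['dataMR'][:], f['dataCT'][:]  # hypothetical dataset names
    #       n = mr.shape[0]
    #       while True:
    #           idx = np.random.choice(n, batch_size, replace=False)
    #           yield mr[idx][..., np.newaxis], ct[idx][..., np.newaxis]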

    def build_model(self):
        self.inputMR = tf.placeholder(tf.float32, shape=[None, self.depth_MR, self.height_MR, self.width_MR, 1])
        self.CT_GT = tf.placeholder(tf.float32, shape=[None, self.depth_CT, self.height_CT, self.width_CT, 1])
        batch_size_tf = tf.shape(self.inputMR)[0]  # dynamic batch size, so testing can use a different batch size
        self.train_phase = tf.placeholder(tf.bool, name='phase_train')
        self.G = self.generator(self.inputMR, batch_size_tf)
        print('shape output G ', self.G.get_shape())
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.g_loss = lp_loss(self.G, self.CT_GT, self.l_num, batch_size_tf)
        print('learning rate ', self.learning_rate)
        # The optimizer is created in train(), after a possible checkpoint restore,
        # so that only its slot variables need to be initialized (see train()).
        self.merged = tf.merge_all_summaries()
        self.writer = tf.train.SummaryWriter("./summaries", self.sess.graph)
        self.saver = tf.train.Saver()
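
    # lp_loss comes from loss_functions. A minimal sketch of what a voxel-wise Lp
    # reconstruction loss with this signature might compute (an assumption, not the
    # actual loss_functions code):
    #
    #   def lp_loss(gen_frames, gt_frames, l_num, batch_size_tf):
    #       # mean |G - GT|^p over the batch; p = l_num (p=2 gives an L2 loss)
    #       return tf.reduce_sum(tf.abs(gen_frames - gt_frames) ** l_num) / tf.to_float(batch_size_tf)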


    def generator(self, inputMR, batch_size_tf):

        ######## FCN mapping 32x32x32 MR patches to 24x24x24 CT patches ########
        conv1_a = conv_op_3d_bn(inputMR, name="conv1_a", kh=5, kw=5, kz=5, n_out=48, dh=1, dw=1, dz=1, wd=self.wd, padding='VALID', train_phase=self.train_phase)  # 32 -> 28 (5^3 VALID)
        conv2_a = conv_op_3d_bn(conv1_a, name="conv2_a", kh=3, kw=3, kz=3, n_out=96, dh=1, dw=1, dz=1, wd=self.wd, padding='SAME', train_phase=self.train_phase)
        conv3_a = conv_op_3d_bn(conv2_a, name="conv3_a", kh=3, kw=3, kz=3, n_out=128, dh=1, dw=1, dz=1, wd=self.wd, padding='SAME', train_phase=self.train_phase)
        conv4_a = conv_op_3d_bn(conv3_a, name="conv4_a", kh=5, kw=5, kz=5, n_out=96, dh=1, dw=1, dz=1, wd=self.wd, padding='VALID', train_phase=self.train_phase)  # 28 -> 24 (5^3 VALID)
        conv5_a = conv_op_3d_bn(conv4_a, name="conv5_a", kh=3, kw=3, kz=3, n_out=48, dh=1, dw=1, dz=1, wd=self.wd, padding='SAME', train_phase=self.train_phase)
        conv6_a = conv_op_3d_bn(conv5_a, name="conv6_a", kh=3, kw=3, kz=3, n_out=32, dh=1, dw=1, dz=1, wd=self.wd, padding='SAME', train_phase=self.train_phase)
        # Final layer has no BN/ReLU so the output can span the full range of CT intensities.
        conv7_a = conv_op_3d_norelu(conv6_a, name="conv7_a", kh=3, kw=3, kz=3, n_out=1, dh=1, dw=1, dz=1, wd=self.wd, padding='SAME')
        self.MR_16_downsampled = conv7_a  # kept as an attribute so train() can log its statistics
        return conv7_a
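
    # conv_op_3d_bn and conv_op_3d_norelu come from utils. A minimal sketch of what a
    # helper with this signature typically does (an assumption, not the actual utils
    # code): 3-D convolution + batch norm + ReLU, with an L2 weight-decay term added
    # to a collection that the loss picks up. batch_norm is a hypothetical helper.
    #
    #   def conv_op_3d_bn(x, name, kh, kw, kz, n_out, dh, dw, dz, wd, padding, train_phase):
    #       n_in = x.get_shape()[-1].value
    #       with tf.variable_scope(name):
    #           kernel = tf.get_variable('w', [kz, kh, kw, n_in, n_out],
    #                                    initializer=tf.truncated_normal_initializer(stddev=0.01))
    #           tf.add_to_collection('losses', tf.mul(tf.nn.l2_loss(kernel), wd))
    #           conv = tf.nn.conv3d(x, kernel, [1, dz, dh, dw, 1], padding=padding)
    #           return tf.nn.relu(batch_norm(conv, train_phase))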




    def train(self, config):
        path_test = '/home/dongnie/warehouse/prostate/ganData64to24Test'
        print('global_step ', self.global_step.name)
        print('trainable vars ')
        for v in tf.trainable_variables():
            print(v.name)

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            self.sess.run(tf.initialize_all_variables())
        start = self.global_step.eval()  # resume from the last saved global_step
        print("Start from:", start)

        ############ Initialize only the Adam slot variables ####################
        # Creating the optimizer here, after a possible restore, lets us initialize
        # just the variables it adds, without clobbering the restored weights.
        temp = set(tf.all_variables())
        self.g_optim = tf.train.AdamOptimizer(self.learning_rate).minimize(self.g_loss, global_step=self.global_step)
        self.sess.run(tf.initialize_variables(set(tf.all_variables()) - temp))
        print("Start after adam (should be the same):", start)
        #####################################

        for it in range(start, config.iterations):

            X, y = next(self.data_generator)

            # Update the generator network; global_step is incremented by the optimizer.
            _, loss_eval, layer_out_eval = self.sess.run(
                [self.g_optim, self.g_loss, self.MR_16_downsampled],
                feed_dict={self.inputMR: X, self.CT_GT: y, self.train_phase: True})
            

            if it % config.show_every == 0:  # report the loss every show_every iterations
                print('it ', it, 'loss ', loss_eval)
                print('layer min ', np.min(layer_out_eval))
                print('layer max ', np.max(layer_out_eval))
                print('layer mean ', np.mean(layer_out_eval))
                # Optional debugging: per-variable gradient and filter statistics.
                # for v in tf.trainable_variables():
                #     print(v.name)
                #     data_var = self.sess.run(v)
                #     grads = tf.gradients(self.g_loss, v)
                #     var_grad_val = self.sess.run(grads, feed_dict={self.inputMR: X, self.CT_GT: y})
                #     print('grad min/max/mean ', np.min(var_grad_val), np.max(var_grad_val), np.mean(var_grad_val))
                #     print('filter min/max/mean ', np.min(data_var), np.max(data_var), np.mean(data_var))
                # self.writer.add_summary(summary, it)
            
            if it % config.test_every == 0 and it != 0:  # test on one subject

                mr_test_itk = sitk.ReadImage(os.path.join(path_test, 'prostate_1to1_MRI.nii'))
                ct_test_itk = sitk.ReadImage(os.path.join(path_test, 'prostate_1to1_CT.nii'))
                mrnp = sitk.GetArrayFromImage(mr_test_itk)
                #mu = np.mean(mrnp)
                #mrnp = (mrnp - mu) / (np.max(mrnp) - np.min(mrnp))
                ctnp = sitk.GetArrayFromImage(ct_test_itk)
                print(mrnp.dtype)
                print(ctnp.dtype)
                ct_estimated = self.test_1_subject(mrnp, ctnp, [32, 32, 32], [24, 24, 24], [5, 5, 2])
                psnrval = psnr(ct_estimated, ctnp)
                print(ct_estimated.dtype)
                print(ctnp.dtype)
                print('psnr= ', psnrval)
                volout = sitk.GetImageFromArray(ct_estimated)
                sitk.WriteImage(volout, 'ct_estimated_{}.nii.gz'.format(it))

            if it % config.save_every == 0:  # save weights every save_every iterations
                self.save(self.checkpoint_dir, it)
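
    # psnr comes from utils. A minimal sketch of a PSNR with this signature (an
    # assumption, not the actual utils code):
    #
    #   def psnr(estimated, gt):
    #       mse = np.mean((estimated.astype(np.float64) - gt.astype(np.float64)) ** 2)
    #       return 10.0 * np.log10((np.max(gt) - np.min(gt)) ** 2 / mse)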

    def evaluate(self, patch_MR):
        """ patch_MR is a np array of shape [D,H,W] (one 3-D MR patch)
        """
        patch_MR = np.expand_dims(patch_MR, axis=0)  # [1,D,H,W]
        patch_MR = np.expand_dims(patch_MR, axis=4)  # [1,D,H,W,1]

        patch_CT_pred, MR16_eval = self.sess.run([self.G, self.MR_16_downsampled],
                        feed_dict={self.inputMR: patch_MR, self.train_phase: False})

        patch_CT_pred = np.squeeze(patch_CT_pred)  # [D,H,W]
        #imsave('mr32.png',np.squeeze(MR16_eval[0,:,:,2]))
        #imsave('ctpred.png',np.squeeze(patch_CT_pred[0,:,:,0]))
        #print 'mean of layer  ',np.mean(MR16_eval)
        #print 'min ct estimated ',np.min(patch_CT_pred)
        #print 'max ct estimated ',np.max(patch_CT_pred)
        #print 'mean of ctpatch estimated ',np.mean(patch_CT_pred)
        return patch_CT_pred


    def test_1_subject(self, MR_image, CT_GT, MR_patch_sz, CT_patch_sz, step):
        """
            Receives an MR image and returns an estimated CT image of the same size,
            reconstructed by sliding-window prediction with overlap averaging.
        """
        matFA = MR_image
        matSeg = CT_GT
        dFA = MR_patch_sz
        dSeg = CT_patch_sz

        eps = 1e-5
        row, col, leng = matFA.shape
        margin1 = int((dFA[0] - dSeg[0]) / 2)
        margin2 = int((dFA[1] - dSeg[1]) / 2)
        margin3 = int((dFA[2] - dSeg[2]) / 2)
        marginD = [margin1, margin2, margin3]
        print('matFA shape is ', matFA.shape)
        # Pad the MR volume so every output voxel sees a full input patch.
        matFAOut = np.zeros([row + 2 * marginD[0], col + 2 * marginD[1], leng + 2 * marginD[2]])
        print('matFAOut shape is ', matFAOut.shape)
        matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA


        if margin1 != 0:  # mirror the borders along the 1st dimension
            matFAOut[0:marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA[marginD[0] - 1::-1, :, :]
            matFAOut[row + marginD[0]:matFAOut.shape[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA[matFA.shape[0] - 1:row - marginD[0] - 1:-1, :, :]
        if margin2 != 0:  # mirror the borders along the 2nd dimension
            matFAOut[marginD[0]:row + marginD[0], 0:marginD[1], marginD[2]:leng + marginD[2]] = matFA[:, marginD[1] - 1::-1, :]
            matFAOut[marginD[0]:row + marginD[0], col + marginD[1]:matFAOut.shape[1], marginD[2]:leng + marginD[2]] = matFA[:, matFA.shape[1] - 1:col - marginD[1] - 1:-1, :]
        if margin3 != 0:  # mirror the borders along the 3rd dimension
            matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 0:marginD[2]] = matFA[:, :, marginD[2] - 1::-1]
            matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2] + leng:matFAOut.shape[2]] = matFA[:, :, matFA.shape[2] - 1:leng - marginD[2] - 1:-1]
        


        matOut = np.zeros((matSeg.shape[0], matSeg.shape[1], matSeg.shape[2]))
        used = np.zeros((matSeg.shape[0], matSeg.shape[1], matSeg.shape[2])) + eps  # eps avoids division by zero for uncovered voxels
        for i in range(0, row - dSeg[0], step[0]):
            for j in range(0, col - dSeg[1], step[1]):
                for k in range(0, leng - dSeg[2], step[2]):
                    volFA = matFAOut[i:i + dSeg[0] + 2 * marginD[0], j:j + dSeg[1] + 2 * marginD[1], k:k + dSeg[2] + 2 * marginD[2]]
                    temppremat = self.evaluate(volFA)
                    # Accumulate overlapping predictions, then average by the coverage count.
                    matOut[i:i + dSeg[0], j:j + dSeg[1], k:k + dSeg[2]] += temppremat
                    used[i:i + dSeg[0], j:j + dSeg[1], k:k + dSeg[2]] += 1
        matOut = matOut / used
        return matOut
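
    # How the pieces above fit together at test time (a sketch; the file names are
    # hypothetical, the patch sizes and stride mirror the ones used in train()):
    #
    #   mrnp = sitk.GetArrayFromImage(sitk.ReadImage('subject_MRI.nii'))  # hypothetical file
    #   ctnp = sitk.GetArrayFromImage(sitk.ReadImage('subject_CT.nii'))   # hypothetical file
    #   ct_estimated = model.test_1_subject(mrnp, ctnp, [32, 32, 32], [24, 24, 24], [5, 5, 2])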


            
    def save(self, checkpoint_dir, step):
        model_name = "MR2CT.model"
        
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            return True
        else:
            return False
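

# A minimal usage sketch, not part of the original class: the training script below
# is an assumption, and 'checkpoint' / 'train_patches.h5' are hypothetical paths.
# `config` only needs the fields train() reads (iterations, show_every, test_every,
# save_every).
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--iterations', type=int, default=100000)
    parser.add_argument('--show_every', type=int, default=100)
    parser.add_argument('--test_every', type=int, default=5000)
    parser.add_argument('--save_every', type=int, default=1000)
    config = parser.parse_args()

    with tf.Session() as sess:
        model = MR2CT(sess, checkpoint_dir='checkpoint',        # hypothetical checkpoint dir
                      path_patients_h5='train_patches.h5')      # hypothetical HDF5 path
        model.train(config)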