import keras
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
from keras.layers import Dense, Reshape
from keras.layers.core import Activation, Flatten
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from keras.utils.vis_utils import plot_model

from utils import combine_images, load_emnist_balanced
from PIL import Image, ImageFilter
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask
from snapshot import SnapshotCallbackBuilder
import os
import numpy as np
import tensorflow as tf
import os
import argparse

K.set_image_data_format('channels_last')

"""
Switching the GPU to allow growth
"""
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
K.set_session(sess)


def CapsNet(input_shape, n_class, routings):
    """
    Defining the CapsNet
    :param input_shape: data shape, 3d, [width, height, channels]
    :param n_class: number of classes
    :param routings: number of routing iterations
    :return: Two Keras Models, the first one used for training, and the second one for evaluation.
    """
    x = layers.Input(shape=input_shape)
    conv1 = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='valid', activation='relu', name='conv1')(x)
    conv2 = layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='valid', activation='relu', name='conv2')(conv1)
    conv3 = layers.Conv2D(filters=256, kernel_size=3, strides=2, padding='valid', activation='relu', name='conv3')(conv2)
    primarycaps = PrimaryCap(conv3, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,channels=32,name='digitcaps')(primarycaps)    
    out_caps = Length(name='capsnet')(digitcaps)

    """
    Decoder Network
    """
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps, y])
    masked = Mask()(digitcaps) 

    decoder = models.Sequential(name='decoder')
    decoder.add(Dense(input_dim=16*n_class, activation="relu", output_dim=7*7*32))
    decoder.add(Reshape((7, 7, 32)))
    decoder.add(BatchNormalization(momentum=0.8))
    decoder.add(layers.Deconvolution2D(32, 3, 3,subsample=(1, 1),border_mode='same', activation="relu"))
    decoder.add(layers.Deconvolution2D(16, 3, 3,subsample=(2, 2),border_mode='same', activation="relu"))
    decoder.add(layers.Deconvolution2D(8, 3, 3,subsample=(2, 2),border_mode='same', activation="relu"))
    decoder.add(layers.Deconvolution2D(4, 3, 3,subsample=(1, 1),border_mode='same', activation="relu"))
    decoder.add(layers.Deconvolution2D(1, 3, 3,subsample=(1, 1),border_mode='same', activation="sigmoid"))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))     
    
    """
    Models for training and evaluation (prediction)
    """
    train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)])
    eval_model = models.Model(x, [out_caps, decoder(masked)])

    return train_model, eval_model


def margin_loss(y_true, y_pred):
    """
    Marginal loss used for the CapsNet training
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))


def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    (x_train, y_train), (x_test, y_test) = data

    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
                                           save_best_only=False, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))

    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        shuffle = True,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=snapshot.get_callbacks(log,model_prefix=model_prefix))

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    return model


def test(model, data, args):
    """
    Testing the trained CapsuleNet
    """
    x_test, y_test = data
    y_pred, x_recon = model.predict(x_test, batch_size=args.batch_size*8)
    print('-'*30 + 'Begin: test' + '-'*30)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/float(y_test.shape[0]))
    
class dataGeneration():
    def __init__(self, model,data,args,samples_to_generate = 2):
        """
        Generating new images 
        :param model: the pre-trained CapsNet model
        :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
        :param args: arguments
        :param samples_to_generate: number of new training samples to generate per class
        """
        self.model = model
        self.data = data
        self.args = args
        self.samples_to_generate = samples_to_generate
        print("-"*100)
        
        (x_train, y_train), (x_test, y_test), x_recon = self.remove_missclassifications()
        self.data = (x_train, y_train), (x_test, y_test)
        self.reconstructions = x_recon
        self.inst_parameter, self.global_position, self.masked_inst_parameter = self.get_inst_parameters()
        print("Instantiation parameters extracted.")
        print("-"*100)
        self.x_decoder_retrain,self.y_decoder_retrain = self.decoder_retraining_dataset()
        self.retrained_decoder = self.decoder_retraining()
        print("Decoder re-training completed.")
        print("-"*100)
        self.class_variance, self.class_max, self.class_min = self.get_limits()
        self.generated_images,self.generated_labels = self.generate_data()
        print("New images of the shape ",self.generated_images.shape," Generated.")
        print("-"*100)

    def save_output_image(self,samples,image_name):
        """
        Visualizing and saving images in the .png format 
        :param samples: images to be visualized
        :param image_name: name of the saved .png file
        """
        if not os.path.exists(args.save_dir+"/images"):
            os.makedirs(args.save_dir+"/images")
        img = combine_images(samples)
        img = img * 255
        Image.fromarray(img.astype(np.uint8)).save(args.save_dir + "/images/"+image_name+".png")
        print(image_name, "Image saved.")

    def remove_missclassifications(self):
        """
        Removing the wrongly classified samples from the training set. We do not alter the testing set.
        :return: dataset with miss classified samples removed and the initial reconstructions.
        """
        model = self.model
        data = self.data
        args = self.args   
        (x_train, y_train), (x_test, y_test) = data
        y_pred, x_recon = model.predict(x_train, batch_size=args.batch_size)
        acc = np.sum(np.argmax(y_pred, 1) == np.argmax(y_train, 1))/y_train.shape[0]
        cmp = np.argmax(y_pred, 1) == np.argmax(y_train, 1)
        bin_cmp = np.where(cmp == 0)[0]
        x_train = np.delete(x_train,bin_cmp,axis=0)
        y_train = np.delete(y_train,bin_cmp,axis=0)
        x_recon = np.delete(x_recon,bin_cmp,axis=0)
        self.save_output_image(x_train[:100],"original training")
        self.save_output_image(x_recon[:100],"original reconstruction")
        return (x_train, y_train), (x_test, y_test), x_recon
        
        
    def get_inst_parameters(self):
        """
        Extracting the instantiation parameters for the existing training set
        :return: instantiation parameters, corresponding labels and the masked instantiation parameters
        """
        model = self.model
        data = self.data
        args = self.args
        (x_train, y_train), (x_test, y_test) = data
        if not os.path.exists(args.save_dir+"/check"):
            os.makedirs(args.save_dir+"/check")
        
        if not os.path.exists(args.save_dir+"/check/x_inst.npy"):
            get_digitcaps_output = K.function([model.layers[0].input],[model.get_layer("digitcaps").output])

            get_capsnet_output = K.function([model.layers[0].input],[model.get_layer("capsnet").output])

            if (x_train.shape[0]%args.num_cls==0):
                lim = int(x_train.shape[0]/args.num_cls)
            else:
                lim = int(x_train.shape[0]/args.num_cls)+1

            for t in range(0,lim):
                if (t==int(x_train.shape[0]/args.num_cls)):
                    mod = x_train.shape[0]%args.num_cls
                    digitcaps_output = get_digitcaps_output([x_train[t*args.num_cls:t*args.num_cls+mod]])[0]
                    capsnet_output = get_capsnet_output([x_train[t*args.num_cls:t*args.num_cls+mod]])[0]
                else:
                    digitcaps_output = get_digitcaps_output([x_train[t*args.num_cls:(t+1)*args.num_cls]])[0]
                    capsnet_output = get_capsnet_output([x_train[t*args.num_cls:(t+1)*args.num_cls]])[0]
                masked_inst = []
                inst = []
                where = []
                for j in range(0,digitcaps_output.shape[0]):
                    ind = capsnet_output[j].argmax()
                    inst.append(digitcaps_output[j][ind])
                    where.append(ind)
                    for z in range(0,args.num_cls):
                        if (z==ind):
                            continue
                        else:
                            digitcaps_output[j][z] = digitcaps_output[j][z].fill(0.0)
                    masked_inst.append(digitcaps_output[j].flatten())
                masked_inst = np.asarray(masked_inst)
                masked_inst[np.isnan(masked_inst)] = 0
                inst = np.asarray(inst)
                where = np.asarray(where)
                if (t==0):
                    x_inst = np.concatenate([inst])
                    pos = np.concatenate([where])
                    x_masked_inst = np.concatenate([masked_inst])
                else:
                    x_inst = np.concatenate([x_inst,inst])
                    pos = np.concatenate([pos,where])
                    x_masked_inst = np.concatenate([x_masked_inst,masked_inst])
            np.save(args.save_dir+"/check/x_inst",x_inst)
            np.save(args.save_dir+"/check/pos",pos)
            np.save(args.save_dir+"/check/x_masked_inst",x_masked_inst)
        else:
            x_inst = np.load(args.save_dir+"/check/x_inst.npy")
            pos = np.load(args.save_dir+"/check/pos.npy")
            x_masked_inst = np.load(args.save_dir+"/check/x_masked_inst.npy")
        return x_inst,pos,x_masked_inst
    
    def decoder_retraining_dataset(self):
        """
        Generating the dataset for the decoder retraining technique with unsharp masking
        :return: training samples and labels for decoder retraining 
        """
        model = self.model
        data = self.data
        args = self.args
        x_recon = self.reconstructions
        (x_train, y_train), (x_test, y_test) = data 
        if not os.path.exists(args.save_dir+"/check"):
            os.makedirs(args.save_dir+"/check")
        if not os.path.exists(args.save_dir+"/check/x_decoder_retrain.npy"):
            for q in range(0,x_recon.shape[0]):
                save_img = Image.fromarray((x_recon[q]*255).reshape(28,28).astype(np.uint8))
                image_more_sharp = save_img.filter(ImageFilter.UnsharpMask(radius=1, percent=1000, threshold=1))
                img_arr = np.asarray(image_more_sharp)
                img_arr = img_arr.reshape(-1,28,28,1).astype('float32') / 255.
                if (q==0):
                    x_recon_sharped = np.concatenate([img_arr])
                else:
                    x_recon_sharped = np.concatenate([x_recon_sharped,img_arr])
            self.save_output_image(x_recon_sharped[:100],"sharpened reconstructions")
            x_decoder_retrain = self.masked_inst_parameter
            y_decoder_retrain = x_recon_sharped
            np.save(args.save_dir+"/check/x_decoder_retrain",x_decoder_retrain)
            np.save(args.save_dir+"/check/y_decoder_retrain",y_decoder_retrain)
        else:
            x_decoder_retrain = np.load(args.save_dir+"/check/x_decoder_retrain.npy")
            y_decoder_retrain = np.load(args.save_dir+"/check/y_decoder_retrain.npy")
        return x_decoder_retrain,y_decoder_retrain
    
    def decoder_retraining(self):
        """
        The decoder retraining technique to give the sharpening ability to the decoder 
        :return: the retrained decoder
        """
        model = self.model
        data = self.data
        args = self.args
        x_decoder_retrain, y_decoder_retrain = self.x_decoder_retrain,self.y_decoder_retrain
        
        decoder = eval_model.get_layer('decoder')
        decoder_in = layers.Input(shape=(16*47,))
        decoder_out = decoder(decoder_in)
        retrained_decoder = models.Model(decoder_in,decoder_out)
        if (args.verbose):
            retrained_decoder.summary()
        retrained_decoder.compile(optimizer=optimizers.Adam(lr=args.lr),loss='mse',loss_weights=[1.0])
        if not os.path.exists(args.save_dir+"/retrained_decoder.h5"):
            retrained_decoder.fit(x_decoder_retrain, y_decoder_retrain, batch_size=args.batch_size, epochs=20)
            retrained_decoder.save_weights(args.save_dir + '/retrained_decoder.h5')
        else:
            retrained_decoder.load_weights(args.save_dir + '/retrained_decoder.h5')
        
        retrained_reconstructions = retrained_decoder.predict(x_decoder_retrain, batch_size=args.batch_size)
        self.save_output_image(retrained_reconstructions[:100],"retrained reconstructions")
        return retrained_decoder
        
    def get_limits(self):
        """
        Calculating the boundaries of the instantiation parameter distributions
        :return: instantiation parameter indices in the descending order of variance, min and max values per class
        """
        args = self.args
        x_inst = self.inst_parameter
        pos = self.global_position
        glob_min = np.amin(x_inst.transpose(),axis=1)
        glob_max = np.amax(x_inst.transpose(),axis=1)
        
        if not os.path.exists(args.save_dir+"/check"):
            os.makedirs(args.save_dir+"/check")
        if not os.path.exists(args.save_dir+"/check/class_cov.npy"):
            for cl in range(0,self.args.num_cls):
                tmp_glob = []
                for it in range(0,x_inst.shape[0]):
                    if (pos[it]==cl):
                        tmp_glob.append(x_inst[it])
                tmp_glob = np.asarray(tmp_glob)
                tmp_glob = tmp_glob.transpose()
                tmp_cov_max = np.flip(np.argsort(np.around(np.cov(tmp_glob),5).diagonal()),axis=0)
                tmp_min = np.amin(tmp_glob,axis=1)
                tmp_max = np.amax(tmp_glob,axis=1)
                if (cl==0):
                    class_cov = np.vstack([tmp_cov_max])
                    class_min = np.vstack([tmp_min])
                    class_max = np.vstack([tmp_max])
                else:
                    class_cov = np.vstack([class_cov,tmp_cov_max])
                    class_min = np.vstack([class_min,tmp_min]) 
                    class_max = np.vstack([class_max,tmp_max]) 
            np.save(args.save_dir+"/check/class_cov",class_cov)
            np.save(args.save_dir+"/check/class_min",class_min)
            np.save(args.save_dir+"/check/class_max",class_max)
        else:
            class_cov = np.load(args.save_dir+"/check/class_cov.npy")
            class_min = np.load(args.save_dir+"/check/class_min.npy")
            class_max = np.load(args.save_dir+"/check/class_max.npy")
        return class_cov,class_max,class_min

    def generate_data(self):
        """
        Generating new images and samples with the data generation technique 
        :return: the newly generated images and labels
        """
        data = self.data
        args = self.args   
        (x_train, y_train), (x_test, y_test) = data
        x_masked_inst = self.masked_inst_parameter
        pos = self.global_position
        retrained_decoder = self.retrained_decoder
        class_cov = self.class_variance
        class_max = self.class_max
        class_min = self.class_min
        samples_to_generate = self.samples_to_generate
        generated_images = np.empty([0,x_train.shape[1],x_train.shape[2],x_train.shape[3]])
        generated_images_with_ori = np.empty([0,x_train.shape[1],x_train.shape[2],x_train.shape[3]])
        generated_labels = np.empty([0])
        for cl in range(0,args.num_cls):
            count = 0
            for it in range(0,x_masked_inst.shape[0]): 
                if (count==samples_to_generate):
                    break
                if (pos[it]==cl):
                    count = count + 1
                    generated_images_with_ori = np.concatenate([generated_images_with_ori,x_train[it].reshape(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])])
                    noise_vec = x_masked_inst[it][x_masked_inst[it].nonzero()]
                    for inst in range(int(class_cov.shape[1]/2)):
                        ind = np.where(class_cov[cl]==inst)[0][0]
                        noise = np.random.uniform(class_min[cl][ind],class_max[cl][ind])
                        noise_vec[ind] = noise
                    x_masked_inst[it][x_masked_inst[it].nonzero()] = noise_vec
                    new_image = retrained_decoder.predict(x_masked_inst[it].reshape(1,args.num_cls*class_cov.shape[1]))
                    generated_images = np.concatenate([generated_images,new_image])
                    generated_labels = np.concatenate([generated_labels,np.asarray([cl])])
                    generated_images_with_ori = np.concatenate([generated_images_with_ori,new_image])
        self.save_output_image(generated_images,"generated_images")
        self.save_output_image(generated_images_with_ori,"generated_images with originals")
        generated_labels = keras.utils.to_categorical(generated_labels, num_classes=args.num_cls)
        if not os.path.exists(args.save_dir+"/generated_data"):
            os.makedirs(args.save_dir+"/generated_data")
        np.save(args.save_dir+"/generated_data/generated_images",generated_images)
        np.save(args.save_dir+"/generated_data/generated_label",generated_labels)
        return generated_images,generated_labels

if __name__ == "__main__":
    """
    Setting the hyper-parameters
    """
    parser = argparse.ArgumentParser(description="TextCaps")
    parser.add_argument('--epochs', default=60, type=int)
    parser.add_argument('--verbose', default=False, type=bool)
    parser.add_argument('--cnt', default=200, type=int)
    parser.add_argument('-n','--num_cls', default=47, type=int, help="Iterations")
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--samples_to_generate', default=10, type=int)
    parser.add_argument('--lr', default=0.001, type=float,
                        help="Initial learning rate")
    parser.add_argument('--lr_decay', default=0.9, type=float,
                        help="The value multiplied by lr at each epoch. Set a larger value for larger epochs")
    parser.add_argument('--lam_recon', default=0.392, type=float,
                        help="The coefficient for the loss of decoder")
    parser.add_argument('-r', '--routings', default=3, type=int,
                        help="Number of iterations used in routing algorithm. should > 0")
    parser.add_argument('--shift_fraction', default=0.1, type=float,
                        help="Fraction of pixels to shift at most in each direction.")
    parser.add_argument('--save_dir', default='./emnist_bal_200')
    parser.add_argument('-dg', '--data_generate', action='store_true',
                        help="Generate new data with pre-trained model")
    parser.add_argument('-w', '--weights', default=None,
                        help="The path of the saved weights. Should be specified when testing")
    args = parser.parse_args()
    print(args)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    (x_train, y_train), (x_test, y_test) = load_emnist_balanced(args.cnt)

    model, eval_model = CapsNet(input_shape=x_train.shape[1:],
                                                  n_class=len(np.unique(np.argmax(y_train, 1))),
                                                  routings=args.routings)
    if (args.verbose):
        model.summary()

    """
    Snap shot training
    :param M: number of snapshots
    :param nb_epoch: number of epochs
    :param alpha_zero: initial learning rate
    """
    M = 3
    nb_epoch = T = args.epochs
    alpha_zero = 0.01
    model_prefix = 'Model_'
    snapshot = SnapshotCallbackBuilder(T, M, alpha_zero,args.save_dir)
    
    if args.weights is not None:
        model.load_weights(args.weights)
    if not args.data_generate:      
        train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)
        test(model=eval_model, data=(x_test, y_test), args=args)
        
    else:
        if args.weights is None:
            print('No weights are provided. You need to train a model first.')
        else:
            data_generator = dataGeneration(model=eval_model, data=((x_train, y_train), (x_test, y_test)), args=args, samples_to_generate = args.samples_to_generate)