"""
Keras implementation of Multi-level Dense Capsule Networks (Sai Samarth R Phaye*, Apoorva Sikka*, Abhinav Dhall, Deepti R. Bathula), ACCV 2018.

This file trains a 3-level DCNet on CIFAR-10 dataset with the parameters as mentioned in the paper.

We have developed Multi-level DCNets' code using the following GitHub repositories:
- Xifeng Guo's CapsNet code (https://github.com/XifengGuo/CapsNet-Keras) 
- titu1994's DenseNet code (https://github.com/titu1994/DenseNet)

Usage:
       python 3leveldcnet.py
       python 3leveldcnet.py --epochs 50
       python 3leveldcnet.py --epochs 50 --routings 3
       ... ...

Author: Sai Samarth R Phaye, E-mail: `phaye.samarth@gmail.com`, Github: `https://github.com/ssrp/Multi-level-DCNet`
"""

import numpy as np
import random as rn
import os
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(42)
rn.seed(12345)
import tensorflow as tf
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
session_conf.gpu_options.allow_growth=True
from keras import backend as K
tf.set_random_seed(1234)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
K.set_image_data_format('channels_last')

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
from keras import layers, models, optimizers
from keras.utils import to_categorical
import matplotlib.pyplot as plt
from utils import combine_images, plot_log
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask
from keras.layers.normalization import BatchNormalization
from keras.preprocessing.image import ImageDataGenerator
import densenet


def MultiLevelDCNet(input_shape, n_class, routings):
    """
    A Multi-level DCNet on CIFAR-10.

    :param input_shape: data shape, 3d, [width, height, channels]
    :param n_class: number of classes
    :param routings: number of routing iterations
    :return: Two Keras Models, the first one used for training, and the second one for evaluation.
    """
    
    x = layers.Input(shape=input_shape)
    concat_axis = 1 if K.image_data_format() == 'channels_first' else -1

    ########################### Level 1 Capsules ###########################
    # Incorporating DenseNets - Creating a dense block with 8 layers having 32 filters and 32 growth rate.
    conv, nb_filter = densenet.DenseBlock(x, growth_rate=32, nb_layers=8, nb_filter=32)
    # Batch Normalization
    DenseBlockOutput = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(conv)

    # Creating Primary Capsules (Level 1)
    # Here PrimaryCapsConv2D is the Conv2D output which is used as the primary capsules by reshaping and squashing (squash activation).
    # primarycaps_1 (size: [None, num_capsule, dim_capsule]) is the "reshaped and sqashed output" which will be further passed to the dynamic routing protocol.
    primarycaps_1, PrimaryCapsConv2D = PrimaryCap(DenseBlockOutput, dim_capsule=8, n_channels=12, kernel_size=5, strides=2, padding='valid')

    # Applying ReLU Activation to primary capsules 
    conv = layers.Activation('relu')(PrimaryCapsConv2D)

    ########################### Level 2 Capsules ###########################
    # Incorporating DenseNets - Creating a dense block with 8 layers having 32 filters and 32 growth rate.
    conv, nb_filter = densenet.DenseBlock(conv, growth_rate=32, nb_layers=8, nb_filter=32)
    # Batch Normalization
    DenseBlockOutput = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(conv)

    # Creating Primary Capsules (Level 2)
    primarycaps_2, PrimaryCapsConv2D = PrimaryCap(DenseBlockOutput, dim_capsule=8, n_channels=12, kernel_size=5, strides=2, padding='valid')

    # Applying ReLU Activation to primary capsules 
    conv = layers.Activation('relu')(PrimaryCapsConv2D)

    ########################### Level 3 Capsules ###########################
    # Incorporating DenseNets - Creating a dense block with 8 layers having 32 filters and 32 growth rate.
    conv, nb_filter = densenet.DenseBlock(conv, growth_rate=32, nb_layers=8, nb_filter=32)
    # Batch Normalization
    DenseBlockOutput = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(conv)

    # Creating Primary Capsules (Level 3)
    primarycaps_3, PrimaryCapsConv2D = PrimaryCap(DenseBlockOutput, dim_capsule=8, n_channels=12, kernel_size=3, strides=2, padding='valid')

    # Merging Primary Capsules for the Merged DigitCaps (CapsuleLayer formed by combining all levels of primary capsules)
    mergedLayer = layers.merge([primarycaps_1,primarycaps_2,primarycaps_3], mode='concat', concat_axis=1)


    ########################### Separate DigitCaps Outputs (used for training) ###########################
    # Merged DigitCaps
    digitcaps_0 = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                             name='digitcaps0')(mergedLayer)
    out_caps_0 = Length(name='capsnet_0')(digitcaps_0)

    # First Level DigitCaps
    digitcaps_1 = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                             name='digitcaps1')(primarycaps_1)
    out_caps_1 = Length(name='capsnet_1')(digitcaps_1)

    # Second Level DigitCaps
    digitcaps_2 = CapsuleLayer(num_capsule=n_class, dim_capsule=12, routings=routings,
                             name='digitcaps2')(primarycaps_2)
    out_caps_2 = Length(name='capsnet_2')(digitcaps_2)

    # Third Level DigitCaps
    digitcaps_3 = CapsuleLayer(num_capsule=n_class, dim_capsule=10, routings=routings,
                             name='digitcaps3')(primarycaps_3)
    out_caps_3 = Length(name='capsnet_3')(digitcaps_3)

    ########################### Combined DigitCaps Output (used for evaluation) ###########################
    digitcaps = layers.merge([digitcaps_1,digitcaps_2,digitcaps_3, digitcaps_0], mode='concat', concat_axis=2,
                             name='digitcaps')
    out_caps = Length(name='capsnet')(digitcaps)

    # Reconstruction (decoder) network
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps, y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()(digitcaps)  # Mask using the capsule with maximal length. For prediction

    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(600, activation='relu', input_dim=int(digitcaps.shape[2]*n_class), name='zero_layer'))
    decoder.add(layers.Dense(600, activation='relu', name='one_layer'))
    decoderFinal = models.Sequential(name='decoderFinal')
    # Concatenating two layers
    decoderFinal.add(layers.Merge([decoder.get_layer('zero_layer'), decoder.get_layer('one_layer')], mode='concat'))
    decoderFinal.add(layers.Dense(1200, activation='relu'))
    decoderFinal.add(layers.Dense(np.prod([32,32,1]), activation='sigmoid'))
    decoderFinal.add(layers.Reshape(target_shape=[32,32,1], name='out_recon'))

    # Model for training
    train_model = models.Model([x, y], [out_caps_0, out_caps_1, out_caps_2, out_caps_3, decoderFinal(masked_by_y)])

    # Model for evaluation (prediction)
    # Note that out_caps is the final prediction. Other predictions could be used for analysing separate-level predictions. 
    eval_model = models.Model(x, [out_caps, out_caps_0, out_caps_1, out_caps_2, out_caps_3, decoderFinal(masked)])

    return train_model, eval_model


def margin_loss(y_true, y_pred):
    """
    Margin loss, as introduced for Capsule Networks.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))


def train(model, data, args):
    """
    Training a 3-level DCNet
    :param model: the 3-level DCNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """

    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data
    row = x_train.shape[1]
    col = x_train.shape[2]
    channel = x_train.shape[3]

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs', histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))

    # compile the model
    # Notice the four separate losses (for separate backpropagations)
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, margin_loss, margin_loss, margin_loss, 'mse'],
                  loss_weights=[1., 1., 1., 1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    #model.load_weights('result/weights.h5')

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, y_train, y_train, y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, y_test, y_test, y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Training with data augmentation
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixel for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, y_batch, y_batch, y_batch, x_batch[:,:,:,0:1]])

    # Training with data augmentation. If shift_fraction=0., also no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, y_test, y_test, y_test, x_test[:,:,:,0:1]]],
                        callbacks=[log, tb, checkpoint, lr_decay])

    # Save model weights
    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    plot_log(args.save_dir + '/log.csv', show=True)

    return model


def test(model, data, args):
    x_test, y_test = data
    print('Testing the model...')
    y_pred, y_pred0, y_pred1, y_pred2, y_pred3, x_recon = model.predict(x_test, batch_size=100)

    print('Test Accuracy (All DigitCaps): ', 100.0*np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/(1.0*y_test.shape[0]))

    print('Test Accuracy (Merged DigitCaps): ', 100.0*np.sum(np.argmax(y_pred0, 1) == np.argmax(y_test, 1))/(1.0*y_test.shape[0]))

    print('Test Accuracy (Level 1 DigitCaps): ', 100.0*np.sum(np.argmax(y_pred1, 1) == np.argmax(y_test, 1))/(1.0*y_test.shape[0]))

    print('Test Accuracy (Level 2 DigitCaps): ', 100.0*np.sum(np.argmax(y_pred2, 1) == np.argmax(y_test, 1))/(1.0*y_test.shape[0]))

    print('Test Accuracy (Level 3 DigitCaps): ', 100.0*np.sum(np.argmax(y_pred3, 1) == np.argmax(y_test, 1))/(1.0*y_test.shape[0]))

    img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save(args.save_dir + "/real_and_recon.png")
    print()
    print('Reconstructed images are saved to %s/real_and_recon.png' % args.save_dir)
    plt.imshow(plt.imread(args.save_dir + "/real_and_recon.png"))
    plt.show()


def preprocess_input(x, data_format=None):
    """Preprocesses a tensor encoding a batch of images.
    # Arguments
        x: input Numpy tensor, 4D.
        data_format: data format of the image tensor.
    # Returns
        Preprocessed tensor.
    """
    if data_format is None:
        data_format = K.image_data_format()
    assert data_format in {'channels_last', 'channels_first'}

    if data_format == 'channels_first':
        if x.ndim == 3:
            # 'RGB'->'BGR'
            x = x[::-1, ...]
            # Zero-center by mean pixel
            x[0, :, :] -= 103.939
            x[1, :, :] -= 116.779
            x[2, :, :] -= 123.68
        else:
            x = x[:, ::-1, ...]
            x[:, 0, :, :] -= 103.939
            x[:, 1, :, :] -= 116.779
            x[:, 2, :, :] -= 123.68
    else:
        # 'RGB'->'BGR'
        x = x[..., ::-1]
        # Zero-center by mean pixel
        x[..., 0] -= 103.939
        x[..., 1] -= 116.779
        x[..., 2] -= 123.68

    x *= 0.017 # scale values

    return x

def load_dataset():
    # Load the dataset from Keras
    from keras.datasets import cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Preprocessing the dataset
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train= preprocess_input(x_train)
    x_test= preprocess_input(x_test)
    x_train = x_train.reshape(-1, 32, 32, 3).astype('float32') 
    x_test = x_test.reshape(-1, 32, 32, 3).astype('float32')
    y_train = to_categorical(y_train.astype('float32'))
    y_test = to_categorical(y_test.astype('float32'))

    return (x_train, y_train), (x_test, y_test)


if __name__ == "__main__":
    import argparse
    from keras import callbacks

    # setting the hyper parameters
    parser = argparse.ArgumentParser(description="Multi-level DCNets on CIFAR-10.")
    parser.add_argument('--epochs', default=500, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--lr', default=0.001, type=float,
                        help="Initial learning rate")
    parser.add_argument('--lr_decay', default=0.9, type=float,
                        help="The value multiplied by lr at each epoch. Set a larger value for larger epochs")
    parser.add_argument('--lam_recon', default=0.512, type=float,
                        help="The coefficient for the loss of decoder")
    parser.add_argument('-r', '--routings', default=3, type=int,
                        help="Number of iterations used in routing algorithm. should > 0")
    parser.add_argument('--shift_fraction', default=0.1, type=float,
                        help="Fraction of pixels to shift at most in each direction.")
    parser.add_argument('--debug', action='store_true',
                        help="Save weights by TensorBoard")
    parser.add_argument('--save_dir', default='./result')
    parser.add_argument('-t', '--testing', action='store_true',
                        help="Test the trained model on testing dataset")
    parser.add_argument('--digit', default=5, type=int,
                        help="Digit to manipulate")
    parser.add_argument('-w', '--weights', default=None,
                        help="The path of the saved weights. Should be specified when testing")
    args = parser.parse_args()
    print(args)

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # load data
    (x_train, y_train), (x_test, y_test) = load_dataset()

    # define model
    model, eval_model = MultiLevelDCNet(input_shape=x_train.shape[1:],
                                                  n_class=len(np.unique(np.argmax(y_train, 1))),
                                                  routings=args.routings)
    model.summary()
    
    # train or test
    if args.weights is not None: # init the model weights with provided one
        model.load_weights(args.weights)
    if not args.testing:
        train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)
    else:  # as long as weights are given, will run testing
        if args.weights is None:
            print('No weights are provided. Will test using random initialized weights.')
        test(model=eval_model, data=(x_test, y_test), args=args)