#!/usr/bin/env python
"""
Reproduce the artificial example described in the paper [Zellinger2017a]

The central moment discrepancy (CMD) is used for domain adaptation
as first described in the preliminary conference paper [Zellinger2017b].
It is implemented as a Keras objective function.
This implementation uses Keras 1.1.0.

[Zellinger2017a] W. Zellinger, B.A. Moser, T. Grubinger, E. Lughofer,
T. Natschlaeger, and S. Saminger-Platz, "Robust unsupervised domain adaptation
for neural networks via moment alignment," arXiv preprint arXiv:1711.06114, 2017
[Zellinger2017b] W. Zellinger, T. Grubinger, E. Lughofer, T. Natschlaeger,
and S. Saminger-Platz, "Central moment discrepancy (CMD) for
domain-invariant representation learning," International Conference on Learning
Representations (ICLR), 2017

__author__ = "Werner Zellinger"
__copyright__ = "Copyright 2017, Werner Zellinger"
__credits__ = ["Thomas Grubinger", "Robert Pollak"]
__license__ = "GPL"
__version__ = "1.0.0"
__maintainer__ = "Werner Zellinger"
__email__ = "werner.zellinger@jku.at"
"""

from __future__ import print_function

import matplotlib
matplotlib.use('PS')
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
np.random.seed(0)
from scipy import stats
from keras.layers import Dense, Input, merge
from keras.models import Model
from keras.optimizers import Adadelta
from keras import backend as K

plt.close('all')

N_HIDDEN_NODES = 15
N_MOMENTS = 5
N_CLASSES = 3
DATA_FOLDER = 'data/artificial_dataset/'
TMP_FOLDER = 'temp/artificial_example/'
OUTPUT_FOLDER = 'output/artificial_example/'
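# (N_HIDDEN_NODES is the size of the shared hidden layer, N_MOMENTS the number
# K of central moments matched by the cmd, and N_CLASSES the number of classes
# of the artificial dataset)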
    

def cmd(labels, y_pred):
    """
    central moment discrepancy (cmd)
    objective function for keras models (theano or tensorflow backend)
    
    - Zellinger, Werner, et al. "Robust unsupervised domain adaptation for
    neural networks via moment alignment.", arXiv preprint arXiv:1711.06114, 2017
    - Zellinger, Werner, et al. "Central moment discrepancy (CMD) for
    domain-invariant representation learning.", ICLR, 2017.
    """
    x1 = y_pred[:, :N_HIDDEN_NODES]
    x2 = y_pred[:, N_HIDDEN_NODES:]
    mx1 = K.mean(x1, axis=0)
    mx2 = K.mean(x2, axis=0)
    sx1 = x1 - mx1
    sx2 = x2 - mx2
    dm = l2diff(mx1, mx2)
    scms = dm
    for i in range(N_MOMENTS-1):
        # moment diff of centralized samples
        scms+=moment_diff(sx1,sx2,i+2)
    return scms

def l2diff(x1, x2):
    """
    standard euclidean norm
    """
    return K.sqrt(K.sum((x1-x2)**2))

def moment_diff(sx1, sx2, k):
    """
    difference between moments
    """
    ss1 = K.mean(sx1**k, axis=0)
    ss2 = K.mean(sx2**k, axis=0)
    return l2diff(ss1,ss2)
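
# A minimal NumPy sketch of the same computation (cmd_numpy is not part of the
# original model; it only mirrors the keras objective above and can be used to
# sanity-check cmd values on plain arrays):
def cmd_numpy(x1, x2, n_moments=N_MOMENTS):
    """numpy version of the cmd objective above (for inspection only)"""
    mx1, mx2 = x1.mean(0), x2.mean(0)
    sx1, sx2 = x1 - mx1, x2 - mx2
    # difference of the means plus differences of the centralized moments
    scms = np.linalg.norm(mx1 - mx2)
    for k in range(2, n_moments + 1):
        scms += np.linalg.norm((sx1**k).mean(0) - (sx2**k).mean(0))
    return scms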

def neural_network(domain_adaptation=False):
    """
    moment alignment neural network (MANN)
    
    - Zellinger, Werner, et al. "Robust unsupervised domain adaptation for
    neural networks via moment alignment.", arXiv preprint arXiv:1711.06114, 2017
    """
    # layer definition
    input_s = Input(shape=(2,), name='source_input')
    input_t = Input(shape=(2,), name='target_input')
    encoding = Dense(N_HIDDEN_NODES,
                     activation='sigmoid',
                     name='hidden')
    prediction = Dense(N_CLASSES,
                       activation='softmax',
                       name='pred')
    # network architecture
    encoded_s = encoding(input_s)
    encoded_t = encoding(input_t)
    pred_s = prediction(encoded_s)
    pred_t = prediction(encoded_t)
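    # the concatenated source/target hidden activations form the third model
    # output; the cmd objective splits this tensor again at N_HIDDEN_NODES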
    dense_s_t = merge([encoded_s,encoded_t], mode='concat', concat_axis=1)
    # input/output definition
    nn = Model(input=[input_s,input_t],
               output=[pred_s,pred_t,dense_s_t])
    # separate model for activation visualization
    visualize_model = Model(input=[input_s,input_t],
                            output=[encoded_s,encoded_t])
    # compile model
    if not domain_adaptation:
        cmd_weight = 0.
    else:
        # note that the loss weight of the cmd is one by default (see paper)
        cmd_weight = 1.
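    # three losses: cross-entropy on the source predictions (weight 1), a
    # monitoring-only cross-entropy on the target predictions (weight 0, so the
    # target labels do not influence training), and the cmd on the concatenated
    # hidden activations (weight cmd_weight)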
    nn.compile(loss=['categorical_crossentropy',
                     'categorical_crossentropy',cmd],
               loss_weights=[1.,0.,cmd_weight],
               optimizer=Adadelta(),
               metrics=['accuracy'])
    return nn, visualize_model

def plot_classification_borders(nn, save_name):
    """
    plot dataset and classification borders
    """
    plt.figure()
    plt.plot(x_s[y_s[:,0]==1,0],x_s[y_s[:,0]==1,1],color='k',marker=r'$+$',
             linestyle='',ms=15)
    plt.plot(x_s[y_s[:,1]==1,0],x_s[y_s[:,1]==1,1],color='k',marker=r'$-$',
             linestyle='',ms=15)
    plt.plot(x_s[y_s[:,2]==1,0],x_s[y_s[:,2]==1,1],color='k',marker='*',
             linestyle='',ms=15)
    plt.plot(x_t[:,0],x_t[:,1],'k.')
    x_min = -1
    y_min = -0.75
    x_max = 1.2
    y_max = 1.3
    xy = np.mgrid[x_min:x_max:0.001, y_min:y_max:0.001].reshape(2,-1).T
    z = nn.predict([xy,xy])[0]
    ind = np.argmax(z,axis=1)
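    # keep only the probability of the predicted class at each grid point, so
    # that the contour lines below trace the decision borders of the first two
    # classes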
    z[ind!=0,0]=0
    z[ind!=1,1]=0
    x,y = np.mgrid[x_min:x_max:0.001, y_min:y_max:0.001]
    plt.contour(x,y,z[:,0].reshape(x.shape),levels = [0.1],
                colors=('k',),linestyles=('-',),linewidths=(2,))
    plt.contour(x,y,z[:,1].reshape(x.shape),levels = [0],
                colors=('k',),linestyles=('-',),linewidths=(2,))
    plt.axis('off')
    plt.savefig(save_name)
    
def plot_activations(a_s,a_t,save_name):
    """
    activation visualization via seaborn library
    """
    n_dim=a_s.shape[1]
    n_rows=1
    n_cols=int(n_dim/n_rows)
    fig, axs = plt.subplots(nrows=n_rows,ncols=n_cols, sharey=True,
                            sharex=True)
    for k,ax in enumerate(axs.reshape(-1)):
        if k>=n_dim:
            continue
        sns.kdeplot(a_t[:,k],ax=ax, shade=True, label='target',
                    legend=False, color='0.4',bw=0.03)
        sns.kdeplot(a_s[:,k],ax=ax, shade=True, label='source',
                    legend=False, color='0',bw=0.03)
        plt.setp(ax.xaxis.get_ticklabels(),fontsize=10)
        plt.setp(ax.yaxis.get_ticklabels(),fontsize=10)
    fig.set_figheight(3)
    plt.setp(axs, xticks=[0, 0.5, 1])
    plt.setp(axs, ylim=[0,10])
    plt.savefig(save_name)


# load dataset
x_s = np.load(DATA_FOLDER+'x_s.npy')
y_s = np.load(DATA_FOLDER+'y_s.npy')
x_t = np.load(DATA_FOLDER+'x_t.npy')
y_t = np.load(DATA_FOLDER+'y_t.npy')


# plot dataset
plt.plot(x_s[y_s[:,0]==1,0],x_s[y_s[:,0]==1,1],color='k',marker=r'$+$',
         linestyle='',ms=15)
plt.plot(x_s[y_s[:,1]==1,0],x_s[y_s[:,1]==1,1],color='k',marker=r'$-$',
         linestyle='',ms=15)
plt.plot(x_s[y_s[:,2]==1,0],x_s[y_s[:,2]==1,1],color='k',marker='*',
         linestyle='',ms=15)
plt.plot(x_t[:,0],x_t[:,1],'k.')
plt.axis('off')
plt.savefig(OUTPUT_FOLDER+'dataset.jpg')


# train source model without domain adaptation
nn, nn_vis_model = neural_network(domain_adaptation=False)
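# the cmd objective ignores its label argument, so a dummy all-zero array is
# passed as target for the third (concatenated activations) output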
nn.fit(x=[x_s,x_t],
       y=[y_s,y_t,np.zeros((x_s.shape[0],1))],
       shuffle=True,
       nb_epoch=10000,
       verbose=0,
       batch_size=x_s.shape[0])
# save the weights
nn.save_weights(TMP_FOLDER+'nn.hdf5')
# train for another 5000 epochs so that the comparison with the MANN below
# (which is also trained for 5000 additional epochs) is fair
np.random.seed(0)
nn.fit(x=[x_s,x_t],
       y=[y_s,y_t,np.zeros((x_s.shape[0],1))],
       shuffle=True,
       nb_epoch=5000,
       verbose=0,
       batch_size=x_s.shape[0])
# plot the classification borders (Fig. 3 left in paper)
plot_classification_borders(nn, OUTPUT_FOLDER+'nn.jpg')
# evaluate the target accuracy
# The final source accuracy of the NN should be 100% and the final target
# accuracy should be around 89%, depending on your system's random numbers,
# Theano configuration (float32, ...), cuDNN version, etc.
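# (evaluate returns the total loss, one loss per output and one accuracy per
# output; index -2 is the accuracy of the target prediction output)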
print('\nNN acc='+str(nn.evaluate([x_s,x_t],[y_s,y_t,y_t])[-2]))


# adapt the network to the target domain by means of the central
# moment discrepancy
mann, mann_vis_model = neural_network(domain_adaptation=True)
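# warm start: initialize the MANN with the weights of the non-adapted network
# saved after its first 10000 epochs of training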
mann.load_weights(TMP_FOLDER+'nn.hdf5')
np.random.seed(0)
mann.fit(x=[x_s,x_t],
         y=[y_s,y_t,np.zeros((x_s.shape[0],1))],
         shuffle=True,
         nb_epoch=5000,
         verbose=0,
         batch_size=x_s.shape[0])
# plot the classification borders (Fig. 3 right in paper)
plot_classification_borders(mann, OUTPUT_FOLDER+'mann.jpg')
# evaluate the target accuracy
# The target accuracy of the MANN is around 10% higher than the accuracy
# of the NN above. This improvement is the result of our approach and does
# not depend strongly on the random numbers.
print('\nMANN acc='+str(mann.evaluate([x_s,x_t],[y_s,y_t,y_t])[-2]))


# Plot activations of NN (Fig. 4 top)
a_s,a_t = nn_vis_model.predict([x_s,x_t])
# Find the five hidden nodes whose source and target activation distributions
# differ most significantly (two-sample Kolmogorov-Smirnov test)
p_vals = np.zeros(a_s.shape[1])
for i in range(a_s.shape[1]):
    ksstat, p_vals[i] = stats.ks_2samp(a_s[:,i],a_t[:,i])
ind_worst = p_vals.argsort()[:5]
plot_activations(a_s[:,ind_worst],a_t[:,ind_worst],
                 OUTPUT_FOLDER+'activations_nn.jpg')


# Plot activations of MANN (Fig. 4 bottom)
a_s,a_t = mann_vis_model.predict([x_s,x_t])
# Find the five hidden nodes whose source and target activation distributions
# differ most significantly (two-sample Kolmogorov-Smirnov test)
p_vals = np.zeros(a_s.shape[1])
for i in range(a_s.shape[1]):
    ksstat, p_vals[i] = stats.ks_2samp(a_s[:,i],a_t[:,i])
ind_worst = p_vals.argsort()[:5]
plot_activations(a_s[:,ind_worst],a_t[:,ind_worst],
                 OUTPUT_FOLDER+'activations_mann.jpg')