# Copyright 2018 Alexander Matthews
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import pickle
import numpy as np
from IPython import embed
from matplotlib import pylab as plt

import shared
import mmd_experiment

def process_mmd_experiment(width_class, *, interactive=True):
    """Plot squared MMD(GP, NN) against the number of hidden units per layer.

    Loads the pickled results for ``width_class``, averages the squared MMDs
    over repeats, and draws one error-bar line per hidden-layer count. A red
    dashed horizontal line marks the mean of the values in
    ``results/callibration_mmds.csv``. The figure is saved to
    ``../figures/mmds_<width_class>.pdf`` and then shown.

    Args:
        width_class: Suffix selecting which results pickle to load
            (``<results_file_stub>_<width_class>.pickle``) and which figure
            file to write.
        interactive: If True (the default, matching the previous behaviour),
            open an IPython shell via ``embed()`` after saving the figure so
            the local variables can be inspected. Pass False for batch runs.

    Note:
        All paths are relative to the current working directory. The results
        file is read with ``pickle``, so it must be a trusted, locally
        produced file.
    """
    results_file_name = mmd_experiment.results_file_stub + "_" + width_class + ".pickle"
    # Context manager guarantees the handle is closed, even if unpickling fails.
    with open(results_file_name, 'rb') as results_file:
        results = pickle.load(results_file)

    callibration_mmds = np.loadtxt('results/callibration_mmds.csv')
    mean_callibration = np.mean(callibration_mmds)

    # The indexing below assumes mmd_squareds is laid out as
    # (hidden-unit settings, hidden-layer settings, repeats).
    mmd_squareds = results['mmd_squareds']
    hidden_layer_numbers = results['hidden_layer_numbers']
    hidden_unit_numbers = results['hidden_unit_numbers']
    num_repeats = mmd_squareds.shape[2]

    mean_mmds = np.mean(mmd_squareds, axis=2)
    # Standard error of the mean over repeats (np.std default ddof=0).
    std_mmds = np.std(mmd_squareds, axis=2) / np.sqrt(num_repeats)

    plt.figure()

    for index, hidden_layer_number in enumerate(hidden_layer_numbers):
        layer_string = ' hidden layer' if hidden_layer_number == 1 else ' hidden layers'
        line_name = str(hidden_layer_number) + layer_string
        # Error bars span +/- two standard errors.
        plt.errorbar(hidden_unit_numbers, mean_mmds[:, index],
                     yerr=2. * std_mmds[:, index], label=line_name)
    plt.xlabel('Number of hidden units per layer')
    plt.xlim([0, 60])
    plt.ylabel('MMD SQUARED(GP, NN)')
    plt.ylim([0., 0.02])
    # Reference level from the calibration MMD file loaded above.
    plt.axhline(y=mean_callibration, color='r', linestyle='--')
    plt.legend()
    output_file_name = "../figures/mmds_" + width_class + ".pdf"
    plt.savefig(output_file_name)
    if interactive:
        embed()
    plt.show()

if __name__ == '__main__':
    # Expect exactly one argument: a width class known to the shared module.
    cli_args = sys.argv
    args_ok = len(cli_args) == 2 and cli_args[1] in shared.valid_width_classes
    if not args_ok:
        print("Usage: ", cli_args[0], " <width_class>")
        sys.exit(-1)
    process_mmd_experiment(cli_args[1])