from torch.utils.data import DataLoader

from dataio.loader import get_dataset, get_dataset_path
from dataio.transformation import get_dataset_transformation
from utils.util import json_file_to_pyobj
from utils.visualiser import Visualiser
from models import get_model
import os, time

# Uncomment for headless (non-interactive) use:
# import matplotlib
# matplotlib.use('Agg')

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import math, numpy
import numpy as np
# scipy.misc.imresize was removed in SciPy 1.3; skimage.transform.resize
# is used instead for resampling the attention maps.
from skimage.transform import resize

def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Plot each channel of `units` (H x W x C) in its own subplot."""
    plt.ion()
    filters = units.shape[2]
    n_columns = round(math.sqrt(filters))
    n_rows = math.ceil(filters / n_columns) + 1  # extra row leaves room for the suptitle
    fig = plt.figure(figure_id, figsize=(n_columns*3, n_rows*3))  # figsize is (width, height)
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i+1)
        plt.imshow(units[:,:,i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0],colormap_lim[1])

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.suptitle(title)
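
# Illustrative usage of plotNNFilter, kept commented so the script's
# behaviour is unchanged. `demo_units` is a hypothetical stand-in for the
# H x W x C feature maps pulled from the model further below.
# demo_units = np.random.rand(64, 64, 16)
# plotNNFilter(demo_units, figure_id=99, title='demo filters')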

def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay each channel of `units` (H x W x C) on `input_im` (H x W x 1).

    All channels are drawn into the same axes, so with C > 1 only the last
    overlay remains fully visible.
    """
    plt.ion()
    filters = units.shape[2]
    fig = plt.figure(figure_id, figsize=(5,5))
    fig.clf()

    for i in range(filters):
        plt.imshow(input_im[:,:,0], interpolation=interp, cmap='gray')
        plt.imshow(units[:,:,i], interpolation=interp, cmap=colormap, alpha=alpha)
        plt.axis('off')
        plt.colorbar()
        plt.title(title, fontsize='small')
        if colormap_lim:
            plt.clim(colormap_lim[0],colormap_lim[1])

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

    # plt.savefig('{}/{}.png'.format(dir_name,time.time()))
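
# Optional helper, added for illustration and not called by default: rescale
# a 2D map to [0, 1] so colour ranges are comparable across frames. The
# epsilon guards against division by zero on constant maps.
def _normalise_map(m, eps=1e-8):
    m = m.astype(np.float32)
    return (m - m.min()) / (m.max() - m.min() + eps)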




## Load options
PAUSE = .01

# Earlier configs, kept for reference; only the last uncommented assignment
# takes effect.
#config_name = 'config_sononet_attention_fs8_v6.json'
#config_name = 'config_sononet_attention_fs8_v8.json'
#config_name = 'config_sononet_attention_fs8_v9.json'
#config_name = 'config_sononet_attention_fs8_v10.json'
#config_name = 'config_sononet_attention_fs8_v11.json'
#config_name = 'config_sononet_attention_fs8_v13.json'
#config_name = 'config_sononet_attention_fs8_v14.json'
#config_name = 'config_sononet_attention_fs8_v15.json'
#config_name = 'config_sononet_attention_fs8_v16.json'
#config_name = 'config_sononet_grid_attention_fs8_v1.json'
#config_name = 'config_sononet_grid_attention_fs8_deepsup_v1.json'
#config_name = 'config_sononet_grid_attention_fs8_deepsup_v2.json'
#config_name = 'config_sononet_grid_attention_fs8_deepsup_v3.json'
#config_name = 'config_sononet_grid_attention_fs8_deepsup_v4.json'
#config_name = 'config_sononet_grid_att_fs8_avg.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v2.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v3.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v4.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v5.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v6.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v7.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v8.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v9.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v10.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v11.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v12.json'
#config_name = 'config_sononet_grid_att_fs8_avg_v12_scratch.json'
#config_name = 'config_sononet_grid_attention_fs8_v3.json'

config_name = 'config_sononet_grid_att_fs4_avg_v12.json'

json_opts = json_file_to_pyobj('/vol/bitbucket/js3611/projects/transfer_learning/ultrasound/configs_2/{}'.format(config_name))
train_opts = json_opts.training

dir_name = os.path.join('visualisation_debug', config_name)
os.makedirs(dir_name, exist_ok=True)
os.makedirs(os.path.join(dir_name, 'pos'), exist_ok=True)
os.makedirs(os.path.join(dir_name, 'neg'), exist_ok=True)

# Setup the NN Model
model = get_model(json_opts.model)
# Force attention-based classification and disable deep supervision so that
# the compatibility scores are exposed for visualisation.
if hasattr(model.net, 'classification_mode'):
    model.net.classification_mode = 'attention'
if hasattr(model.net, 'deep_supervised'):
    model.net.deep_supervised = False

# Setup Dataset and Augmentation
dataset_class = get_dataset(train_opts.arch_type)
dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
dataset_transform = get_dataset_transformation(train_opts.arch_type, opts=json_opts.augmentation)

# Setup Data Loader
dataset = dataset_class(dataset_path, split='train', transform=dataset_transform['valid'])
data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, shuffle=True)
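
# Optional sanity check, commented out: peek at one batch before the main
# loop. The exact shapes are an assumption; the loop below only relies on the
# dataset yielding (image, label) pairs.
# sample_im, sample_lbl = next(iter(data_loader))
# print(sample_im.shape, sample_lbl)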

# Iterate over the dataset, run inference, and visualise the attention maps
for iteration, data in enumerate(data_loader, 1):
    model.set_input(data[0], data[1])

    cls = dataset.label_names[int(data[1])]

    model.validate()
    pred_class = model.pred[1]
    pred_cls = dataset.label_names[int(pred_class)]

    #########################################################
    # Display the input image; keep it as H x W x 1 for the overlay helpers
    input_img = model.input[0,0].cpu().numpy()
    input_img = numpy.expand_dims(input_img, axis=2)

    # plotNNFilter(input_img, figure_id=0, colormap="gray")
    plotNNFilterOverlay(input_img, numpy.zeros_like(input_img), figure_id=0, interp='bilinear',
                        colormap=cm.jet, title='[GT:{}|P:{}]'.format(cls, pred_cls),alpha=0)

    # Save all misclassified frames, but only ~1% of correctly classified
    # BACKGROUND frames so the dominant class does not flood the output.
    chance = (np.random.random() < 0.01) if cls == "BACKGROUND" else True
    if cls != pred_cls:
        plt.savefig('{}/neg/{:03d}.png'.format(dir_name, iteration))
    elif chance:
        plt.savefig('{}/pos/{:03d}.png'.format(dir_name, iteration))
    #########################################################
    # Compatibility Scores overlay with input
    attentions = []
    for i in [1,2]:
        fmap = model.get_feature_maps('compatibility_score%d'%i, upscale=False)
        if not fmap:  # this attention block is absent in some configurations
            continue

        # Output of the attention block
        fmap_0 = fmap[0].squeeze().permute(1,2,0).cpu().numpy()
        fmap_size = fmap_0.shape

        # Attention coefficients (b x c x w x h x s); squeeze to 2D and
        # upsample to the input resolution for overlaying
        attention = fmap[1].squeeze().cpu().numpy()
        attention = numpy.expand_dims(resize(attention, (input_img.shape[0], input_img.shape[1]), mode='constant', preserve_range=True), axis=2)
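
        # To make colour ranges comparable across frames, the map could be
        # rescaled with the illustrative _normalise_map helper defined above:
        # attention = numpy.expand_dims(_normalise_map(attention[:, :, 0]), axis=2)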

        # Raw compatibility feature maps (not informative; kept for reference)
        #plotNNFilter(fmap_0, figure_id=i+3, interp='bilinear', colormap=cm.jet, title='compat. feature %d' %i)

        plotNNFilterOverlay(input_img, attention, figure_id=i, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. {}'.format(cls,pred_cls,i), alpha=0.5)
        attentions.append(attention)

    #plotNNFilterOverlay(input_img, attentions[0], figure_id=4, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. (all)'.format(cls, pred_cls), alpha=0.5)
    plotNNFilterOverlay(input_img, numpy.mean(attentions,0), figure_id=4, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. (all)'.format(cls, pred_cls), alpha=0.5)
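
    # An alternative aggregation, commented out: take the per-pixel maximum
    # instead of the mean, which highlights the strongest response at each
    # location (numpy.max mirrors the numpy.mean call above).
    # plotNNFilterOverlay(input_img, numpy.max(attentions, 0), figure_id=5,
    #                     interp='bilinear', colormap=cm.jet,
    #                     title='[GT:{}|P:{}] compat. (max)'.format(cls, pred_cls),
    #                     alpha=0.5)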

    if cls != pred_cls:
        plt.savefig('{}/neg/{:03d}_hm.png'.format(dir_name, iteration))
    elif chance:
        plt.savefig('{}/pos/{:03d}_hm.png'.format(dir_name, iteration))
    # Linear embedding g(x)
    # (b, c, h, w)
    #gx = fmap[2].squeeze().permute(1,2,0).cpu().numpy()
    #plotNNFilter(gx, figure_id=3, interp='nearest', colormap=cm.jet)

    plt.show()
    plt.pause(PAUSE)

model.destructor()