import json
import pandas
import argparse 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from scipy.misc import imread, imresize
from matplotlib.colors import Normalize, LinearSegmentedColormap

flatten = lambda ll: [e for l in ll for e in l]

parser = argparse.ArgumentParser()

# experiment settings
parser.add_argument("--tier",  default = "val", choices = ["train", "val", "test"], type = str)
parser.add_argument("--expName",  default = "experiment", type = str)

# plotting
parser.add_argument("--cmap", default = "custom", type = str) # "gnuplot2", "GreysT" 

parser.add_argument("--trans", help = "transpose question attention", action = "store_true")
parser.add_argument("--sa", action = "store_true")
parser.add_argument("--gate", action = "store_true")

# filtering
parser.add_argument("--instances", nargs = "*", type = int)
parser.add_argument("--maxNum", default = 0, type = int)

parser.add_argument("--filter", default = [], nargs = "*", choices = ["mod", "length", "field"])
parser.add_argument("--filterMod", action = "store_true")
parser.add_argument("--filterLength", type = int) # 19
parser.add_argument("--filterField", type = str)
parser.add_argument("--filterIn", action = "store_true")
parser.add_argument("--filterList", nargs = "*") # ["how many", "more"], numbers

args = parser.parse_args()

isRight = lambda instance: instance["answer"] == instance["prediction"]
isRightStr = lambda instance: "RIGHT" if isRight(instance) else "WRONG"

# files
# jsonFilename = "valHPredictions.json" if args.humans else "valPredictions.json"
imagesDir = "./CLEVR_v1/images/{tier}".format(
    tier = args.tier)

dataFile = "./preds/{expName}/{tier}Predictions-{expName}.json".format(
    tier = args.tier, 
    expName = args.expName)

inImgName = lambda index: "{dir}/CLEVR_{tier}_{index}.png".format(
    dir = imagesDir, 
    index = ("000000%d" % index)[-6:],
    tier = args.tier)

outImgAttName = lambda instance, j: "./preds/{expName}/{tier}{id}Img_{step}.png".format(
    expName = args.expName, 
    tier = args.tier, 
    id = instance["index"], 
    step = j + 1)

outTableAttName = lambda instance, name: "./preds/{expName}/{tier}{id}{tableName}_{right}{orientation}.png".format(
    expName = args.expName, 
    tier = args.tier, 
    id = instance["index"], 
    tableName = name, 
    right = isRightStr(instance), 
    orientation = "_t" if args.trans else "")

# plotting
imageDims = (14,14)
figureImageDims = (2,3)
figureTableDims = (5,4)
fontScale = 1

# set transparent mask for low attention areas  
# cdict = plt.get_cmap("gnuplot2")._segmentdata
cdict = {"red": ((0.0, 0.0, 0.0), (0.6, 0.8, 0.8), (1.0, 1, 1)), 
    "green": ((0.0, 0.0, 0.0), (0.6, 0.8, 0.8), (1.0, 1, 1)), 
    "blue": ((0.0, 0.0, 0.0), (0.6, 0.8, 0.8), (1.0, 1, 1))}
cdict["alpha"] = ((0.0, 0.35, 0.35),
                  (1.0,0.65, 0.65))
plt.register_cmap(name = "custom", data = cdict)

def savePlot(fig, fileName):
    plt.savefig(fileName, dpi = 720)
    plt.close(fig) 
    del fig

def filter(instance):
    if "length" in args.filter: 
        if len(instance["question"].split(" ")) > args.filterLength:
             return True

    if "field" in args.filter:
        if args.filterIn:  
            if not (instance[args.filterField] in args.filterList):
                return True
        else:
            if not any((l in instance[args.filterField]) for l in args.filterList):
                return True            

    if "mod" in args.filter:
        if (not isRight(instance)) and args.filterMod:
            return True

        if isRight(instance) and (not args.filterMod):
            return True

    return False

def showImgAtt(img, instance, step, ax):
    dx, dy = 0.05, 0.05
    x = np.arange(-1.5, 1.5, dx)
    y = np.arange(-1.0, 1.0, dy)
    X, Y = np.meshgrid(x, y)
    extent = np.min(x), np.max(x), np.min(y), np.max(y)

    ax.cla()

    img1 = ax.imshow(img, interpolation = "nearest", extent = extent)
    ax.imshow(np.array(instance["attentions"]["kb"][step]).reshape(imageDims), cmap = plt.get_cmap(args.cmap), 
        interpolation = "bicubic", extent = extent)

    ax.set_axis_off()
    plt.axis("off")

    ax.set_aspect("auto")


def showImgAtts(instance):
    img = imread(inImgName(instance["imageId"]))

    length = len(instance["attentions"]["kb"])
    
    # show images
    for j in range(length):
        fig, ax = plt.subplots()
        fig.set_figheight(figureImageDims[0])
        fig.set_figwidth(figureImageDims[1])              
        
        showImgAtt(img, instance, j, ax)
        
        plt.subplots_adjust(bottom = 0, top = 1, left = 0, right = 1)
        savePlot(fig, outImgAttName(instance, j))

def showTableAtt(instance, table, x, y, name):
    # if args.trans:
    #     figureTableDims = (len(y) / 2 + 4, len(x) + 2)
    # else:
    #     figureTableDims = (len(y) / 2, len(x) / 2)
    # xx = np.arange(0, len(x), 1)
    # yy = np.arange(0, len(y), 1)
    # extent2 = np.min(xx), np.max(xx), np.min(yy), np.max(yy)
    
    fig2, bx = plt.subplots(1, 1) # figsize = figureTableDims
    bx.cla()

    sns.set(font_scale = fontScale)

    if args.trans:
        table = np.transpose(table)
        x, y = y, x
    
    tableMap = pandas.DataFrame(data = table, index = x, columns = y)
    
    bx = sns.heatmap(tableMap, cmap = "Purples", cbar = False, linewidths = .5, linecolor = "gray", square = True)
    
    # x ticks
    if args.trans:
        bx.xaxis.tick_top()
    locs, labels = plt.xticks()
    if args.trans:
        plt.setp(labels, rotation = 0)
    else:
        plt.setp(labels, rotation = 60)

    # y ticks
    locs, labels = plt.yticks()
    plt.setp(labels, rotation = 0)

    plt.savefig(outTableAttName(instance, name), dpi = 720)

def main():
    with open(dataFile) as inFile:
        results = json.load(inFile)

    # print(args.exp)

    count = 0
    if args.instances is None:
        args.instances = range(len(results))

    for i in args.instances:        
        if filter(results[i]):
            continue

        if count > args.maxNum and args.maxNum > 0:
            break
        count += 1

        length = len(results[i]["attentions"]["kb"])
        showImgAtts(results[i])

        iterations = range(1, length + 1)
        questionList = results[i]["question"].split(" ")
        table = np.array(results[i]["attentions"]["question"])[:,:(len(questionList) + 1)]        
        showTableAtt(results[i], table, iterations, questionList, "text")

        if args.sa:
            iterations = range(length)
            sa = np.zeros((length, length))
            for i in range(length):
                for j in range(i+1):
                    sa[i][j] = results[i]["attentions"]["self"][i][j]
            
            showTableAtt(results[i], sa[i][j], iterations, iterations, "sa")                    

        print(i)
        print("id:", results[i]["index"])        
        print("img:", results[i]["imageId"])
        print("Q:", results[i]["question"])
        print("G:", results[i]["answer"])
        print("P:", results[i]["prediction"])
        print(isRightStr(results[i]))

        if args.gate:
            print(results[i]["attentions"]["gate"])

        print("________________________________________________________________________")

if __name__ == "__main__":
    main()