from collections import namedtuple
from pprint import pprint

# "jpeg_psnr,jpeg_ssim,our_ssim,our_q,jpeg_psnrhvs,png_size,model_number,our_size,filename,jpeg_vifp,jpeg_q,jpeg_msssim,our_psnrhvsm,jpeg_psnrhvsm,our_vifp,our_psnr,our_msssim,our_psnrhvs,jpeg_size"


def process_one(eg):
    """Parse a single '|'-separated log line into a flat dict.

    The line has three segments:
      0. whitespace-separated metadata: filename, model_number, png_size,
         jpeg_size, our_size, jpeg_q, our_q (stored as the raw strings);
      1. JPEG metric values;
      2. our metric values.
    Metric segments are matched positionally against the names
    psnr, ssim, msssim, vifp, psnrhvs, psnrhvsm, converted with float(),
    and stored under keys '<jpeg|our>_<metric>'. Extra or missing metric
    values are silently truncated by zip().
    """
    segments = eg.split('|')

    metric_names = "PSNR SSIM MSSSIM VIFP PSNRHVS PSNRHVSM".lower().split(' ')
    meta_names = "filename model_number png_size jpeg_size our_size jpeg_q our_q".split()

    header = segments[0].strip().split()
    # Raises IndexError if the header has fewer fields than meta_names.
    record = {name: header[pos] for pos, name in enumerate(meta_names)}

    for prefix, segment in (('jpeg', segments[1]), ('our', segments[2])):
        numbers = segment.strip().split()
        record.update(
            ('%s_%s' % (prefix, name), float(num))
            for name, num in zip(metric_names, numbers)
        )
    return record


def process_log(filename):
    """Parse every non-blank line of the log file *filename* with process_one.

    Args:
        filename: path to a log whose lines are in the '|'-separated format
            accepted by process_one.

    Returns:
        list of dicts, one per non-blank line, in file order.
    """
    # Close the file deterministically instead of leaking the handle.
    with open(filename) as log_file:
        lines = log_file.read().splitlines()
    # Blank/whitespace-only lines (e.g. separators or a trailing empty line)
    # carry no record and would make process_one raise IndexError.
    return [process_one(line) for line in lines if line.strip()]

def print_given_average_metrics(values, metric):
    """Print (and return) the mean of one field across all records.

    Each value is coerced with float(), so numeric strings such as the
    metadata fields produced by process_one ('our_q', 'jpeg_size', ...)
    are accepted; non-numeric fields like 'filename' raise ValueError.

    Args:
        values: sequence of dicts, e.g. as returned by process_log.
        metric: key to average.

    Returns:
        the average as a float (printed as "<metric> <average>").

    Raises:
        ValueError: if values is empty or a value is not numeric.
    """
    if not values:
        raise ValueError('cannot average %r over an empty list' % (metric,))
    avg = sum(float(v[metric]) for v in values) / len(values)
    # Single-argument print behaves identically under Python 2 and 3.
    print(str(metric) + ' ' + str(avg))
    return avg

def print_all_average_metrics(values, silent=False):
    """Average every field except 'filename' across all records.

    Field names are taken from the first record; each value is coerced
    with float(), so numeric strings (the metadata columns produced by
    process_one) are accepted.

    Args:
        values: sequence of dicts, e.g. as returned by process_log.
        silent: when False, print one "<field> <average>" line per field.

    Returns:
        dict mapping field name (str) to its mean; an empty dict when
        values is empty.
    """
    out = {}
    if not values:
        # Nothing to average; avoids IndexError on values[0].
        return out
    total = len(values)
    for k in values[0].keys():
        if k == 'filename':
            continue
        avg = sum(float(v[k]) for v in values) / total
        out[str(k)] = avg
        if not silent:
            # Single-argument print behaves identically under Python 2 and 3.
            print(str(k) + ' ' + str(avg))
    return out

def pprint_metrics(avg_metrics):
    """Regroup flat '<model>_<metric>' keys into {metric: {model: value}}.

    Despite the name, nothing is printed; the nested dict is returned for
    the caller to display. Only the first underscore separates the model
    prefix, so 'our_psnr_y' becomes out['psnr_y']['our'].

    Raises:
        ValueError: if a key contains no underscore.
    """
    out = {}
    # .items() works on both Python 2 and 3 (iteritems is Py2-only).
    for key, value in avg_metrics.items():
        model, metric = key.split('_', 1)
        out.setdefault(metric, {})[model] = value
    return out

def pprint_by_categories(values, metric=None, data_type='mit'):
    """Average metrics per filename category, pretty-print and return them.

    Args:
        values: list of record dicts (see process_log).
        metric: if given, keep only that metric's {model: value} dict per
            category; otherwise keep the full nested dict from pprint_metrics.
        data_type: 'mit' loads category substrings from ./categories.json;
            any other value uses the two-digit ids '01'..'24'.

    Returns:
        list of (category, result) pairs. A category is included only if at
        least one record's filename contains it as a substring; categories
        with no matching records are skipped.
    """
    # Imported locally: json was previously only available as a global when
    # the module ran as a script.
    import json

    if data_type == 'mit':
        with open('./categories.json') as cat_file:
            categories = json.load(cat_file)
    else:
        categories = ['%02d' % i for i in range(1, 25)]

    out = []
    for cat in categories:
        filtered_values = [v for v in values if cat in v['filename']]
        if not filtered_values:
            # Averaging an empty selection is undefined; skip the category.
            continue
        avg_metrics = print_all_average_metrics(filtered_values, silent=True)
        grouped = pprint_metrics(avg_metrics)
        res = grouped if metric is None else grouped[metric]
        pprint(res)
        out.append((cat, res))
    return out

def plot(values, metric_name):
    """Save a bar chart of per-category (our - jpeg) differences for a metric.

    Args:
        values: sequence of (category, scores) pairs where scores has 'our'
            and 'jpeg' entries, as returned by
            pprint_by_categories(..., metric=metric_name).
        metric_name: used for the chart title and output filename.

    Writes 'plot_kodak_<metric_name>.png' to the current directory.
    """
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: render straight to file
    import matplotlib.pyplot as plt

    plt.style.use('ggplot')

    fig, ax = plt.subplots(1, 1, figsize=(25, 3))
    ax.margins(0)

    x = list(range(len(values)))
    y = [v[1]['our'] - v[1]['jpeg'] for v in values]

    # Split into the non-negative and non-positive parts so they can be
    # drawn in different colours.
    y_above = [max(0, d) for d in y]
    y_below = [min(0, d) for d in y]

    plt.bar(x, y_above)
    plt.bar(x, y_below, color='r')
    # Booleans, not 'off' strings: a non-empty string is truthy.
    plt.tick_params(axis='x', which='both', bottom=False, top=False,
                    labelbottom=False)

    plt.title(metric_name.upper(), x=0.5, y=0.8, fontsize=14)
    # No legend: no artists carry labels, and loc='' is not a valid location.
    ax.get_xaxis().set_visible(False)
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    fig.tight_layout()
    plt.savefig('plot_kodak_' + metric_name + '.png')
    # Release the figure; this is called once per metric in a loop.
    plt.close(fig)


if __name__ == '__main__':
    import sys
    # The original pprint_by_categories looks json up as a module global.
    import json
    print_average = True
    print_categories = False

    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        # Alternative default: 'logs/log_small_ALL_model_69.log'
        filename = 'logs/log_small_100_model_6_79.log'
    values = process_log(filename)

    metrics = "PSNR SSIM MSSSIM VIFP PSNRHVS PSNRHVSM".lower().split(' ')

    if print_categories:
        for metric_name in metrics:
            out = pprint_by_categories(values, metric_name, data_type='size')
            print('%s %d' % (metric_name, len(out)))
            plot(out, metric_name)

    if print_average:
        avg_metrics = print_all_average_metrics(values, silent=True)
        avg_metrics = pprint_metrics(avg_metrics)
        pprint(avg_metrics)
        # Inside the guard: avg_metrics only exists when print_average is set.
        print('%s %s' % (filename, avg_metrics['ssim']))

    # 'filename' is deliberately not averaged: its values are file names and
    # float() on them raises ValueError.
    for key in ('our_q', 'jpeg_q', 'our_size', 'jpeg_size', 'jpeg_psnr'):
        print_given_average_metrics(values, key)