#!/usr/bin/env python
import inspect
import os
import random
import sys
import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.legend as lgd
import matplotlib.markers as mks
from parse_log import parse_log
import numpy as np

# (x field, y field) pairs plotted for an RPN training phase.
RPN_FIELDS = [
    ['NumIters', 'loss_cls'],
    ['NumIters', 'loss_bbox'],
]

# (x field, y field) pairs plotted for a Fast R-CNN training phase.
FAST_RCNN_FIELDS = [
    ['NumIters', 'loss_bbox'],
    ['NumIters', 'loss_cls'],
]

# Indexed by phase position (as returned by parse_log): the two tables
# alternate, and each entry references the shared list object above.
FIELDS = [RPN_FIELDS, FAST_RCNN_FIELDS] * 2
# Subplot titles, one per entry of FIELDS.
LABELS = [
    "Stage 1 : RPN",
    "Stage 1 : Faster-RCNN",
    "Stage 2 : RPN",
    "Stage 2 : Faster-RCNN",
]

def enum(**enums):
    """Build a class named 'Enum' whose attributes are the given keywords.

    Example: enum(A=1).A == 1.
    """
    namespace = dict(enums)
    return type('Enum', (object,), namespace)

# Plotting modes accepted by plot_chart's `mode` argument.
PLOT_MODE = enum(
    NORMAL="normal",
    MOVING_AVG="movingavg",
    BOTH="both",
)

def get_log_file_suffix():
    """Return the file-name suffix that marks a training log."""
    suffix = '.log'
    return suffix

def get_chart_type_description_separator():
    """Return the separator placed between the y and x field names.

    Note the two leading spaces: get_field_descriptions splits chart type
    descriptions on this exact string.
    """
    separator = '  vs. '
    return separator

def is_x_axis_field(field):
    """Return True if *field* names a quantity usable as a chart's x axis.

    create_field_index (and the FIELDS tables) call the iteration column
    'NumIters', so that name is accepted alongside the older 'Iters';
    without it the iteration column was misclassified as a y-axis field.
    """
    return field in ('Iters', 'NumIters', 'Seconds')

def create_field_index():
    """Describe the column layout of the train and test data files.

    Returns:
        (field_index, fields) where field_index maps 'Train'/'Test' to a
        {column name: column position} dict, and fields is the sorted list
        of every column name appearing in either table.
    """
    train, test = 'Train', 'Test'
    field_index = {
        train: {
            'NumIters': 0,
            'Seconds': 1,
            train + ' loss': 2,
            train + ' learning rate': 3,
            'rpn_cls_loss': 4,
            'rpn_loss_bbox': 5,
        },
        test: {
            'NumIters': 0,
            'Seconds': 1,
            test + ' accuracy': 2,
            test + ' loss': 3,
        },
    }
    all_fields = set()
    for columns in field_index.values():
        all_fields.update(columns)
    return field_index, sorted(all_fields)

def get_supported_chart_types():
    """List every '<y field><separator><x field>' chart description.

    Each non-x-axis field is paired with every x-axis field, in the sorted
    order produced by create_field_index.
    """
    _, fields = create_field_index()
    separator = get_chart_type_description_separator()
    x_fields = [name for name in fields if is_x_axis_field(name)]
    chart_types = []
    for y_field in fields:
        if is_x_axis_field(y_field):
            continue
        # A y field can never equal an x field here, so no self-pair check.
        for x_field in x_fields:
            chart_types.append(y_field + separator + x_field)
    return chart_types

def get_chart_type_description(chart_type):
    """Return the description string at position *chart_type*.

    chart_type is used as an index into get_supported_chart_types().
    """
    return get_supported_chart_types()[chart_type]

def get_data_file_type(chart_type):
    """Return the first whitespace-separated word of the chart description.

    Presumably meant to be 'Train' or 'Test' (the create_field_index keys);
    verify for fields such as 'NumIters' or 'rpn_cls_loss'.
    """
    words = get_chart_type_description(chart_type).split()
    return words[0]

def get_data_file(path_to_log):
    """Return the '.train' data file name derived from *path_to_log*.

    The directory part is dropped: '/logs/run.log' -> 'run.log.train'.
    """
    base_name = os.path.basename(path_to_log)
    return base_name + '.train'

def get_field_descriptions(chart_type):
    """Split a chart description into its axis field names.

    Prints the split description as a debugging aid.

    Returns:
        (x_axis_field, y_axis_field) -- note x comes first even though the
        description lists the y field first.
    """
    separator = get_chart_type_description_separator()
    parts = get_chart_type_description(chart_type).split(separator)
    print('description:')
    print(parts)
    y_axis_field, x_axis_field = parts[0], parts[1]
    return x_axis_field, y_axis_field

def get_field_indices(file):
    """Map each comma-separated label on the first line of *file* to its index.

    Only the header line is read. If a label repeats, its last position wins.
    """
    with open(file, 'r') as handle:
        header = handle.readline().strip()
    return {label: position for position, label in enumerate(header.split(','))}

def load_data(data_file, field_idx0, field_idx1):
    """Read two numeric columns from a comma-separated data file.

    The first line (column labels) is skipped, as are blank lines and lines
    starting with '#'.

    Args:
        data_file: path to the comma-separated file.
        field_idx0, field_idx1: 0-based indices of the columns to extract.
    Returns:
        [column0_values, column1_values], two lists of floats.
    """
    data = [[], []]
    with open(data_file, 'r') as f:
        f.readline()  # skip column labels
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing empty line) have no line[0];
            # skip them along with comment lines.
            if not line or line.startswith('#'):
                continue
            fields = line.split(',')
            data[0].append(float(fields[field_idx0].strip()))
            data[1].append(float(fields[field_idx1].strip()))
    return data

def random_marker():
    """Return a random matplotlib marker code that actually draws something.

    Candidates come from matplotlib.markers.MarkerStyle.markers. Entries whose
    description is 'nothing' (such as 'None', ' ', '') are excluded so a curve
    never gets an invisible marker. Iterating items() instead of indexing
    keys() also works on Python 3, where dict views are not indexable.
    """
    markers = mks.MarkerStyle.markers
    visible = [code for code, desc in markers.items() if desc != 'nothing']
    # Fall back to every known marker if the filter ever leaves nothing.
    return random.choice(visible or list(markers))

def get_data_label(path_to_log):
    """Return the log's file name, cut at the last occurrence of the log suffix.

    Examples: 'logs/run1.log' -> 'run1'. If the file name does not contain
    the suffix, the whole file name is returned. Previously the last
    character was chopped off in that case (rfind returned -1). A suffix
    that appeared only in a directory name also produced an empty label.
    """
    name = path_to_log[path_to_log.rfind('/') + 1:]
    suffix_pos = name.rfind(get_log_file_suffix())
    if suffix_pos != -1:
        name = name[:suffix_pos]
    return name

def get_legend_loc(chart_type):
    """Choose a matplotlib legend location for the given chart type.

    Decreasing-style quantities (anything with 'loss' or 'learning rate' in
    the y-axis name) get 'upper right'; everything else, accuracy included,
    gets 'lower right'.
    """
    _, y_axis = get_field_descriptions(chart_type)
    if 'loss' in y_axis or 'learning rate' in y_axis:
        return 'upper right'
    return 'lower right'

def moving_average(values, N):
    """Return a trailing moving average of *values* with the same length.

    Output point i (for i >= N-1) is the mean of values[i-N+1 : i+1]. The
    first N-1 points, which lack a full window, repeat the first full-window
    mean. The window is clamped to len(values) so the output length always
    equals the input length. With 'valid' convolution, a window longer than
    the data would otherwise yield a result of the wrong length. Empty input
    returns an empty array.

    Args:
        values: sequence of numbers.
        N: window size in points.
    Returns:
        1-D float numpy array of length len(values).
    """
    data = np.asarray(values, dtype=float)
    if data.size == 0:
        return data
    window = max(1, min(int(N), data.size))
    averaged = np.convolve(data, np.ones(window) / window, mode='valid')
    return np.concatenate((np.tile(averaged[0], window - 1), averaged))

def _random_color():
    """Return a random RGB triple with each channel in [0, 1)."""
    return [random.random(), random.random(), random.random()]


def plot_chart(log_file, path_to_png, mode=PLOT_MODE.NORMAL):
    """Parse a training log and save its per-phase curves as an image.

    One subplot row is reserved per entry of LABELS. Each parsed phase (up
    to that many) gets a subplot plotting the (x, y) field pairs listed in
    FIELDS for that phase.

    Args:
        log_file: path to the training log, passed to parse_log.
        path_to_png: output image path handed to plt.savefig.
        mode: PLOT_MODE.NORMAL plots raw values, PLOT_MODE.MOVING_AVG plots
            a 100-point moving average, PLOT_MODE.BOTH plots both.
    """
    phases, detected_mean_ap = parse_log(log_file)
    # parse_log may report no mAP (None); fall back to 0 for display.
    mean_ap = detected_mean_ap if detected_mean_ap is not None else 0

    print("Processing %s with mAP=%f" % (path_to_png, mean_ap))

    linewidth = 0.75
    # Markers are always drawn; set to False for very dense curves.
    use_marker = True
    window = 100  # moving-average window, in data points

    fig = plt.figure(1, figsize=(8, 32))
    # plt.figure(1) returns the existing figure 1 if there is one, so drop
    # anything left over from earlier calls.
    fig.clf()

    num_rows = len(LABELS)
    for phase_idx, raw_phase in enumerate(phases[:num_rows]):
        phase = np.array(raw_phase)
        stage_label = LABELS[phase_idx]
        plt.subplot(num_rows, 1, phase_idx + 1)
        # Title was previously built from stage_label[phase_idx], i.e. a
        # single character of the label.
        title_prefix = "mAP = %f    " % mean_ap if phase_idx == 0 else ""
        plt.title(title_prefix + stage_label)

        for x_label, y_label in FIELDS[phase_idx]:
            # TODO: more systematic color cycle for lines
            color = _random_color()
            x_data = [row[x_label] for row in phase]
            y_data = [row[y_label] for row in phase]

            if mode == PLOT_MODE.MOVING_AVG:
                y_data = moving_average(y_data, window)
            elif mode == PLOT_MODE.BOTH:
                # Raw curve first; the smoothed curve below gets a new color.
                plt.plot(x_data, y_data, label=stage_label, color=color,
                         marker=random_marker(), linewidth=linewidth)
                color = _random_color()
                y_data = moving_average(y_data, window)

            plot_kwargs = dict(label=stage_label, color=color,
                               linewidth=linewidth)
            if use_marker:
                plot_kwargs['marker'] = random_marker()
            plt.plot(x_data, y_data, **plot_kwargs)

    # TODO: add a legend (see get_legend_loc) and axis labels.
    sys.stdout.write("Saving... ")
    sys.stdout.flush()
    plt.savefig(path_to_png, dpi=600)
    print("done")
    # The Agg backend selected at import cannot display figures, so release
    # the figure instead of calling plt.show().
    plt.close(fig)

def print_help():
    """Print command-line usage, then terminate via sys.exit()."""
    usage = """
    Usage: ./plot.py [log file] [output picture]
    """
    print(usage)
    sys.exit()

if __name__ == '__main__':
    if len(sys.argv) < 3:
        print_help()
    else:
        log_file = sys.argv[1]
        path_to_png = sys.argv[2]
        if not os.path.exists(log_file):
            # sys.exit with a string prints it to stderr and exits with
            # status 1, so callers and shells can detect the failure
            # (a bare sys.exit() reported success).
            sys.exit('Log file does not exist: %s' % log_file)
        plot_chart(log_file, path_to_png, PLOT_MODE.MOVING_AVG)