'''
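Error analysis of decoder hypotheses against references: edit-distance
statistics (CER, insertion/deletion/substitution rates, sentence accuracy)
and a plot of errors by position.
TODO Make generic/modular and move to nn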
'''

import numpy as np
import cPickle as pickle
from editDist import edit_distance as ed
#from progressbar import ProgressBar
from colorama import Fore, Back


def disp_corr(hyp, ref):
    '''
    Placeholder for displaying correspondences between hyp and ref.

    Not implemented yet: both arguments are ignored, nothing is printed,
    and None is returned.
    '''
    return None


def disp_errs_by_pos(err_by_pos, out_file):
    '''
    Plot per-position error values and save the figure to out_file.

    err_by_pos -- sequence of y-values, plotted against their index
    out_file   -- path handed to savefig
    '''
    import matplotlib
    # Non-interactive backend so the plot can be written without a display.
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # Draw on a dedicated figure and always close it: plotting on pyplot's
    # implicit current figure made every later call (e.g. one per subset)
    # draw its curve on top of the curves from earlier calls.
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(err_by_pos)
        fig.savefig(out_file)
    finally:
        plt.close(fig)


def disp_err_corr(hyp_corr, ref_corr):
    '''
    Print an aligned hypothesis/reference pair with colour-coded errors.

    hyp_corr and ref_corr must have the same length (AssertionError
    otherwise). '[space]' tokens are shown as a blank; an '<ins>' token in
    hyp_corr is shown as a green blank and a '<del>' token in ref_corr as a
    red blank. Where the rendered hyp and ref tokens are both single
    characters and differ, both are shown black-on-blue. The hypothesis line
    is printed first, then the reference line.
    '''
    assert len(hyp_corr) == len(ref_corr)

    def _render(tok, gap_tok, gap_colour):
        # Map one alignment token to the text that is printed for it.
        if tok == '[space]':
            return ' '
        if tok == gap_tok:
            return gap_colour + ' ' + Back.RESET
        return tok

    hyp_parts = []
    ref_parts = []
    for h_tok, r_tok in zip(hyp_corr, ref_corr):
        h_shown = _render(h_tok, '<ins>', Back.GREEN)
        r_shown = _render(r_tok, '<del>', Back.RED)

        # Coloured gap markers are longer than one character, so only plain
        # single-character mismatches are highlighted here.
        if len(h_shown) == 1 and len(r_shown) == 1 and h_shown != r_shown:
            h_shown = Back.BLUE + Fore.BLACK + h_shown + Fore.RESET + Back.RESET
            r_shown = Back.BLUE + Fore.BLACK + r_shown + Fore.RESET + Back.RESET

        hyp_parts.append(h_shown)
        ref_parts.append(r_shown)

    print(''.join(hyp_parts))
    print(''.join(ref_parts))


def replace_contractions(utt):
    '''
    Normalise an utterance given as a sequence of character tokens.

    Leading and trailing '[space]' tokens are removed. The remaining tokens
    are joined and re-split into single characters, with every space mapped
    to '[space]' (so a literal ' ' token also becomes '[space]'). Returns a
    new list; the input is not modified.

    Despite the name, no contractions are expanded at the moment: the
    expansion rules below are kept as comments only and are disabled.
    '''
    # Find the first/last non-'[space]' token by index. The previous version
    # re-sliced the sequence once per stripped token (quadratic).
    start, end = 0, len(utt)
    while start < end and utt[start] == '[space]':
        start += 1
    while end > start and utt[end - 1] == '[space]':
        end -= 1
    core = utt[start:end]

    # TODO Replace in training text instead
    utt_str = ''.join([' ' if c == '[space]' else c for c in core])

    # Disabled expansion rules, kept for reference (they would be applied in
    # this order to utt_str). Beware before re-enabling: these are naive
    # substring replacements, e.g. 'n\'t' -> ' not' turns "won't" into
    # "wo not", and ' um' / ' uh' also match word prefixes such as " umbrella".
    #
    # utt_str = utt_str.replace('can\'t', 'cannot')
    # utt_str = utt_str.replace('let\'s', 'let us')
    #
    # # Possessive vs " is"
    # utt_str = utt_str.replace('ere\'s', 'ere is')
    # utt_str = utt_str.replace('that\'s', 'that is')
    # utt_str = utt_str.replace('he\'s', 'he is')
    # utt_str = utt_str.replace('it\'s', 'it is')
    # utt_str = utt_str.replace('how\'s', 'how is')
    # utt_str = utt_str.replace('what\'s', 'what is')
    # utt_str = utt_str.replace('when\'s', 'when is')
    # utt_str = utt_str.replace('why\'s', 'why is')
    #
    # utt_str = utt_str.replace('\'re', ' are')
    #
    # utt_str = utt_str.replace('i\'m', 'i am')
    # utt_str = utt_str.replace('\'ll', ' will')
    # utt_str = utt_str.replace('\'d', ' would')  # had / would ambiguity
    # utt_str = utt_str.replace('n\'t', ' not')
    # utt_str = utt_str.replace('\'ve', ' have')
    #
    # utt_str = utt_str.replace(' uh', '')
    # utt_str = utt_str.replace(' um', '')
    # utt_str = utt_str.replace('uh ', '')
    # utt_str = utt_str.replace('um ', '')

    return ['[space]' if c == ' ' else c for c in utt_str]


def compute_and_display_stats(hyps, refs, hypscores, refscores, numphones, subsets, subset=None, display=False):
    '''
    Score hypotheses against references and print aggregate error statistics.

    Each (hyp, ref) pair is scored with editDist.edit_distance. The printed
    statistics are:
      - average hyp/ref length and average number of phones;
      - average hyp/ref score;
      - fraction of equal/inserted/deleted/substituted tokens, normalised by
        the sum of max(len(hyp), len(ref)) over utterances;
      - CER = 100 * total edit distance / total numphones;
      - number of fully correct sentences and their average length.
    The mean error per position is plotted to
    err_by_pos.<subset or 'all'>.png.

    hyps, refs, hypscores, refscores, numphones -- parallel per-utterance
        arrays; they are filtered with a boolean mask when subset is given,
        so they must support numpy boolean indexing
    subsets -- per-utterance subset labels (list or array), parallel to hyps
    subset  -- if truthy, only utterances whose label equals it are scored
    display -- if True, print every aligned hyp/ref pair (in reverse order)
    '''
    # Filter by subset
    if subset:
        print('USING SUBSET: %s' % subset)
        # np.asarray so that a plain list of labels gives an element-wise
        # mask; `list == str` would be a single False and break the indexing.
        filt = np.asarray(subsets) == subset
        hyps = hyps[filt]
        refs = refs[filt]
        hypscores = hypscores[filt]
        refscores = refscores[filt]
        numphones = numphones[filt]

    if len(hyps) == 0:
        # Nothing to score (e.g. unknown subset label); max() below would raise.
        print('No utterances to score')
        return

    # Compute stats
    hyp_lens = [len(s) for s in hyps]
    ref_lens = [len(s) for s in refs]

    max_hyp_len = max(hyp_lens)
    tot_errs_by_pos = np.zeros(max_hyp_len)
    counts_by_pos = np.zeros(max_hyp_len, dtype=np.int32)

    tot_dist = tot_eq = tot_ins = tot_dels = tot_subs = 0.0
    num_sents_correct = 0
    correct_sents_len = 0

    for hyp, ref, _hypscore, _refscore in reversed(list(zip(hyps, refs, hypscores, refscores))):
        #hyp = replace_contractions(hyp)
        dist, eq, ins, dels, subs, errs_by_pos, hyp_corr, ref_corr = ed(hyp, ref)
        tot_eq += eq
        tot_ins += ins
        tot_dels += dels
        tot_subs += subs
        # NOTE(review): assumes errs_by_pos is never longer than max_hyp_len
        # (i.e. it is indexed by hypothesis position) -- confirm in editDist.
        n_pos = errs_by_pos.shape[0]
        tot_errs_by_pos[0:n_pos] += errs_by_pos
        counts_by_pos[0:n_pos] += 1

        if dist == 0:
            num_sents_correct += 1
            correct_sents_len += len(ref)
        tot_dist += dist

        if display:
            disp_err_corr(hyp_corr, ref_corr)
            print('')

    # Display aggregate stats
    print('avg len hyp: %f' % np.mean(hyp_lens))
    print('avg len ref: %f' % np.mean(ref_lens))
    print('avg num phones: %f' % np.mean(numphones))

    # np.mean always yields a float (sum()/len() was integer division under
    # Python 2 for integer scores).
    print('avg ref score: %f' % np.mean(refscores))
    print('avg hyp score: %f' % np.mean(hypscores))

    tot_comp_len = float(np.sum([max(h, r) for (h, r) in zip(hyp_lens, ref_lens)]))
    print('frac eq: %f ins: %f del: %f sub: %f' %
          tuple(np.array([tot_eq, tot_ins, tot_dels, tot_subs]) / tot_comp_len))

    print('CER: %f' % (100.0 * tot_dist / np.sum(numphones)))

    print('%d/%d sents correct' % (num_sents_correct, len(hyps)))
    # Guard against no correct sentences (previously ZeroDivisionError).
    if num_sents_correct:
        avg_correct_len = correct_sents_len / float(num_sents_correct)
    else:
        avg_correct_len = float('nan')
    print('avg len of correct sent: %f' % avg_correct_len)

    disp_errs_by_pos(tot_errs_by_pos / counts_by_pos,
                     'err_by_pos.%s.png' % ('all' if not subset else subset))


def main(args):
    '''
    Load decoder output from args.pk_file and print error statistics.

    The pickle file holds six consecutively dumped objects, read in this
    order: hyps, refs, hyp scores, ref scores, number of phones and subset
    labels. All except the subset labels are converted to numpy arrays.
    All utterances are scored (no subset filter); args.display is forwarded
    to compute_and_display_stats.
    '''
    # NOTE Make sure synced with order dumped in runDecode.py
    # NOTE(review): unpickling can execute arbitrary code -- only load files
    # produced by your own decoding runs.
    with open(args.pk_file, 'rb') as fid:
        hyps = np.array(pickle.load(fid))
        refs = np.array(pickle.load(fid))
        hypscores = np.array(pickle.load(fid))
        refscores = np.array(pickle.load(fid))
        numphones = np.array(pickle.load(fid))
        subsets = pickle.load(fid)

    compute_and_display_stats(hyps, refs, hypscores, refscores, numphones, subsets, subset=None, display=args.display)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Error analysis of decoder output')
    # nargs='?' makes the positional optional; without it argparse ignores
    # `default` and always requires pk_file on the command line.
    parser.add_argument('pk_file', nargs='?', default='hyp.pk',
                        help='Pickle file with data (default: hyp.pk)')
    parser.add_argument('--display', action='store_true',
                        help='print every aligned hyp/ref pair with errors highlighted')
    args = parser.parse_args()
    main(args)