from __future__ import print_function

import os
import pickle
import time
from collections import OrderedDict

# both the `numpy` and `np` names are used throughout this module
import numpy
import numpy as np
import torch

from data import get_test_loader, get_text_loader
from model import VSE, order_sim
from vocab import Vocabulary  # NOQA


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=0):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / (.0001 + self.count)

    def __str__(self):
        """String representation for logging
        """
        # for values that should be recorded exactly e.g. iteration number
        if self.count == 0:
            return str(self.val)
        # for stats
        return '%.4f (%.4f)' % (self.val, self.avg)
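# Illustrative use of AverageMeter (a sketch, not part of the original API):
#
#     loss = AverageMeter()
#     loss.update(0.50, n=32)   # batch of 32 samples with mean loss 0.50
#     loss.update(0.30, n=32)
#     print(loss)               # -> "0.3000 (0.4000)"
#
# With the default n=0 the value is stored but adds nothing to the running
# average, which is how exact quantities such as iteration numbers are logged.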


class LogCollector(object):
    """A collection of logging objects that can change from train to val"""

    def __init__(self):
        # to keep the order of logged variables deterministic
        self.meters = OrderedDict()

    def update(self, k, v, n=0):
        # create a new meter if previously not recorded
        if k not in self.meters:
            self.meters[k] = AverageMeter()
        self.meters[k].update(v, n)

    def __str__(self):
        """Concatenate the meters in one log line
        """
        s = ''
        for i, (k, v) in enumerate(self.meters.items()):
            if i > 0:
                s += '  '
            s += k + ' ' + str(v)
        return s

    def tb_log(self, tb_logger, prefix='', step=None):
        """Log using tensorboard
        """
        for k, v in self.meters.items():
            tb_logger.log_value(prefix + k, v.val, step=step)
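# Illustrative use of LogCollector (a sketch, not part of the original API):
#
#     logger = LogCollector()
#     logger.update('Le', 0.12, n=128)
#     logger.update('Lr', 0.03, n=128)
#     print(logger)   # -> "Le 0.1200 (0.1200)  Lr 0.0300 (0.0300)"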


def encode_data(model, data_loader, log_step=10, logging=print):
    """Encode all images and captions loadable by `data_loader`
    """
    batch_time = AverageMeter()
    val_logger = LogCollector()

    # switch to evaluate mode
    model.val_start()

    end = time.time()

    # numpy array to keep all the embeddings
    img_embs = None
    cap_embs = None
    for i, (images, captions, lengths, ids) in enumerate(data_loader):
        # make sure val logger is used
        model.logger = val_logger

        # compute the embeddings (volatile=True: inference without keeping
        # gradients, per the pre-0.4 PyTorch API used by the model)
        img_emb, cap_emb = model.forward_emb(images, captions, lengths, volatile=True)

        # initialize the numpy arrays given the size of the embeddings
        if cap_embs is None:
            if img_emb is not None:
                img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
            if cap_emb is not None:
                if img_emb is not None:
                    # number of caption rows stored per image (positive plus
                    # any external captions emitted by the loader)
                    pn_number = cap_emb.size(0) // img_emb.size(0)
                else:
                    pn_number = 1
                cap_embs = np.zeros((len(data_loader.dataset) * pn_number, cap_emb.size(1)))

        # preserve the embeddings by copying from gpu and converting to numpy
        if img_emb is not None:
            img_embs[ids] = img_emb.data.cpu().numpy().copy()

        if (cap_emb is not None) and (img_emb is not None):
            # expand image ids into caption-row ids: image `idx` owns rows
            # idx * pn_number .. idx * pn_number + pn_number - 1
            pn_number = cap_emb.size(0) // img_emb.size(0)
            modified_ids = []
            for idx in ids:
                for j in range(pn_number):
                    modified_ids.append(idx * pn_number + j)
            ids = modified_ids
        if cap_emb is not None:
            cap_embs[ids] = cap_emb.data.cpu().numpy().copy()

        # loss computation is skipped while encoding; uncomment to record it
        # model.forward_loss(img_emb, cap_emb)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % log_step == 0:
            logging('Test: [{0}/{1}]\t'
                    '{e_log}\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    .format(
                        i, len(data_loader), batch_time=batch_time,
                        e_log=str(model.logger)))
        del images, captions

    return img_embs, cap_embs


def evalrank(model_path, data_path=None, data_name=None, split='dev', fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = False
    if data_path is not None:
        opt.data_path = data_path
    if data_name is not None:
        opt.data_name = data_name

    # load vocabulary used by the model
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = VSE(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] // 5, cap_embs.shape[0]))

    if not fold5:
        # no cross-validation, full evaluation
        r, rt = i2t(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
        ri, rti = t2i(img_embs, cap_embs,
                      measure=opt.measure, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % ri)
    else:
        # 5fold cross-validation, only for MSCOCO
        results = []
        for i in range(5):
            r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000],
                         cap_embs[i * 5000:(i + 1) * 5000],
                         measure=opt.measure, return_ranks=True)
            print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)
            ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000],
                           cap_embs[i * 5000:(i + 1) * 5000],
                           measure=opt.measure, return_ranks=True)
            if i == 0:
                rt, rti = rt0, rti0
            print("Text to image: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % ri)
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
        print("rsum: %.1f" % (mean_metrics[10] * 6))
        print("Average i2t Recall: %.1f" % mean_metrics[11])
        print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" %
              mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_metrics[12])
        print("Text to image: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" %
              mean_metrics[5:10])

    # note: assumes `model_path` contains 'model_best'; the rank tensors are
    # written next to the checkpoint
    torch.save({'rt': rt, 'rti': rti},
               model_path[:model_path.find('model_best')] + 'ranks.pth.tar')
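# Typical use from an interactive session (paths here are placeholders):
#
#     from evaluation import evalrank
#     evalrank('runs/coco_vse/model_best.pth.tar',
#              data_path='data/', data_name='coco_precomp',
#              split='test', fold5=True)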


def i2t(images, captions, npts=None, measure='cosine', return_ranks=False):
    """
    Images->Text (Image Annotation)
    Images: (5N, K) matrix of images
    Captions: (5N, K) matrix of captions
    """
    if npts is None:
        npts = images.shape[0] // 5
    index_list = []

    ranks = numpy.zeros(npts)
    top1 = numpy.zeros(npts)
    for index in range(npts):

        # Get query image
        im = images[5 * index].reshape(1, images.shape[1])

        # Compute scores
        if measure == 'order':
            bs = 100
            if index % bs == 0:
                mx = min(images.shape[0], 5 * (index + bs))
                im2 = images[5 * index:mx:5]
                d2 = order_sim(torch.Tensor(im2).cuda(),
                               torch.Tensor(captions).cuda())
                d2 = d2.cpu().numpy()
            d = d2[index % bs]
        else:
            d = numpy.dot(im, captions.T).flatten()
        inds = numpy.argsort(d)[::-1]
        index_list.append(inds[0])

        # Score: rank of the best-ranked ground-truth caption (captions
        # 5*index .. 5*index+4 all describe image `index`)
        rank = 1e20
        for i in range(5 * index, 5 * index + 5, 1):
            tmp = numpy.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank
        top1[index] = inds[0]

    # Compute metrics
    r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
    medr = numpy.floor(numpy.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return (r1, r5, r10, medr, meanr), (ranks, top1)
    else:
        return r1, r5, r10, medr, meanr
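# Toy sanity check for i2t (an illustrative sketch, not part of the module):
#
#     K, N = 4, 3
#     imgs = np.repeat(np.eye(N, K), 5, axis=0)  # (5N, K): 5 rows per image
#     caps = imgs.copy()                         # each caption matches its image
#     i2t(imgs, caps)  # -> (100.0, 100.0, 100.0, 1.0, 1.0)
#
# Every query image scores 1.0 against its own five captions and 0.0 against
# the rest, so the best ground-truth caption always lands at rank 0.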


def t2i(images, captions, npts=None, measure='cosine', return_ranks=False):
    """
    Text->Images (Image Search)
    Images: (5N, K) matrix of images
    Captions: (5N, K) matrix of captions
    """
    if npts is None:
        npts = images.shape[0] // 5
    # keep one embedding per unique image (rows repeat 5x, once per caption)
    ims = numpy.array([images[i] for i in range(0, len(images), 5)])
    ranks = numpy.zeros(5 * npts)
    top1 = numpy.zeros(5 * npts)
    for index in range(npts):

        # Get query captions
        queries = captions[5 * index:5 * index + 5]

        # Compute scores
        if measure == 'order':
            bs = 100
            if 5 * index % bs == 0:
                mx = min(captions.shape[0], 5 * index + bs)
                q2 = captions[5 * index:mx]
                d2 = order_sim(torch.Tensor(ims).cuda(),
                               torch.Tensor(q2).cuda())
                d2 = d2.cpu().numpy()

            d = d2[:, (5 * index) % bs:(5 * index) % bs + 5].T
        else:
            d = numpy.dot(queries, ims.T)
        inds = numpy.zeros(d.shape)
        for i in range(len(inds)):
            inds[i] = numpy.argsort(d[i])[::-1]
            # rank of the single ground-truth image for this caption
            ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
            top1[5 * index + i] = inds[i][0]

    # Compute metrics
    r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
    medr = numpy.floor(numpy.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return (r1, r5, r10, medr, meanr), (ranks, top1)
    else:
        return r1, r5, r10, medr, meanr
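# On the toy data sketched above for i2t, t2i(imgs, caps) likewise returns
# (100.0, 100.0, 100.0, 1.0, 1.0): each caption scores 1.0 only against its
# own image.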


def i2t_split(images, captions_orig, captions_ex, npts=None, measure='cosine', return_ranks=False):
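    """
    Images->Text with extended distractors: each query image is ranked
    against all original captions plus five image-specific sets of external
    captions appended as extra distractors. Ground-truth indices refer to
    `captions_orig`, which comes first in the concatenated pool.
    """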
    if npts is None:
        npts = images.shape[0] // 5
    index_list = []

    ranks = numpy.zeros(npts)
    top1 = numpy.zeros(npts)
    for index in range(npts):
        # Get query image
        im = images[5 * index].reshape(1, images.shape[1])
        captions_ex_ind = [captions_orig] + [captions_ex[5 * index + j] for j in range(5)]
        captions = np.concatenate(captions_ex_ind, axis=0)

        # Compute scores
        if measure == 'order':
            bs = 100
            if index % bs == 0:
                mx = min(images.shape[0], 5 * (index + bs))
                im2 = images[5 * index:mx:5]
                d2 = order_sim(torch.Tensor(im2).cuda(),
                               torch.Tensor(captions).cuda())
                d2 = d2.cpu().numpy()
            d = d2[index % bs]
        else:
            d = numpy.dot(im, captions.T).flatten()
        inds = numpy.argsort(d)[::-1]
        index_list.append(inds[0])

        # Score
        rank = 1e20
        for i in range(5 * index, 5 * index + 5, 1):
            tmp = numpy.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank
        top1[index] = inds[0]

    # Compute metrics
    r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
    medr = numpy.floor(numpy.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return (r1, r5, r10, medr, meanr), (ranks, top1)
    else:
        return r1, r5, r10, medr, meanr


def i2t_text_only(images, captions, npts=None, measure='cosine', return_ranks=False):
    """
    Images->Text (Image Annotation)
    Images: (5N, K) matrix of images
    Captions: (5N, K) matrix of captions
    """
    if npts is None:
        npts = images.shape[0] // 5
    index_list = []

    ranks = numpy.zeros(npts)
    top1 = numpy.zeros(npts)
    for index in range(npts):

        # Get query image
        im = images[5 * index].reshape(1, images.shape[1])

        # Compute scores
        if measure == 'order':
            bs = 100
            if index % bs == 0:
                mx = min(images.shape[0], 5 * (index + bs))
                im2 = images[5 * index:mx:5]
                d2 = order_sim(torch.Tensor(im2).cuda(),
                               torch.Tensor(captions).cuda())
                d2 = d2.cpu().numpy()
            d = d2[index % bs]
        else:
            d = numpy.dot(im, captions.T).flatten()
        # rows stored per original caption; ground truth for caption i sits
        # at row i * pn_number (integer division keeps the index exact)
        pn_number = captions.shape[0] // images.shape[0]
        inds = numpy.argsort(d)[::-1]
        index_list.append(inds[0])

        # Score
        rank = 1e20
        for i in range(5 * index, 5 * index + 5, 1):
            tmp = numpy.where(inds == i * pn_number)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank
        top1[index] = inds[0]

    # Compute metrics
    r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
    medr = numpy.floor(numpy.median(ranks)) + 1
    meanr = ranks.mean() + 1
    if return_ranks:
        return (r1, r5, r10, medr, meanr), (ranks, top1)
    else:
        return r1, r5, r10, medr, meanr


def eval_with_extended(model_path, data_path=None, data_name=None, split='test'):
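    """
    Evaluate i2t when the test loader also emits external captions alongside
    each original caption (opt.use_external_captions=True,
    opt.negative_number=5); ranking uses i2t_text_only, which expects
    pn_number caption rows per original caption.
    """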
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = True
    opt.negative_number = 5
    if data_path is not None:
        opt.data_path = data_path
    if data_name is not None:
        opt.data_name = data_name

    # load vocabulary used by the model
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = VSE(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    print('Computing results...')
    img_embs, cap_embs = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] // 5, cap_embs.shape[0]))

    r, rt = i2t_text_only(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)
    torch.save({'rt': rt}, model_path[:model_path.find('model_best')] + 'ranks_extended.pth.tar')


def eval_with_single_extended(model_path, data_path=None, data_name=None, split='test', backup_vec_ex=None):
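    """
    Evaluate i2t with per-image extended caption sets: for each image row,
    external captions are encoded from the 'ex/%d' text split (or loaded
    from the `backup_vec_ex` cache) and appended as distractors via
    i2t_split.
    """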
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = False
    if data_path is not None:
        opt.data_path = data_path
    if data_name is not None:
        opt.data_name = data_name

    # load vocabulary used by the model
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = VSE(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    img_embs, cap_embs = encode_data(model, data_loader)
    if backup_vec_ex is None:
        cap_embs_ex = list()
        for i in range(img_embs.shape[0]):
            data_loader_ex = get_text_loader(
                split, opt.data_name, vocab, opt.batch_size, opt.workers, opt, 'ex/%d' % i)
            encoding = encode_data(model, data_loader_ex)[1]
            if encoding is not None:
                cap_embs_ex.append(encoding.copy())
            else:
                cap_embs_ex.append(np.zeros(cap_embs[:1].shape))
            print('Caption Embedding: %d' % i)
        # optionally cache the extended-caption embeddings for reuse via
        # `backup_vec_ex`:
        # torch.save(cap_embs_ex, 'data/coco_precomp/cap_embs_ex.pth')
    else:
        cap_embs_ex = torch.load(backup_vec_ex)
    print('Computing results...')

    r, rt = i2t_split(img_embs, cap_embs, cap_embs_ex, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)
    torch.save({'rt': rt}, model_path[:model_path.find('model_best')] + 'ranks_single_extended.pth.tar')


def eval_with_manually_extended(model_path, data_path=None, split='test'):
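    """
    Evaluate i2t on the first 100 rows (20 images x 5 captions) with two
    manually created caption sets ('manually_ex_0' and 'manually_ex_1')
    appended as distractors, four extra captions per original caption.
    """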
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = False
    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = VSE(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    img_embs, cap_embs = encode_data(model, data_loader)
    # keep the first 100 rows: 20 unique images x 5 captions each
    img_embs = img_embs[:100]
    cap_embs = cap_embs[:100]
    cap_embs_ex = list()
    data_loader_ex_0 = get_text_loader(
        split, opt.data_name, vocab, opt.batch_size, opt.workers, opt, 'manually_ex_%d' % 0)
    encoding_0 = encode_data(model, data_loader_ex_0)[1]
    data_loader_ex_1 = get_text_loader(
        split, opt.data_name, vocab, opt.batch_size, opt.workers, opt, 'manually_ex_%d' % 1)
    encoding_1 = encode_data(model, data_loader_ex_1)[1]
    for i in range(100):
        # four manual caption rows per original caption: two from each set
        cap_emb = np.concatenate((encoding_0[i * 2:i * 2 + 2],
                                  encoding_1[i * 2:i * 2 + 2]), axis=0)
        cap_embs_ex.append(cap_emb)
    print('Computing results...')

    r, rt = i2t_split(img_embs, cap_embs, cap_embs_ex, measure=opt.measure, return_ranks=True)
    # baseline without the manual distractors:
    # r, rt = i2t(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
    torch.save({'rt': rt}, model_path[:model_path.find('model_best')] + 'ranks_manually_extended_1.pth.tar')


def debug_show_similarity_with_manually_created_examples(model_path, data_path=None, split='test'):
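    """
    Dump raw image-caption similarity scores for the original captions and
    the two manually created caption sets to 'shy_runs/debug.pt' for manual
    inspection.
    """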
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = False
    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = VSE(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    img_embs, cap_embs = encode_data(model, data_loader)
    img_embs = img_embs[:100]
    cap_embs = cap_embs[:100]
    data_loader_ex_0 = get_text_loader(
        split, opt.data_name, vocab, opt.batch_size, opt.workers, opt, 'manually_ex_%d' % 0)
    encoding_0 = encode_data(model, data_loader_ex_0)[1]
    data_loader_ex_1 = get_text_loader(
        split, opt.data_name, vocab, opt.batch_size, opt.workers, opt, 'manually_ex_%d' % 1)
    encoding_1 = encode_data(model, data_loader_ex_1)[1]
    print('Computing results...')

    # compute similarity
    result = list()
    result_0 = list()
    result_1 = list()

    npts = img_embs.shape[0] // 5
    for index in range(npts):
        # Get query image
        im = img_embs[5 * index].reshape(1, img_embs.shape[1])

        # Compute scores
        if opt.measure == 'order':
            raise NotImplementedError('Measure "order" not supported.')
        else:
            result.append(numpy.dot(im, cap_embs.T).flatten())
            result_0.append(numpy.dot(im, encoding_0.T).flatten())
            result_1.append(numpy.dot(im, encoding_1.T).flatten())
    torch.save({'orig': result, 'Tete': result_0, 'Haoyue': result_1}, 'shy_runs/debug.pt')