import numpy as np
import logging
from timeit import default_timer as timer
from scipy.optimize import fmin_l_bfgs_b, basinhopping
import torch
import torch.nn.functional as F
from v1_metrics import compute_eer, compute_confuse
import data_reader.adv_kaldi_io as ako

## Get the same logger from main
logger = logging.getLogger("anti-spoofing")

def validation(args, model, device, train_loader, val_loader, val_scp, val_utt2label):
    """Evaluate the model on the training and validation sets and log a summary.

    The training set is scored for loss and frame accuracy only.  The
    validation frame scores are pooled per utterance with `utt_scores`, and
    `best_eer` then searches for the threshold giving the lowest EER.

    Args:
        args: not used by this function.
        model: network handed to `compute_loss`.
        device: device handed to `compute_loss`.
        train_loader: loader over the training set (needs a `.dataset`).
        val_loader: loader over the validation set.
        val_scp: scp file forwarded to `utt_scores`.
        val_utt2label: utt2label file forwarded to `utt_scores`.

    Returns:
        Tuple `(val_loss, val_eer, threshold)`.
    """
    logger.info("Starting Validation")

    # Frame-level pass over both sets; training scores are not needed.
    train_stats = compute_loss(model, device, train_loader)
    val_stats = compute_loss(model, device, val_loader)
    train_loss, train_correct = train_stats[0], train_stats[2]
    val_loss, frame_scores = val_stats[0], val_stats[1]

    # Utterance-level scores, then the threshold search.
    utt_predictions, utt_labels = utt_scores(frame_scores, val_scp, val_utt2label)
    val_eer, threshold = best_eer(utt_labels, utt_predictions)

    n_train = len(train_loader.dataset)
    train_accuracy = 100. * train_correct / n_train
    train_summary = '\n===> Training set: Average loss: {:.4f}\tAccuracy: {}/{} ({:.0f}%)\n'.format(
        train_loss, train_correct, n_train, train_accuracy)
    logger.info(train_summary)

    val_summary = '===> Validation set: Average loss: {:.4f}\tEER: {:.4f}\tThreshold: {}\n'.format(
        val_loss, val_eer, threshold)
    logger.info(val_summary)

    return val_loss, val_eer, threshold

def best_eer(true_labels, predictions):
    """Search for the threshold on `predictions` that minimizes `compute_eer`.

    The search runs SciPy's basinhopping (random steps) with L-BFGS-B as the
    local minimizer, starting from 0.20 and keeping candidates inside [0, 1].

    Args:
        true_labels: reference labels passed to `compute_eer`.
        predictions: scores; they are binarized as `predictions >= threshold`
            before being passed to `compute_eer`.

    Returns:
        Tuple `(score, threshold)`: the objective value at the optimum and
        the optimizer's solution array (one element).
    """
    # One threshold only (binary class): starting point and box constraint.
    initial_threshold = [0.20]
    box = [(0., 1.)]

    def objective(threshold):
        ## Value that Scipy tries to minimize
        return compute_eer(true_labels, predictions >= threshold)

    def inside_unit_interval(**kwargs):
        # Accept a basinhopping step only if the candidate stays in [0, 1].
        candidate = kwargs["x_new"]
        return bool(np.all(candidate <= 1)) and bool(np.all(candidate >= 0))

    # L-BFGS-B needs a large finite-difference step (eps), otherwise there
    # is no gradient to follow.
    local_search = dict(
        method="L-BFGS-B",
        bounds=box,
        options={"eps": 0.05},
    )

    logger.info("===> Searching optimal threshold for each label")
    started = timer()
    result = basinhopping(
        objective,
        initial_threshold,
        stepsize=0.1,
        minimizer_kwargs=local_search,
        niter=10,
        accept_test=inside_unit_interval,
    )
    elapsed = timer() - started

    best_threshold = result.x
    logger.info("===> Optimal threshold for each label:\n{}".format(best_threshold))
    logger.info("Threshold found in: %s seconds" % elapsed)

    return result.fun, best_threshold

def utt_scores(scores, scp, utt2label):
    """Collapse frame-level scores into one averaged score per utterance.

    `scores` is consumed sequentially: utterances are visited in the order
    returned by `ako.read_all_key(scp)`, and each one takes the next
    `utt2len[key]` entries of `scores`.

    Args:
        scores: frame-level scores, sliceable along the first axis.
        scp: scp file used to read utterance keys and lengths.
        utt2label: utt2label file used to read the label of each key.

    Returns:
        Tuple `(preds, labels)` of numpy arrays, one entry per utterance.
    """
    frames_in = ako.read_key_len(scp)
    label_of = ako.read_key_label(utt2label)
    utt_keys = ako.read_all_key(scp)

    utt_preds = []
    utt_labels = []
    start = 0
    for utt in utt_keys:
        # [start, stop) is the block of frames belonging to this utterance.
        stop = start + frames_in[utt]
        utt_preds.append(np.average(scores[start:stop]))
        utt_labels.append(label_of[utt])
        start = stop

    return np.array(utt_preds), np.array(utt_labels)

def compute_loss(model, device, data_loader, threshold=0.5):
    """Run the model over a data loader and collect loss, scores and accuracy.

    The model is switched to eval mode and evaluated without gradients.

    Args:
        model: callable module; `model(data)` must return probabilities
            shaped like the target after `target.view(-1, 1)`.
        device: device the batches are moved to.
        data_loader: iterable of `(data, target)` batches exposing a
            `.dataset` with a length.
        threshold: outputs strictly greater than this value count as the
            positive class when computing `correct`.

    Returns:
        Tuple `(loss, scores, correct)`:
            loss: summed binary cross-entropy divided by
                `len(data_loader.dataset)` (a 0-dim tensor).
            scores: numpy array of all model outputs stacked vertically,
                i.e. one row per frame.
            correct: number of thresholded outputs equal to the target.
    """
    model.eval()
    loss = 0
    correct = 0
    scores = []

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            target = target.view(-1, 1).float()
            output = model(data)
            # Sum over the batch here; the division by the dataset size
            # below turns the running total into the average loss.
            loss += F.binary_cross_entropy(output, target, reduction='sum')
            # Honor the caller-supplied threshold (was hard-coded to 0.5).
            pred = output > threshold
            correct += pred.byte().eq(target.byte()).sum().item()  # not really meaningful

            scores.append(output.cpu().numpy())

    loss /= len(data_loader.dataset)  # average loss
    scores = np.vstack(scores)  # scores per frame

    return loss, scores, correct