import numpy as np
import logging
from timeit import default_timer as timer
from scipy.optimize import fmin_l_bfgs_b, basinhopping
import torch
import torch.nn.functional as F
from v1_metrics import compute_eer
import data_reader.adv_kaldi_io as ako

"""
Utterance-based validation without a stochastic search for a threshold.
Important: EER does not need a threshold.
"""

# Fetch the logger named "anti-spoofing" (presumably the one configured in main — confirm).
logger = logging.getLogger("anti-spoofing")

def validation(args, model, device, val_loader, val_scp, val_utt2label):
    """Run one validation pass and log the average loss and the EER.

    The loader is scored with ``compute_loss``, the scores are paired with
    labels by ``utt_scores``, and the EER is obtained from ``compute_eer``.
    ``args`` is accepted but not used in this function.

    Returns:
        (val_loss, val_eer): the loss returned by ``compute_loss`` and the
        value returned by ``compute_eer``.
    """
    logger.info("Starting Validation")

    avg_loss, raw_scores = compute_loss(model, device, val_loader)
    predictions, labels = utt_scores(raw_scores, val_scp, val_utt2label)
    eer = compute_eer(labels, predictions)

    summary = '===> Validation set: Average loss: {:.4f}\tEER: {:.4f}\n'.format(
        avg_loss, eer)
    logger.info(summary)

    return avg_loss, eer

def utt_scores(scores, scp, utt2label):
    """Return per-key predictions and labels as two NumPy arrays.

    Keys are read from ``scp`` with ``ako.read_all_key`` and the key-to-label
    mapping from ``utt2label`` with ``ako.read_key_label``. The i-th key is
    paired with ``scores[i]``, so ``scores`` is assumed to be ordered like the
    keys in ``scp`` (not checked here — confirm against the data loader).

    Returns:
        (preds, labels): ``scores`` entries and looked-up labels, in key order.
    """
    label_of = ako.read_key_label(utt2label)
    keys = ako.read_all_key(scp)

    preds = []
    labels = []
    # Single pass so a score lookup and its label lookup happen together.
    for position, key in enumerate(keys):
        preds.append(scores[position])
        labels.append(label_of[key])

    return np.array(preds), np.array(labels)

def compute_loss(model, device, data_loader):
    """Evaluate ``model`` over ``data_loader`` and collect its outputs.

    The model is switched to eval mode and every batch is run with gradient
    tracking disabled.

    Args:
        model: object with ``eval()`` that is called as ``model(X1, X2)``.
            Its output is passed to ``F.binary_cross_entropy``, which expects
            probabilities (values in [0, 1]).
        device: device the batch tensors are moved to.
        data_loader: iterable yielding ``(X1, X2, target)`` batches and
            exposing a sized ``dataset`` attribute.

    Returns:
        (loss, scores): the summed binary cross-entropy divided by
        ``len(data_loader.dataset)``, and the outputs of all batches stacked
        row-wise into one NumPy array.
    """
    model.eval()
    total_loss = 0
    batch_scores = []

    with torch.no_grad():
        for X1, X2, target in data_loader:
            X1, X2, target = X1.to(device), X2.to(device), target.to(device)
            # BCE needs float targets; reshape to a column to line up with the
            # model output (assumed to be (batch, 1) — TODO confirm).
            target = target.view(-1, 1).float()
            y = model(X1, X2)
            # reduction='sum' is the supported spelling of the deprecated
            # size_average=False: sum over the batch, average once at the end.
            total_loss += F.binary_cross_entropy(y, target, reduction='sum')
            batch_scores.append(y.detach().cpu().numpy())

    total_loss /= len(data_loader.dataset)  # average loss over all samples
    scores = np.vstack(batch_scores)  # one row per sample, in loader order

    return total_loss, scores