import numpy as np
import logging
import torch
import torch.nn.functional as F
from v1_metrics import compute_eer
import data_reader.adv_kaldi_io as ako

"""
validation without stochastic search for threshold 
important: EER does not need a threshold. 
"""

# Get the same logger as in main
logger = logging.getLogger("anti-spoofing")

def validation(args, model, device, train_loader, train_scp, train_utt2label, val_loader, val_scp, val_utt2label):
    logger.info("Starting Validation")
    train_loss, train_scores = compute_loss(model, device, train_loader)
    val_loss, val_scores = compute_loss(model, device, val_loader)
    train_preds, train_labels = utt_scores(train_scores, train_scp, train_utt2label)
    val_preds, val_labels = utt_scores(val_scores, val_scp, val_utt2label)
    train_eer = compute_eer(train_labels, train_preds)
    val_eer = compute_eer(val_labels, val_preds)

    logger.info('===> Training set: Average loss: {:.4f}\tEER: {:.4f}\n'.format(
                train_loss, train_eer))
    logger.info('===> Validation set: Average loss: {:.4f}\tEER: {:.4f}\n'.format(
                val_loss, val_eer))
    return val_loss, val_eer

def utt_scores(scores, scp, utt2label):
    """Average frame-level scores into utterance-level predictions and labels."""
    utt2len   = ako.read_key_len(scp)          # number of frames per utterance
    utt2label = ako.read_key_label(utt2label)  # utterance key -> label (replaces the path argument)
    key_list  = ako.read_all_key(scp)          # utterance keys in scp order

    preds, labels = [], []
    idx = 0  # running frame index; scores are stacked contiguously per utterance in scp order
    for key in key_list:
        frames_per_utt = utt2len[key]
        avg_scores = np.average(scores[idx:idx+frames_per_utt])
        idx = idx + frames_per_utt
        preds.append(avg_scores)
        labels.append(utt2label[key])

    return np.array(preds), np.array(labels)
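
# Toy illustration (assumed data, not tied to any real scp/utt2label files) of the
# frame-to-utterance pooling above: frame scores are averaged over each utterance's
# contiguous frame span, in scp order.
def _demo_frame_pooling():
    frame_scores = np.array([0.9, 0.8, 0.7, 0.1, 0.2])  # 5 frames in total
    utt2len_demo = {"utt1": 3, "utt2": 2}                # frames per utterance
    idx, pooled = 0, {}
    for key, n_frames in utt2len_demo.items():
        pooled[key] = float(np.average(frame_scores[idx:idx + n_frames]))
        idx += n_frames
    return pooled  # {'utt1': 0.8, 'utt2': 0.15}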

def compute_loss(model, device, data_loader):
    """Run the model over a loader and return (average loss, per-frame scores)."""
    model.eval()
    loss = 0.0
    scores = []

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            target = target.view(-1, 1).float()
            output = model(data)
            # sum the per-sample losses; averaged over the whole dataset below
            loss += F.binary_cross_entropy(output, target, reduction='sum').item()

            scores.append(output.cpu().numpy())

    loss /= len(data_loader.dataset)  # average loss per sample
    scores = np.vstack(scores)        # frame-level scores

    return loss, scores