from datetime import datetime
import sys

sys.path.append('..')

import torch as th
import pandas as pd

try:
    from tqdm import tqdm
except ImportError:
    # Best-effort fallback: evaluation still works without a progress bar.
    def tqdm(iterable, **kwargs):
        return iterable


def precision_k(pred, label, k=(1, 3, 5)):
    """Compute Precision@k for a batch of multi-label predictions.

    Args:
        pred: (batch, >=max(k)) integer tensor of label indices, ranked
            best-first (e.g. the indices returned by ``topk``).
        label: (batch, num_labels) binary relevance tensor; must be a float
            dtype so ``mean()`` is defined.
        k: iterable of cutoffs to evaluate.

    Returns:
        list of Precision@k values (percentages, 0-100), one per cutoff.
    """
    batch_size = pred.shape[0]
    precision = []
    for _k in k:
        hits = 0.0
        for i in range(batch_size):
            # Fraction of the top-_k predicted labels that are relevant.
            hits += label[i, pred[i, :_k]].mean().item()
        precision.append(hits * 100 / batch_size)
    return precision


def ndcg_k(pred, label, k=(1, 3, 5)):
    """Compute nDCG@k for a batch of multi-label predictions.

    Samples with no relevant label among the top-_k contribute 0 but are
    still counted in the batch average (original behavior preserved).

    Args:
        pred: (batch, >=max(k)) integer tensor of ranked label indices.
        label: (batch, num_labels) binary relevance tensor (float dtype).
        k: iterable of cutoffs to evaluate.

    Returns:
        list of nDCG@k values (percentages, 0-100), one per cutoff.
    """
    batch_size = pred.shape[0]
    ndcg = []
    for _k in k:
        score = 0.0
        # Discount factors log2(rank + 1) for ranks 1.._k.
        rank = th.log2(th.arange(2, 2 + _k, dtype=label.dtype, device=label.device))
        for i in range(batch_size):
            rel = label[i, pred[i, :_k]]
            if rel.sum().item() == 0:
                continue  # nothing relevant retrieved; sample contributes 0
            dcg = (rel / rank).sum().item()
            # Ideal DCG: every positive label ranked first (capped at _k).
            label_count = int(label[i].sum().item())
            ideal = th.log2(
                th.arange(2, 2 + min(_k, label_count),
                          dtype=label.dtype, device=label.device)
            )
            idcg = (1 / ideal).sum().item()
            score += dcg / idcg
        ndcg.append(score * 100 / batch_size)
    return ndcg


def evaluate(net, if_log=False, test_data_loader=None, data_path='./data/sample',
             test_batch_size=50, word_num=500):
    """Evaluate ``net`` on a test set and print P@{1,3,5} and nDCG@{1,3,5}.

    Args:
        net: model mapping an input batch to per-label scores; run on CUDA.
        if_log: when True, append the metrics to a timestamped CSV in ./log/.
        test_data_loader: iterable of (X, y) batches; when None one is built
            from ``data_path`` / ``test_batch_size`` / ``word_num``.
        data_path, test_batch_size, word_num: forwarded to the loader factory.

    Returns:
        (p1, p3, p5, ndcg1, ndcg3, ndcg5) as percentages.

    Raises:
        ValueError: if the data loader yields no batches.
    """
    if test_data_loader is None:
        # NOTE(review): load_test_data is not defined or imported in this
        # file as shown -- confirm it is provided elsewhere before relying
        # on this branch.
        test_data_loader = load_test_data(data_path, test_batch_size, word_num)
    p1, p3, p5 = 0.0, 0.0, 0.0
    ndcg1, ndcg3, ndcg5 = 0.0, 0.0, 0.0
    num_batches = 0
    with th.no_grad():
        for X_batch, y_batch in tqdm(test_data_loader, desc='evaluating'):
            X_batch = X_batch.cuda()
            y_batch = y_batch.cuda()
            output = net(X_batch)
            # Indices of the 5 highest-scoring labels per sample.
            pred = output.topk(k=5)[1]
            _p1, _p3, _p5 = precision_k(pred, y_batch, k=[1, 3, 5])
            p1 += _p1
            p3 += _p3
            p5 += _p5
            _ndcg1, _ndcg3, _ndcg5 = ndcg_k(pred, y_batch, k=[1, 3, 5])
            ndcg1 += _ndcg1
            ndcg3 += _ndcg3
            ndcg5 += _ndcg5
            num_batches += 1
    if num_batches == 0:
        raise ValueError('test_data_loader yielded no batches')
    # NOTE(review): this averages per-batch percentages, which weights a
    # smaller final batch the same as full batches -- original behavior kept.
    p1 /= num_batches
    p3 /= num_batches
    p5 /= num_batches
    ndcg1 /= num_batches
    ndcg3 /= num_batches
    ndcg5 /= num_batches
    print('P@1\t%.3f\t\tP@3\t%.3f\t\tP@5\t%.3f' % (p1, p3, p5))
    print('nDCG@1\t%.3f\t\tnDCG@3\t%.3f\t\tnDCG@5\t%.3f' % (ndcg1, ndcg3, ndcg5))
    if if_log:
        # Header typo fixed: was 'nDCGP@1'.
        log_columns = ['P@1', 'P@3', 'P@5', 'nDCG@1', 'nDCG@3', 'nDCG@5']
        log = pd.DataFrame([[p1, p3, p5, ndcg1, ndcg3, ndcg5]],
                           columns=log_columns)
        log.to_csv('./log/result-' + str(datetime.now()) + '.csv',
                   encoding='utf-8', index=False)
    return p1, p3, p5, ndcg1, ndcg3, ndcg5