""" Copyright (c) 2019-present NAVER Corp. MIT License """ import numpy as np def calc_rank(truth, candidates): return 1 + candidates.index(truth) if truth in candidates else 0 def calc_partial_rank(truth, candidates): for i, candidate in enumerate(candidates, 1): if truth == candidate or truth.startswith(candidate + ' '): return i return 0 def mrr_summary(ranks, pranks, seens, n_candidates): ranks = np.array(ranks) pranks = np.array(pranks) n = np.zeros(3, dtype=int) rank = np.zeros((3, n_candidates + 1), dtype=int) prank = np.zeros((3, n_candidates + 1), dtype=int) reciprocal = np.array([0.] + [1. / r for r in range(1, n_candidates + 1)]).reshape(1, -1) for s, r, pr in zip(seens, ranks, pranks): for i in [1 - s, 2]: n[i] += 1 rank[i, r] += 1 prank[i, pr] += 1 mrr = np.cumsum(rank * reciprocal, 1) / n.reshape((3, 1)) pmrr = np.cumsum(prank * reciprocal, 1) / n.reshape((3, 1)) logs = [] for i in range(1, n_candidates + 1): i_str = ' '.join(f"{mrr[s, i]:.4f} ({seen_str})" for s, seen_str in enumerate(['seen', 'unseen', 'all'])) logs.append(f"mrr @{i:-2d}: {i_str}") logs.append(" ") for i in range(1, n_candidates + 1): i_str = ' '.join(f"{pmrr[s, i]:.4f} ({seen_str})" for s, seen_str in enumerate(['seen', 'unseen', 'all'])) logs.append(f"pmrr @{i:-2d}: {i_str}") logs.append(" ") return logs def mrl_summary(recover_lengths, seens, n_candidates): recover_lengths = np.array(recover_lengths) seens = np.array(seens) mrl = np.concatenate((recover_lengths[seens == 1].mean(0).reshape((1, -1)), recover_lengths[seens == 0].mean(0).reshape((1, -1)), recover_lengths.mean(0).reshape((1, -1))), 0) logs = [] for i in range(1, n_candidates + 1): i_str = ' '.join(f"{mrl[s, i]:.4f} ({seen_str})" for s, seen_str in enumerate(['seen', 'unseen', 'all'])) logs.append(f"mrl @{i:-2d}: {i_str}") logs.append(" ") return logs