"""Benchmark forward/backward timings of entmax mappings against softmax."""
import time

import numpy as np
import torch
from torch.autograd import grad
from torch.nn.functional import softmax

from entmax import sparsemax, entmax15, entmax_bisect


def _sync():
    # Synchronize only when CUDA is available, so CPU runs do not crash.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def bench(f_):
    """Run 100 timed forward/backward passes; return (fwd, bck) quartiles."""
    timings_fwd = []
    timings_bck = []
    for _ in range(100):
        with f_ as f:
            tic = time.perf_counter()
            f.forward()
            _sync()
            toc = time.perf_counter()
            timings_fwd.append(toc - tic)

            tic = time.perf_counter()
            f.backward()
            _sync()
            toc = time.perf_counter()
            timings_bck.append(toc - tic)

    return (
        np.percentile(timings_fwd, [25, 50, 75]),
        np.percentile(timings_bck, [25, 50, 75]),
    )


class MappingBencher:
    """Context manager that times one mapping (softmax, sparsemax, entmax15)."""

    def __init__(self, mapping, X):
        self.mapping = mapping
        self.X_data = X

    def __enter__(self):
        # Fresh leaf tensor and random cotangent for every timed iteration.
        self.X = self.X_data.clone().requires_grad_()
        self.dY = torch.randn_like(self.X)
        return self

    def forward(self):
        self.Y = self.mapping(self.X, dim=-1)

    def backward(self):
        # Backpropagate the random cotangent dY through the mapping.
        grad(outputs=(self.Y,), inputs=(self.X,), grad_outputs=(self.dY,))

    def __exit__(self, *args):
        try:
            del self.X
            del self.Y
        except AttributeError:
            pass


class EntmaxAlphaBencher:
    """Times entmax_bisect with a per-row alpha drawn from (1.01, 2.01)."""

    def __init__(self, X, n_iter=25):
        self.n_iter = n_iter
        self.X_data = X

    def __enter__(self):
        self.X = self.X_data.clone().requires_grad_()
        self.dY = torch.randn_like(self.X)
        self.alpha = 1.01 + torch.rand(
            self.X.shape[0], 1, device=self.X.device, requires_grad=True
        )
        return self

    def forward(self):
        self.Y = entmax_bisect(self.X, self.alpha, dim=-1, n_iter=self.n_iter)

    def backward(self):
        # Gradients flow to both the scores X and the per-row alpha.
        grad(
            outputs=(self.Y,),
            inputs=(self.X, self.alpha),
            grad_outputs=(self.dY,),
        )

    def __exit__(self, *args):
        try:
            del self.X
            del self.alpha
        except AttributeError:
            pass

        try:
            del self.Y
        except AttributeError:
            pass


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dim", dest="dim", type=int, default=30)
    parser.add_argument(
        "--batch-size", dest="batch_size", type=int, default=64 * 8 * 30
    )
    parser.add_argument("--device", dest="device", default="cpu")
    opt = parser.parse_args()
    print(opt)

    X = torch.randn(opt.batch_size, opt.dim, device=opt.device)
    _sync()

    print("softmax", bench(MappingBencher(softmax, X)))
    print("sparsemax", bench(MappingBencher(sparsemax, X)))
    print("entmax15", bench(MappingBencher(entmax15, X)))
    print("a-entmax 25iter", bench(EntmaxAlphaBencher(X, n_iter=25)))
    print("a-entmax 10iter", bench(EntmaxAlphaBencher(X, n_iter=10)))


if __name__ == "__main__":
    main()
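# Example invocation (a sketch: the filename bench.py is an assumption, and the
# CUDA device string assumes a GPU is present; flag names and the 64 * 8 * 30
# default batch size come from the argparse options above):
#
#   python bench.py --batch-size 15360 --dim 30 --device cuda
#
# Each line of output prints the [25th, 50th, 75th] percentile timings in
# seconds, first for the forward pass and then for the backward pass.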