r"""torecsys.utils.opeations is a sub module of utils including anything used in the package, and I don't know where should I put them to. """ from functools import reduce import matplotlib.pyplot as plt import matplotlib.ticker as ticker import numpy as np import operator as op import torch import torch.nn as nn from typing import List, Tuple, Tuple, Union def combination(n: int, r: int) -> int: r"""function to calculate combination. Args: n (int): An integer of number of elements r (int): An integer of size of combinations Returns: int: An integer of number of combinations. """ r = min(r, n - r) numer = reduce(op.mul, range(n, n - r, -1), 1) denom = reduce(op.mul, range(1, r + 1), 1) return int(numer / denom) def dummy_attention(key : torch.Tensor, query: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r"""function for dummy in jit-compile features of torch, which have the same inputs and outputs to nn.MultiheadAttention().__call__() Args: key (T): inputs to be passed as output query (T): dummy inputs value (T): dummy inputs Returns: Tuple[T, T]: values = (key, dummy outputs = torch.Tensor([])) """ return key, torch.Tensor([]) def inner_product_similarity(a: torch.Tensor, b: torch.Tensor, dim=1) -> torch.Tensor: r"""function to calculate inner-product of two vectors Args: a (T, shape = (B, N_{a}, E)), dtype = torch.float: the first batch of vector to be multiplied. b (T, shape = (B, N_{b}, E)), dtype = torch.float: the second batch of vector to be multiplied. Returns: T, dtype = torch.float: inner product tensor. """ outputs = (a * b).sum(dim=dim) return outputs def regularize(parametes : List[Tuple[str, nn.Parameter]], weight_decay : float = 0.01, norm : int = 2) -> torch.Tensor: r"""function to calculate p-th order regularization of paramters in the model Args: parametes (List[Tuple[str, nn.Parameter]]): list of tuple of names and paramters to calculate the regularized loss weight_decay (float, optional): multiplier of regularized loss. Defaults to 0.01. norm (int, optional): order of norm to calculate regularized loss. Defaults to 2. Returns: T, shape = (1, ), dtype = torch.float: regularized loss """ loss = 0.0 for name, param in parametes: if "weight" in name: loss += torch.norm(param, p=norm) return loss * weight_decay def replicate_tensor(tensor: torch.Tensor, size: int, dim: int) -> torch.Tensor: r"""replicate tensor by batch / by row Args: tensor (T), shape = (B, ...): Tensor to be replicated. size (int): Size to replicate tensor. dim (int): Dimension to replicate tensor. If dim = 0, replicated by batch, else replicate by row. Returns: T, shape = (B * size, ...): Replicated Tensor. """ # get shape of tensor from pos_samples # inputs: tensor, shape = (B, ...) # output: batch_size, int, values = B # output: tensor_shape, tuple, values = (...) batch_size = tensor.size(0) tensor_shape = tuple(tensor.size()[1:]) # unsqueeze by dim 1 and repeat n-times by dim 1 # inputs: tensor, shape = (B, ...) # output: repeat_tensor, shape = (B, size, ...) / (1, B * size, ...) # TODO. update repeat_size by dim repeat_size = (1, size) + tuple([1 for _ in range(len(tensor_shape))]) repeat_tensor = tensor.unsqueeze(dim).repeat(repeat_size) # reshape repeat_tensor to (batch_size * size, ...) # inputs: repeat_tensor, shape = (B, size, ...) / (1, B * size, ...) # output: repeat_tensor, shape = (B * size, ...) 

def replicate_tensor(tensor: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    r"""Replicate a tensor by batch or by row.

    Args:
        tensor (T, shape = (B, ...)): Tensor to be replicated.
        size (int): Number of times to replicate the tensor.
        dim (int): Dimension along which to replicate. If dim = 0, replicate by batch;
            else, replicate by row.

    Returns:
        T, shape = (B * size, ...): Replicated tensor.
    """
    # get the shape of the input tensor
    # inputs: tensor, shape = (B, ...)
    # output: batch_size, int, values = B
    # output: tensor_shape, tuple, values = (...)
    batch_size = tensor.size(0)
    tensor_shape = tuple(tensor.size()[1:])

    # unsqueeze at dim and repeat size-times along the new axis
    # inputs: tensor, shape = (B, ...)
    # output: repeat_tensor, shape = (B, size, ...) / (1, B * size, ...)
    # TODO. update repeat_size by dim
    repeat_size = (1, size) + tuple(1 for _ in range(len(tensor_shape)))
    repeat_tensor = tensor.unsqueeze(dim).repeat(repeat_size)

    # flatten the first two axes back into a single batch axis
    # inputs: repeat_tensor, shape = (B, size, ...) / (1, B * size, ...)
    # output: repeat_tensor, shape = (B * size, ...)
    reshape_size = (batch_size * size,) + tensor_shape
    repeat_tensor = repeat_tensor.view(reshape_size)

    return repeat_tensor


def show_attention(attentions: np.ndarray,
                   xaxis: Union[list, str] = None,
                   yaxis: Union[list, str] = None,
                   savedir: str = None):
    r"""Show attention weights of nn.MultiheadAttention in a matplotlib heatmap.

    Args:
        attentions (np.ndarray, shape = (sequence length, sequence length), dtype = np.float32):
            Attention weights output by nn.MultiheadAttention.
        xaxis (Union[list, str], optional): List, or comma-separated string, of x-axis labels.
            Defaults to None.
        yaxis (Union[list, str], optional): List, or comma-separated string, of y-axis labels.
            Defaults to None.
        savedir (str, optional): File path to save the attention heatmap as a png.
            If None, show the plot instead. Defaults to None.
    """
    # set up the figure with a colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions)
    fig.colorbar(cax)

    # set up the axes, prepending an empty label for the unused first tick
    if xaxis is not None:
        if isinstance(xaxis, str):
            xaxis = [""] + xaxis.split(",")
        elif isinstance(xaxis, list):
            xaxis = [""] + xaxis
        ax.set_xticklabels(xaxis, rotation=90)

    if yaxis is not None:
        if isinstance(yaxis, str):
            yaxis = [""] + yaxis.split(",")
        elif isinstance(yaxis, list):
            yaxis = [""] + yaxis
        ax.set_yticklabels(yaxis)

    # show a label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    if savedir is None:
        plt.show()
    else:
        plt.savefig(savedir)


def squash(inputs: torch.Tensor, dim: int = -1) -> torch.Tensor:
    r"""Apply the `squashing` non-linearity to the inputs, scaling each vector
    to a norm in [0, 1) while preserving its direction.

    Args:
        inputs (T): Inputs tensor to which squashing is applied.
        dim (int, optional): Dimension along which squashing is applied. Defaults to -1.

    Returns:
        T: Squashed tensor.
    """
    # calculate the squared norm of the inputs
    squared_norm = torch.sum(torch.pow(inputs, 2), dim=dim, keepdim=True)

    # scale the unit vector by squared_norm / (1 + squared_norm);
    # 1e-8 avoids division by zero for all-zero inputs
    c_j = (squared_norm / (1 + squared_norm)) * (inputs / (torch.sqrt(squared_norm) + 1e-8))
    return c_j
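
# Minimal smoke test; an illustrative addition, not part of the original
# module. The tensors and sizes below are arbitrary examples.
if __name__ == "__main__":
    print(combination(10, 2))  # -> 45

    # dummy_attention passes the key through and returns empty weights
    k = q = v = torch.rand(2, 3, 4)
    outputs, weights = dummy_attention(k, q, v)
    print(outputs.shape, weights.shape)  # -> torch.Size([2, 3, 4]) torch.Size([0])

    # inner product over the embedding dimension of two (B, N, E) batches
    a, b = torch.rand(2, 3, 4), torch.rand(2, 3, 4)
    print(inner_product_similarity(a, b, dim=2).shape)  # -> torch.Size([2, 3])

    # replicate each row of a (3, 2) tensor twice -> shape (6, 2)
    x = torch.arange(6.0).view(3, 2)
    print(replicate_tensor(x, size=2, dim=1).shape)  # -> torch.Size([6, 2])

    # squashed vectors keep their direction, with norms scaled into [0, 1)
    print(squash(torch.rand(2, 5)).norm(dim=-1))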