# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

import logging

from ..BaseLayer import BaseConf, BaseLayer
from utils.DocInherit import DocInherit
from utils.exceptions import ConfigurationError
import copy


class CalculateDistanceConf(BaseConf):
    """ Configuration of CalculateDistance Layer

    Args:
        operations (list): a subset of ["cos", "euclidean", "manhattan", "chebyshev"].
            Only one operation is applied by the layer: the first one found when
            checking in the fixed order cos, euclidean, manhattan, chebyshev.
            With the default value (all four) the layer therefore computes "cos".
    """

    # init the args
    def __init__(self, **kwargs):
        super(CalculateDistanceConf, self).__init__(**kwargs)

    # set default params
    @DocInherit
    def default(self):
        self.operations = ["cos", "euclidean", "manhattan", "chebyshev"]

    @DocInherit
    def declare(self):
        # The layer compares exactly two inputs, each expected to be of rank 2.
        self.num_of_inputs = 2
        self.input_ranks = [2]

    @DocInherit
    def inference(self):
        # Output keeps the shape of the first input except that the last
        # dimension collapses to a single distance value.
        self.output_dim = copy.deepcopy(self.input_dims[0])
        self.output_dim[-1] = 1

        super(CalculateDistanceConf, self).inference()

    @DocInherit
    def verify(self):
        super(CalculateDistanceConf, self).verify()

        assert len(self.input_dims) == 2, "Operation requires that there should be two inputs"

        # Every declared input rank must be 2 (which also makes them all equal).
        if any(rank != 2 for rank in self.input_ranks):
            raise ConfigurationError("For layer CalculateDistance, the ranks of each inputs should be equal and 2!")


class CalculateDistance(BaseLayer):
    """ CalculateDistance layer to calculate the distance of sequences(2D representation)

    Args:
        layer_conf (CalculateDistanceConf): configuration of a layer
    """

    def __init__(self, layer_conf):
        super(CalculateDistance, self).__init__(layer_conf)
        self.layer_conf = layer_conf

    def forward(self, x, x_len, y, y_len):
        """ Compute one row-wise measure between x and y.

        The measure is the first entry of ``layer_conf.operations`` matched in
        the order "cos", "euclidean", "manhattan", "chebyshev". Note that "cos"
        yields cosine *similarity* (as returned by ``F.cosine_similarity``),
        not a distance.

        Args:
            x: [batch_size, dim]
            x_len: [batch_size] (unused by this layer)
            y: [batch_size, dim]
            y_len: [batch_size] (unused by this layer)

        Returns:
            Tensor: [batch_size, 1], None

        Raises:
            ConfigurationError: if none of the supported operations is configured.
        """
        batch_size = x.size()[0]
        operations = self.layer_conf.operations

        if "cos" in operations:
            result = F.cosine_similarity(x, y)
        elif "euclidean" in operations:
            result = torch.sqrt(torch.sum((x - y) ** 2, dim=1))
        elif "manhattan" in operations:
            result = torch.sum(torch.abs(x - y), dim=1)
        elif "chebyshev" in operations:
            # Tensor.max(dim=...) returns a (values, indices) pair; only the
            # maximum values are the distance, the argmax indices are dropped.
            result = torch.abs(x - y).max(dim=1)[0]
        else:
            raise ConfigurationError("This operation is not supported!")

        result = result.view(batch_size, 1)
        return result, None