from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin


##############################
#    Basic layers
##############################
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """
    Build an activation module from its (case-insensitive) name.

    :param act_type: one of 'relu', 'leakyrelu' or 'prelu'
    :param inplace: forwarded to ReLU / LeakyReLU as their ``inplace`` flag
    :param neg_slope: negative slope for LeakyReLU; initial slope for PReLU
    :param n_prelu: number of learnable slopes (``num_parameters``) for PReLU
    :return: the constructed ``nn.Module``
    :raises NotImplementedError: if ``act_type`` is not a supported name
    """
    act = act_type.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer


def norm_layer(norm_type, nc):
    """
    Build a 1-D normalization module from its (case-insensitive) name.

    :param norm_type: 'batch' (BatchNorm1d, affine) or 'instance'
        (InstanceNorm1d, non-affine)
    :param nc: number of features the layer normalizes
    :return: the constructed ``nn.Module``
    :raises NotImplementedError: if ``norm_type`` is not a supported name
    """
    norm = norm_type.lower()
    if norm == 'batch':
        layer = nn.BatchNorm1d(nc, affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm1d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MultiSeq(Seq):
    """
    Sequential container whose modules may take and return several values.

    When a module returns a plain tuple, that tuple is unpacked into the
    positional arguments of the next module; any other return value is
    passed on as a single argument.
    """

    def __init__(self, *args):
        super(MultiSeq, self).__init__(*args)

    def forward(self, *inputs):
        for module in self._modules.values():
            # Exact type match on purpose: only plain tuples are unpacked,
            # tuple subclasses are handed to the next module as one argument.
            if type(inputs) is tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class MLP(nn.Module):
    """
    Stack of Linear layers, each optionally followed by an activation and
    a normalization layer.

    :param channels: sequence of layer widths; ``channels[0]`` is the input
        width and one Linear layer is created per consecutive pair
    :param act: activation name understood by ``act_layer``; a falsy value
        disables activations
    :param norm: normalization name understood by ``norm_layer``; a falsy
        value disables normalization
    :param bias: whether the Linear layers have a bias term
    """

    def __init__(self, channels, act='relu', norm=None, bias=True):
        super(MLP, self).__init__()
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if act:
                m.append(act_layer(act))
            if norm:
                # Size the norm by the width of the layer it follows.  Using
                # channels[-1] here breaks any MLP whose hidden widths differ
                # from the output width.
                m.append(norm_layer(norm, channels[i]))
        self.body = Seq(*m)

    def forward(self, x, edge_index=None):
        # edge_index is accepted but unused; presumably it keeps the call
        # signature interchangeable with graph layers -- confirm with callers.
        return self.body(x)