import torch import torch.nn as nn class DenseResidualLayer(nn.Module): """ PyTorch like layer for standard linear layer with identity residual connection. :param num_features: (int) Number of input / output units for the layer. """ def __init__(self, num_features): super(DenseResidualLayer, self).__init__() self.linear = nn.Linear(num_features, num_features) def forward(self, x): """ Forward-pass through the layer. Implements the following computation: f(x) = f_theta(x) + x f_theta(x) = W^T x + b :param x: (torch.tensor) Input representation to apply layer to ( dim(x) = (batch, num_features) ). :return: (torch.tensor) Return f(x) ( dim(f(x) = (batch, num_features) ). """ identity = x out = self.linear(x) out += identity return out class DenseResidualBlock(nn.Module): """ Wrapping a number of residual layers for residual block. Will be used as building block in FiLM hyper-networks. :param in_size: (int) Number of features for input representation. :param out_size: (int) Number of features for output representation. """ def __init__(self, in_size, out_size): super(DenseResidualBlock, self).__init__() self.linear1 = nn.Linear(in_size, out_size) self.linear2 = nn.Linear(out_size, out_size) self.linear3 = nn.Linear(out_size, out_size) self.elu = nn.ELU() def forward(self, x): """ Forward pass through residual block. Implements following computation: h = f3( f2( f1(x) ) ) + x or h = f3( f2( f1(x) ) ) where fi(x) = Elu( Wi^T x + bi ) :param x: (torch.tensor) Input representation to apply layer to ( dim(x) = (batch, in_size) ). :return: (torch.tensor) Return f(x) ( dim(f(x) = (batch, out_size) ). """ identity = x out = self.linear1(x) out = self.elu(out) out = self.linear2(out) out = self.elu(out) out = self.linear3(out) if x.shape[-1] == out.shape[-1]: out += identity return out class FilmAdaptationNetwork(nn.Module): """ FiLM adaptation network (outputs FiLM adaptation parameters for all layers in a base feature extractor). :param layer: (FilmLayerNetwork) Layer object to be used for adaptation. :param num_maps_per_layer: (list::int) Number of feature maps for each layer in the network. :param num_blocks_per_layer: (list::int) Number of residual blocks in each layer in the network (see ResNet file for details about ResNet architectures). :param z_g_dim: (int) Dimensionality of network input. For this network, z is shared across all layers. """ def __init__(self, layer, num_maps_per_layer, num_blocks_per_layer, z_g_dim): super().__init__() self.z_g_dim = z_g_dim self.num_maps = num_maps_per_layer self.num_blocks = num_blocks_per_layer self.num_target_layers = len(self.num_maps) self.layer = layer self.layers = self.get_layers() def get_layers(self): """ Loop over layers of base network and initialize adaptation network. :return: (nn.ModuleList) ModuleList containing the adaptation network for each layer in base network. """ layers = nn.ModuleList() for num_maps, num_blocks in zip(self.num_maps, self.num_blocks): layers.append( self.layer( num_maps=num_maps, num_blocks=num_blocks, z_g_dim=self.z_g_dim ) ) return layers def forward(self, x): """ Forward pass through adaptation network to create list of adaptation parameters. :param x: (torch.tensor) (z -- task level representation for generating adaptation). :return: (list::adaptation_params) Returns a list of adaptation dictionaries, one for each layer in base net. """ return [self.layers[layer](x) for layer in range(self.num_target_layers)] def regularization_term(self): """ Simple function to aggregate the regularization terms from each of the layers in the adaptation network. :return: (torch.scalar) A order-0 torch tensor with the regularization term for the adaptation net params. """ l2_term = 0 for layer in self.layers: l2_term += layer.regularization_term() return l2_term class FilmLayerNetwork(nn.Module): """ Single adaptation network for generating the parameters of each layer in the base network. Will be wrapped around by FilmAdaptationNetwork. :param num_maps: (int) Number of output maps to be adapted in base network layer. :param num_blocks: (int) Number of blocks being adapted in the base network layer. :param z_g_dim: (int) Dimensionality of input to network (task level representation). """ def __init__(self, num_maps, num_blocks, z_g_dim): super().__init__() self.z_g_dim = z_g_dim self.num_maps = num_maps self.num_blocks = num_blocks # Initialize a simple shared layer for all parameter adapters (gammas and betas) self.shared_layer = nn.Sequential( nn.Linear(self.z_g_dim, self.num_maps), nn.ReLU() ) # Initialize the processors (adaptation networks) and regularization lists for each of the output params self.gamma1_processors, self.gamma1_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() self.gamma2_processors, self.gamma2_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() self.beta1_processors, self.beta1_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() self.beta2_processors, self.beta2_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() # Generate the required layers / regularization parameters, and collect them in ModuleLists and ParameterLists for _ in range(self.num_blocks): self.gamma1_processors.append(self._make_layer(num_maps)) self.gamma1_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) self.beta1_processors.append(self._make_layer(num_maps)) self.beta1_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) self.gamma2_processors.append(self._make_layer(num_maps)) self.gamma2_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) self.beta2_processors.append(self._make_layer(num_maps)) self.beta2_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) @staticmethod def _make_layer(size): """ Simple layer generation method for adaptation network of one of the parameter sets (all have same structure). :param size: (int) Number of parameters in layer. :return: (nn.Sequential) Three layer dense residual network to generate adaptation parameters. """ return nn.Sequential( DenseResidualLayer(size), nn.ReLU(), DenseResidualLayer(size), nn.ReLU(), DenseResidualLayer(size) ) def forward(self, x): """ Forward pass through adaptation network. :param x: (torch.tensor) Input representation to network (task level representation z). :return: (list::dictionaries) Dictionary for every block in layer. Dictionary contains all the parameters necessary to adapt layer in base network. Base network is aware of dict structure and can pull params out during forward pass. """ x = self.shared_layer(x) block_params = [] for block in range(self.num_blocks): block_param_dict = { 'gamma1': self.gamma1_processors[block](x).squeeze() * self.gamma1_regularizers[block] + torch.ones_like(self.gamma1_regularizers[block]), 'beta1': self.beta1_processors[block](x).squeeze() * self.beta1_regularizers[block], 'gamma2': self.gamma2_processors[block](x).squeeze() * self.gamma2_regularizers[block] + torch.ones_like(self.gamma2_regularizers[block]), 'beta2': self.beta2_processors[block](x).squeeze() * self.beta2_regularizers[block] } block_params.append(block_param_dict) return block_params def regularization_term(self): """ Compute the regularization term for the parameters. Recall, FiLM applies gamma * x + beta. As such, params gamma and beta are regularized to unity, i.e. ||gamma - 1||_2 and ||beta||_2. :return: (torch.tensor) Scalar for l2 norm for all parameters according to regularization scheme. """ l2_term = 0 for gamma_regularizer, beta_regularizer in zip(self.gamma1_regularizers, self.beta1_regularizers): l2_term += (gamma_regularizer ** 2).sum() l2_term += (beta_regularizer ** 2).sum() for gamma_regularizer, beta_regularizer in zip(self.gamma2_regularizers, self.beta2_regularizers): l2_term += (gamma_regularizer ** 2).sum() l2_term += (beta_regularizer ** 2).sum() return l2_term class NullFeatureAdaptationNetwork(nn.Module): """ Dummy adaptation network for the case of "no_adaptation". """ def __init__(self): super().__init__() def forward(self, x): return {} @staticmethod def regularization_term(): return 0 class LinearClassifierAdaptationNetwork(nn.Module): """ Versa-style adaptation network for linear classifier (see https://arxiv.org/abs/1805.09921 for full details). :param d_theta: (int) Input / output feature dimensionality for layer. """ def __init__(self, d_theta): super(LinearClassifierAdaptationNetwork, self).__init__() self.weight_means_processor = self._make_mean_dense_block(d_theta, d_theta) self.bias_means_processor = self._make_mean_dense_block(d_theta, 1) @staticmethod def _make_mean_dense_block(in_size, out_size): """ Simple method for generating different types of blocks. Final code only uses dense residual blocks. :param in_size: (int) Input representation dimensionality. :param out_size: (int) Output representation dimensionality. :return: (nn.Module) Adaptation network parameters for outputting classification parameters. """ return DenseResidualBlock(in_size, out_size) def forward(self, representation_dict): """ Forward pass through adaptation network. Returns classification parameters for task. :param representation_dict: (dict::torch.tensors) Dictionary containing class-level representations for each class in the task. :return: (dict::torch.tensors) Dictionary containing the weights and biases for the classification of each class in the task. Model can extract parameters and build the classifier accordingly. Supports sampling if ML-PIP objective is desired. """ classifier_param_dict = {} class_weight_means = [] class_bias_means = [] # Extract and sort the label set for the task label_set = list(representation_dict.keys()) label_set.sort() num_classes = len(label_set) # For each class, extract the representation and pass it through adaptation network to generate classification # params for that class. Store parameters in a list, for class_num in label_set: nu = representation_dict[class_num] class_weight_means.append(self.weight_means_processor(nu)) class_bias_means.append(self.bias_means_processor(nu)) # Save the parameters as torch tensors (matrix and vector) and add to dictionary classifier_param_dict['weight_mean'] = torch.cat(class_weight_means, dim=0) classifier_param_dict['bias_mean'] = torch.reshape(torch.cat(class_bias_means, dim=1), [num_classes, ]) return classifier_param_dict class FilmArAdaptationNetwork(nn.Module): """ Auto-Regressive FiLM adaptation network (outputs FiLM adaptation parameters for all layers in a base feature extractor). Similar to FilmAdaptation network, but forward pass leverages Auto-regressive information. :param feature_extractor: (nn.Module) Base network for adaptation (used in AR pass). :param num_maps_per_layer: (list::int) Number of feature maps for each layer in the network. :param num_blocks_per_layer: (list::int) Number of residual blocks in each layer in the network (see ResNet file for details about ResNet architecures). :param num_initial_conv_maps: (int) Number of maps from initial conv layer in base network. :param z_g_dim: (int) Dimensionality of network input. For this network, z is shared across all layers. """ def __init__(self, feature_extractor, num_maps_per_layer, num_blocks_per_layer, num_initial_conv_maps, z_g_dim): super().__init__() self.feature_extractor = feature_extractor self.z_g_dim = z_g_dim self.num_maps = num_maps_per_layer self.num_blocks = num_blocks_per_layer self.num_target_layers = len(self.num_maps) self.affine_layer = FilmArLayerNetwork self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.layers = nn.ModuleList() previous_maps = [num_initial_conv_maps] + self.num_maps for input_dim, num_maps, num_blocks in zip(previous_maps[:-1], self.num_maps, self.num_blocks): self.layers.append( self.affine_layer( input_dim=input_dim, z_g_dim=self.z_g_dim, num_maps=num_maps, num_blocks=num_blocks ) ) def forward(self, x, task_representation): """ Forward pass through the adaptation network. Implements auto-regressive computation detailed in paper (see Section 2.2 in https://arxiv.org/pdf/1906.07697 for further details). :param x: (torch.tensor) Example images context set of task. :param task_representation: (torch.tensor) Global task representation z_G from set encoder. :return: (list::dict::torch.tensor) List of dictionaries of adaptation parameters to be used by model. """ def flatten(t): t = self.avgpool(t) return t.view(t.size(0), -1) param_dicts = [] # Start with initial convolution layer from ResNet to embedd context set and generate first local representation z = self.feature_extractor.get_layer_output(x, None, 0) z_hn = flatten(z) # For every following layer: pass global and local representations through hypernet layer. This returns the # next layer adaptation parameters. Use these to make a pass through the next layer with context set, and # save adaptation parameters. for layer, hn_layer in enumerate(self.layers): param_dicts.append(hn_layer(z_hn, task_representation)) z = self.feature_extractor.get_layer_output(z, param_dicts, layer + 1) z_hn = flatten(z) return param_dicts def regularization_term(self): """ Simple function to aggregate the regularization terms from each of the layers in the adaptation network. :return: (torch.scalar) A order-0 torch tensor with the regularization term for the adaptation net params. """ l2_term = 0 for layer in self.layers: l2_term += layer.regularization_term() return l2_term class FilmArLayerNetwork(nn.Module): """ Single adaptation network for generating the parameters of each layer in the base network. Will be wrapped around by FilmARAdaptationNetwork. Generates adaptation parameters for next layer give z_G and z_AR. :param input_dim: (int) Dimensionality of local representation. :param num_maps: (int) Number of output maps to be adapted in base network layer. :param num_blocks: (int) Number of blocks being adapted in the base network layer. :param z_g_dim: (int) Dimensionality of input to network (task level representation). """ def __init__(self, input_dim, z_g_dim, num_maps, num_blocks): super().__init__() self.input_dim = input_dim self.z_g_dim = z_g_dim self.num_maps = num_maps self.num_blocks = num_blocks self.shared_layer, self.shared_layer_post = self.get_shared_layers() # Initialize ModuleLists and ParameterLists for layer processeors (hyper-nets) and regluarizers self.gamma1_processors, self.gamma1_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() self.gamma2_processors, self.gamma2_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() self.beta1_processors, self.beta1_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() self.beta2_processors, self.beta2_regularizers = torch.nn.ModuleList(), torch.nn.ParameterList() # Loop over blocks. For each block, collect necessary parameters and regularizers for _ in range(self.num_blocks): self.gamma1_processors.append(self._make_layer(self.num_maps + self.z_g_dim, num_maps)) self.gamma1_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) self.beta1_processors.append(self._make_layer(self.num_maps + self.z_g_dim, num_maps)) self.beta1_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) self.gamma2_processors.append(self._make_layer(self.num_maps + self.z_g_dim, num_maps)) self.gamma2_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) self.beta2_processors.append(self._make_layer(self.num_maps + self.z_g_dim, num_maps)) self.beta2_regularizers.append(torch.nn.Parameter(torch.nn.init.normal_(torch.empty(num_maps), 0, 0.001), requires_grad=True)) def get_shared_layers(self): """ Simple layer generation method for shared layer to be used in layer adaptation network. :param size: (int) Number of parameters in layer. :return: (nn.Sequential) Three layer dense residual network to generate adaptation parameters. """ shared_layer_pre = nn.Sequential( nn.Linear(self.input_dim, self.num_maps), nn.ReLU(), DenseResidualLayer(self.num_maps), nn.ReLU(), DenseResidualLayer(self.num_maps), nn.ReLU(), DenseResidualLayer(self.num_maps) ) shared_layer_post = nn.Sequential( nn.Linear(self.num_maps, self.num_maps), nn.ReLU() ) return shared_layer_pre, shared_layer_post @staticmethod def _make_layer(in_size, out_size): """ Simple layer generation method for processor for each of the parameters associated with the base net layer. :param size: (int) Number of parameters in layer. :return: (nn.Sequential) Three layer dense residual network to generate adaptation parameters. """ return nn.Sequential( nn.Linear(in_size, out_size), nn.ReLU(), DenseResidualLayer(out_size), nn.ReLU(), DenseResidualLayer(out_size), nn.ReLU(), DenseResidualLayer(out_size) ) def forward(self, x, task_representation): """ Forward pass through adaptation network. :param x: (torch.tensor) Input representation to network (task level representation z). :return: (list::dictionaries) Dictionary for every block in layer. Dictionary contains all the parameters necessary to adapt layer in base network. Base network is aware of dict structure and can pull params out during forward pass. """ x = self.shared_layer(x) x = torch.mean(x, dim=0, keepdim=True) x = self.shared_layer_post(x) x = torch.cat([x, task_representation], dim=-1) block_params = [] for block in range(self.num_blocks): block_param_dict = { 'gamma1': self.gamma1_processors[block](x).squeeze() * self.gamma1_regularizers[block] + torch.ones_like(self.gamma1_regularizers[block]), 'beta1': self.beta1_processors[block](x).squeeze() * self.beta1_regularizers[block], 'gamma2': self.gamma2_processors[block](x).squeeze() * self.gamma2_regularizers[block] + torch.ones_like(self.gamma2_regularizers[block]), 'beta2': self.beta2_processors[block](x).squeeze() * self.beta2_regularizers[block] } block_params.append(block_param_dict) return block_params def regularization_term(self): """ Compute the regularization term for the parameters. Recall, FiLM applies gamma * x + beta. As such, params gamma and beta are regularized to unity, i.e. ||gamma - 1||_2 and ||beta||_2. :return: (torch.tensor) Scalar for l2 norm for all parameters according to regularization scheme. """ l2_term = 0 for gamma_regularizer, beta_regularizer in zip(self.gamma1_regularizers, self.beta1_regularizers): l2_term += (gamma_regularizer ** 2).sum() l2_term += (beta_regularizer ** 2).sum() for gamma_regularizer, beta_regularizer in zip(self.gamma2_regularizers, self.beta2_regularizers): l2_term += (gamma_regularizer ** 2).sum() l2_term += (beta_regularizer ** 2).sum() return l2_term