import torch
import torch.nn as nn
import torch.nn.functional as F

from metal.classifier import Classifier
from metal.end_model.em_defaults import em_default_config
from metal.end_model.identity_module import IdentityModule
from metal.end_model.loss import SoftCrossEntropyLoss
from metal.utils import MetalDataset, pred_to_prob, recursive_merge_dicts


class EndModel(Classifier):
    """A dynamically constructed discriminative classifier

    Args:
        layer_out_dims: a list of integers corresponding to the output sizes
            of the layers of your network. The first element is the
            dimensionality of the input layer, the last element is the
            dimensionality of the head layer (equal to the cardinality of the
            task), and all other elements dictate the sizes of middle layers.
            The number of middle layers will be inferred from this list.
        input_module: (nn.Module) a module that converts the user-provided
            model inputs to torch.Tensors. Defaults to IdentityModule.
        middle_modules: (nn.Module) a list of modules to execute between the
            input_module and task head. Defaults to nn.Linear.
        head_module: (nn.Module) a module to execute right before the final
            softmax that outputs a prediction for the task.
        **kwargs: config overrides; merged on top of em_default_config.

    Raises:
        ValueError: if layer_out_dims has fewer than two elements and the
            head is not being skipped.
    """

    def __init__(
        self,
        layer_out_dims,
        input_module=None,
        middle_modules=None,
        head_module=None,
        **kwargs,
    ):
        # Look up skip_head defensively: callers usually do not pass it, and
        # indexing kwargs directly would surface a KeyError here instead of
        # the descriptive ValueError below. Fall back to the default config,
        # since that is what the merged config will contain.
        skip_head = kwargs.get(
            "skip_head", em_default_config.get("skip_head", False)
        )
        if len(layer_out_dims) < 2 and not skip_head:
            raise ValueError(
                "Arg layer_out_dims must have at least two "
                "elements corresponding to the output dim of the input module "
                "and the cardinality of the task. If the input module is the "
                "IdentityModule, then the output dim of the input module will "
                "be equal to the dimensionality of your input data points"
            )

        # Add layer_out_dims to kwargs so it will be merged into the config dict
        kwargs["layer_out_dims"] = layer_out_dims
        config = recursive_merge_dicts(em_default_config, kwargs, misses="insert")
        super().__init__(k=layer_out_dims[-1], config=config)

        self._build(input_module, middle_modules, head_module)

        # Show network
        if self.config["verbose"]:
            print("\nNetwork architecture:")
            self._print()
            print()

    def _build(self, input_module, middle_modules, head_module):
        """Assemble self.network and self.criteria from the given modules.

        The network is the input layer, followed by any middle layers,
        followed by the task head (unless config["skip_head"] is set). The
        loss is a SoftCrossEntropyLoss configured from config["train_config"].
        """
        input_layer = self._build_input_layer(input_module)
        middle_layers = self._build_middle_layers(middle_modules)

        # Construct list of layers
        layers = [input_layer]
        if middle_layers is not None:
            layers += middle_layers
        if not self.config["skip_head"]:
            head = self._build_task_head(head_module)
            layers.append(head)

        # Construct network; a single layer is used directly rather than
        # being wrapped in an nn.Sequential.
        if len(layers) > 1:
            self.network = nn.Sequential(*layers)
        else:
            self.network = layers[0]

        # Construct loss module
        loss_weights = self.config["train_config"]["loss_weights"]
        if loss_weights is not None and self.config["verbose"]:
            print(f"Using class weight vector {loss_weights}...")
        reduction = self.config["train_config"]["loss_fn_reduction"]
        self.criteria = SoftCrossEntropyLoss(
            weight=self._to_torch(loss_weights, dtype=torch.FloatTensor),
            reduction=reduction,
        )

    def _build_input_layer(self, input_module):
        """Wrap input_module (IdentityModule if None) as the input layer."""
        if input_module is None:
            input_module = IdentityModule()
        output_dim = self.config["layer_out_dims"][0]
        input_layer = self._make_layer(
            input_module,
            "input",
            self.config["input_layer_config"],
            output_dim=output_dim,
        )
        return input_layer

    def _build_middle_layers(self, middle_modules):
        """Build the middle layers, or return None if there are exactly zero.

        When middle_modules is None, each middle layer is an nn.Linear sized
        from consecutive entries of config["layer_out_dims"]. Otherwise the
        i-th user-provided module is used and no output dim is passed on, so
        _make_layer adds no batchnorm for it.
        """
        layer_out_dims = self.config["layer_out_dims"]
        num_mid_layers = len(layer_out_dims) - 2
        if num_mid_layers == 0:
            return None

        middle_layers = nn.ModuleList()
        for i in range(num_mid_layers):
            if middle_modules is None:
                module = nn.Linear(*layer_out_dims[i : i + 2])
                output_dim = layer_out_dims[i + 1]
            else:
                module = middle_modules[i]
                output_dim = None
            layer = self._make_layer(
                module,
                "middle",
                self.config["middle_layer_config"],
                output_dim=output_dim,
            )
            middle_layers.add_module(f"layer{i+1}", layer)
        return middle_layers

    def _build_task_head(self, head_module):
        """Return head_module, or a default nn.Linear head if it is None."""
        if head_module is None:
            head = nn.Linear(self.config["layer_out_dims"][-2], self.k)
        else:
            # Note that if head module is provided, it must have input dim of
            # the last middle module and output dim of self.k, the cardinality
            head = head_module
        return head

    def _make_layer(self, module, prefix, layer_config, output_dim=None):
        """Optionally follow module with ReLU, BatchNorm1d and Dropout.

        Args:
            module: the core module of the layer. An IdentityModule is
                returned unchanged.
            prefix: key prefix ("input" or "middle" in this class) used to
                read "<prefix>_relu", "<prefix>_batchnorm" and
                "<prefix>_dropout" from layer_config.
            layer_config: dict holding the three keys above.
            output_dim: feature count for BatchNorm1d; batchnorm is skipped
                when this is falsy.

        Returns:
            The bare module if nothing was added, else an nn.Sequential.
        """
        if isinstance(module, IdentityModule):
            return module
        layer = [module]
        if layer_config[f"{prefix}_relu"]:
            layer.append(nn.ReLU())
        if layer_config[f"{prefix}_batchnorm"] and output_dim:
            layer.append(nn.BatchNorm1d(output_dim))
        if layer_config[f"{prefix}_dropout"]:
            layer.append(nn.Dropout(layer_config[f"{prefix}_dropout"]))
        if len(layer) > 1:
            return nn.Sequential(*layer)
        else:
            return layer[0]

    def _print(self):
        """Print the network architecture."""
        print(self.network)

    def forward(self, x):
        """Returns the output of self.network applied to x

        Args:
            x: a [batch_size, ...] batch from X
        """
        return self.network(x)

    @staticmethod
    def _reset_module(m):
        """Reset the parameters of a module if it knows how to

        Calls m.reset_parameters() when the module exposes a callable with
        that name; any other module is left untouched.

        This does not recurse into children of m; it is presumably applied to
        every submodule by the caller (e.g. via nn.Module.apply), so do not
        recurse manually.
        """
        if callable(getattr(m, "reset_parameters", None)):
            m.reset_parameters()

    def update_config(self, update_dict):
        """Updates self.config with the values in a given update dictionary"""
        self.config = recursive_merge_dicts(self.config, update_dict)

    def _preprocess_Y(self, Y, k):
        """Convert Y to prob labels if necessary

        Works on a clone of Y. A 1-dim Y, or one whose second dimension has
        size 1, is treated as predictions and converted with pred_to_prob;
        anything else is returned as the clone.
        """
        Y = Y.clone()

        # If preds, convert to probs
        if Y.dim() == 1 or Y.shape[1] == 1:
            Y = pred_to_prob(Y.long(), k=k)
        return Y

    def _create_dataset(self, *data):
        """Wrap the given data in a MetalDataset."""
        return MetalDataset(*data)

    def _get_loss_fn(self):
        """Return a function (X, Y) -> loss using self.criteria."""
        criteria = self.criteria.to(self.config["device"])

        # This self.preprocess_Y allows us to not handle preprocessing
        # in a custom dataloader, but decreases speed a bit
        def loss_fn(X, Y):
            return criteria(self.forward(X), self._preprocess_Y(Y, self.k))

        return loss_fn

    def train_model(self, train_data, valid_data=None, log_writer=None, **kwargs):
        """Train the model on train_data.

        Args:
            train_data: either a tuple/list (X, Y), in which case Y is
                converted to a float tensor and preprocessed, or any other
                input accepted by self._create_data_loader.
            valid_data: passed through to self._train_model.
            log_writer: passed through to self._train_model.
            **kwargs: config overrides merged into self.config.
        """
        self.config = recursive_merge_dicts(self.config, kwargs)

        # If train_data is provided as a tuple (X, Y), we can make sure Y is in
        # the correct format
        # NOTE: Better handling for if train_data is Dataset or DataLoader...?
        if isinstance(train_data, (tuple, list)):
            X, Y = train_data
            Y = self._preprocess_Y(self._to_torch(Y, dtype=torch.FloatTensor), self.k)
            train_data = (X, Y)

        # Convert input data to data loaders
        train_loader = self._create_data_loader(train_data, shuffle=True)

        # Create loss function
        loss_fn = self._get_loss_fn()

        # Execute training procedure
        self._train_model(
            train_loader, loss_fn, valid_data=valid_data, log_writer=log_writer
        )

    def predict_proba(self, X):
        """Returns a [n, k] tensor of probs (probabilistic labels)."""
        return F.softmax(self.forward(X), dim=1).data.cpu().numpy()