""" A collection of neural network code. The first part of the script includes blocks, which are the building blocks of our models. The second part includes the actual Pytorch models. """ import torch import torchvision.transforms as transforms class ConvBlock(torch.nn.Module): """ A ConvBlock represents a convolution. It's not just a convolution however, as some common operations (dropout, activation, batchnorm, 2x2 pooling) can be set and run in the order mentioned. """ def __init__( self, dim, n_out, kernel_size=3, stride=1, padding=1, batchnorm=False, dropout=0, activation=True, ): """ A convolution operation """ super(ConvBlock, self).__init__() n_in = int(dim[0]) self.conv2d = torch.nn.Conv2d( n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding ) self.batchnorm = torch.nn.BatchNorm2d(n_out) if batchnorm else None self.activation = torch.nn.ReLU(inplace=True) if activation else None self.dropout = torch.nn.Dropout2d(dropout) if dropout else None dim[0] = n_out dim[1:] = 1 + (dim[1:] + padding * 2 - kernel_size) // stride self.n_params = n_out * (n_in * kernel_size * kernel_size + (3 if batchnorm else 1)) print( "Conv2d in %4i out %4i h %4i w %4i k %i s %i params %9i" % (n_in, *dim, kernel_size, stride, self.n_params) ) def forward(self, batch): """ Forward the 4D batch """ out = self.conv2d(batch) if self.activation: out = self.activation(out) if self.batchnorm: out = self.batchnorm(out) if self.dropout: out = self.dropout(out) return out class LinearBlock(torch.nn.Module): """ A LinearBlock represents a fully connected layer. It's not just this, as some common operations (dropout, activation, batchnorm) can be set and run in the order mentioned. """ def __init__(self, dim, n_out, batchnorm=False, dropout=0.0, activation=True): """ A fully connected operation """ super(LinearBlock, self).__init__() n_in = int(dim[0]) self.linear = torch.nn.Linear(n_in, n_out) dim[0] = n_out if type(n_out) in (int, float) else n_out[0] self.batchnorm = torch.nn.BatchNorm1d(dim[0]) if batchnorm else None self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None self.activation = torch.nn.ReLU(inplace=True) if activation else None self.n_params = n_out * (n_in + (3 if batchnorm else 1)) print( "Linear in %4i out %4i params %9i" % (n_in, n_out, self.n_params) ) def forward(self, batch): """ Forward the 2D batch """ out = self.linear(batch) if self.activation: out = self.activation(out) if self.batchnorm: out = self.batchnorm(out) if self.dropout: out = self.dropout(out) return out class PoolBlock(torch.nn.Module): """ A PoolBlock is a pooling operation that happens on a matrix, often between convolutional layers, on each channel individually. By default only two are supported: max and avg. """ def __init__(self, dim, pool="max", size=None, stride=None): """ A pooling operation """ super(PoolBlock, self).__init__() stride = size if stride is None else stride if size: dim[1:] //= stride else: size = [int(x) for x in dim[1:]] dim[1:] = 1 if pool == "max": self.pool = torch.nn.MaxPool2d(size, stride=stride, padding=0) elif pool == "avg": self.pool = torch.nn.AvgPool2d(size, stride=stride, padding=0) self.n_params = 0 def forward(self, batch): """ Forward the 4D batch """ out = self.pool(batch) return out class ViewBlock(torch.nn.Module): """ A ViewBlock restructures the shape of our activation maps so they're represented as 1D instead of 3D. """ def __init__(self, dim, shape=-1): """ A reshape operation """ super(ViewBlock, self).__init__() self.shape = shape if self.shape == -1: dim[0] = dim[0] * dim[1] * dim[2] dim[-2] = 0 dim[-1] = 0 else: dim[:] = shape self.n_params = 0 print("View d %4i h %4i w %4i" % (*dim,)) def forward(self, batch): """ Forward the 4D batch into a 2D batch """ return batch.view(batch.size(0), self.shape) class Tiny(torch.nn.Module): """ A small and quick model """ def __init__(self, in_dim, n_status, n_out): """ Args: in_dim (list): The input size of each example n_status (int): Number of status inputs to add n_out (int): Number of values to predict """ super(Tiny, self).__init__() self.n_status = n_status dim = in_dim.copy() self.feat = torch.nn.Sequential( ConvBlock(dim, 16), PoolBlock(dim, "max", 2), ConvBlock(dim, 32), PoolBlock(dim, "max", 2), ConvBlock(dim, 48), PoolBlock(dim, "max", 2), ConvBlock(dim, 64), PoolBlock(dim, "max", 2), ) self.view = ViewBlock(dim) dim[0] += n_status self.head = torch.nn.Sequential(LinearBlock(dim, n_out, activation=False)) self.n_params = sum([x.n_params for x in self.feat]) + sum([x.n_params for x in self.head]) print("Tiny params %9i" % self.n_params) def forward(self, batch, status): """ Args: batch (4D tensor): A batch of camera input. status (1D tensor): Status inputs indicating things like speed. """ out = self.feat(batch) out = self.view(out) if self.n_status: out = torch.cat((out, status), 1) out = self.head(out) return out class StarTree(torch.nn.Module): """ A medium-sized model that uses layers with few activation maps to efficiently increase the number of layers, and therefore nonlinearities. """ def __init__(self, in_dim, n_status, n_out): """ Args: in_dim (list): The input size of each example n_status (int): Number of status inputs to add n_out (int): Number of values to predict """ super(StarTree, self).__init__() self.n_status = n_status dim = in_dim.copy() self.feat = torch.nn.Sequential( ConvBlock(dim, 64, dropout=0.25), ConvBlock(dim, 16), ConvBlock(dim, 32), PoolBlock(dim, "max", 2), ConvBlock(dim, 24), ConvBlock(dim, 48), PoolBlock(dim, "max", 2), ConvBlock(dim, 32), ConvBlock(dim, 64), PoolBlock(dim, "max", 2), ConvBlock(dim, 40), ConvBlock(dim, 80, dropout=0.25), PoolBlock(dim, "max", 2), ) self.view = ViewBlock(dim) dim[0] += n_status self.head = torch.nn.Sequential( LinearBlock(dim, 50), LinearBlock(dim, n_out, activation=False), ) self.n_params = sum([x.n_params for x in self.feat]) + sum([x.n_params for x in self.head]) print("StarTree params %9i" % self.n_params) def forward(self, batch, status): """ Args: batch (4D tensor): A batch of camera input. status (1D tensor): Status inputs indicating things like speed. """ out = self.feat(batch) out = self.view(out) if self.n_status: out = torch.cat((out, status), 1) out = self.head(out) return out def train_epoch(device, model, optimizer, criterion, loader): """ Run the optimzer over all batches in an epoch """ model.train() epoch_loss = 0 batch_index = 0 for batch_index, (examples, statuses, labels) in enumerate(loader): optimizer.zero_grad() guesses = model(examples.to(device), statuses.to(device)) loss = criterion(guesses, labels.to(device)) loss.backward() optimizer.step() epoch_loss += loss.item() return epoch_loss / (batch_index + 1) def test_epoch(device, model, criterion, loader): """ Run the evaluator over all batches in an epoch """ model.eval() epoch_loss = 0 batch_index = 0 with torch.no_grad(): for batch_index, (examples, statuses, labels) in enumerate(loader): guesses = model(examples.to(device), statuses.to(device)) loss = criterion(guesses, labels.to(device)) epoch_loss += loss.item() return epoch_loss / (batch_index + 1) def compose_transforms(transform_config): """ Apply all image transforms """ transform_list = [] for perturb_config in transform_config: if perturb_config["name"] == "colorjitter": transform = transforms.ColorJitter( brightness=perturb_config["brightness"], contrast=perturb_config["contrast"], saturation=perturb_config["saturation"], hue=perturb_config["hue"], ) transform_list.append(transform) transform_list.append(transforms.ToTensor()) return transforms.Compose(transform_list)