# pylint: disable=E1101,R,C
"""Train a spherical CNN (s2cnn) classifier on the spherical-MNIST dataset.

Expects a gzipped pickle at MNIST_PATH with the layout
``{"train": {"images", "labels"}, "test": {"images", "labels"}}`` where
images are arrays of shape (N, H, W) and labels are integer class ids.

Usage: ``python run.py --network {original,deep}``
"""
import argparse
import gzip
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
# NOTE(review): Variable is unused and deprecated in modern PyTorch; kept so
# removing it is a deliberate follow-up, not a side effect of this cleanup.
from torch.autograd import Variable

from s2cnn import SO3Convolution
from s2cnn import S2Convolution
from s2cnn import so3_integrate
from s2cnn import so3_near_identity_grid
from s2cnn import s2_near_identity_grid

MNIST_PATH = "s2_mnist.gz"

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

NUM_EPOCHS = 20
BATCH_SIZE = 32
LEARNING_RATE = 5e-3


def load_data(path, batch_size):
    """Load the spherical-MNIST pickle and build shuffled DataLoaders.

    :param path: path to the gzipped pickle file (see module docstring)
    :param batch_size: mini-batch size used for both train and test loaders
    :return: tuple ``(train_loader, test_loader, train_dataset, test_dataset)``
    """
    # SECURITY: pickle.load can execute arbitrary code — only load dataset
    # files from a trusted source.
    with gzip.open(path, 'rb') as f:
        dataset = pickle.load(f)

    # Insert a channel axis ([:, None, :, :]) so images become (N, 1, H, W).
    train_data = torch.from_numpy(
        dataset["train"]["images"][:, None, :, :].astype(np.float32))
    train_labels = torch.from_numpy(
        dataset["train"]["labels"].astype(np.int64))

    # TODO normalize dataset
    # mean = train_data.mean()
    # stdv = train_data.std()

    train_dataset = data_utils.TensorDataset(train_data, train_labels)
    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)

    test_data = torch.from_numpy(
        dataset["test"]["images"][:, None, :, :].astype(np.float32))
    test_labels = torch.from_numpy(
        dataset["test"]["labels"].astype(np.int64))

    test_dataset = data_utils.TensorDataset(test_data, test_labels)
    test_loader = data_utils.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=True)

    return train_loader, test_loader, train_dataset, test_dataset


class S2ConvNet_original(nn.Module):
    """Small spherical CNN: S2 conv -> SO3 conv -> SO(3) integration -> linear."""

    def __init__(self):
        super(S2ConvNet_original, self).__init__()

        f1 = 20        # feature channels after the first convolution
        f2 = 40        # feature channels after the second convolution
        f_output = 10  # number of output classes

        b_in = 30  # input bandwidth
        b_l1 = 10  # bandwidth after layer 1
        b_l2 = 6   # bandwidth after layer 2

        grid_s2 = s2_near_identity_grid()
        grid_so3 = so3_near_identity_grid()

        self.conv1 = S2Convolution(
            nfeature_in=1,
            nfeature_out=f1,
            b_in=b_in,
            b_out=b_l1,
            grid=grid_s2)

        self.conv2 = SO3Convolution(
            nfeature_in=f1,
            nfeature_out=f2,
            b_in=b_l1,
            b_out=b_l2,
            grid=grid_so3)

        self.out_layer = nn.Linear(f2, f_output)

    def forward(self, x):
        """Return class logits for a batch of spherical signals."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        # Integrate over SO(3) to obtain rotation-invariant features.
        x = so3_integrate(x)
        x = self.out_layer(x)
        return x


class S2ConvNet_deep(nn.Module):
    """Deeper spherical CNN: one S2 conv + six SO3 convs, then a BN/linear head.

    :param bandwidth: input signal bandwidth (default 30); internal layers
        halve it progressively down to bandwidth//8.
    """

    def __init__(self, bandwidth=30):
        super(S2ConvNet_deep, self).__init__()

        # Grids with progressively wider max_beta as bandwidth shrinks.
        grid_s2 = s2_near_identity_grid(
            n_alpha=6, max_beta=np.pi/16, n_beta=1)
        grid_so3_1 = so3_near_identity_grid(
            n_alpha=6, max_beta=np.pi/16, n_beta=1, max_gamma=2*np.pi, n_gamma=6)
        grid_so3_2 = so3_near_identity_grid(
            n_alpha=6, max_beta=np.pi/8, n_beta=1, max_gamma=2*np.pi, n_gamma=6)
        grid_so3_3 = so3_near_identity_grid(
            n_alpha=6, max_beta=np.pi/4, n_beta=1, max_gamma=2*np.pi, n_gamma=6)
        grid_so3_4 = so3_near_identity_grid(
            n_alpha=6, max_beta=np.pi/2, n_beta=1, max_gamma=2*np.pi, n_gamma=6)

        self.convolutional = nn.Sequential(
            S2Convolution(
                nfeature_in=1,
                nfeature_out=8,
                b_in=bandwidth,
                b_out=bandwidth,
                grid=grid_s2),
            nn.ReLU(inplace=False),
            SO3Convolution(
                nfeature_in=8,
                nfeature_out=16,
                b_in=bandwidth,
                b_out=bandwidth//2,
                grid=grid_so3_1),
            nn.ReLU(inplace=False),
            SO3Convolution(
                nfeature_in=16,
                nfeature_out=16,
                b_in=bandwidth//2,
                b_out=bandwidth//2,
                grid=grid_so3_2),
            nn.ReLU(inplace=False),
            SO3Convolution(
                nfeature_in=16,
                nfeature_out=24,
                b_in=bandwidth//2,
                b_out=bandwidth//4,
                grid=grid_so3_2),
            nn.ReLU(inplace=False),
            SO3Convolution(
                nfeature_in=24,
                nfeature_out=24,
                b_in=bandwidth//4,
                b_out=bandwidth//4,
                grid=grid_so3_3),
            nn.ReLU(inplace=False),
            SO3Convolution(
                nfeature_in=24,
                nfeature_out=32,
                b_in=bandwidth//4,
                b_out=bandwidth//8,
                grid=grid_so3_3),
            nn.ReLU(inplace=False),
            SO3Convolution(
                nfeature_in=32,
                nfeature_out=64,
                b_in=bandwidth//8,
                b_out=bandwidth//8,
                grid=grid_so3_4),
            nn.ReLU(inplace=False)
        )

        self.linear = nn.Sequential(
            # linear 1
            nn.BatchNorm1d(64),
            nn.Linear(in_features=64, out_features=64),
            nn.ReLU(inplace=False),
            # linear 2
            nn.BatchNorm1d(64),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(inplace=False),
            # linear 3
            nn.BatchNorm1d(32),
            nn.Linear(in_features=32, out_features=10)
        )

    def forward(self, x):
        """Return class logits for a batch of spherical signals."""
        x = self.convolutional(x)
        # Integrate over SO(3) to obtain rotation-invariant features.
        x = so3_integrate(x)
        x = self.linear(x)
        return x


def main(network):
    """Train the selected architecture and print test accuracy each epoch.

    :param network: ``'original'`` or ``'deep'``
    :raises ValueError: for any other architecture name
    """
    train_loader, test_loader, train_dataset, _ = load_data(
        MNIST_PATH, BATCH_SIZE)

    if network == 'original':
        classifier = S2ConvNet_original()
    elif network == 'deep':
        classifier = S2ConvNet_deep()
    else:
        raise ValueError('Unknown network architecture')
    classifier.to(DEVICE)

    print("#params", sum(x.numel() for x in classifier.parameters()))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(DEVICE)

    optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        # Hoisted out of the batch loop: one mode switch per epoch is enough.
        classifier.train()
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = classifier(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
                epoch+1, NUM_EPOCHS, i+1,
                len(train_dataset)//BATCH_SIZE,
                loss.item()), end="")
        print("")

        # Evaluation pass: eval mode and no_grad hoisted around the loop
        # instead of being re-entered per batch.
        correct = 0
        total = 0
        classifier.eval()
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)

                outputs = classifier(images)

                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).long().sum().item()

        print('Test Accuracy: {0}'.format(100 * correct / total))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--network",
                        help="network architecture to use",
                        default='original',
                        choices=['original', 'deep'])
    args = parser.parse_args()

    main(args.network)