import time
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data as data
import torch.nn.functional as F


def normalize_to_unit_interval(x, normalize_by=None):
    """Divide ``x`` by a normalizer, deriving it from ``x`` when not given.

    Args:
        x: numpy array to scale.
        normalize_by: divisor to apply. When ``None``, ``np.max(x)`` (the
            maximum over all elements) is used, so non-negative data ends up
            in [0, 1].

    Returns:
        Tuple ``(x / divisor, divisor)``; the divisor is returned so the same
        scaling can be reused on other splits (e.g. validation/test).
    """
    # `is None`, not `== None`: with an array-valued divisor, `== None` is
    # elementwise and the `if` would raise "truth value is ambiguous".
    if normalize_by is None:
        normalize_by = np.max(x)
    return x / normalize_by, normalize_by


def load_CIFAR_batch(filename):
    """Load one pickled CIFAR batch file.

    Args:
        filename: path to a batch file whose pickled dict holds ``'data'``
            (one flattened 3x32x32 image per row) and ``'labels'``.

    Returns:
        ``(X, Y)``: ``X`` is float32 with shape ``(N, 3, 32, 32)``; ``Y`` is an
        int64 label vector of length ``N``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load batch
    # files obtained from a trusted source.
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1')
    # -1 instead of a hard-coded 10000 so batches of any size can be loaded.
    X = datadict['data'].reshape(-1, 3, 32, 32).astype('float32')
    # Explicit int64: classification losses such as nn.CrossEntropyLoss need
    # long targets, and numpy's default integer dtype is platform-dependent.
    Y = np.array(datadict['labels'], dtype=np.int64)
    return X, Y


def load_dataset(dataset, num_training_samples=-1, datapath='data/cifar-10-batches-py'):
    """Load a dataset split into train / validation / test numpy arrays.

    Only ``'cifar-10'`` is supported. The last 5000 images of the five training
    batches form the validation set. All splits are divided by the maximum
    value of the training split.

    Args:
        dataset: dataset name; must be ``'cifar-10'``.
        num_training_samples: if > 0, keep only the first this-many training
            samples (validation/test are unaffected).
        datapath: directory containing ``data_batch_1..5`` and ``test_batch``.

    Returns:
        ``(x_train, y_train, x_val, y_val, x_test, y_test)``.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    if dataset != 'cifar-10':
        raise ValueError('Import for the dataset you have provided is not yet implemented.')

    batches = [load_CIFAR_batch(os.path.join(datapath, 'data_batch_%d' % (b, )))
               for b in range(1, 6)]
    x_all = np.concatenate([x for x, _ in batches])
    y_all = np.concatenate([y for _, y in batches])
    del batches  # release the per-batch copies before building the splits

    # Hold out the last 5000 training images for validation.
    x_train, x_val = x_all[:-5000], x_all[-5000:]
    y_train, y_val = y_all[:-5000], y_all[-5000:]
    x_test, y_test = load_CIFAR_batch(os.path.join(datapath, 'test_batch'))

    # The normalizer comes from the training split only and is reused for
    # validation and test so all splits share the same scaling.
    x_train, normalize_by = normalize_to_unit_interval(x_train)
    x_val, _ = normalize_to_unit_interval(x_val, normalize_by)
    x_test, _ = normalize_to_unit_interval(x_test, normalize_by)

    if num_training_samples > 0:
        x_train = x_train[:num_training_samples]
        y_train = y_train[:num_training_samples]

    return x_train, y_train, x_val, y_val, x_test, y_test


def compute_loss_and_accuracy(model, loss_fn, x, y, batch_size=64, device='cuda'):
    """Evaluate ``model`` on ``(x, y)`` in eval mode without tracking gradients.

    Args:
        model: module mapping a batch of inputs to per-class scores
            (argmax over dim 1 is taken as the prediction).
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar
            tensor; assumed to be a per-batch *mean* (PyTorch's default
            ``reduction='mean'``) -- confirm for custom losses.
        x, y: numpy arrays of inputs and integer targets.
        batch_size: evaluation batch size.
        device: device the batches are moved to.

    Returns:
        ``(loss, acc)``: the per-sample mean loss over all of ``x`` and the
        fraction of samples whose argmax prediction equals the target.

    The model's train/eval mode is restored to whatever it was on entry.
    """
    num_samples = x.shape[0]
    dataset = torch.utils.data.TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

    summed_loss = 0.0
    correct_samples = 0
    was_training = model.training
    model.train(False)
    try:
        with torch.no_grad():
            for sample_x, sample_y in loader:
                sample_x = sample_x.to(device)
                sample_y = sample_y.to(device)
                sample_out = model(sample_x)
                # Weight each batch mean by its size so a short final batch
                # does not skew the overall average.
                summed_loss += loss_fn(sample_out, sample_y).item() * sample_x.size(0)
                correct_samples += (sample_out.argmax(dim=1) == sample_y).sum().item()
    finally:
        # Restore the caller's mode instead of unconditionally switching to train.
        model.train(was_training)

    loss = float(summed_loss) / float(num_samples)
    acc = float(correct_samples) / float(num_samples)
    return loss, acc


def train(model, loss_fn, optimizer, data, num_epochs, batch_size, scheduler=None, device='cuda'):
    """Train ``model`` with minibatch gradient descent and record metrics.

    Args:
        model: the module to train (updated in place).
        loss_fn: loss callable ``loss_fn(outputs, targets)``.
        optimizer: torch optimizer over ``model``'s parameters.
        data: tuple ``(x_train, y_train, x_val, y_val, x_test, y_test)`` of
            numpy arrays, e.g. as returned by ``load_dataset``.
        num_epochs: number of passes over the training set.
        batch_size: minibatch size for training and for train/val evaluation.
        scheduler: optional LR scheduler, stepped once per epoch after that
            epoch's optimizer updates. A ``ReduceLROnPlateau`` scheduler is
            given the epoch's validation loss.
        device: device batches are moved to.

    Returns:
        Dict of metric lists, each with ``num_epochs + 1`` entries (index 0 is
        measured before training): ``'minibatch_avg_loss'``,
        ``'full_batch_loss'``, ``'full_batch_acc'``, ``'val_loss'``,
        ``'val_acc'``, ``'epoch_time'``; plus ``'test_loss_acc'``, a
        ``(loss, accuracy)`` tuple on the test split after training.
    """
    x_train, y_train, x_val, y_val, x_test, y_test = data
    train_data = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    start_time = time.perf_counter()
    training_metrics = {key: [] for key in ('minibatch_avg_loss', 'full_batch_loss', 'full_batch_acc',
                                            'val_loss', 'val_acc', 'epoch_time')}

    def _evaluate_and_record(minibatch_avg_loss, epoch_time):
        """Evaluate on train and val splits, append one row of metrics, and return
        ``(full_batch_loss, val_loss)``. ``minibatch_avg_loss=None`` records the
        full-batch loss in its place (used before any training step)."""
        full_loss, full_acc = compute_loss_and_accuracy(model, loss_fn, x_train, y_train,
                                                        batch_size=batch_size, device=device)
        val_loss, val_acc = compute_loss_and_accuracy(model, loss_fn, x_val, y_val,
                                                      batch_size=batch_size, device=device)
        training_metrics['minibatch_avg_loss'].append(
            full_loss if minibatch_avg_loss is None else minibatch_avg_loss)
        training_metrics['full_batch_loss'].append(full_loss)
        training_metrics['full_batch_acc'].append(full_acc)
        training_metrics['val_loss'].append(val_loss)
        training_metrics['val_acc'].append(val_acc)
        training_metrics['epoch_time'].append(epoch_time)
        return full_loss, val_loss

    # Initial losses and accuracies, before any update.
    _evaluate_and_record(None, 0)

    for epoch in range(num_epochs):
        epoch_start_time = time.perf_counter()
        epoch_summed_loss = 0.0
        epoch_batch_counter = 0
        for x_batch, y_batch in train_loader:
            x = x_batch.to(device)
            y = y_batch.to(device)
            loss = loss_fn(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_summed_loss += loss.item()
            epoch_batch_counter += 1
        # Epoch time covers only the training pass, not the evaluation below.
        epoch_time = time.perf_counter() - epoch_start_time

        epoch_avg_loss = epoch_summed_loss / epoch_batch_counter
        epoch_full_batch_loss, val_loss = _evaluate_and_record(epoch_avg_loss, epoch_time)

        # Step once per epoch *after* the optimizer updates (PyTorch >= 1.1
        # ordering); the old per-batch `step(epoch=...)` form is deprecated.
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        print(f'{epoch}: mini batch avg loss: {epoch_avg_loss}, full batch loss: '
              f'{epoch_full_batch_loss}, epoch time: {epoch_time}s')

    print('Trained in {0} seconds.'.format(int(time.perf_counter() - start_time)))
    test_loss, test_acc = compute_loss_and_accuracy(model, loss_fn, x_test, y_test, batch_size=256, device=device)
    training_metrics['test_loss_acc'] = (test_loss, test_acc)
    print('Avg. test loss: {0}, avg. test accuracy: {1}'.format(test_loss, test_acc))
    return training_metrics