import os
import shutil

import matplotlib
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
from tqdm import tqdm, trange

# Select the non-interactive Agg backend before importing pyplot; plots are only written to files here.
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from torch.optim.lr_scheduler import CyclicLR


def train(model, loader, mixup, epoch, optim, criterion, device, dtype, batch_size, log_interval, child):
    """Run one training epoch over ``loader``.

    Args:
        model: network to train; switched to ``train()`` mode here.
        loader: iterable of ``(data, labels)`` batches; ``len(loader)`` and
            ``len(loader.sampler)`` must be defined.
        mixup: callable ``mixup(data, labels) -> (data, target)``; the returned
            ``target`` is what ``criterion`` receives, while accuracy is
            measured against the original ``labels``.
        epoch: epoch index, used only in the progress message.
        optim: object providing ``zero_grad()`` and ``batch_step()``
            (presumably an optimizer wrapper that also advances a per-batch
            schedule -- confirm against its definition).
        criterion: loss function called as ``criterion(output, target)``.
        device: device the batches are moved to.
        dtype: dtype the input data is cast to (labels keep their dtype).
        batch_size: nominal batch size. No longer used for the logged
            percentages (the actual number of samples seen is used, so a short
            final batch is counted correctly); kept for API compatibility.
        log_interval: log progress every ``log_interval`` batches.
        child: when true, disable the tqdm progress bar and all logging.

    Returns:
        ``(loss, top1, top5)``: the last batch's loss as a Python float, and
        the fraction of ``len(loader.sampler)`` samples classified correctly
        at top-1 and top-5.
    """
    model.train()
    correct1, correct5 = 0, 0
    seen = 0  # samples processed so far in this epoch
    batches = loader if child else tqdm(loader)
    for batch_idx, (data, t) in enumerate(batches):
        data, t = data.to(device=device, dtype=dtype), t.to(device=device)
        data, target = mixup(data, t)
        optim.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optim.batch_step()

        corr = correct(output, t, topk=(1, 5))
        correct1 += corr[0]
        correct5 += corr[1]
        n = t.size(0)  # actual size of this batch; the final batch may be smaller than batch_size
        seen += n

        if batch_idx % log_interval == 0 and not child:
            tqdm.write(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}. '
                'Top-1 accuracy: {:.2f}%({:.2f}%). '
                'Top-5 accuracy: {:.2f}%({:.2f}%).'.format(epoch, batch_idx, len(loader),
                                                           100. * batch_idx / len(loader), loss.item(),
                                                           100. * corr[0] / n,
                                                           100. * correct1 / seen,
                                                           100. * corr[1] / n,
                                                           100. * correct5 / seen))
    return loss.item(), correct1 / len(loader.sampler), correct5 / len(loader.sampler)


def test(model, loader, criterion, device, dtype, child):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; switched to ``eval()`` mode here.
        loader: iterable of ``(data, target)`` batches; ``len(loader)`` and
            ``len(loader.sampler)`` must be defined.
        criterion: loss function called as ``criterion(output, target)``.
        device: device the batches are moved to.
        dtype: dtype the input data is cast to (targets keep their dtype).
        child: when true, disable the tqdm progress bar and the summary line.

    Returns:
        ``(loss, top1, top5)``: the mean of the per-batch losses, and the
        fraction of ``len(loader.sampler)`` samples classified correctly at
        top-1 and top-5.
    """
    model.eval()
    test_loss = 0
    correct1, correct5 = 0, 0
    batches = loader if child else tqdm(loader)
    with torch.no_grad():
        for data, target in batches:
            data, target = data.to(device=device, dtype=dtype), target.to(device=device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            corr = correct(output, target, topk=(1, 5))
            correct1 += corr[0]
            correct5 += corr[1]

    test_loss /= len(loader)
    n_samples = len(loader.sampler)
    if not child:
        tqdm.write(
            '\nTest set: Average loss: {:.4f}, Top1: {}/{} ({:.2f}%), '
            'Top5: {}/{} ({:.2f}%)'.format(test_loss, int(correct1), n_samples,
                                           100. * correct1 / n_samples, int(correct5),
                                           n_samples, 100. * correct5 / n_samples))
    return test_loss, correct1 / n_samples, correct5 / n_samples


def correct(output, target, topk=(1,)):
    """Count correct predictions at each k in ``topk``.

    Args:
        output: scores of shape ``(N, C)``; the top-k is taken along dim 1.
        target: ``N`` class indices.
        topk: the k values to evaluate.

    Returns:
        A list with one float per k: the number of samples whose target is
        among the ``k`` highest-scoring classes.
    """
    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t().type_as(target)  # (maxk, N); a transposed, non-contiguous view
    hits = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # reshape, not view: `hits` inherits the transposed strides, and view(-1) on it fails for k > 1.
        correct_k = hits[:k].reshape(-1).float().sum(0).item()
        res.append(correct_k)
    return res


def save_checkpoint(state, is_best, filepath='./', filename='checkpoint{}.pth.tar', local_rank=0, test=False):
    """Save ``state`` with ``torch.save`` and optionally mark it as the best model.

    Args:
        state: object passed to ``torch.save``.
        is_best: when true, the written file is also copied to
            ``model_best{local_rank}.pth.tar`` in ``filepath``.
        filepath: directory both files are written to.
        filename: format string for the checkpoint name; ``{}`` is replaced by
            ``local_rank``.
        local_rank: substituted into both file names (presumably so each
            distributed process writes its own files -- confirm with callers).
        test: when true, load the written file back to check it is readable.
    """
    save_path = os.path.join(filepath, filename.format(local_rank))
    best_path = os.path.join(filepath, 'model_best{}.pth.tar'.format(local_rank))
    torch.save(state, save_path)
    if test:
        # Read-back check only; mapping to CPU avoids allocating device memory for a throwaway copy.
        torch.load(save_path, map_location='cpu')
    if is_best:
        shutil.copyfile(save_path, best_path)


def swa_clr(folder, device):
    """Average the ``state_dict`` entries of every checkpoint in ``folder`` whose file name contains ``'SW'``.

    Checkpoints are loaded one at a time (in sorted file-name order), and a
    running mean ``avg_i = i/(i+1) * avg + 1/(i+1) * sd_i`` is kept, so only
    one checkpoint is in memory at once.

    Args:
        folder: directory to scan.
        device: passed as ``map_location`` to ``torch.load``.

    Returns:
        A dict mapping each key to the element-wise mean over the matched
        checkpoints.

    Raises:
        IndexError: if no file name in ``folder`` contains ``'SW'``.
    """
    averaged = None
    count = 0
    for file in sorted(os.listdir(folder)):
        if 'SW' not in file:
            continue
        state = torch.load(os.path.join(folder, file), map_location=device)['state_dict']
        if averaged is None:
            averaged = state.copy()
        for w in state:
            averaged[w] = count / (count + 1.) * averaged[w] + 1. / (count + 1.) * state[w]
        count += 1
    if averaged is None:
        # IndexError kept for backward compatibility with the original list-indexing failure.
        raise IndexError("no checkpoint file containing 'SW' found in {}".format(folder))
    return averaged


def find_bounds_clr(model, loader, optimizer, criterion, device, dtype, min_lr=8e-6, max_lr=8e-5, step_size=2000,
                    mode='triangular', save_path='.'):
    """Run a cyclic-LR range test and plot training accuracy against learning rate.

    A ``CyclicLR`` schedule with ``base_lr=min_lr``, ``max_lr=max_lr`` and
    ``step_size_up=step_size`` drives ``optimizer`` for exactly ``step_size``
    optimizer steps (one per batch), cycling through ``loader`` for as many
    epochs as needed. For each step, the learning rate it used and the batch's
    top-1 accuracy are recorded.

    Writes ``find_bounds_clr.pdf`` (accuracy vs. learning rate) and ``acc.npy``
    (the per-step accuracies) into ``save_path``.

    Args:
        model: network to train; switched to ``train()`` mode here.
        loader: iterable of ``(data, target)`` batches with a defined ``len``.
        optimizer: torch optimizer handed to ``CyclicLR``.
        criterion: loss function called as ``criterion(output, target)``.
        device: device the batches are moved to.
        dtype: dtype the input data is cast to.
        min_lr, max_lr, step_size, mode: forwarded to ``CyclicLR``.
        save_path: output directory for the plot and the accuracy array.
    """
    model.train()
    scheduler = CyclicLR(optimizer, base_lr=min_lr, max_lr=max_lr, step_size_up=step_size, mode=mode)
    # Ceil division, so step_size need not be a multiple of the number of batches per epoch.
    epoch_count = -(-step_size // len(loader))
    accuracy = []
    lrs = []
    for _ in trange(epoch_count):
        for data, target in tqdm(loader):
            if len(accuracy) >= step_size:
                break  # the sweep may end part-way through the final epoch
            data, target = data.to(device=device, dtype=dtype), target.to(device=device)
            lrs.append(optimizer.param_groups[0]['lr'])  # learning rate this step actually uses
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # Since PyTorch 1.1 the scheduler must be stepped after optimizer.step().
            scheduler.step()
            corr = correct(output, target)
            accuracy.append(corr[0] / data.shape[0])

    fig = plt.figure()  # fresh figure so repeated calls do not draw on top of each other
    plt.plot(lrs, accuracy)
    plt.xlabel('learning rate')
    plt.ylabel('top-1 accuracy')
    plt.savefig(os.path.join(save_path, 'find_bounds_clr.pdf'))
    plt.close(fig)
    np.save(os.path.join(save_path, 'acc.npy'), accuracy)
    return