#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Training script for semantic segmentation
"""

import datetime
import math
import time
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.utils.data
from fastprogress import master_bar, progress_bar
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.ops.misc import FrozenBatchNorm2d
from torchvision.transforms import functional as F

import holocron
from transforms import (Compose, Resize, ImageTransform, CenterCrop, RandomResizedCrop,
                        RandomHorizontalFlip, convert_to_relative)

VOC_CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
               'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

# Channel-wise statistics used to normalize inputs (and to undo it for display)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def train_one_batch(model, x, target, optimizer, criterion, device=None):
    """Run a single optimization step on one batch.

    Args:
        model: module being trained
        x: input batch
        target: expected output for ``x``
        optimizer: optimizer updating the parameters of ``model``
        criterion: callable taking ``(output, target)`` and returning the loss
        device: device the batch is moved to before the forward pass

    Returns:
        The batch loss as a Python float.
    """
    x, target = x.to(device), target.to(device)

    out = model(x)
    batch_loss = criterion(out, target)

    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()


def train_one_epoch(model, optimizer, criterion, scheduler, data_loader, device, master_bar):
    """Train the model over one full pass of ``data_loader``.

    Args:
        model: module being trained (switched to training mode)
        optimizer: optimizer updating the parameters of ``model``
        criterion: callable taking ``(output, target)`` and returning the loss
        scheduler: learning rate scheduler; it is only stepped here when it is a ``OneCycleLR``
        data_loader: iterable yielding ``(x, target)`` batches
        device: device the batches are moved to
        master_bar: parent progress bar the per-batch bar is attached to
    """
    model.train()
    for x, target in progress_bar(data_loader, parent=master_bar):
        # Same step as the LR finder uses, so both code paths stay in sync
        batch_loss = train_one_batch(model, x, target, optimizer, criterion, device)

        # OneCycleLR is defined per iteration, other schedulers are stepped by the caller
        if isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
            scheduler.step()

        master_bar.child.comment = f"Training loss: {batch_loss:.4}"


def evaluate(model, data_loader, criterion, device, ignore_index=255):
    """Compute the average loss and mean IoU of the model over ``data_loader``.

    For each batch, the IoU is averaged over the classes present in the target, then
    averaged over batches. Pixels whose target equals ``ignore_index`` are excluded
    from both the intersection and the union.

    Args:
        model: module being evaluated (switched to evaluation mode)
        data_loader: iterable yielding ``(x, target)`` batches
        criterion: callable taking ``(output, target)`` and returning the loss
        device: device the batches are moved to
        ignore_index: target value marking pixels that should not be scored

    Returns:
        Tuple ``(val_loss, mean_iou)``.
    """
    model.eval()
    val_loss, mean_iou = 0., 0.
    scored_batches = 0
    with torch.no_grad():
        for x, target in data_loader:
            x, target = x.to(device), target.to(device)
            out = model(x)

            val_loss += criterion(out, target).item()

            pred = out.argmax(dim=1)
            # Drop ignored pixels first so that predictions made there do not inflate the union
            valid = target != ignore_index
            pred, target = pred[valid], target[valid]

            ious = []
            for class_idx in torch.unique(target):
                pred_mask, target_mask = pred == class_idx, target == class_idx
                inter = (pred_mask & target_mask).sum().item()
                # The class comes from the target, so the union is never empty
                union = (pred_mask | target_mask).sum().item()
                ious.append(inter / union)

            # A batch made only of ignored pixels has no class to score
            if ious:
                mean_iou += sum(ious) / len(ious)
                scored_batches += 1

    val_loss /= len(data_loader)
    mean_iou = mean_iou / scored_batches if scored_batches else 0.

    return val_loss, mean_iou


def load_data(datadir):
    """Build the Pascal VOC training and validation datasets with their samplers.

    Args:
        datadir: folder where the dataset is stored (it is downloaded if missing)

    Returns:
        Tuple ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    # NOTE(review): these names are used below but were never imported at module level;
    # they are assumed to be defined in the local ``transforms`` module -- confirm.
    from transforms import RandomCrop, RandomResize, SampleTransform, ToTensor

    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

    base_size = 320
    crop_size = 256
    min_size = int(0.5 * base_size)
    max_size = int(2.0 * base_size)

    print("Loading training data")
    st = time.time()
    train_transforms = Compose([
        RandomResize(min_size, max_size),
        RandomHorizontalFlip(0.5),
        RandomCrop(crop_size),
        SampleTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.1, hue=0.02)),
        ToTensor(),
        SampleTransform(normalize),
    ])
    dataset = VOCSegmentation(datadir, image_set='train', download=True, transforms=train_transforms)
    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    val_transforms = Compose([
        RandomResize(base_size, base_size),
        ToTensor(),
        SampleTransform(normalize),
    ])
    dataset_test = VOCSegmentation(datadir, image_set='val', download=True, transforms=val_transforms)
    print("Took", time.time() - st)

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def plot_lr_finder(train_batch, model, data_loader, optimizer, criterion, device,
                   start_lr=1e-7, end_lr=1, loss_margin=1e-2):
    """Run the LR finder and plot the training loss against the learning rate.

    Args:
        train_batch: function performing one training step (e.g. ``train_one_batch``)
        model: module being trained
        data_loader: iterable yielding ``(x, target)`` batches
        optimizer: optimizer updating the parameters of ``model``
        criterion: callable taking ``(output, target)`` and returning the loss
        device: device the batches are moved to
        start_lr: first learning rate explored
        end_lr: last learning rate explored
        loss_margin: currently unused, kept for backward compatibility
    """
    lrs, losses = holocron.utils.lr_finder(train_batch, model, data_loader,
                                           optimizer, criterion, device,
                                           start_lr=start_lr, end_lr=end_lr,
                                           stop_threshold=10, beta=0.95)
    # Plot Loss vs LR (the first and last few points are left out of the plot)
    plt.plot(lrs[10:-5], losses[10:-5])
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Training loss')
    plt.grid(True, linestyle='--', axis='x')
    plt.show()


def plot_samples(images, targets, ignore_index=None):
    """Display up to 4 images (top row) with their segmentation targets (bottom row).

    The inputs are left untouched: un-normalization and masking are done on copies.

    Args:
        images: batch of normalized image tensors
        targets: batch of segmentation targets matching ``images``
        ignore_index: if an int, target pixels with this value are displayed as 0
    """
    nb_samples = min(4, len(images))
    if nb_samples == 0:
        return

    std = torch.tensor(_NORM_STD).view(-1, 1, 1)
    mean = torch.tensor(_NORM_MEAN).view(-1, 1, 1)

    # squeeze=False keeps a 2D grid of axes even with a single sample
    _, axes = plt.subplots(2, nb_samples, figsize=(20, 5), squeeze=False)
    for idx, (img, target) in enumerate(zip(images[:nb_samples], targets[:nb_samples])):
        # Unnormalize image; clamp so that rounding errors do not wrap around in the uint8 conversion
        img = F.to_pil_image((img * std + mean).clamp(0, 1))
        if isinstance(ignore_index, int):
            target = target.clone()
            target[target == ignore_index] = 0

        axes[0][idx].imshow(img)
        axes[0][idx].axis('off')
        axes[1][idx].imshow(target)
        axes[1][idx].axis('off')
    plt.show()


def _build_optimizer(args, model):
    """Instantiate the optimizer selected by ``args.opt``."""
    opt_kwargs = dict(betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
    if args.opt == 'adam':
        return torch.optim.Adam(model.parameters(), args.lr, **opt_kwargs)
    if args.opt == 'radam':
        return holocron.optim.RAdam(model.parameters(), args.lr, **opt_kwargs)
    if args.opt == 'ranger':
        # NOTE(review): Lookahead was referenced without being imported; it is assumed to be
        # the wrapper exposed by holocron.optim.wrapper -- confirm.
        from holocron.optim.wrapper import Lookahead
        return Lookahead(holocron.optim.RAdam(model.parameters(), args.lr, **opt_kwargs))
    raise ValueError(f"unsupported optimizer: {args.opt!r} (expected 'adam', 'radam' or 'ranger')")


def _build_scheduler(args, optimizer, steps_per_epoch):
    """Instantiate the learning rate scheduler selected by ``args.sched``."""
    if args.sched == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, threshold=5e-3)
    if args.sched == 'onecycle':
        return torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr,
                                                   total_steps=args.epochs * steps_per_epoch,
                                                   cycle_momentum=False, div_factor=25, final_div_factor=25e4)
    raise ValueError(f"unsupported scheduler: {args.sched!r} (expected 'plateau' or 'onecycle')")


def main(args):
    """Train and/or evaluate a segmentation model according to the parsed arguments.

    Depending on the flags, this either displays training samples, runs the LR finder,
    evaluates the model once, or runs the full training loop while saving the state
    with the lowest validation loss.

    Args:
        args: namespace produced by ``parse_args``

    Raises:
        ValueError: if ``args.opt`` or ``args.sched`` is not a supported value
    """
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    dataset, dataset_test, train_sampler, test_sampler = load_data(args.data_path)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    # getattr keeps working with namespaces built without the --show-samples flag
    if getattr(args, 'show_samples', False):
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    val_loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    # VOC masks use index 0 for background on top of the object classes (255 marks ignored pixels)
    num_classes = len(VOC_CLASSES) + 1

    print("Creating model")
    kwargs = {}
    if args.freeze_backbone:
        kwargs['norm_layer'] = FrozenBatchNorm2d
    model = holocron.models.__dict__[args.model](args.pretrained, num_classes=num_classes,
                                                 in_channels=3, **kwargs)

    # Backbone freezing
    if args.freeze_backbone:
        for p in model.backbone.parameters():
            p.requires_grad_(False)
    model.to(device)

    optimizer = _build_optimizer(args, model)

    # Down-weight the first class (background) in the loss
    loss_weight = torch.ones(num_classes)
    loss_weight[0] = 0.1
    criterion = torch.nn.CrossEntropyLoss(weight=loss_weight, ignore_index=255).to(device)

    if args.lr_finder:
        plot_lr_finder(train_one_batch, model, train_loader, optimizer, criterion, device,
                       start_lr=1e-7, end_lr=1)
        return

    lr_scheduler = _build_scheduler(args, optimizer, len(train_loader))

    best_loss = math.inf
    if args.resume:
        # NOTE: torch.load unpickles the file, only resume from checkpoints you trust
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['val_loss']

    if args.test_only:
        val_loss, mean_iou = evaluate(model, val_loader, criterion, device=device)
        print(f"Validation loss: {val_loss:.4} | Mean IoU: {mean_iou:.2%}")
        return

    if args.output_dir:
        # Make sure checkpoints can be written before spending time on training
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    print("Start training")
    start_time = time.time()
    mb = master_bar(range(args.start_epoch, args.epochs))
    for epoch in mb:
        train_one_epoch(model, optimizer, criterion, lr_scheduler, train_loader, device, mb)
        val_loss, mean_iou = evaluate(model, val_loader, criterion, device=device)
        # `epoch` already starts at args.start_epoch, so it must not be offset again
        progress = f"Epoch {epoch + 1}/{args.epochs}"
        mb.main_bar.comment = progress
        mb.write(f"{progress} - "
                 f"Validation loss: {val_loss:.4} | Mean IoU: {mean_iou:.2%}")
        if args.sched == 'plateau':
            lr_scheduler.step(val_loss)

        if val_loss < best_loss:
            if args.output_dir:
                print(f"Validation loss decreased {best_loss:.4} --> {val_loss:.4}: saving state...")
                torch.save(dict(model=model.state_dict(),
                                optimizer=optimizer.state_dict(),
                                lr_scheduler=lr_scheduler.state_dict(),
                                epoch=epoch,
                                val_loss=val_loss),
                           Path(args.output_dir, f"{args.checkpoint}_best_state.pth"))
            best_loss = val_loss

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f'Training time {total_time_str}')


def parse_args():
    """Parse the command line arguments of the training script.

    Returns:
        The ``argparse.Namespace`` holding the parsed arguments.
    """
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Semantic Segmentation Training',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('data_path', type=str, help='path to dataset folder')
    parser.add_argument('--model', default='unet3p', help='model')
    parser.add_argument("--freeze-backbone", dest='freeze_backbone', action='store_true',
                        help="Should the backbone be frozen")
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int, help='batch size')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers')
    parser.add_argument('--loss', default='crossentropy', type=str, help='loss')
    parser.add_argument('--opt', default='adam', type=str, help='optimizer')
    parser.add_argument('--sched', default='plateau', type=str, help='scheduler')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--wd', '--weight-decay', default=0, type=float,
                        metavar='W', help='weight decay',
                        dest='weight_decay')
    parser.add_argument("--lr-finder", dest='lr_finder', action='store_true',
                        help="Should you run LR Finder")
    # Read by main() to preview a training batch instead of training
    parser.add_argument("--show-samples", dest='show_samples', action='store_true',
                        help="Whether training samples should be displayed")
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--checkpoint', default='model', help='checkpoint name')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)