import os
import time
import itertools

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

# custom modules
from schedulers import get_scheduler
from optimizers import get_optimizer
from networks import get_aux_net
from utils.metrics import AverageMeter
from utils.utils import to_device

# summary
from tensorboardX import SummaryWriter


class AuxModel:
    """Trainer for a network with a main classification head plus an
    auxiliary (self-supervised pretext) head.

    Training combines, per batch:
      * source-domain class loss,
      * source-domain auxiliary loss (weight ``src_aux_weight``),
      * target-domain auxiliary loss (weight ``tar_aux_weight``),
      * entropy loss on target class predictions for non-scrambled
        target images (weight ``tar_entropy_weight``).

    NOTE(review): the exact pretext task (aux labels, label 0 meaning
    "non-scrambled") is defined by the data loaders / ``get_aux_net`` —
    confirmed here only by how ``aux_labels == 0`` is used below.
    """

    def __init__(self, args, logger):
        """Build the network and, depending on ``args.mode``, either set
        up optimizer/scheduler/losses for training or load a checkpoint
        for validation/testing.

        Args:
            args: namespace-like config (``args.mode`` in
                {'train', 'val', anything-else-for-test}).
            logger: logger with an ``info`` method.
        """
        self.args = args
        self.logger = logger
        self.writer = SummaryWriter(args.log_dir)
        cudnn.enabled = True

        # set up model; aux head gets one extra class (index 0 = "non-scrambled")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = get_aux_net(args.network.arch)(
            aux_classes=args.aux_classes + 1, classes=args.n_classes)
        self.model = self.model.to(self.device)

        if args.mode == 'train':
            # set up optimizer, lr scheduler and loss functions
            optimizer = get_optimizer(self.args.training.optimizer)
            optimizer_params = {
                k: v for k, v in self.args.training.optimizer.items()
                if k != "name"}
            self.optimizer = optimizer(self.model.parameters(), **optimizer_params)
            self.scheduler = get_scheduler(self.optimizer,
                                           self.args.training.lr_scheduler)

            self.class_loss_func = nn.CrossEntropyLoss()

            self.start_iter = 0

            # resume from a checkpoint if requested
            if args.training.resume:
                self.load(args.model_dir + '/' + args.training.resume)

            cudnn.benchmark = True
        elif args.mode == 'val':
            self.load(os.path.join(args.model_dir, args.validation.model))
        else:
            self.load(os.path.join(args.model_dir, args.testing.model))

    def entropy_loss(self, x):
        """Mean Shannon entropy of the softmax distribution over logits ``x``.

        Args:
            x: logits tensor of shape (batch, classes).

        Returns:
            Scalar tensor: mean over the batch of -sum(p * log p).
        """
        return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()

    def train(self, src_loader, tar_loader, val_loader, test_loader):
        """Run the full training loop.

        The target loader is cycled so each epoch covers the whole source
        loader once. After every epoch the model is saved and optionally
        evaluated on ``val_loader`` / ``test_loader``.

        Args:
            src_loader: labelled source-domain loader.
            tar_loader: target-domain loader (aux labels only are used).
            val_loader: optional validation loader (or None).
            test_loader: optional test loader (or None).
        """
        num_batches = len(src_loader)
        print_freq = max(num_batches // self.args.training.num_print_epoch, 1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.args.training.num_epochs
        best_acc = 0

        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            for it, (src_batch, tar_batch) in enumerate(
                    zip(src_loader, itertools.cycle(tar_loader))):
                t = time.time()

                # BUG FIX: zero_grad was called twice per iteration; once is enough.
                self.optimizer.zero_grad()

                if isinstance(src_batch, list):
                    src = src_batch[0]  # data, dataset_idx
                else:
                    src = src_batch

                src = to_device(src, self.device)
                src_imgs = src['images']
                src_cls_lbls = src['class_labels']
                src_aux_lbls = src['aux_labels']

                src_aux_logits, src_class_logits = self.model(src_imgs)

                src_aux_loss = self.class_loss_func(src_aux_logits, src_aux_lbls)
                # If true, the network will only try to classify the non scrambled images
                if self.args.training.only_non_scrambled:
                    src_class_loss = self.class_loss_func(
                        src_class_logits[src_aux_lbls == 0],
                        src_cls_lbls[src_aux_lbls == 0])
                else:
                    src_class_loss = self.class_loss_func(src_class_logits, src_cls_lbls)

                tar = to_device(tar_batch, self.device)
                tar_imgs = tar['images']
                tar_aux_lbls = tar['aux_labels']

                tar_aux_logits, tar_class_logits = self.model(tar_imgs)
                tar_aux_loss = self.class_loss_func(tar_aux_logits, tar_aux_lbls)
                # entropy minimization only on non-scrambled target images
                tar_entropy_loss = self.entropy_loss(
                    tar_class_logits[tar_aux_lbls == 0])

                loss = src_class_loss + src_aux_loss * self.args.training.src_aux_weight
                loss += tar_aux_loss * self.args.training.tar_aux_weight
                loss += tar_entropy_loss * self.args.training.tar_entropy_weight

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1

                if i_iter % print_freq == 0:
                    print_string = 'Epoch {:>2} | iter {:>4} | src_class: {:.3f} | src_aux: {:.3f} | tar_entropy: {:.3f} | tar_aux: {:.3f} |{:4.2f} s/it'
                    # BUG FIX: arguments were passed as (src_aux, src_class, ...),
                    # swapping the values printed under the src_class / src_aux labels.
                    self.logger.info(print_string.format(
                        epoch, i_iter,
                        src_class_loss.item(),
                        src_aux_loss.item(),
                        tar_entropy_loss.item(),
                        tar_aux_loss.item(),
                        batch_time.avg))
                    self.writer.add_scalar('losses/src_class_loss', src_class_loss, i_iter)
                    self.writer.add_scalar('losses/src_aux_loss', src_aux_loss, i_iter)
                    self.writer.add_scalar('losses/tar_entropy_loss', tar_entropy_loss, i_iter)
                    self.writer.add_scalar('losses/tar_aux_loss', tar_aux_loss, i_iter)

                # free graph references before the next iteration
                del loss, src_class_loss, src_aux_loss, tar_aux_loss, tar_entropy_loss
                del src_aux_logits, src_class_logits
                del tar_aux_logits, tar_class_logits

            # BUG FIX: scheduler.step() was called at the start of each epoch,
            # i.e. before any optimizer.step(); PyTorch >= 1.1 requires the
            # scheduler to step after the optimizer.
            self.scheduler.step()

            # checkpoint + evaluation at the end of every epoch
            self.save(self.args.model_dir, i_iter)

            if val_loader is not None:
                self.logger.info('validating...')
                aux_acc, class_acc = self.test(val_loader)
                self.writer.add_scalar('val/aux_acc', aux_acc, i_iter)
                self.writer.add_scalar('val/class_acc', class_acc, i_iter)

            if test_loader is not None:
                self.logger.info('testing...')
                aux_acc, class_acc = self.test(test_loader)
                self.writer.add_scalar('test/aux_acc', aux_acc, i_iter)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    # todo copy current model to best model
                self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))

        self.logger.info('Best testing accuracy: {:.2f} %'.format(best_acc))
        self.logger.info('Finished Training.')

    def save(self, path, i_iter):
        """Write a checkpoint (model/optimizer/scheduler state + iteration).

        Args:
            path: directory in which to save.
            i_iter: current iteration count (stored as ``iter`` = i_iter + 1
                so that resuming continues at the next iteration).
        """
        state = {"iter": i_iter + 1,
                 "model_state": self.model.state_dict(),
                 "optimizer_state": self.optimizer.state_dict(),
                 "scheduler_state": self.scheduler.state_dict(),
                 }
        save_path = os.path.join(path, 'model_{:06d}.pth'.format(i_iter))
        self.logger.info('Saving model to %s' % save_path)
        torch.save(state, save_path)

    def load(self, path):
        """Load a checkpoint; in train mode also restore optimizer,
        scheduler and start iteration.

        Args:
            path: path to a checkpoint produced by :meth:`save`.
        """
        checkpoint = torch.load(path)
        # BUG FIX: was self.seg_model.load_state_dict(...) — no such
        # attribute exists; every load raised AttributeError. Load into
        # self.model once instead of duplicating the call in train mode.
        self.model.load_state_dict(checkpoint['model_state'])
        self.logger.info('Loaded model from: ' + path)

        if self.args.mode == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer_state'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state'])
            self.start_iter = checkpoint['iter']
            self.logger.info('Start iter: %d ' % self.start_iter)

    def test(self, val_loader):
        """Evaluate auxiliary and class accuracy on ``val_loader``.

        Returns:
            Tuple ``(aux_acc, class_acc)`` as percentages (floats).
        """
        val_loader_iterator = iter(val_loader)
        num_val_iters = len(val_loader)
        tt = tqdm(range(num_val_iters), total=num_val_iters, desc="Validating")
        aux_correct = 0
        class_correct = 0
        total = 0

        self.model.eval()
        with torch.no_grad():
            for cur_it in tt:
                data = next(val_loader_iterator)
                if isinstance(data, list):
                    data = data[0]
                # Get the inputs
                data = to_device(data, self.device)
                imgs = data['images']
                cls_lbls = data['class_labels']
                aux_lbls = data['aux_labels']

                aux_logits, class_logits = self.model(imgs)
                _, cls_pred = class_logits.max(dim=1)
                _, aux_pred = aux_logits.max(dim=1)

                class_correct += torch.sum(cls_pred == cls_lbls.data)
                aux_correct += torch.sum(aux_pred == aux_lbls.data)
                total += imgs.size(0)

        tt.close()
        aux_acc = 100 * float(aux_correct) / total
        class_acc = 100 * float(class_correct) / total
        self.logger.info('aux acc: {:.2f} %, class_acc: {:.2f} %'.format(
            aux_acc, class_acc))
        return aux_acc, class_acc