import numpy as np from tqdm import tqdm import shutil import torch from torch.backends import cudnn from torch.autograd import Variable from graphs.models.erfnet import ERF from graphs.models.erfnet_imagenet import ERFNet from datasets.voc2012 import VOCDataLoader from graphs.losses.cross_entropy import CrossEntropyLoss from torch.optim import lr_scheduler from tensorboardX import SummaryWriter from utils.metrics import AverageMeter, IOUMetric from utils.misc import print_cuda_statistics from agents.base import BaseAgent cudnn.benchmark = True class ERFNetAgent(BaseAgent): """ This class will be responsible for handling the whole process of our architecture. """ def __init__(self, config): super().__init__(config) # Create an instance from the Model self.logger.info("Loading encoder pretrained in imagenet...") if self.config.pretrained_encoder: pretrained_enc = torch.nn.DataParallel(ERFNet(self.config.imagenet_nclasses)).cuda() pretrained_enc.load_state_dict(torch.load(self.config.pretrained_model_path)['state_dict']) pretrained_enc = next(pretrained_enc.children()).features.encoder else: pretrained_enc = None # define erfNet model self.model = ERF(self.config, pretrained_enc) # Create an instance from the data loader self.data_loader = VOCDataLoader(self.config) # Create instance from the loss self.loss = CrossEntropyLoss(self.config) # Create instance from the optimizer self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.learning_rate, betas=(self.config.betas[0], self.config.betas[1]), eps=self.config.eps, weight_decay=self.config.weight_decay) # Define Scheduler lambda1 = lambda epoch: pow((1 - ((epoch - 1) / self.config.max_epoch)), 0.9) self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda1) # initialize my counters self.current_epoch = 0 self.current_iteration = 0 self.best_valid_mean_iou = 0 # Check is cuda is available or not self.is_cuda = torch.cuda.is_available() # Construct the flag and make sure that cuda is available self.cuda = self.is_cuda & self.config.cuda if self.cuda: torch.cuda.manual_seed_all(self.config.seed) self.device = torch.device("cuda") torch.cuda.set_device(self.config.gpu_device) self.logger.info("Operation will be on *****GPU-CUDA***** ") print_cuda_statistics() else: self.device = torch.device("cpu") torch.manual_seed(self.config.seed) self.logger.info("Operation will be on *****CPU***** ") self.model = self.model.to(self.device) self.loss = self.loss.to(self.device) # Model Loading from the latest checkpoint if not found start from scratch. self.load_checkpoint(self.config.checkpoint_file) # Tensorboard Writer self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir, comment='FCN8s') # # scheduler for the optimizer # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, # 'min', patience=self.config.learning_rate_patience, # min_lr=1e-10, verbose=True) def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0): """ Saving the latest checkpoint of the training :param filename: filename which will contain the state :param is_best: flag is it is the best model :return: """ state = { 'epoch': self.current_epoch + 1, 'iteration': self.current_iteration, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), } # Save the state torch.save(state, self.config.checkpoint_dir + filename) # If it is the best copy it to another file 'model_best.pth.tar' if is_best: shutil.copyfile(self.config.checkpoint_dir + filename, self.config.checkpoint_dir + 'model_best.pth.tar') def load_checkpoint(self, filename): filename = self.config.checkpoint_dir + filename try: self.logger.info("Loading checkpoint '{}'".format(filename)) checkpoint = torch.load(filename) self.current_epoch = checkpoint['epoch'] self.current_iteration = checkpoint['iteration'] self.model.load_state_dict(checkpoint['state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer']) self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n" .format(self.config.checkpoint_dir, checkpoint['epoch'], checkpoint['iteration'])) except OSError as e: self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir)) self.logger.info("**First time to train**") def run(self): """ This function will the operator :return: """ assert self.config.mode in ['train', 'test', 'random'] try: if self.config.mode == 'test': self.test() else: self.train() except KeyboardInterrupt: self.logger.info("You have entered CTRL+C.. Wait to finalize") def train(self): """ Main training function, with per-epoch model saving """ for epoch in range(self.current_epoch, self.config.max_epoch): self.current_epoch = epoch self.scheduler.step(epoch) self.train_one_epoch() valid_mean_iou, valid_loss = self.validate() self.scheduler.step(valid_loss) is_best = valid_mean_iou > self.best_valid_mean_iou if is_best: self.best_valid_mean_iou = valid_mean_iou self.save_checkpoint(is_best=is_best) def train_one_epoch(self): """ One epoch training function """ # Initialize tqdm tqdm_batch = tqdm(self.data_loader.train_loader, total=self.data_loader.train_iterations, desc="Epoch-{}-".format(self.current_epoch)) # Set the model to be in training mode (for batchnorm) self.model.train() # Initialize your average meters epoch_loss = AverageMeter() metrics = IOUMetric(self.config.num_classes) for x, y in tqdm_batch: if self.cuda: x, y = x.pin_memory().cuda(async=self.config.async_loading), y.cuda(async=self.config.async_loading) x, y = Variable(x), Variable(y) # model pred = self.model(x) # loss cur_loss = self.loss(pred, y) if np.isnan(float(cur_loss.item())): raise ValueError('Loss is nan during training...') # optimizer self.optimizer.zero_grad() cur_loss.backward() self.optimizer.step() epoch_loss.update(cur_loss.item()) _, pred_max = torch.max(pred, 1) metrics.add_batch(pred_max.data.cpu().numpy(), y.data.cpu().numpy()) self.current_iteration += 1 # exit(0) epoch_acc, _, epoch_iou_class, epoch_mean_iou, _ = metrics.evaluate() self.summary_writer.add_scalar("epoch-training/loss", epoch_loss.val, self.current_iteration) self.summary_writer.add_scalar("epoch_training/mean_iou", epoch_mean_iou, self.current_iteration) tqdm_batch.close() print("Training Results at epoch-" + str(self.current_epoch) + " | " + "loss: " + str( epoch_loss.val) + " - acc-: " + str( epoch_acc) + "- mean_iou: " + str(epoch_mean_iou) + "\n iou per class: \n" + str( epoch_iou_class)) def validate(self): """ One epoch validation :return: """ tqdm_batch = tqdm(self.data_loader.valid_loader, total=self.data_loader.valid_iterations, desc="Valiation at -{}-".format(self.current_epoch)) # set the model in training mode self.model.eval() epoch_loss = AverageMeter() metrics = IOUMetric(self.config.num_classes) for x, y in tqdm_batch: if self.cuda: x, y = x.pin_memory().cuda(async=self.config.async_loading), y.cuda(async=self.config.async_loading) x, y = Variable(x), Variable(y) # model pred = self.model(x) # loss cur_loss = self.loss(pred, y) if np.isnan(float(cur_loss.item())): raise ValueError('Loss is nan during Validation.') _, pred_max = torch.max(pred, 1) metrics.add_batch(pred_max.data.cpu().numpy(), y.data.cpu().numpy()) epoch_loss.update(cur_loss.item()) epoch_acc, _, epoch_iou_class, epoch_mean_iou, _ = metrics.evaluate() self.summary_writer.add_scalar("epoch_validation/loss", epoch_loss.val, self.current_iteration) self.summary_writer.add_scalar("epoch_validation/mean_iou", epoch_mean_iou, self.current_iteration) print("Validation Results at epoch-" + str(self.current_epoch) + " | " + "loss: " + str( epoch_loss.val) + " - acc-: " + str( epoch_acc) + "- mean_iou: " + str(epoch_mean_iou) + "\n iou per class: \n" + str( epoch_iou_class)) tqdm_batch.close() return epoch_mean_iou, epoch_loss.val def test(self): # TODO pass def finalize(self): """ Finalize all the operations of the 2 Main classes of the process the operator and the data loader :return: """ print("Please wait while finalizing the operation.. Thank you") self.save_checkpoint() self.summary_writer.export_scalars_to_json("{}all_scalars.json".format(self.config.summary_dir)) self.summary_writer.close() self.data_loader.finalize()