import heapq
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

torch.multiprocessing.set_sharing_strategy('file_system')


class Learning():
    """Drive the train / validate / checkpoint loop for a torch model.

    Validation scores are written to ``summary.csv`` inside the checkpoint
    history folder, the ``checkpoints_topk`` best per-epoch checkpoints are
    kept on disk, and the single best model is saved to
    ``<best_checkpoint_folder>/<calculation_name>.pth``.
    """

    def __init__(self,
                 optimizer,
                 binarizer_fn,
                 loss_fn,
                 eval_fn,
                 device,
                 n_epoches,
                 scheduler,
                 freeze_model,
                 grad_clip,
                 grad_accum,
                 early_stopping,
                 validation_frequency,
                 calculation_name,
                 best_checkpoint_folder,
                 checkpoints_history_folder,
                 checkpoints_topk,
                 logger
                 ):
        """Store the training configuration and restore the best known score.

        Args:
            optimizer: optimizer that is stepped and zeroed in ``batch_train``;
                its ``param_groups[0]['lr']`` is shown in the progress bar.
            binarizer_fn: object exposing ``thresholds`` and
                ``transform(probas)``; the two are zipped during validation.
            loss_fn: callable ``loss_fn(predicted, labels)`` returning a
                value supporting ``backward()`` and ``item()``.
            eval_fn: callable ``eval_fn(mask, labels)`` returning a value
                supporting ``item()``.
            device: device the model and batches are moved to.
            n_epoches: number of epochs ``run_train`` iterates over.
            scheduler: learning-rate scheduler; stepped with the score when
                its class is named ``ReduceLROnPlateau``, without arguments
                otherwise.
            freeze_model: when truthy, training and all post-processing
                (checkpointing, scheduler step) are skipped.
            grad_clip: max norm passed to ``clip_grad_norm_``.
            grad_accum: number of batches between optimizer steps.
            early_stopping: training stops once
                ``epoch - best_epoch > early_stopping``.
            validation_frequency: validation runs on epochs where
                ``epoch % validation_frequency == validation_frequency - 1``.
            calculation_name: name used to build checkpoint file names.
            best_checkpoint_folder: folder for the best-model checkpoint.
            checkpoints_history_folder: folder for per-epoch checkpoints and
                ``summary.csv``.
            checkpoints_topk: how many per-epoch checkpoints to keep.
            logger: object with an ``info`` method.
        """
        self.logger = logger

        self.optimizer = optimizer
        self.binarizer_fn = binarizer_fn
        self.loss_fn = loss_fn
        self.eval_fn = eval_fn
        self.device = device
        self.n_epoches = n_epoches
        self.scheduler = scheduler
        self.freeze_model = freeze_model
        self.grad_clip = grad_clip
        self.grad_accum = grad_accum
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency

        self.calculation_name = calculation_name
        self.best_checkpoint_path = Path(
            best_checkpoint_folder,
            '{}.pth'.format(self.calculation_name)
        )
        self.checkpoints_history_folder = Path(checkpoints_history_folder)
        self.checkpoints_topk = checkpoints_topk
        # Min-heap of (score, checkpoint path); the root is the worst kept.
        self.score_heap = []
        self.summary_file = Path(self.checkpoints_history_folder, 'summary.csv')
        if self.summary_file.is_file():
            # Resume from the best score recorded by a previous run.
            self.best_score = pd.read_csv(self.summary_file).best_metric.max()
            logger.info('Pretrained best score is {:.5}'.format(self.best_score))
        else:
            self.best_score = 0
        self.best_epoch = -1

    def train_epoch(self, model, loader):
        """Train on every batch of ``loader`` and return the mean loss."""
        tqdm_loader = tqdm(loader)
        current_loss_mean = 0
        for batch_idx, (imgs, labels) in enumerate(tqdm_loader):
            loss, _ = self.batch_train(model, imgs, labels, batch_idx)

            # Incremental (running) mean of the per-batch losses.
            current_loss_mean = (current_loss_mean * batch_idx + loss) / (batch_idx + 1)

            tqdm_loader.set_description('loss: {:.4} lr:{:.6}'.format(
                current_loss_mean, self.optimizer.param_groups[0]['lr']))
        return current_loss_mean

    def batch_train(self, model, batch_imgs, batch_labels, batch_idx):
        """Run forward/backward on one batch.

        The optimizer is stepped (after gradient-norm clipping) and its
        gradients are zeroed only on every ``grad_accum``-th batch.

        Returns:
            Tuple ``(loss_value, predicted)`` where ``loss_value`` is
            ``loss.item()`` and ``predicted`` is the raw model output.
        """
        batch_imgs, batch_labels = batch_imgs.to(self.device), batch_labels.to(self.device)
        predicted = model(batch_imgs)
        loss = self.loss_fn(predicted, batch_labels)

        loss.backward()
        if batch_idx % self.grad_accum == self.grad_accum - 1:
            clip_grad_norm_(model.parameters(), self.grad_clip)
            self.optimizer.step()
            self.optimizer.zero_grad()
        return loss.item(), predicted

    def valid_epoch(self, model, loader):
        """Evaluate ``model`` on ``loader`` for every binarizer threshold.

        Returns:
            Tuple ``(metrics, best_metric)``: ``metrics`` maps each threshold
            (as a tuple) to the running mean of ``eval_fn`` over the batches,
            ``best_metric`` is the largest of those means.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        tqdm_loader = tqdm(loader)
        used_thresholds = self.binarizer_fn.thresholds
        metrics = defaultdict(float)
        best_metric = None

        for batch_idx, (imgs, labels) in enumerate(tqdm_loader):
            with torch.no_grad():
                predicted_probas = self.batch_valid(model, imgs)
                labels = labels.to(self.device)
                mask_generator = self.binarizer_fn.transform(predicted_probas)
                for current_thr, current_mask in zip(used_thresholds, mask_generator):
                    current_metric = self.eval_fn(current_mask, labels).item()
                    current_thr = tuple(current_thr)
                    # Incremental mean of this threshold's metric.
                    metrics[current_thr] = (
                        metrics[current_thr] * batch_idx + current_metric
                    ) / (batch_idx + 1)

                best_threshold = max(metrics, key=metrics.get)
                best_metric = metrics[best_threshold]
                tqdm_loader.set_description(
                    'score: {:.5} on {}'.format(best_metric, best_threshold))

        if best_metric is None:
            # Without this guard an empty loader ends in an UnboundLocalError.
            raise ValueError('validation loader produced no batches')
        return metrics, best_metric

    def batch_valid(self, model, batch_imgs):
        """Return ``sigmoid(model(batch_imgs))`` computed on ``self.device``."""
        batch_imgs = batch_imgs.to(self.device)
        predicted = model(batch_imgs)
        predicted = torch.sigmoid(predicted)
        return predicted

    def process_summary(self, metrics, epoch):
        """Log the epoch's best score and append one row to ``summary.csv``.

        The row holds ``epoch``, ``best_metric`` and one column per threshold
        (column name is ``str(threshold)``). The file is created on first use
        and extended afterwards.
        """
        best_threshold = max(metrics, key=metrics.get)
        best_metric = metrics[best_threshold]

        # Build the row with string column names from the start.
        row = {'epoch': epoch, 'best_metric': best_metric}
        row.update({str(thr): value for thr, value in metrics.items()})
        epoch_summary = pd.DataFrame([row])

        self.logger.info('{} epoch: \t Score: {:.5}\t Params: {}'.format(
            epoch, best_metric, best_threshold))

        if not self.summary_file.is_file():
            epoch_summary.to_csv(self.summary_file, index=False)
        else:
            summary = pd.read_csv(self.summary_file)
            # DataFrame.append was removed in pandas 2.0; concat is the
            # supported equivalent.
            summary = pd.concat([summary, epoch_summary], ignore_index=True)
            summary.to_csv(self.summary_file, index=False)

    @staticmethod
    def get_state_dict(model):
        """Return the model's state dict, unwrapping ``DataParallel``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module.state_dict()
        return model.state_dict()

    def post_processing(self, score, epoch, model):
        """Checkpoint the model, track the best score and step the scheduler.

        Does nothing at all when ``freeze_model`` is set.
        """
        if self.freeze_model:
            return

        checkpoints_history_path = Path(
            self.checkpoints_history_folder,
            '{}_epoch{}.pth'.format(self.calculation_name, epoch)
        )

        torch.save(self.get_state_dict(model), checkpoints_history_path)
        heapq.heappush(self.score_heap, (score, checkpoints_history_path))
        if len(self.score_heap) > self.checkpoints_topk:
            # Drop the lowest-scoring checkpoint so only the top-k remain.
            _, removing_checkpoint_path = heapq.heappop(self.score_heap)
            removing_checkpoint_path.unlink()
            self.logger.info('Removed checkpoint is {}'.format(removing_checkpoint_path))
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            torch.save(self.get_state_dict(model), self.best_checkpoint_path)
            self.logger.info('best model: {} epoch - {:.5}'.format(epoch, score))

        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(score)
        else:
            self.scheduler.step()

    def run_train(self, model, train_dataloader, valid_dataloader):
        """Run the full training loop.

        Each epoch trains (unless ``freeze_model``), validates on the epochs
        selected by ``validation_frequency``, records the summary, performs
        post-processing and stops early once
        ``epoch - best_epoch > early_stopping``.

        Returns:
            Tuple ``(best_score, best_epoch)``.
        """
        model.to(self.device)
        for epoch in range(self.n_epoches):
            if not self.freeze_model:
                self.logger.info('{} epoch: \t start training....'.format(epoch))
                model.train()
                train_loss_mean = self.train_epoch(model, train_dataloader)
                self.logger.info('{} epoch: \t Calculated train loss: {:.5}'.format(
                    epoch, train_loss_mean))

            if epoch % self.validation_frequency != (self.validation_frequency - 1):
                self.logger.info('skip validation....')
                continue

            self.logger.info('{} epoch: \t start validation....'.format(epoch))
            model.eval()
            metrics, score = self.valid_epoch(model, valid_dataloader)

            self.process_summary(metrics, epoch)

            self.post_processing(score, epoch, model)

            if epoch - self.best_epoch > self.early_stopping:
                self.logger.info('EARLY STOPPING')
                break

        return self.best_score, self.best_epoch