import os

import numpy as np
import torch
import torch.nn as nn


def accuracy(logits, y_true):
    """
    Fraction of samples whose arg-max prediction equals the label.

    :param logits: np.ndarray, output of the model; arg-max is taken over axis 1
    :param y_true: np.ndarray of labels; its first dimension is the sample count
    :return: float in [0, 1]
    """
    predictions = np.argmax(logits, axis=1)
    correct_samples = np.sum(predictions == y_true)
    total_samples = y_true.shape[0]
    return float(correct_samples) / total_samples


def compute_accuracy(model, loader):
    """
    Sample-weighted accuracy of the classifier head over a whole loader.

    :param model: a model which returns classifier_output and segmentator_output
    :param loader: data loader yielding 4-tuples; only the first two items
        (input, label) are used here
    :return: float, accuracy averaged over all samples
    """
    model.eval()  # enter evaluation mode
    score_accum = 0
    count = 0
    # Gradients are never needed for scoring, so skip building the graph.
    with torch.no_grad():
        for x, y, _, _ in loader:
            classifier_output, _ = model(x)
            batch_accuracy = accuracy(classifier_output.detach().cpu().numpy(),
                                      y.detach().cpu().numpy())
            score_accum += batch_accuracy * y.shape[0]
            count += y.shape[0]
    return float(score_accum / count)


def iou(logits, y_true, smooth=1e-2):
    """
    Smoothed intersection over union, averaged over samples and channels.

    Predictions are binarised with a fixed ``> 0.5`` threshold, so callers
    are expected to pass probabilities (e.g. sigmoid outputs).

    :param logits: np.ndarray of shape (batch, channels, samples), or
        (batch, samples) which is treated as a single channel
    :param y_true: np.ndarray with the same layout as ``logits``
    :param smooth: float added to numerator and denominator so that an empty
        union yields a score of 1 instead of a division by zero
    :return: mean IoU (numpy float)
    """
    logits = np.asarray(logits)
    y_true = np.asarray(y_true)
    # Accept per-channel 2-D input by inserting a singleton channel axis.
    if logits.ndim == 2:
        logits = logits[:, np.newaxis, :]
    if y_true.ndim == 2:
        y_true = y_true[:, np.newaxis, :]

    batch_size, channels, samples = logits.shape
    values = np.zeros(channels)
    for i in range(channels):
        pred = logits[:, i, :] > 0.5
        # Builtin bool: the np.bool alias was removed from NumPy.
        gt = y_true[:, i, :].astype(bool)
        intersection = (pred & gt).sum(axis=1)
        union = (pred | gt).sum(axis=1)
        values[i] = np.mean((intersection + smooth) / (union + smooth))
    return np.mean(values)


def compute_iou(model, loader):
    """
    Computes intersection over union on the dataset wrapped in a loader.

    The segmentation output is passed through a sigmoid before scoring, the
    same way train_model feeds its segmentation metric, because ``iou``
    thresholds at 0.5.

    :param model: a model which returns classifier_output and segmentator_output;
        channel 0 of segmentator_output is scored against the integration
        mask and channel 1 against the intersection mask
    :param loader: data loader yielding (x, label, integration_mask, intersection_mask)
    :return: IoU (jaccard index) for integration and intersection masks,
        each an unweighted mean of the per-batch scores
    """
    model.eval()  # evaluation mode
    integration_score = []
    intersection_score = []
    with torch.no_grad():
        for x, _, integration_mask, intersection_mask in loader:
            _, segmentator_output = model(x)
            probabilities = torch.sigmoid(segmentator_output).detach().cpu().numpy()
            integration_mask = integration_mask.detach().cpu().numpy()
            intersection_mask = intersection_mask.detach().cpu().numpy()
            integration_score.append(iou(probabilities[:, 0, :], integration_mask))
            intersection_score.append(iou(probabilities[:, 1, :], intersection_mask))
    return np.mean(integration_score), np.mean(intersection_score)


class WeightedBCE:
    """
    Binary cross-entropy on raw logits with optional per-class weights.

    ``weights[0]`` scales the negative (target == 0) term and ``weights[1]``
    scales the positive (target == 1) term.
    """

    def __init__(self, weights=None):
        """
        :param weights: None for unweighted BCE, or a sequence of two numbers
        """
        self.weights = weights
        self.logsigmoid = nn.LogSigmoid()

    def __call__(self, output, target):
        """
        :param output: tensor of raw logits
        :param target: tensor of targets in [0, 1], same shape as output
        :return: scalar tensor, mean (weighted) negative log-likelihood
        """
        positive_term = target * self.logsigmoid(output)
        negative_term = (1 - target) * self.logsigmoid(-output)
        if self.weights is not None:
            assert len(self.weights) == 2
            loss = self.weights[1] * positive_term + self.weights[0] * negative_term
        else:
            loss = positive_term + negative_term
        return torch.neg(torch.mean(loss))


class DiceLoss:
    """
    Soft Dice-style loss computed on sigmoid-activated logits.
    """

    def __init__(self, smooth=1e-2):
        """
        :param smooth: float added to numerator and denominator
        """
        self.smooth = smooth

    def __call__(self, output, target):
        """
        :param output: tensor of raw logits; sums are taken over dim 1
        :param target: tensor of targets, same shape as output
        :return: scalar tensor, 1 minus the mean smoothed Dice ratio
        """
        output = output.sigmoid()
        numerator = torch.sum(output * target, dim=1)
        # NOTE(review): the denominator uses sqrt(output) rather than output
        # or output ** 2 as in common Dice formulations; kept as-is because
        # changing it alters training - confirm this is intentional.
        denominator = torch.sum(torch.sqrt(output) + target, dim=1)
        return 1 - torch.mean((2 * numerator + self.smooth) / (denominator + self.smooth))


class CombinedLoss:
    """
    Sum of DiceLoss and WeightedBCE evaluated on the same logits/targets.
    """

    def __init__(self, weights=None):
        """
        :param weights: forwarded to WeightedBCE
        """
        self.dice = DiceLoss()
        self.bce = WeightedBCE(weights)

    def __call__(self, output, target):
        """
        :param output: tensor of raw logits
        :param target: tensor of targets, same shape as output
        :return: scalar tensor, dice loss + BCE loss
        """
        return self.dice(output, target) + self.bce(output, target)


def _batch_scores(classifier_output, integrator_output, y, integration_mask,
                  intersection_mask, classification_metric, segmentation_metric):
    """
    Metric values for one batch, each multiplied by the batch size.

    A metric that is None contributes 0, so the caller can always accumulate.
    """
    batch_len = len(y)
    classification_score = 0
    segmentation_score = 0
    if classification_metric is not None:
        classification_score = classification_metric(
            classifier_output.detach().cpu().numpy(),
            y.detach().cpu().numpy()) * batch_len
    if segmentation_metric is not None:
        # Stack both masks and move the mask axis to position 1 so the layout
        # matches the (batch, 2, samples) segmentation output.
        gt = np.stack((integration_mask.data.cpu().numpy(),
                       intersection_mask.data.cpu().numpy())).transpose(1, 0, 2)
        segmentation_score = segmentation_metric(
            integrator_output.detach().cpu().sigmoid().numpy(), gt) * batch_len
    return classification_score, segmentation_score


def train_model(model, loader, val_loader, optimizer, num_epoch, print_epoch=10,
                classification_metric=None, segmentation_metric=None, scheduler=None,
                label_criterion=None, integration_criterion=None, intersection_criterion=None,
                accumulation=1, loss_ax=None, classification_score_ax=None,
                segmentation_score_ax=None, figure=None, canvas=None):
    """
    Train a two-headed (classifier + segmentator) model and track metrics.

    :param model: a model which returns classifier_output and integrator_output;
        channel 0 of integrator_output is trained against the integration
        mask and channel 1 against the intersection mask
    :param loader: training loader yielding (x, y, integration_mask, intersection_mask)
    :param val_loader: validation loader with the same item layout
    :param optimizer: optimizer stepped every ``accumulation`` batches
    :param num_epoch: number of epochs
    :param print_epoch: progress is printed every ``print_epoch`` epochs and
        on the last epoch
    :param classification_metric: callable(np.ndarray, np.ndarray) or None
    :param segmentation_metric: callable(np.ndarray, np.ndarray) or None; it
        receives sigmoid-activated outputs
    :param scheduler: optional LR scheduler, stepped once per epoch
    :param label_criterion: loss for the classifier head or None
    :param integration_criterion: loss for segmentation channel 0 or None
    :param intersection_criterion: loss for segmentation channel 1 or None
    :param accumulation: number of batches whose gradients are accumulated
        before an optimizer step
    :param loss_ax: optional axes for the loss curve
    :param classification_score_ax: optional axes for classification scores
    :param segmentation_score_ax: optional axes for segmentation scores
    :param figure: optional figure; tight_layout() is called after each epoch
    :param canvas: optional canvas; draw() is called after each epoch
    :return: tuple (loss_history, train_classification_score_history,
        train_segmentation_score_history, val_classification_score_history,
        val_segmentation_score_history), one float per epoch in each list

    Side effect: whenever the validation score improves, the model state
    dict is written to ``data/tmp_weights/<model class name>``.
    """
    loss_history = []
    train_classification_score_history = []
    train_segmentation_score_history = []
    val_classification_score_history = []
    val_segmentation_score_history = []
    best_score = 0
    weights_dir = 'data/tmp_weights'

    for epoch in range(num_epoch):
        model.train()  # enter train mode
        loss_accum = 0.0
        classification_score_accum = 0
        segmentation_score_accum = 0
        count = 0
        step = 0
        for x, y, integration_mask, intersection_mask in loader:
            classifier_output, integrator_output = model(x)

            # calculate loss and gradients
            loss = torch.tensor(0, dtype=torch.float32, device=x.device)
            if label_criterion is not None:
                loss = loss + label_criterion(classifier_output, y)
            if integration_criterion is not None:
                loss = loss + integration_criterion(integrator_output[:, 0, :], integration_mask)
            if intersection_criterion is not None:
                loss = loss + intersection_criterion(integrator_output[:, 1, :], intersection_mask)
            loss = loss / accumulation
            loss.backward()
            step += 1
            if step == accumulation:
                # gradients of the last `accumulation` batches are applied at once
                optimizer.step()
                optimizer.zero_grad()
                step = 0

            classification_score, segmentation_score = _batch_scores(
                classifier_output, integrator_output, y, integration_mask,
                intersection_mask, classification_metric, segmentation_metric)
            classification_score_accum += classification_score
            segmentation_score_accum += segmentation_score
            # .item() drops the autograd graph; adding the tensor itself would
            # keep every batch's graph alive until the end of the epoch.
            loss_accum += loss.item()
            count += len(y)

        if step:
            # Apply gradients of a trailing, incomplete accumulation group so
            # they do not leak into the first update of the next epoch.
            optimizer.step()
            optimizer.zero_grad()

        # NOTE: the accumulated value is the already-scaled batch loss, divided
        # by the number of samples (kept as in the original bookkeeping).
        loss_history.append(float(loss_accum / count))
        train_classification_score_history.append(float(classification_score_accum / count))
        train_segmentation_score_history.append(float(segmentation_score_accum / count))

        model.eval()  # enter evaluation mode
        classification_score_accum = 0
        segmentation_score_accum = 0
        count = 0
        with torch.no_grad():  # no gradients are needed for validation
            for x, y, integration_mask, intersection_mask in val_loader:
                classifier_output, integrator_output = model(x)
                classification_score, segmentation_score = _batch_scores(
                    classifier_output, integrator_output, y, integration_mask,
                    intersection_mask, classification_metric, segmentation_metric)
                classification_score_accum += classification_score
                segmentation_score_accum += segmentation_score
                count += len(y)
        val_classification_score_history.append(float(classification_score_accum / count))
        val_segmentation_score_history.append(float(segmentation_score_accum / count))

        # Pick the score used for checkpointing: the product when both metrics
        # are available, otherwise whichever one is.
        if classification_metric is not None and segmentation_metric is not None:
            val_score = val_classification_score_history[-1] * val_segmentation_score_history[-1]
        elif classification_metric is not None:
            val_score = val_classification_score_history[-1]
        elif segmentation_metric is not None:
            val_score = val_segmentation_score_history[-1]
        else:
            val_score = None

        if val_score is not None and best_score < val_score:
            best_score = val_score
            # torch.save does not create missing directories.
            os.makedirs(weights_dir, exist_ok=True)
            torch.save(model.state_dict(),
                       os.path.join(weights_dir, model.__class__.__name__))  # save best model

        if scheduler:
            scheduler.step()

        if not epoch % print_epoch or epoch == num_epoch - 1:
            print('Epoch #{}, train loss: {:.4f}'.format(epoch, loss_history[-1]))
            if classification_metric is not None:
                print('Train classification score: {:.4f}, val classificiation score: {:.4f}'.format(
                    train_classification_score_history[-1],
                    val_classification_score_history[-1]
                ))
            if segmentation_metric is not None:
                print('Train segmentation score: {:.4f}, val segmentation score: {:.4f}'.format(
                    train_segmentation_score_history[-1],
                    val_segmentation_score_history[-1]
                ))

        # visualization
        if loss_ax is not None:
            loss_ax.clear()
            loss_ax.plot(loss_history)
            loss_ax.set_title('Loss function')
        if classification_score_ax is not None:
            classification_score_ax.clear()
            classification_score_ax.plot(train_classification_score_history, label='train')
            classification_score_ax.plot(val_classification_score_history, label='validation')
            classification_score_ax.legend(loc='best')
            classification_score_ax.set_title('Classification score')
        if segmentation_score_ax is not None:
            segmentation_score_ax.clear()
            segmentation_score_ax.plot(train_segmentation_score_history, label='train')
            segmentation_score_ax.plot(val_segmentation_score_history, label='validation')
            segmentation_score_ax.legend(loc='best')
            segmentation_score_ax.set_title('Segmentation score')
        if figure is not None:
            figure.tight_layout()
        if canvas is not None:
            canvas.draw()

    return (loss_history, train_classification_score_history, train_segmentation_score_history,
            val_classification_score_history, val_segmentation_score_history)