pytorch-UNet

A tunable implementation of U-Net in PyTorch.

About U-Net

U-Net is a powerful encoder-decoder CNN architecture for semantic segmentation, developed by Olaf Ronneberger, Philipp Fischer and Thomas Brox. It has won several competitions, for example the ISBI Cell Tracking Challenge 2015 or the Kaggle Data Science Bowl 2018.

An example image from the Kaggle Data Science Bowl 2018:

This repository was created to

  1. provide a reference implementation of 2D and 3D U-Net in PyTorch,
  2. allow fast prototyping and hyperparameter tuning by providing an easily parametrizable model.

In essence, the U-Net is built up from encoder and decoder blocks. Each encoder block consists of convolutional and pooling layers, while each decoder block upsamples its input and combines it with the matching encoder features before its convolutions. With this implementation, you can build your U-Net from First, Encoder, Center, Decoder and Last blocks, controlling both the complexity and the number of these blocks. (Because the first, last and middle blocks are somewhat special, each of them has its own class.)
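As a rough illustration of how these five block types fit together, the pure-Python sketch below only *plans* a network. It lists each block with its input and output channel depth and the resolution it works at, without building any layers. The helper `plan_unet_blocks` and its `conv_depths` argument are hypothetical, not this repository's API. The split it assumes (First does no pooling, each Encoder and the Center downsample by 2, each Decoder upsamples by 2, Last only maps to the output channels) is one plausible arrangement, not necessarily how the repository's classes divide the work.

```python
from typing import List, Tuple

# (block name, input channel depth, output channel depth, downsampling factor)
Block = Tuple[str, int, int, int]


def plan_unet_blocks(in_channels: int, out_channels: int,
                     conv_depths=(64, 128, 256, 512, 1024)) -> List[Block]:
    """Hypothetical planner: lay out First/Encoder/Center/Decoder/Last blocks.

    Only channel depths and resolutions are recorded; no layers are built.
    For decoders, the recorded input is the depth of the feature map coming
    from below (the skip concatenation happens inside the block).
    """
    if len(conv_depths) < 2:
        raise ValueError("need at least one depth for First and one for Center")
    blocks: List[Block] = [("First", in_channels, conv_depths[0], 1)]
    scale = 1
    # Encoder blocks (assumed): 2x2 pooling, then convolutions to the next depth.
    for d_in, d_out in zip(conv_depths[:-2], conv_depths[1:-1]):
        scale *= 2
        blocks.append(("Encoder", d_in, d_out, scale))
    # Center block: one more pooling step, the widest and coarsest stage.
    scale *= 2
    blocks.append(("Center", conv_depths[-2], conv_depths[-1], scale))
    # Decoder blocks (assumed): upsample, concatenate the skip, narrow the depth.
    for d_in, d_out in zip(conv_depths[:0:-1], conv_depths[-2::-1]):
        scale //= 2
        blocks.append(("Decoder", d_in, d_out, scale))
    # Last block: map the final features to the requested output channels.
    blocks.append(("Last", conv_depths[0], out_channels, scale))
    return blocks


for name, c_in, c_out, factor in plan_unet_blocks(in_channels=1, out_channels=1):
    print(f"{name:<8} {c_in:>5} -> {c_out:<5} at 1/{factor} resolution")
```

With the default depths, the channel columns trace 1 → 64 → … → 1024 and back down to 64 before Last maps to the output channels. Removing one entry from `conv_depths` removes one Encoder and one Decoder, which is the kind of knob this implementation is meant to expose.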

WARNING! The 3D U-Net implementation is currently untested!

U-Net quickstart

The simplest way to use the implemented U-Net is with the provided train.py and predict.py scripts.

Training

For training, train.py should be used, where the required arguments are

Optional arguments:

Predicting

For prediction, the predict.py script should be used, where the required arguments are

Customizing the network

As you can see in this figure, the U-Net architecture is essentially built from convolution blocks. In the original architecture, the channel flow looks like this:

1 → 64 → 128 → 256 → 512 → 1024 (channels)
1024 → 512 → 256 → 128 → 64 → 2 (channels, one output channel per class).

These channel counts are fairly arbitrary, and this configuration might not be the best one for your problem. With the implementation provided in this repository, you can change it quickly without modifying the code, as you'll see in the next section.
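To get a feel for how much the channel depths drive model size, the sketch below counts parameters for a given list of depths. It assumes the layout of the original paper: two 3×3 convolutions per stage, 2×2 up-convolutions that halve the channel count, skip concatenation, a final 1×1 convolution, biases everywhere and no batch normalization. `unet_param_count` and `conv_depths` are hypothetical names, and this repository's blocks may contain extra layers, so treat the result as an estimate rather than the exact size of `UNet2D`.

```python
def conv_params(c_in: int, c_out: int, kernel: int) -> int:
    """Weights plus biases of a 2D convolution (same formula for a transposed one)."""
    return kernel * kernel * c_in * c_out + c_out


def unet_param_count(in_channels: int, out_channels: int,
                     conv_depths=(64, 128, 256, 512, 1024)) -> int:
    """Hypothetical estimate of U-Net parameters under the paper's layout."""
    total = 0
    prev = in_channels
    # Contracting path: every depth gets two 3x3 convolutions.
    for depth in conv_depths:
        total += conv_params(prev, depth, 3) + conv_params(depth, depth, 3)
        prev = depth
    # Expanding path: a 2x2 up-convolution halves the depth, then the skip
    # concatenation doubles the input channels of the first 3x3 convolution.
    for deep, shallow in zip(conv_depths[:0:-1], conv_depths[-2::-1]):
        total += conv_params(deep, shallow, 2)
        total += conv_params(2 * shallow, shallow, 3) + conv_params(shallow, shallow, 3)
    # Final 1x1 convolution to the requested number of output channels.
    total += conv_params(conv_depths[0], out_channels, 1)
    return total


full = unet_param_count(in_channels=1, out_channels=2)
slim = unet_param_count(in_channels=1, out_channels=2, conv_depths=(32, 64, 128, 256, 512))
print(f"default depths: {full:,} parameters")
print(f"halved depths:  {slim:,} parameters ({slim / full:.1%} of the default)")
```

Under these assumptions the default depths give about 31 million parameters. Halving every depth leaves roughly a quarter of them, because almost all weights sit in convolutions whose input and output widths both shrink.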

The UNet2D object

The 2D U-Net architecture is implemented by the unet.unet.UNet2D class. It accepts the following arguments during initialization:

Utilities for training the model

To save time writing the usual boilerplate PyTorch code for training, a dataset generator and a simple wrapper are provided.

Wrapper for training and inference

The wrapper is implemented in the unet.model.Model class. Upon initialization, you are required to provide the following arguments:

Optional arguments are:

To train the model, use the .fit_dataset() method; see its docstring for details on how to use it. This requires the unet.dataset.ImageToImage2D dataset generator, which is described in the next section.

Datasets and augmentation transforms

For training the U-Net, simple classes for augmentation and dataset input are implemented. The joint augmentation transform for image and mask is implemented in unet.dataset.JointTransform2D. This transform is used by the unet.dataset.ImageToImage2D dataset generator. For more details on their usage, see their corresponding docstrings.
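The key property of a joint transform is that every random parameter is drawn once and applied to both the image and the mask; otherwise a flipped image would end up paired with an unflipped label. The sketch below shows the idea with a random crop and a random horizontal flip on plain nested lists. It is not `JointTransform2D` itself (see its docstring for the real parameters); `joint_transform` and its arguments are hypothetical.

```python
import random
from typing import List, Optional, Tuple

Grid = List[List[int]]


def joint_transform(image: Grid, mask: Grid, crop: Tuple[int, int],
                    p_flip: float = 0.5,
                    rng: Optional[random.Random] = None) -> Tuple[Grid, Grid]:
    """Hypothetical joint augmentation: same random crop and flip for both inputs."""
    rng = rng or random.Random()
    h, w = len(image), len(image[0])
    if len(mask) != h or any(len(row) != w for row in mask):
        raise ValueError("image and mask must have the same size")
    ch, cw = crop
    if ch > h or cw > w:
        raise ValueError("crop is larger than the image")
    # Draw the random parameters exactly once...
    top = rng.randint(0, h - ch)
    left = rng.randint(0, w - cw)
    flip = rng.random() < p_flip

    def apply(grid: Grid) -> Grid:
        # ...and reuse them for both inputs, so labels stay aligned with pixels.
        out = [row[left:left + cw] for row in grid[top:top + ch]]
        return [row[::-1] for row in out] if flip else out

    return apply(image), apply(mask)


image = [[10 * r + c for c in range(6)] for r in range(5)]  # toy 5x6 "image"
mask = [[v % 2 for v in row] for row in image]              # label derived from pixel values
img_t, mask_t = joint_transform(image, mask, crop=(3, 4), rng=random.Random(0))
assert mask_t == [[v % 2 for v in row] for row in img_t]     # still aligned after augmentation
```

Because the mask is a function of the pixel values here, the final assertion holds for any seed only if both inputs received identical crops and flips.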

Experiments with U-Net

To get a better grasp of U-Net and how it depends on hyperparameters, I ran a simple experiment on the dataset from the Kaggle Data Science Bowl 2018, which aims to find cell nuclei in microscopy images. Although the goal of the competition was instance segmentation, which is not exactly what U-Net is designed for, a U-Net-based approach actually won the competition with some really clever tricks. (For details, see this post by the winning team, explaining what they did in detail.)

For simplicity, the following experiments are focused on a simplified problem: segmenting out nuclei from the background, disregarding the differences between instances of nuclei.
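Disregarding instances amounts to taking the pixelwise union of the per-nucleus masks. The sketch below does this on nested lists, assuming each instance mask is a same-sized 2D array in which nonzero pixels mark one nucleus. `merge_instance_masks` is a hypothetical helper; it illustrates the idea and is not necessarily what the provided preprocessing script does internally.

```python
from typing import Iterable, List

Grid = List[List[int]]


def merge_instance_masks(instance_masks: Iterable[Grid]) -> Grid:
    """Hypothetical helper: binary foreground mask from per-instance masks."""
    masks = list(instance_masks)
    if not masks:
        raise ValueError("need at least one instance mask")
    height = len(masks[0])
    width = len(masks[0][0]) if height else 0
    merged = [[0] * width for _ in range(height)]
    for mask in masks:
        if len(mask) != height or any(len(row) != width for row in mask):
            raise ValueError("all instance masks must have the same shape")
        for r, row in enumerate(mask):
            for c, value in enumerate(row):
                if value:  # any nonzero pixel belongs to some nucleus
                    merged[r][c] = 1
    return merged


nucleus_a = [[255, 255, 0, 0], [255, 255, 0, 0], [0, 0, 0, 0]]
nucleus_b = [[0, 0, 0, 0], [0, 0, 255, 0], [0, 0, 255, 255]]
print(merge_instance_masks([nucleus_a, nucleus_b]))
# [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]]
```

Touching nuclei merge into a single foreground region, which is exactly the information this simplified, semantic-segmentation version of the problem throws away.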

The Kaggle Data Science Bowl 2018 nuclei detection challenge dataset

If you would like to play around with the data, you can download the images from here. Since the ground truth masks are given for each instance, we need some preprocessing. This can be done with the provided script kaggle_dsb18_preprocessing.py, in the kaggle_dsb18 folder. It requires two arguments:

The images in this dataset can be subdivided further: fluorescent images, brightfield images and histopathological images containing tissue. If you also want to make this split, you can find the corresponding image names in the kaggle_dsb18 folder.