← Return to overview

U-Net

We include a lightweight, baseline U-Net to provide a playground environment for participants and kickstart their development cycle. Its goal is to help developers get familiar with the end-to-end pipeline of training a U-Net model for csPCa detection/diagnosis in 3D, encapsulating the trained AI model in a Docker container, and uploading it to the grand-challenge.org platform as an "algorithm". The U-Net baseline generates quick results with minimal complexity, but does so at the expense of sub-optimal performance and limited flexibility for adapting to other tasks.

U-Net - Data Preparation

We use the same cross-validation splits and the same data preparation/preprocessing pipeline for this U-Net as for the nnU-Net, with two exceptions:

  • Resampling Spatial Resolution: The PI-CAI: Public Training and Development Dataset contains MRI scans acquired using seven different scanners, from two vendors, at three centres. Thus, the spatial resolution of its images varies across different patient exams. For the axial T2W scans, the most common voxel spacing (in mm/voxel) observed is 3.0×0.5×0.5 (43%), followed by 3.6×0.3×0.3 (25%), 3.0×0.342×0.342 (15%) and others (17%). As a naive approach, we simply resample all scans to 3.0×0.5×0.5 mm/voxel.

  • Cropping to Region-of-Interest: We naively assume that the prostate gland is typically located at the centre of every prostate MRI scan. Hence, we take a centre crop of each scan, measuring 20×256×256 voxels in dimensions. Note that this assumption does not hold true for the entirety of the PI-CAI: Public Training and Development Dataset, where the prostate gland is off-centre in several cases.
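The centre-cropping step can be sketched as follows (a minimal numpy illustration; center_crop is a hypothetical helper, and the actual preprocessing is handled by the data preparation script below):

```python
import numpy as np

def center_crop(volume, size=(20, 256, 256)):
    # Hypothetical helper: take a centre crop of a (depth, height, width)
    # volume, assuming the volume is at least as large as `size` per axis.
    starts = [(v - s) // 2 for v, s in zip(volume.shape, size)]
    return volume[tuple(slice(st, st + s) for st, s in zip(starts, size))]

# e.g. a T2W scan already resampled to 3.0x0.5x0.5 mm/voxel
scan = np.random.rand(24, 384, 384).astype(np.float32)
cropped = center_crop(scan)
assert cropped.shape == (20, 256, 256)
```

Note that when the prostate gland is off-centre, such a crop can cut away part of the gland, which is exactly the limitation described above.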

For the data preparation, this means we prepare the data for the baseline semi-supervised U-Net algorithm as follows:

python prepare_data_semi_supervised.py --spacing 3.0 0.5 0.5 --matrix_size 20 256 256

After following all the steps listed under sections 'Folder Structure' and 'Data Preparation', generate overviews of the dataset using the provided plan_overview.py script. Set the paths to the corresponding folders on your system:

python src/picai_baseline/unet/plan_overview.py \
  --task=Task2203_picai_baseline \
  --workdir=/workdir \
  --preprocessed_data_path=/workdir/nnUNet_raw_data/{task} \
  --overviews_path=/workdir/results/UNet/overviews/{task} \
  --splits=picai_pub

This command creates .json-based lists of every scan and its corresponding details (e.g. patient ID, study ID, paths to its imaging and annotation files) used in each split (training or validation split) per fold during 5-fold cross-validation, and stores them in /workdir/results/UNet/overviews/Task2203_picai_baseline. These lists are subsequently used by the U-Net's data loaders during training. For example, lists used to complete the first fold of cross-validation would be: PI-CAI_train-fold-0.json and PI-CAI_val-fold-0.json. To generate overviews for the supervised subset only, use --task=Task2201_picai_baseline and --splits=picai_pub_nnunet.
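These overview lists are plain .json files and can be loaded with the standard library; the sketch below round-trips a hypothetical overview (the keys shown are assumptions; inspect the generated files for the authoritative schema):

```python
import json
import os
import tempfile

# Hypothetical overview contents: one entry per case, with each case
# bundling its MRI sequences. The actual keys are written by plan_overview.py.
overview = {
    "image_paths": [["case_0000_t2w.nii.gz", "case_0000_adc.nii.gz", "case_0000_hbv.nii.gz"]],
    "label_paths": ["case_0000.nii.gz"],
}

path = os.path.join(tempfile.mkdtemp(), "PI-CAI_train-fold-0.json")
with open(path, "w") as f:
    json.dump(overview, f)

with open(path) as f:
    loaded = json.load(f)

# one image entry per annotation; data loaders iterate over these pairs
assert len(loaded["image_paths"]) == len(loaded["label_paths"])
```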

U-Net - Training and Cross-Validation

The overall framework for training this U-Net has been set up using various modular components from the monai module (e.g. U-Net architecture, template for training) and the batchgenerators module (e.g. data loaders, data augmentation policy as incorporated in the nnU-Net). To train the model, run the following command:

python -u src/picai_baseline/unet/train.py \
  --weights_dir='/workdir/results/UNet/weights/' \
  --overviews_dir='/workdir/results/UNet/overviews/Task2203_picai_baseline' \
  --folds 0 1 2 3 4 --max_threads 6 --enable_da 1 --num_epochs 250 \
  --validate_n_epochs 1 --validate_min_epoch 0

⚠️ If you are running this command inside a Docker container, please make sure that your container has at least 16 GB of shared memory space. Otherwise, you may run into issues with data generators and multithreading, as documented here.

A full list of all available training arguments can be found in train.py. Here is a summary of the arguments used in the command above:

| Argument | Meaning |
| --- | --- |
| `weights_dir` | Directory in which model weights/checkpoints and an overview of performance metrics are stored at train-time. |
| `enable_da` | Enable data augmentations (simplified policy adapted from the nnU-Net). |
| `overviews_dir` | Directory from which the overview lists are loaded (these define all images and annotations used per split, per cross-validation fold). |
| `validate_min_epoch` | Minimum number of epochs after which evaluation is performed on the validation split. Performance metrics and model weights are stored only after this point. |
| `max_threads` | Number of CPU threads/cores used to parallelize the data loaders. |
| `validate_n_epochs` | Number of epochs between two consecutive rounds of evaluation. |
| `focal_loss_gamma` | Value of the gamma parameter of the focal loss (FL) function used at train-time. When gamma is set to 0, FL reduces to weighted cross-entropy loss. |
| `folds` | Cross-validation folds to be completed sequentially during training, e.g. `--folds 0` for a single fold, or `--folds 0 1 2 3 4` for all five folds. |
| `num_epochs` | Total number of epochs in the training period. |
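The relationship between focal loss and weighted cross-entropy noted for focal_loss_gamma can be verified directly. The sketch below implements the binary focal loss of Lin et al. for a single prediction (the alpha weighting value shown is an assumption, not the baseline's setting):

```python
import math

def focal_loss(p, y, alpha=0.75, gamma=2.0):
    # Binary focal loss for one prediction: p is the predicted probability
    # of the positive class, y is the true label (0 or 1).
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def weighted_ce(p, y, alpha=0.75):
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * math.log(p_t)

# with gamma = 0, focal loss equals weighted cross-entropy
assert abs(focal_loss(0.9, 1, gamma=0.0) - weighted_ce(0.9, 1)) < 1e-12

# with gamma > 0, confident (easy) predictions are down-weighted
assert focal_loss(0.9, 1, gamma=2.0) < weighted_ce(0.9, 1)
```

The `(1 - p_t) ** gamma` factor is what suppresses the contribution of well-classified voxels, focusing training on hard examples such as small lesions.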


U-Net - Inference Algorithm Submission to Grand Challenge

Once training is complete, there should be a single model checkpoint file (in .pt format) per fold, stored in the weights_dir that was specified at train-time. If the default command (noted in section 'U-Net - Training and Cross-Validation') is used, then one of these should be '/workdir/results/UNet/weights/unet_F0.pt'. Given that this checkpoint not only includes the trained model weights, but also the optimizer state and epoch number (which are used to resume training), its memory footprint can be quite large. Before preparing our Docker container for the algorithm, we should trim down its size and store only what we need for deployment. Please apply the following function to every model checkpoint file that you plan to use in your grand-challenge.org algorithm:

import torch

def process_model_weights(input_ckpt_path, output_ckpt_path):
    '''
    Loads a model checkpoint that was stored at train-time, discards
    "optimizer_state_dict" and "epoch", and keeps only the trained
    model weights (i.e. "model_state_dict"). Reduces the memory
    footprint of the weights file by nearly 5x.
    '''
    checkpoint = torch.load(input_ckpt_path, map_location=torch.device('cpu'))
    torch.save({'model_state_dict': checkpoint['model_state_dict']}, output_ckpt_path)

If you're using an ensemble, this function should be applied to each checkpoint file per member model (e.g. an ensemble can consist of the five models derived from all five folds of training/cross-validation).

Next, we highly recommend completing the full tutorial on how to create algorithms on grand-challenge.org. In line with that tutorial, we've built an example algorithm for you to use/adapt here. Head over to the "AI: Algorithm Submissions" page on the challenge website for more details, and for the final steps needed to make a submission to the PI-CAI challenge.