Monai training of an MS lesion segmentation model #12

Draft: wants to merge 107 commits into base: main

Commits (107)
bbde8f2
created script to build msd dataset for monai nnunet model training
plbenveniste Mar 11, 2024
081276a
added requirements script
plbenveniste Mar 11, 2024
0ef3f0a
removed yaml from requirements
plbenveniste Mar 11, 2024
451ca30
changed output file names
plbenveniste Mar 12, 2024
1dbd553
added missing requirements
plbenveniste Mar 12, 2024
91bac32
initialised config.yml file example
plbenveniste Mar 12, 2024
48f109c
initialised main file
plbenveniste Mar 12, 2024
5aa2eae
initialised models file
plbenveniste Mar 12, 2024
5b233d2
initialised transforms file
plbenveniste Mar 12, 2024
794430a
modified to fit our training parameters
plbenveniste Mar 12, 2024
0619980
simplified main training script
plbenveniste Mar 12, 2024
e04c423
fixed canproco problem (having both img and label with same link)
plbenveniste Mar 13, 2024
26629c8
added training script for a monai trained UNETR
plbenveniste Mar 13, 2024
e815e14
added fake config file (for debugging)
plbenveniste Mar 13, 2024
5aaf6db
updated requirements to add monai[all] for problem with data loading
plbenveniste Mar 13, 2024
8eafeff
removed old config file
plbenveniste Mar 13, 2024
8f43eac
changed links to dataset
plbenveniste Mar 13, 2024
86e1a32
fixed training. Still need to fix validation params
plbenveniste Mar 13, 2024
32ca1cd
removed files using pytorch lightning training
plbenveniste Mar 13, 2024
5647955
working monai script based on Jan's code: but no dice score improvement
plbenveniste Mar 14, 2024
a4d7958
pytorch lightning script based on Naga's work
plbenveniste Mar 14, 2024
4e768e4
modified requirements for pytorch lightning
plbenveniste Mar 14, 2024
186c602
added multiply by -1 transform
plbenveniste Mar 18, 2024
adaf04f
parameters changed for config file for first inference run
plbenveniste Mar 18, 2024
dab7dcc
added SoftDiceLoss
plbenveniste Mar 28, 2024
80494e5
removed print from inverse function in utils
plbenveniste Mar 28, 2024
4727dbc
changed resolution to 0.6 isotropic
plbenveniste Mar 28, 2024
5992732
added plot images function for wandb
plbenveniste Mar 28, 2024
9aefb6d
changed loss function and added image printing
plbenveniste Mar 28, 2024
855e26f
changed some parameters for training
plbenveniste Apr 1, 2024
e1ed6e2
added the image plot function
plbenveniste Apr 1, 2024
cffa941
changed model parameters for training
plbenveniste Apr 1, 2024
fc37ca0
code reviewed with no prob but output still problematic
plbenveniste Apr 1, 2024
9f1effd
added lines to save images before training
plbenveniste Apr 1, 2024
01c0912
correction: removed intensity normalisation for labels
plbenveniste Apr 1, 2024
6aa8225
fixed filename to save
plbenveniste Apr 1, 2024
a89b5d7
updated to add some data aug but then removed :/
plbenveniste Apr 2, 2024
6342be0
created file for unet training with multiple input channels
plbenveniste Apr 3, 2024
b858fbc
created file for unet training with multiple output channels
plbenveniste Apr 3, 2024
afb7b4f
training script cleaned for ms lesion seg
plbenveniste Apr 3, 2024
9d6af9b
moved all files in monai and removed nnunet folder
plbenveniste Apr 4, 2024
97b495f
renamed config_fake.yml to config.yml
plbenveniste Apr 4, 2024
1df0305
removed useless previous training script train_monai_UNETR.py
plbenveniste Apr 4, 2024
27eb185
script for training unet with finetuning data-aug parameters
plbenveniste Apr 5, 2024
4d1996f
modified training script with new data augmentation strategies
plbenveniste Apr 5, 2024
a28d176
fixed typos in script and arranged in functions
plbenveniste Apr 5, 2024
a9c293f
fixed parameters for model training on entire dataset
plbenveniste Apr 5, 2024
29b8b69
added function to count lesions and get total volume
plbenveniste Apr 5, 2024
fc18318
changed batch-size to 8
plbenveniste Apr 5, 2024
b1563f8
changed to attentionUnet
plbenveniste Apr 5, 2024
8e8ee70
added lesion only dataset on entirety of images
plbenveniste Apr 5, 2024
d9997a6
added crop foreground for model training
plbenveniste Apr 10, 2024
6f35a4a
added more data augmentation
plbenveniste Apr 10, 2024
a3df668
added precision and recall metric
plbenveniste Apr 10, 2024
1c3c557
added precision and recall metric : function must be reviewed (not su…
plbenveniste Apr 11, 2024
9d80931
modified version of precision/recall metric
plbenveniste Apr 16, 2024
60a63e6
modified config file
plbenveniste Apr 16, 2024
318c117
modified code to include swinUNETR model
plbenveniste Apr 16, 2024
188ee48
fixed wandb and config file for cleaner pipeline
plbenveniste Apr 16, 2024
01bce84
new script to test the dataset
plbenveniste Apr 17, 2024
d96f1e3
config file for testing the dataset
plbenveniste Apr 17, 2024
761d32f
added cupy install for inference
plbenveniste Apr 17, 2024
11abb73
added script for plotting the performance (dice metric) on the data s…
plbenveniste Apr 17, 2024
e4d6088
correct typo in parser
plbenveniste Apr 17, 2024
aa42e94
fixed typo on basel and bavaria data import
plbenveniste Apr 17, 2024
55b0add
changes made for previous run (before ISMRM)
plbenveniste Jun 4, 2024
149fa34
add function to not take files which are in canproco/exclude.yml
plbenveniste Jun 4, 2024
f28bfe4
added lesion wide metrics
plbenveniste Jun 4, 2024
9ac26f3
changed for bavaria dataset new format
plbenveniste Jun 4, 2024
dadc8ac
added function to remove small objects in utils
plbenveniste Jun 4, 2024
3965b9b
added remove small objects for train, val and inference
plbenveniste Jun 4, 2024
71faa52
changed the min volume threshold
plbenveniste Jun 4, 2024
baaa980
changed msd dataset creation for nih and updated bavaria unstitched data
plbenveniste Jun 26, 2024
500de22
updated requirements and added loguru
plbenveniste Jun 26, 2024
8096913
corrected lesion mask name for nih
plbenveniste Jun 26, 2024
3c7e488
corrected requirements
plbenveniste Jun 26, 2024
7cea4c8
config file for training on ETS server
plbenveniste Jul 15, 2024
5fa1ac1
added nnUNet data augmentation
plbenveniste Jul 16, 2024
88b6619
added contrast, site and orientation in msd dataset
plbenveniste Jul 22, 2024
f387067
improved computation of orientation of image
plbenveniste Jul 22, 2024
ec256ab
added __init__.py file for import possibility
plbenveniste Jul 22, 2024
554b522
removed unused files
plbenveniste Jul 22, 2024
1fdbdd3
moved files to utils folder
plbenveniste Jul 22, 2024
8292ff8
updated parameters for model testing
plbenveniste Jul 23, 2024
7e361ba
updated inference script and evaluation plots scripts
plbenveniste Jul 23, 2024
bf07262
added removal of .nii.gz for UINT1 contrast
plbenveniste Jul 23, 2024
9f5587a
changed workers to 0 for test_model
plbenveniste Aug 1, 2024
c23d63f
added more info in output
plbenveniste Aug 1, 2024
034d97e
created file for cropping around the head
plbenveniste Aug 1, 2024
00ec30c
updated training script to sota model training script (set workers to 0)
plbenveniste Aug 9, 2024
e038c67
changed location of saving of yaml file to save with the same date as…
plbenveniste Aug 9, 2024
391193c
init mednext training script
plbenveniste Aug 30, 2024
25e7874
added library for diffusion model
plbenveniste Sep 4, 2024
b7ee720
first draft (non-functioning) of diffusion model training script
plbenveniste Sep 4, 2024
6f807f9
created script to train a mednext model
plbenveniste Sep 4, 2024
74b6eed
removed cropping of image before inference
plbenveniste Sep 4, 2024
c497c4e
added new config files
plbenveniste Sep 4, 2024
825d816
added script to perform inference and compute the dice score with var…
plbenveniste Sep 4, 2024
ba6d90e
fixed .cpu problem and added more thresholds
plbenveniste Sep 4, 2024
9027f32
first draft of script for TTA
plbenveniste Sep 4, 2024
11eec85
added code to plot the opt threshold output
plbenveniste Sep 5, 2024
38c1594
fixed threshold to 0.5
plbenveniste Sep 5, 2024
e3b1eff
fixed parenthesis when computing dice score
plbenveniste Sep 11, 2024
a37c6a0
added script to compute TTA with 2nd strategy
plbenveniste Sep 11, 2024
60506b4
added script for mednext inference
plbenveniste Sep 12, 2024
a79cb44
added computation of f1-score, ppv and sensitivity
plbenveniste Sep 25, 2024
953de3f
fixed utils command
plbenveniste Oct 31, 2024
311 changes: 311 additions & 0 deletions monai/1_create_msd_data.py
@@ -0,0 +1,311 @@
"""
This file creates the MSD-style JSON datalist to train an nnunet model using monai.
The datasets used are canproco, basel-mp2rage, bavaria-quebec-spine-ms-unstitched, nih-ms-mp2rage and sct-testing-large.

Arguments:
-pd, --path-data: Path to the data set directory
-po, --path-out: Path to the output directory where dataset json is saved
--lesion-only: Use only masks which contain some lesions
--seed: Seed for reproducibility
--canproco-exclude: Path to the file containing the list of subjects to exclude from CanProCo

Example:
python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt

TO DO:
*

Pierre-Louis Benveniste
"""

import os
import json
from tqdm import tqdm
import yaml
import argparse
from loguru import logger
from sklearn.model_selection import train_test_split
from datetime import date
from pathlib import Path
import nibabel as nib
import numpy as np
import skimage.measure  # explicit submodule import for connected-component labelling
from utils.image import Image


def get_parser():
"""
Get parser for script create_msd_data.py

Input:
None

Returns:
parser : argparse object
"""

parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.')

parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets')
parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved')
parser.add_argument('--canproco-exclude', type=str, help='Path to the file containing the list of subjects to exclude from CanProCo')
parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions')
parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility")

return parser


def count_lesion(label_file):
"""
This function takes a label file and counts the number of lesions in it.

Input:
label_file : str : Path to the label file

Returns:
total_volume : float : Total volume of lesions in the label file (in mm^3)
nb_lesions : int : Number of lesions in the label file
"""

label = nib.load(label_file)
label_data = label.get_fdata()

# get the total volume of the lesions
total_volume = np.sum(label_data)
resolution = label.header.get_zooms()
total_volume = total_volume * np.prod(resolution)

# get the number of lesions
_, nb_lesions = skimage.measure.label(label_data, connectivity=2, return_num=True)

return total_volume, nb_lesions
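To make the volume computation above concrete, here is a standalone sketch on a synthetic binary mask (no NIfTI file involved; the blob size and the 0.5 x 0.5 x 2.0 mm resolution are made up for illustration):

```python
import numpy as np

# synthetic binary lesion mask: one 2x2x2 blob = 8 foreground voxels
mask = np.zeros((10, 10, 10))
mask[2:4, 2:4, 2:4] = 1

# voxel spacing in mm, as nibabel's header.get_zooms() would report it
resolution = (0.5, 0.5, 2.0)

# total lesion volume = number of foreground voxels * volume of one voxel
total_volume = mask.sum() * np.prod(resolution)
# total_volume -> 4.0 mm^3 (8 voxels * 0.5 mm^3 each)
```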


def get_orientation(image_path):
"""
This function takes an image file as input and returns its orientation.

Input:
image_path : str : Path to the image file

Returns:
orientation : str : Orientation of the image
"""
img = Image(str(image_path))
img.change_orientation('RPI')
# Get pixdim
pixdim = img.dim[4:7]
# If all spacings are (nearly) equal, the image is isotropic
if np.allclose(pixdim, pixdim[0], atol=1e-3):
orientation = 'iso'
return orientation
# Otherwise, the axis with the largest spacing is the slice axis:
# largest spacing on axis 0 -> sagittal
elif np.argmax(pixdim) == 0:
orientation = 'sag'
# largest spacing on axis 1 -> coronal
elif np.argmax(pixdim) == 1:
orientation = 'cor'
# otherwise (axis 2) -> axial
else:
orientation = 'ax'
return orientation
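The pixdim-to-orientation rule can be sketched in isolation, operating on plain spacing tuples rather than an `Image` header (the example spacings below are hypothetical):

```python
import numpy as np

def orientation_from_pixdim(pixdim):
    # near-equal spacings in all three axes -> isotropic
    if np.allclose(pixdim, pixdim[0], atol=1e-3):
        return 'iso'
    # otherwise the axis with the largest spacing is the slice axis
    return {0: 'sag', 1: 'cor'}.get(int(np.argmax(pixdim)), 'ax')

# orientation_from_pixdim((1.0, 1.0, 1.0)) -> 'iso'
# orientation_from_pixdim((3.0, 0.5, 0.5)) -> 'sag'
# orientation_from_pixdim((0.5, 0.5, 3.0)) -> 'ax'
```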


def main():
"""
This is the main function of the script.

Input:
None

Returns:
None
"""
# Get the arguments
parser = get_parser()
args = parser.parse_args()

root = args.path_data
seed = args.seed

# Get all subjects
basel_path = Path(os.path.join(root, "basel-mp2rage"))
bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms-unstitched"))
canproco_path = Path(os.path.join(root, "canproco"))
nih_path = Path(os.path.join(root, "nih-ms-mp2rage"))
sct_testing_path = Path(os.path.join(root, "sct-testing-large"))

derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz'))
derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz'))
derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz'))
derivatives_nih = list(nih_path.rglob('*_desc-rater1_label-lesion_seg.nii.gz'))
derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz'))

# Subjects to exclude from CanProCo (default to an empty list so the
# membership test below never hits an undefined variable)
canproco_exclude_list = []
if args.canproco_exclude is not None:
with open(args.canproco_exclude, 'r') as file:
canproco_exclude_list = yaml.safe_load(file)
# only keep the PSIR and STIR contrasts
canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR']
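For illustration, the exclude file is assumed to be a YAML mapping from contrast name to a list of subject IDs; a minimal sketch with made-up subject IDs, parsing an inline string instead of a file:

```python
import yaml

text = """
PSIR:
  - sub-xxx001
STIR:
  - sub-yyy002
"""
# parse the (hypothetical) exclude file and flatten the two contrast lists
excluded = yaml.safe_load(text)
exclude_list = excluded['PSIR'] + excluded['STIR']
# exclude_list -> ['sub-xxx001', 'sub-yyy002']
```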

derivatives = derivatives_basel + derivatives_bavaria + derivatives_canproco + derivatives_nih + derivatives_sct
logger.info(f"Total number of derivatives in the root directory: {len(derivatives)}")

# create one json file with 60-20-20 train-val-test split
train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
train_derivatives, test_derivatives = train_test_split(derivatives, test_size=test_ratio, random_state=args.seed)
# Use the training split to further split into training and validation splits
train_derivatives, val_derivatives = train_test_split(train_derivatives, test_size=val_ratio / (train_ratio + val_ratio),
random_state=args.seed, )
# sort the subjects
train_derivatives = sorted(train_derivatives)
val_derivatives = sorted(val_derivatives)
test_derivatives = sorted(test_derivatives)
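The nested split above can be checked on a toy list: holding out 20% for test, then taking 0.2 / (0.6 + 0.2) = 0.25 of the remainder for validation, yields a 60/20/20 split overall (pure-Python sketch, no sklearn):

```python
import random

items = list(range(100))
random.Random(42).shuffle(items)

test_set = items[:20]                     # 20% of the whole for test
rest = items[20:]
n_val = round(len(rest) * 0.2 / 0.8)      # 0.25 of the remaining 80%
val_set = rest[:n_val]
train_set = rest[n_val:]
# len(train_set), len(val_set), len(test_set) -> (60, 20, 20)
```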

logger.info(f"Number of training subjects: {len(train_derivatives)}")
logger.info(f"Number of validation subjects: {len(val_derivatives)}")
logger.info(f"Number of testing subjects: {len(test_derivatives)}")

# dump train/val/test splits into a yaml file, as plain strings (pathlib.Path
# objects would otherwise serialize as python-specific yaml tags)
with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file:
yaml.dump({'train': [str(p) for p in train_derivatives], 'val': [str(p) for p in val_derivatives], 'test': [str(p) for p in test_derivatives]}, file, indent=2, sort_keys=True)

# keys to be defined in the dataset_0.json
params = {}
params["description"] = "ms-lesion-agnostic"
params["labels"] = {
"0": "background",
"1": "ms-lesion-seg"
}
params["license"] = "plb"
params["modality"] = {
"0": "MRI"
}
params["name"] = "ms-lesion-agnostic"
params["seed"] = args.seed
params["reference"] = "NeuroPoly"
params["tensorImageSize"] = "3D"

train_derivatives_dict = {"train": train_derivatives}
val_derivatives_dict = {"validation": val_derivatives}
test_derivatives_dict = {"test": test_derivatives}
all_derivatives_list = [train_derivatives_dict, val_derivatives_dict, test_derivatives_dict]

# iterate through the train/val/test splits and add those which have both image and label
for derivatives_dict in tqdm(all_derivatives_list, desc="Iterating through train/val/test splits"):

for name, derivs_list in derivatives_dict.items():

temp_list = []
for subject_no, derivative in enumerate(derivs_list):


temp_data_basel = {}
temp_data_bavaria = {}
temp_data_canproco = {}
temp_data_nih = {}
temp_data_sct = {}

# Basel
if 'basel-mp2rage' in str(derivative):
relative_path = derivative.relative_to(basel_path).parent
temp_data_basel["label"] = str(derivative)
temp_data_basel["image"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '')
if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]):
total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"])
temp_data_basel["total_lesion_volume"] = total_lesion_volume
temp_data_basel["nb_lesions"] = nb_lesions
temp_data_basel["site"]='basel'
temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '')
temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"])
if args.lesion_only and nb_lesions == 0:
continue
temp_list.append(temp_data_basel)

# Bavaria-quebec
elif 'bavaria-quebec-spine-ms' in str(derivative):
temp_data_bavaria["label"] = str(derivative)
temp_data_bavaria["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '')
if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]):
total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"])
temp_data_bavaria["total_lesion_volume"] = total_lesion_volume
temp_data_bavaria["nb_lesions"] = nb_lesions
temp_data_bavaria["site"]='bavaria-quebec'
temp_data_bavaria["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '')
temp_data_bavaria["orientation"] = get_orientation(temp_data_bavaria["image"])
if args.lesion_only and nb_lesions == 0:
continue
temp_list.append(temp_data_bavaria)

# Canproco
elif 'canproco' in str(derivative):
subject_id = derivative.name.replace('_PSIR_lesion-manual.nii.gz', '')
subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '')
if subject_id in canproco_exclude_list:
continue
temp_data_canproco["label"] = str(derivative)
temp_data_canproco["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '')
if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]):
total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"])
temp_data_canproco["total_lesion_volume"] = total_lesion_volume
temp_data_canproco["nb_lesions"] = nb_lesions
temp_data_canproco["site"]='canproco'
temp_data_canproco["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '')
temp_data_canproco["orientation"] = get_orientation(temp_data_canproco["image"])
if args.lesion_only and nb_lesions == 0:
continue
temp_list.append(temp_data_canproco)

# nih-ms-mp2rage
elif 'nih-ms-mp2rage' in str(derivative):
temp_data_nih["label"] = str(derivative)
temp_data_nih["image"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '')
if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]):
total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"])
temp_data_nih["total_lesion_volume"] = total_lesion_volume
temp_data_nih["nb_lesions"] = nb_lesions
temp_data_nih["site"]='nih'
temp_data_nih["contrast"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '')
temp_data_nih["orientation"] = get_orientation(temp_data_nih["image"])
if args.lesion_only and nb_lesions == 0:
continue
temp_list.append(temp_data_nih)

# sct-testing-large
elif 'sct-testing-large' in str(derivative):
temp_data_sct["label"] = str(derivative)
temp_data_sct["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '')
if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]):
total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"])
temp_data_sct["total_lesion_volume"] = total_lesion_volume
temp_data_sct["nb_lesions"] = nb_lesions
temp_data_sct["site"]='sct-testing-large'
temp_data_sct["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '')
temp_data_sct["orientation"] = get_orientation(temp_data_sct["image"])
if args.lesion_only and nb_lesions == 0:
continue
temp_list.append(temp_data_sct)

params[name] = temp_list
logger.info(f"Number of images in {name} set: {len(temp_list)}")
params["numTest"] = len(params["test"])
params["numTraining"] = len(params["train"])
params["numValidation"] = len(params["validation"])
# Print total number of images
logger.info(f"Total number of images in the dataset: {params['numTest'] + params['numTraining'] + params['numValidation']}")

final_json = json.dumps(params, indent=4, sort_keys=True)
os.makedirs(args.path_out, exist_ok=True)
if args.lesion_only:
out_file = os.path.join(args.path_out, f"dataset_{str(date.today())}_seed{seed}_lesionOnly.json")
else:
out_file = os.path.join(args.path_out, f"dataset_{str(date.today())}_seed{seed}.json")
with open(out_file, "w") as jsonFile:
jsonFile.write(final_json)

return None


if __name__ == "__main__":
main()
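The script emits an MSD-style JSON whose split keys ("train", "validation", "test") hold lists of image/label records. A minimal sketch of what one record looks like and how a consumer would read it back (field names follow the code above; the paths and values are fabricated):

```python
import json

# round-trip a fabricated single-entry datalist through JSON, as a consumer
# of the generated dataset file would see it
datalist = json.loads(json.dumps({
    "name": "ms-lesion-agnostic",
    "numTraining": 1,
    "train": [{
        "image": "/data/site/sub-001_T2w.nii.gz",  # hypothetical path
        "label": "/data/site/derivatives/labels/sub-001_T2w_lesion-manual.nii.gz",
        "site": "example",
        "contrast": "T2w",
        "orientation": "ax",
        "nb_lesions": 3,
        "total_lesion_volume": 120.5,
    }],
}))
train_files = datalist["train"]
# each entry carries the image/label pair plus per-image metadata
```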