diff --git a/monai/1_create_msd_data.py b/monai/1_create_msd_data.py new file mode 100644 index 0000000..ddce188 --- /dev/null +++ b/monai/1_create_msd_data.py @@ -0,0 +1,311 @@ +""" +This file creates the MSD-style JSON datalist to train an nnunet model using monai. +The datasets used are CanProCo, Bavaria-quebec, basel and sct-testing-large. + +Arguments: + -pd, --path-data: Path to the data set directory + -po, --path-out: Path to the output directory where dataset json is saved + --lesion-only: Use only masks which contain some lesions + --seed: Seed for reproducibility + --canproco-exclude: Path to the file containing the list of subjects to exclude from CanProCo + +Example: + python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt + +TO DO: + * + +Pierre-Louis Benveniste +""" + +import os +import json +from tqdm import tqdm +import yaml +import argparse +from loguru import logger +from sklearn.model_selection import train_test_split +from datetime import date +from pathlib import Path +import nibabel as nib +import numpy as np +import skimage +from utils.image import Image + + +def get_parser(): + """ + Get parser for script create_msd_data.py + + Input: + None + + Returns: + parser : argparse object + """ + + parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') + + parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') + parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') + parser.add_argument('--canproco-exclude', type=str, help='Path to the file containing the list of subjects to exclude from CanProCo') + parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions') + parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") + + return parser + + +def count_lesion(label_file): + """ + This function takes a label file and counts the number of lesions in it. + + Input: + label_file : str : Path to the label file + + Returns: + count : int : Number of lesions in the label file + total_volume : float : Total volume of lesions in the label file + """ + + label = nib.load(label_file) + label_data = label.get_fdata() + + # get the total volume of the lesions + total_volume = np.sum(label_data) + resolution = label.header.get_zooms() + total_volume = total_volume * np.prod(resolution) + + # get the number of lesions + _, nb_lesions = skimage.measure.label(label_data, connectivity=2, return_num=True) + + return total_volume, nb_lesions + + +def get_orientation(image_path): + """ + This function takes an image file as input and returns its orientation. + + Input: + image_path : str : Path to the image file + + Returns: + orientation : str : Orientation of the image + """ + img = Image(str(image_path)) + img.change_orientation('RPI') + # Get pixdim + pixdim = img.dim[4:7] + # If all are the same, the image is isotropic + if np.allclose(pixdim, pixdim[0], atol=1e-3): + orientation = 'iso' + return orientation + # Elif, the lowest arg is 0 then the orientation is sagittal + elif np.argmax(pixdim) == 0: + orientation = 'sag' + # Elif, the lowest arg is 1 then the orientation is coronal + elif np.argmax(pixdim) == 1: + orientation = 'cor' + # Else the orientation is axial + else: + orientation = 'ax' + return orientation + + +def main(): + """ + This is the main function of the script. + + Input: + None + + Returns: + None + """ + # Get the arguments + parser = get_parser() + args = parser.parse_args() + + root = args.path_data + seed = args.seed + + # Get all subjects + basel_path = Path(os.path.join(root, "basel-mp2rage")) + bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms-unstitched")) + canproco_path = Path(os.path.join(root, "canproco")) + nih_path = Path(os.path.join(root, "nih-ms-mp2rage")) + sct_testing_path = Path(os.path.join(root, "sct-testing-large")) + + derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz')) + derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) + derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) + derivatives_nih = list(nih_path.rglob('*_desc-rater1_label-lesion_seg.nii.gz')) + derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) + + # Path to the file containing the list of subjects to exclude from CanProCo + if args.canproco_exclude is not None: + with open(args.canproco_exclude, 'r') as file: + canproco_exclude_list = yaml.load(file, Loader=yaml.FullLoader) + # only keep the contrast psir and stir + canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR'] + + derivatives = derivatives_basel + derivatives_bavaria + derivatives_canproco + derivatives_nih + derivatives_sct + logger.info(f"Total number of derivatives in the root directory: {len(derivatives)}") + + # create one json file with 60-20-20 train-val-test split + train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 + train_derivatives, test_derivatives = train_test_split(derivatives, test_size=test_ratio, random_state=args.seed) + # Use the training split to further split into training and validation splits + train_derivatives, val_derivatives = train_test_split(train_derivatives, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) + # sort the subjects + train_derivatives = sorted(train_derivatives) + val_derivatives = sorted(val_derivatives) + test_derivatives = sorted(test_derivatives) + + # logger.info(f"Number of training subjects: {len(train_subjects)}") + # logger.info(f"Number of validation subjects: {len(val_subjects)}") + # logger.info(f"Number of testing subjects: {len(test_subjects)}") + + # dump train/val/test splits into a yaml file + with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: + yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = "ms-lesion-agnostic" + params["labels"] = { + "0": "background", + "1": "ms-lesion-seg" + } + params["license"] = "plb" + params["modality"] = { + "0": "MRI" + } + params["name"] = "ms-lesion-agnostic" + params["seed"] = args.seed + params["reference"] = "NeuroPoly" + params["tensorImageSize"] = "3D" + + train_derivatives_dict = {"train": train_derivatives} + val_derivatives_dict = {"validation": val_derivatives} + test_derivatives_dict = {"test": test_derivatives} + all_derivatives_list = [train_derivatives_dict, val_derivatives_dict, test_derivatives_dict] + + # iterate through the train/val/test splits and add those which have both image and label + for derivatives_dict in tqdm(all_derivatives_list, desc="Iterating through train/val/test splits"): + + for name, derivs_list in derivatives_dict.items(): + + temp_list = [] + for subject_no, derivative in enumerate(derivs_list): + + + temp_data_basel = {} + temp_data_bavaria = {} + temp_data_canproco = {} + temp_data_nih = {} + temp_data_sct = {} + + # Basel + if 'basel-mp2rage' in str(derivative): + relative_path = derivative.relative_to(basel_path).parent + temp_data_basel["label"] = str(derivative) + temp_data_basel["image"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) + temp_data_basel["total_lesion_volume"] = total_lesion_volume + temp_data_basel["nb_lesions"] = nb_lesions + temp_data_basel["site"]='basel' + temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_basel) + + # Bavaria-quebec + elif 'bavaria-quebec-spine-ms' in str(derivative): + temp_data_bavaria["label"] = str(derivative) + temp_data_bavaria["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) + temp_data_bavaria["total_lesion_volume"] = total_lesion_volume + temp_data_bavaria["nb_lesions"] = nb_lesions + temp_data_bavaria["site"]='bavaria-quebec' + temp_data_bavaria["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_bavaria["orientation"] = get_orientation(temp_data_bavaria["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_bavaria) + + # Canproco + elif 'canproco' in str(derivative): + subject_id = derivative.name.replace('_PSIR_lesion-manual.nii.gz', '') + subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') + if subject_id in canproco_exclude_list: + continue + temp_data_canproco["label"] = str(derivative) + temp_data_canproco["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) + temp_data_canproco["total_lesion_volume"] = total_lesion_volume + temp_data_canproco["nb_lesions"] = nb_lesions + temp_data_canproco["site"]='canproco' + temp_data_canproco["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_canproco["orientation"] = get_orientation(temp_data_canproco["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_canproco) + + # nih-ms-mp2rage + elif 'nih-ms-mp2rage' in str(derivative): + temp_data_nih["label"] = str(derivative) + temp_data_nih["image"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) + temp_data_nih["total_lesion_volume"] = total_lesion_volume + temp_data_nih["nb_lesions"] = nb_lesions + temp_data_nih["site"]='nih' + temp_data_nih["contrast"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_nih["orientation"] = get_orientation(temp_data_nih["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_nih) + + # sct-testing-large + elif 'sct-testing-large' in str(derivative): + temp_data_sct["label"] = str(derivative) + temp_data_sct["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): + total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) + temp_data_sct["total_lesion_volume"] = total_lesion_volume + temp_data_sct["nb_lesions"] = nb_lesions + temp_data_sct["site"]='sct-testing-large' + temp_data_sct["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_sct["orientation"] = get_orientation(temp_data_sct["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_sct) + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + params["numTest"] = len(params["test"]) + params["numTraining"] = len(params["train"]) + params["numValidation"] = len(params["validation"]) + # Print total number of images + logger.info(f"Total number of images in the dataset: {params['numTest'] + params['numTraining'] + params['numValidation']}") + + final_json = json.dumps(params, indent=4, sort_keys=True) + if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) + if args.lesion_only: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}_lesionOnly.json", "w") + else: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/1_create_msd_data_head_cropped.py b/monai/1_create_msd_data_head_cropped.py new file mode 100644 index 0000000..cf28f9e --- /dev/null +++ b/monai/1_create_msd_data_head_cropped.py @@ -0,0 +1,395 @@ +""" +This file creates the MSD-style JSON datalist to train an nnunet model using monai. +The datasets used are CanProCo, Bavaria-quebec, basel and sct-testing-large. + +Arguments: + -pd, --path-data: Path to the data set directory + -po, --path-out: Path to the output directory where dataset json is saved + --lesion-only: Use only masks which contain some lesions + --seed: Seed for reproducibility + --canproco-exclude: Path to the file containing the list of subjects to exclude from CanProCo + +Example: + python 1_create_msd_data.py -pd /path/dataset -po /path/output --lesion-only --seed 42 --canproco-exclude /path/exclude_list.txt + +TO DO: + * + +Pierre-Louis Benveniste +""" + +import os +import json +from tqdm import tqdm +import yaml +import argparse +from loguru import logger +from sklearn.model_selection import train_test_split +from datetime import date +from pathlib import Path +import nibabel as nib +import numpy as np +import skimage +from utils.image import Image + + +def get_parser(): + """ + Get parser for script create_msd_data.py + + Input: + None + + Returns: + parser : argparse object + """ + + parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for lesion-agnostic nnunet model training.') + + parser.add_argument('-pd', '--path-data', required=True, type=str, help='Path to the folder containing the datasets') + parser.add_argument('-po', '--path-out', type=str, help='Path to the output directory where dataset json is saved') + parser.add_argument('--canproco-exclude', type=str, help='Path to the file containing the list of subjects to exclude from CanProCo') + parser.add_argument('--lesion-only', action='store_true', help='Use only masks which contain some lesions') + parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") + + return parser + + +def count_lesion(label_file): + """ + This function takes a label file and counts the number of lesions in it. + + Input: + label_file : str : Path to the label file + + Returns: + count : int : Number of lesions in the label file + total_volume : float : Total volume of lesions in the label file + """ + + label = nib.load(label_file) + label_data = label.get_fdata() + + # get the total volume of the lesions + total_volume = np.sum(label_data) + resolution = label.header.get_zooms() + total_volume = total_volume * np.prod(resolution) + + # get the number of lesions + _, nb_lesions = skimage.measure.label(label_data, connectivity=2, return_num=True) + + return total_volume, nb_lesions + + +def get_orientation(image_path): + """ + This function takes an image file as input and returns its orientation. + + Input: + image_path : str : Path to the image file + + Returns: + orientation : str : Orientation of the image + """ + img = Image(str(image_path)) + img.change_orientation('RPI') + # Get pixdim + pixdim = img.dim[4:7] + # If all are the same, the image is isotropic + if np.allclose(pixdim, pixdim[0], atol=1e-3): + orientation = 'iso' + return orientation + # Elif, the lowest arg is 0 then the orientation is sagittal + elif np.argmax(pixdim) == 0: + orientation = 'sag' + # Elif, the lowest arg is 1 then the orientation is coronal + elif np.argmax(pixdim) == 1: + orientation = 'cor' + # Else the orientation is axial + else: + orientation = 'ax' + return orientation + + +def cropping_saving(image_path, label_path, cropped_head_data_folder): + """ + This function does the following action successively: + - copy image and label to the output folder for cropped head data + - segments the spinal cord on the image + - crops the image and label to the remove the superior part of the head (what is above the seg of the spinal cord) + - save the cropped image and label in the output folder + + Input: + image_path : str : Path to the image file + label_path : str : Path to the label file + cropped_head_data_folder : str : Path to the output folder + + Returns: + image_cropped : str : Path to the cropped image + seg_cropped : str : Path to the cropped label + """ + + # Copy image and label to the output folder for cropped head data + image_cropped = os.path.join(cropped_head_data_folder, image_path.split('/')[-1]) + seg_cropped = os.path.join(cropped_head_data_folder, label_path.split('/')[-1]) + img = Image(image_path) + img.change_orientation('RPI') + img.save(image_cropped) + seg = Image(label_path) + seg.change_orientation('RPI') + seg.save(seg_cropped) + + # Segment the spinal cord on the image + ## Create a temporary folder + temp_folder = os.path.join(cropped_head_data_folder, "temp") + os.makedirs(temp_folder, exist_ok=True) + ## Segment the spinal cord + os.system(f"sct_deepseg -i {image_cropped} -o {os.path.join(temp_folder, 'seg.nii.gz')} -task seg_sc_contrast_agnostic -thr 0.5") + ## Get the highest point of the spinal cord + spinal_cord_seg = Image(os.path.join(temp_folder, 'seg.nii.gz')) + spinal_cord_seg.change_orientation('RPI') + spinal_cord_seg_data = spinal_cord_seg.data + spinal_cord_superior = np.max(np.where(spinal_cord_seg_data == 1)[2]) + ## Remove the temporary folder + os.system(f"rm -rf {temp_folder}") + + # Crop the image and label to the remove the superior part of the head (what is above the seg of the spinal cord) + os.system(f"sct_crop_image -i {image_cropped} -o {image_cropped} -zmax {spinal_cord_superior}") + os.system(f"sct_crop_image -i {seg_cropped} -o {seg_cropped} -zmax {spinal_cord_superior}") + + return image_cropped, seg_cropped + + +def main(): + """ + This is the main function of the script. + + Input: + None + + Returns: + None + """ + # Get the arguments + parser = get_parser() + args = parser.parse_args() + + root = args.path_data + seed = args.seed + + # Get all subjects + basel_path = Path(os.path.join(root, "basel-mp2rage")) + bavaria_path = Path(os.path.join(root, "bavaria-quebec-spine-ms-unstitched")) + canproco_path = Path(os.path.join(root, "canproco")) + nih_path = Path(os.path.join(root, "nih-ms-mp2rage")) + sct_testing_path = Path(os.path.join(root, "sct-testing-large")) + + derivatives_basel = list(basel_path.rglob('*_desc-rater3_label-lesion_seg.nii.gz')) + derivatives_bavaria = list(bavaria_path.rglob('*_lesion-manual.nii.gz')) + derivatives_canproco = list(canproco_path.rglob('*_lesion-manual.nii.gz')) + derivatives_nih = list(nih_path.rglob('*_desc-rater1_label-lesion_seg.nii.gz')) + derivatives_sct = list(sct_testing_path.rglob('*_lesion-manual.nii.gz')) + + # Make the folder for the cropped images + cropped_head_data_folder = os.path.join(args.path_out, "cropped_head_data") + os.makedirs(args.path_out, exist_ok=True) + os.makedirs(cropped_head_data_folder, exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "basel-mp2rage"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "bavaria-quebec-spine-ms-unstitched"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "canproco"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "nih-ms-mp2rage"), exist_ok=True) + os.makedirs(os.path.join(cropped_head_data_folder, "sct-testing-large"), exist_ok=True) + + # Path to the file containing the list of subjects to exclude from CanProCo + if args.canproco_exclude is not None: + with open(args.canproco_exclude, 'r') as file: + canproco_exclude_list = yaml.load(file, Loader=yaml.FullLoader) + # only keep the contrast psir and stir + canproco_exclude_list = canproco_exclude_list['PSIR'] + canproco_exclude_list['STIR'] + + derivatives = derivatives_basel + derivatives_bavaria + derivatives_canproco + derivatives_nih + derivatives_sct + logger.info(f"Total number of derivatives in the root directory: {len(derivatives)}") + + # create one json file with 60-20-20 train-val-test split + train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2 + train_derivatives, test_derivatives = train_test_split(derivatives, test_size=test_ratio, random_state=args.seed) + # Use the training split to further split into training and validation splits + train_derivatives, val_derivatives = train_test_split(train_derivatives, test_size=val_ratio / (train_ratio + val_ratio), + random_state=args.seed, ) + # sort the subjects + train_derivatives = sorted(train_derivatives) + val_derivatives = sorted(val_derivatives) + test_derivatives = sorted(test_derivatives) + + # logger.info(f"Number of training subjects: {len(train_subjects)}") + # logger.info(f"Number of validation subjects: {len(val_subjects)}") + # logger.info(f"Number of testing subjects: {len(test_subjects)}") + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = "ms-lesion-agnostic" + params["labels"] = { + "0": "background", + "1": "ms-lesion-seg" + } + params["license"] = "plb" + params["modality"] = { + "0": "MRI" + } + params["name"] = "ms-lesion-agnostic" + params["seed"] = args.seed + params["reference"] = "NeuroPoly" + params["tensorImageSize"] = "3D" + + train_derivatives_dict = {"train": train_derivatives} + val_derivatives_dict = {"validation": val_derivatives} + test_derivatives_dict = {"test": test_derivatives} + all_derivatives_list = [train_derivatives_dict, val_derivatives_dict, test_derivatives_dict] + + # iterate through the train/val/test splits and add those which have both image and label + for derivatives_dict in tqdm(all_derivatives_list, desc="Iterating through train/val/test splits"): + + for name, derivs_list in derivatives_dict.items(): + + temp_list = [] + for subject_no, derivative in enumerate(derivs_list): + + + temp_data_basel = {} + temp_data_bavaria = {} + temp_data_canproco = {} + temp_data_nih = {} + temp_data_sct = {} + + # Basel + if 'basel-mp2rage' in str(derivative): + relative_path = derivative.relative_to(basel_path).parent + temp_data_basel["label"] = str(derivative) + temp_data_basel["image"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_basel["label"]) and os.path.exists(temp_data_basel["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_basel["image"], temp_data_basel["label"], os.path.join(cropped_head_data_folder, "basel-mp2rage")) + temp_data_basel["label"] = seg + temp_data_basel["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_basel["label"]) + temp_data_basel["total_lesion_volume"] = total_lesion_volume + temp_data_basel["nb_lesions"] = nb_lesions + temp_data_basel["site"]='basel' + temp_data_basel["contrast"] = str(derivative).replace('_desc-rater3_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_basel["orientation"] = get_orientation(temp_data_basel["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_basel) + + # Bavaria-quebec + elif 'bavaria-quebec-spine-ms' in str(derivative): + temp_data_bavaria["label"] = str(derivative) + temp_data_bavaria["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_bavaria["label"]) and os.path.exists(temp_data_bavaria["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_bavaria["image"], temp_data_bavaria["label"], os.path.join(cropped_head_data_folder, "bavaria-quebec-spine-ms-unstitched")) + temp_data_bavaria["label"] = seg + temp_data_bavaria["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_bavaria["label"]) + temp_data_bavaria["total_lesion_volume"] = total_lesion_volume + temp_data_bavaria["nb_lesions"] = nb_lesions + temp_data_bavaria["site"]='bavaria-quebec' + temp_data_bavaria["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_bavaria["orientation"] = get_orientation(temp_data_bavaria["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_bavaria) + + # Canproco + elif 'canproco' in str(derivative): + subject_id = derivative.name.replace('_PSIR_lesion-manual.nii.gz', '') + subject_id = subject_id.replace('_STIR_lesion-manual.nii.gz', '') + if subject_id in canproco_exclude_list: + continue + temp_data_canproco["label"] = str(derivative) + temp_data_canproco["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_canproco["label"]) and os.path.exists(temp_data_canproco["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_canproco["image"], temp_data_canproco["label"], os.path.join(cropped_head_data_folder, "canproco")) + temp_data_canproco["label"] = seg + temp_data_canproco["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_canproco["label"]) + temp_data_canproco["total_lesion_volume"] = total_lesion_volume + temp_data_canproco["nb_lesions"] = nb_lesions + temp_data_canproco["site"]='canproco' + temp_data_canproco["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_canproco["orientation"] = get_orientation(temp_data_canproco["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_canproco) + + # nih-ms-mp2rage + elif 'nih-ms-mp2rage' in str(derivative): + temp_data_nih["label"] = str(derivative) + temp_data_nih["image"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_nih["label"]) and os.path.exists(temp_data_nih["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_nih["image"], temp_data_nih["label"], os.path.join(cropped_head_data_folder, "nih-ms-mp2rage")) + temp_data_nih["label"] = seg + temp_data_nih["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_nih["label"]) + temp_data_nih["total_lesion_volume"] = total_lesion_volume + temp_data_nih["nb_lesions"] = nb_lesions + temp_data_nih["site"]='nih' + temp_data_nih["contrast"] = str(derivative).replace('_desc-rater1_label-lesion_seg.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_nih["orientation"] = get_orientation(temp_data_nih["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_nih) + + # sct-testing-large + elif 'sct-testing-large' in str(derivative): + temp_data_sct["label"] = str(derivative) + temp_data_sct["image"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').replace('derivatives/labels/', '') + if os.path.exists(temp_data_sct["label"]) and os.path.exists(temp_data_sct["image"]): + # Cropping image and seg and saving to the cropped_head_data folder + image, seg = cropping_saving(temp_data_sct["image"], temp_data_sct["label"], os.path.join(cropped_head_data_folder, "sct-testing-large")) + temp_data_sct["label"] = seg + temp_data_sct["image"] = image + + total_lesion_volume, nb_lesions = count_lesion(temp_data_sct["label"]) + temp_data_sct["total_lesion_volume"] = total_lesion_volume + temp_data_sct["nb_lesions"] = nb_lesions + temp_data_sct["site"]='sct-testing-large' + temp_data_sct["contrast"] = str(derivative).replace('_lesion-manual.nii.gz', '.nii.gz').split('_')[-1].replace('.nii.gz', '') + temp_data_sct["orientation"] = get_orientation(temp_data_sct["image"]) + if args.lesion_only and nb_lesions == 0: + continue + temp_list.append(temp_data_sct) + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + params["numTest"] = len(params["test"]) + params["numTraining"] = len(params["train"]) + params["numValidation"] = len(params["validation"]) + # Print total number of images + logger.info(f"Total number of images in the dataset: {params['numTest'] + params['numTraining'] + params['numValidation']}") + + final_json = json.dumps(params, indent=4, sort_keys=True) + if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) + if args.lesion_only: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}_lesionOnly.json", "w") + else: + jsonFile = open(args.path_out + "/" + f"dataset_{str(date.today())}_seed{seed}.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + # dump train/val/test splits into a yaml file + with open(f"{args.path_out}/data_split_{str(date.today())}_seed{seed}.yaml", 'w') as file: + yaml.dump({'train': train_derivatives, 'val': val_derivatives, 'test': test_derivatives}, file, indent=2, sort_keys=True) + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/average_tta_performance.py b/monai/average_tta_performance.py new file mode 100644 index 0000000..14e6498 --- /dev/null +++ b/monai/average_tta_performance.py @@ -0,0 +1,84 @@ +""" +This file is used to get all the dice_scores_X.txt files in a directory and average them. + +Input: + - Path to the directory containing the dice_scores_X.txt files + +Output: + None + +Example: + python average_tta_performance.py --pred-dir-path /path/to/dice_scores + +Author: Pierre-Louis Benveniste +""" + +import os +import argparse +import numpy as np +import pandas as pd +from pathlib import Path + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Average the performance of the model") + parser.add_argument("--pred-dir-path", help="Path to the directory containing the dice_scores_X.txt files", required=True) + return parser + + +def main(): + """ + This function is used to average the performance of the model on the test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Path to the dice_scores + path_to_outputs = args.pred_dir_path + + # Get all the dice_scores_X.txt files using rglob + dice_score_files = [str(file) for file in Path(path_to_outputs).rglob("dice_scores_*.txt")] + + # Dict to store the dice scores + dice_scores = {} + + # Loop over the dice_scores_X.txt files + for dice_score_file in dice_score_files: + # Open dice results (they are txt files) + with open(os.path.join(path_to_outputs, dice_score_file), 'r') as file: + for line in file: + key, value = line.strip().split(':') + if key in dice_scores: + dice_scores[key].append(float(value)) + else: + dice_scores[key] = [float(value)] + + # Average the dice scores ang get standard deviation + std = {} + for key in dice_scores: + std[key] = np.std(dice_scores[key]) + dice_scores[key] = np.mean(dice_scores[key]) + + # Save the averaged dice scores + with open(os.path.join(path_to_outputs, "dice_scores.txt"), 'w') as file: + for key in dice_scores: + file.write(f"{key}: {dice_scores[key]}\n") + + # Save the standard deviation + with open(os.path.join(path_to_outputs, "std.txt"), 'w') as file: + for key in std: + file.write(f"{key}: {std[key]}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/compute_performance_tta_sum.py b/monai/compute_performance_tta_sum.py new file mode 100644 index 0000000..a8702a5 --- /dev/null +++ b/monai/compute_performance_tta_sum.py @@ -0,0 +1,130 @@ +""" +This script is used to sum all the image predictions of the same subject, then threshold to 0.5 and then compute the dice score. + +Input: + --path-pred: Path to the directory containing the predictions + --path-json: Path to the json file containing the data split + --split: Data split to use (train, validation, test) + --output-dir: Output directory to save the dice scores + +Output: + None + +Example: + python compute_performance_tta_sum.py --path-pred /path/to/predictions --path-json /path/to/data.json --split test --output-dir /path/to/output + +Author: Pierre-Louis Benveniste +""" + +import os +import numpy as np +import argparse +from pathlib import Path +import json +import nibabel as nib +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--path-pred", type=str, required=True, help="Path to the directory containing the predictions") + parser.add_argument("--path-json", type=str, required=True, help="Path to the json file containing the data split") + parser.add_argument("--split", type=str, required=True, help="Data split to use (train, validation, test)") + parser.add_argument("--output-dir", type=str, required=True, help="Output directory to save the dice scores") + return parser.parse_args() + + +def dice_score(prediction, groundtruth, smooth=1.): + numer = (prediction * groundtruth).sum() + denor = (prediction + groundtruth).sum() + # loss = (2 * numer + self.smooth) / (denor + self.smooth) + dice = (2 * numer + smooth) / (denor + smooth) + return dice + + +def main(): + + # Parse arguments + args = parse_args() + path_pred = args.path_pred + path_json = args.path_json + split = args.split + output_dir = args.output_dir + + # Create the output directory + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Get all the predictions (with rglob) + predictions = list(Path(path_pred).rglob("*.nii.gz")) + + # List of subjects + subjects = [pred.name for pred in predictions] + + n_tta = 10 + + for i in range(n_tta): + # Remove the _pred_0, _pred_1 ... _pred_9 at the end of the name + subjects = [sub.replace(f"_pred_{i}", "") for sub in subjects] + + # Open the conversion dictionary (its a json file) + with open(path_json, "r") as f: + conversion_dict = json.load(f) + conversion_dict = conversion_dict[split] + + # Dict of dice score + dice_scores = {} + + # Iterate over the subjects in the predictions + for subject in subjects: + print(f"Processing subject {subject}") + + # Get all predictions corresponding to the subject + subject_predictions = [str(pred) for pred in predictions if subject.replace(".nii.gz", "") in pred.name] + # print(subject_predictions) + + # Find the corresponding label from the conversion dict + + image_dict = [data for data in conversion_dict if subject in data["image"]] + label = image_dict[0]["label"] + image = image_dict[0]["image"] + + # We now sum all the predictions + summed_prediction = None + for pred in subject_predictions: + pred_data = nib.load(pred).get_fdata() + if summed_prediction is None: + summed_prediction = pred_data + else: + summed_prediction += pred_data + + # Threshold the summed prediction + summed_prediction[summed_prediction >= 0.5] = 1 + summed_prediction[summed_prediction < 0.5] = 0 + + # Load the label + label_data = nib.load(label).get_fdata() + + # Compute dice score + dice = dice_score(summed_prediction, label_data) + # print(f"Dice score for summed prediction: {dice}") + + # Compare the dice score with the individual predictions + for pred in subject_predictions: + pred_data = nib.load(pred).get_fdata() + dice_pred = dice_score(pred_data, label_data) + # print(f"Dice score for {pred}: {dice_pred}") + + # Save the dice score + dice_scores[image] = dice + + # Save the results + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + + return None + + +if __name__ == "__main__": + main() diff --git a/monai/config.yml b/monai/config.yml new file mode 100644 index 0000000..5e953b2 --- /dev/null +++ b/monai/config.yml @@ -0,0 +1,50 @@ +# Description: Configuration file for the UNETR model + +# Path to the data json file +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_lesion_sc.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_10_each.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json +data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json +# data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-08-13_seed42_lesionOnly.json +# data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/fake.json + +# Resampling resolution +# pixdim : [1.0, 1.0, 1.0] +pixdim : [0.7, 0.7, 0.7] +# pixdim : [0.5, 0.5, 0.5] + +# Spatial size of the input data +spatial_size : [64, 128, 128] # RL, AP, IS +batch_size : 4 # smaller batch size lead to better generalization https://arxiv.org/abs/1609.04836 but longer to train + +# Augmentation parameters +DA_probability : 0.2 + +# Optimizer parameters +lr : 0.0001 +weight_decay: 0.00001 +early_stopping_patience : 50 + +# Training parameters +max_iterations : 250 +eval_num : 2 + +# Outputs +# output_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/ +output_path : /home/plbenveniste/net/ms-lesion-agnostic/results/ +# output_path : /home/plbenveniste/net/ms-lesion-agnostic/results_cropped_head/ + +# Seed +seed : 42 + +# UNET model parameters +unet_channels : [32, 64, 128, 256, 512, 1024] +unet_strides : [2, 2, 2, 2, 2, 2, 2] + +# AttentionUnet +attention_unet_channels : [32, 64, 128, 256, 512] +attention_unet_strides : [2, 2, 2, 2, 2] \ No newline at end of file diff --git a/monai/config_test.yml b/monai/config_test.yml new file mode 100644 index 0000000..4fc3c59 --- /dev/null +++ b/monai/config_test.yml @@ -0,0 +1,21 @@ +# dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json +# dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json +dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json +# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-08-13_seed42_lesionOnly.json +# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_optThresh.json +# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/fake.json + +pixdim : [0.7, 0.7, 0.7] +spatial_size : [64, 128, 128] +attention_unet_channels : [32, 64, 128, 256, 512] +attention_unet_strides : [2, 2, 2, 2, 2] + +# path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/best_model.pth/best_model.ckpt +# path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp/best_model.pth/best_model.ckpt +path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/best_model.pth/best_model.ckpt +# path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-09-02_12:14:28.124188/best_model.pth/best_model.ckpt + +# output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/ +# output_dir : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp +output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/ +# output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-09-02_12:14:28.124188/ \ No newline at end of file diff --git a/monai/plot_optThresh.py b/monai/plot_optThresh.py new file mode 100644 index 0000000..17d4908 --- /dev/null +++ b/monai/plot_optThresh.py @@ -0,0 +1,85 @@ +""" +This script plots the performance of the model based on the threshold applied to the predictions. + +Input: + --path-scores: Path to the directory containing the dice_scores_X.txt files + +Output: + None + +Example: + python plot_optThresh.py --path-scores /path/to/dice_scores + +Author: Pierre-Louis Benveniste +""" + +import os +import argparse +import numpy as np +from pathlib import Path +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Plot the optimal threshold") + parser.add_argument("--path-scores", help="Path to the directory containing the dice_scores_X.txt files", required=True) + return parser + + +def main(): + + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Path to the dice_scores + path_to_outputs = args.path_scores + + # Get all the dice_scores_X.txt files using rglob + dice_score_files = [str(file) for file in Path(path_to_outputs).rglob("dice_scores_*.txt")] + + # Create a list to store the dataframes + test_dice_results_list = [None] * len(dice_score_files) + + # For each file, get the threshold and the dice score + for i, dice_score_file in enumerate(dice_score_files): + test_dice_results = {} + with open(dice_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + test_dice_results[key] = float(value) + # convert to a df with name and dice score + test_dice_results_list[i] = pd.DataFrame(list(test_dice_results.items()), columns=['name', 'dice_score']) + # Create a column which stores the threshold + test_dice_results_list[i]['threshold'] = str(Path(dice_score_file).name).replace('dice_scores_', '').replace('.txt', '').replace('_', '.') + + # Concatenate all the dataframes + test_dice_results = pd.concat(test_dice_results_list) + + # Plot + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='threshold', y='dice_score', data=test_dice_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('Dice scores per threshold') + plt.show() + + # Save the plot + plt.savefig(path_to_outputs + '/dice_scores_contrast.png') + print(f"Saved the dice_scores plot in {path_to_outputs}") + + # Print the average dice score per threshold + for thresh in test_dice_results['threshold'].unique(): + print(f"Threshold: {thresh} - Average dice score: {test_dice_results[test_dice_results['threshold'] == thresh]['dice_score'].mean()}") + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/plot_performance.py b/monai/plot_performance.py new file mode 100644 index 0000000..2fc719d --- /dev/null +++ b/monai/plot_performance.py @@ -0,0 +1,206 @@ +"""" +This script is used to plot the performance of the model on the test set, validation and train set. +It saves a plot of dice scores per contrat in the output folder + +""" +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import argparse +import json + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Plot the performance of the model") + parser.add_argument("--pred-dir-path", help="Path to the directory containing the dice_score.txt file", required=True) + parser.add_argument("--data-json-path", help="Path to the json file containing the data split", required=True) + parser.add_argument("--split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to plot the performance of the model on the test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Path to the dice_scores + path_to_outputs = args.pred_dir_path + dice_score_file = path_to_outputs + '/dice_scores.txt' + + # Open dice results (they are txt files) + test_dice_results = {} + with open(dice_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + test_dice_results[key] = float(value) + + # convert to a df with name and dice score + test_dice_results = pd.DataFrame(list(test_dice_results.items()), columns=['name', 'dice_score']) + + # Create an empty column for the contrast, the site and the resolution + test_dice_results['contrast'] = None + test_dice_results['site'] = None + test_dice_results['resolution'] = None + + # Load the data json file + data_json_path = args.data_json_path + with open(data_json_path, 'r') as f: + jsondata = json.load(f) + + # Iterate over the test files + for file in test_dice_results['name']: + # We find the corresponding file in the json file + for data in jsondata[args.split]: + if data["image"] == file: + # Add the contrat, the site and the resolution to the df + test_dice_results.loc[test_dice_results['name'] == file, 'contrast'] = data['contrast'] + test_dice_results.loc[test_dice_results['name'] == file, 'site'] = data['site'] + test_dice_results.loc[test_dice_results['name'] == file, 'orientation'] = data['orientation'] + test_dice_results.loc[test_dice_results['name'] == file, 'nb_lesions'] = data['nb_lesions'] + test_dice_results.loc[test_dice_results['name'] == file, 'total_lesion_volume'] = data['total_lesion_volume'] + + # Count the number of samples per contrast + contrast_counts = test_dice_results['contrast'].value_counts() + + # In the df replace the contrats by the number of samples of the contarsts( for example, T2 becomes T2 (n=10)) + test_dice_results['contrast_count'] = test_dice_results['contrast'].apply(lambda x: x + f' (n={contrast_counts[x]})') + + # Same for the site + site_counts = test_dice_results['site'].value_counts() + test_dice_results['site_count'] = test_dice_results['site'].apply(lambda x: x + f' (n={site_counts[x]})') + + # Same for the resolution + resolution_counts = test_dice_results['orientation'].value_counts() + test_dice_results['orientation_count'] = test_dice_results['orientation'].apply(lambda x: x + f' (n={resolution_counts[x]})') + + # then we add the ppv score to the df + ppv_score_file = path_to_outputs + '/ppv_scores.txt' + ppv_scores = {} + with open(ppv_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + ppv_scores[key] = float(value) + test_dice_results['ppv_score'] = test_dice_results['name'].apply(lambda x: ppv_scores[x]) + + # then we add the f1 score to the df + f1_score_file = path_to_outputs + '/f1_scores.txt' + f1_scores = {} + with open(f1_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + f1_scores[key] = float(value) + test_dice_results['f1_score'] = test_dice_results['name'].apply(lambda x: f1_scores[x]) + + # then we add the sensitivity score to the df + sensitivity_score_file = path_to_outputs + '/sensitivity_scores.txt' + sensitivity_scores = {} + with open(sensitivity_score_file, 'r') as file: + for line in file: + key, value = line.strip().split(':') + sensitivity_scores[key] = float(value) + test_dice_results['sensitivity_score'] = test_dice_results['name'].apply(lambda x: sensitivity_scores[x]) + + # We rename th df to metrics_results + metrics_results = test_dice_results + + # Sort the order of the lines by contrast (alphabetical order) + metrics_results = metrics_results.sort_values(by='contrast').reset_index(drop=True) + + # plot a violin plot per contrast for dice scores + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='contrast_count', y='dice_score', data=metrics_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('Dice scores per contrast') + plt.show() + # # Save the plot + plt.savefig(path_to_outputs + '/dice_scores_contrast.png') + print(f"Saved the dice plot in {path_to_outputs}") + + # plot a violin plot per contrast for ppv scores + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='contrast_count', y='ppv_score', data=metrics_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('PPV scores per contrast') + plt.show() + + # # Save the plot + plt.savefig(path_to_outputs + '/ppv_scores_contrast.png') + print(f"Saved the ppv plot in {path_to_outputs}") + + # plot a violin plot per contrast for f1 scores + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='contrast_count', y='f1_score', data=metrics_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('F1 scores per contrast') + plt.show() + + # # Save the plot + plt.savefig(path_to_outputs + '/f1_scores_contrast.png') + print(f"Saved the F1 plot in {path_to_outputs}") + + # plot a violin plot per contrast for f1 scores + plt.figure(figsize=(20, 10)) + plt.grid(True) + sns.violinplot(x='contrast_count', y='sensitivity_score', data=metrics_results) + # y ranges from -0.2 to 1.2 + plt.ylim(-0.2, 1.2) + plt.title('Sensitivity scores per contrast') + plt.show() + + # # Save the plot + plt.savefig(path_to_outputs + '/sensitivity_scores_contrast.png') + print(f"Saved the sensitivity plot in {path_to_outputs}") + + # # plot a violin plot per site + # plt.figure(figsize=(20, 10)) + # plt.grid(True) + # sns.violinplot(x='site_count', y='dice_score', data=test_dice_results, order = ['bavaria-quebec (n=208)', 'sct-testing-large (n=233)', 'canproco (n=71)','nih (n=25)','basel (n=32)']) + # # y ranges from -0.2 to 1.2 + # plt.ylim(-0.2, 1.2) + # plt.title('Dice scores per site') + # plt.show() + + # # Save the plot + # plt.savefig(path_to_outputs + '/dice_scores_site.png') + # print(f"Saved the dice_scores per site plot in {path_to_outputs}") + + # # plot a violin plot per resolution + # plt.figure(figsize=(20, 10)) + # plt.grid(True) + # sns.violinplot(x='orientation_count', y='dice_score', data=test_dice_results, order = ['iso (n=58)', 'ax (n=343)', 'sag (n=168)']) + # # y ranges from -0.2 to 1.2 + # plt.ylim(-0.2, 1.2) + # plt.title('Dice scores per orientation') + # plt.show() + + # # Save the plot + # plt.savefig(path_to_outputs + '/dice_scores_orientation.png') + # print(f"Saved the dice_scores per orientation plot in {path_to_outputs}") + + # # Save the test_dice_results dataframe + # test_dice_results.to_csv(path_to_outputs + '/dice_results.csv', index=False) + + return None + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/requirements.txt b/monai/requirements.txt new file mode 100644 index 0000000..c14734f --- /dev/null +++ b/monai/requirements.txt @@ -0,0 +1,13 @@ +numpy==1.24.3 +tqdm==4.65.0 +torch==2.0.1 +torchvision==0.15.2 +monai[all]==1.3.0 +matplotlib==3.8.2 +pytorch-lightning==2.2.1 +cupy-cuda117==10.6.0 +loguru==0.7.2 +wandb==0.15.12 +dynamic-network-architectures==0.2 +seaborn==0.13.2 +monai-generative==0.2.3 \ No newline at end of file diff --git a/monai/test_model.py b/monai/test_model.py new file mode 100644 index 0000000..d0bb1d5 --- /dev/null +++ b/monai/test_model.py @@ -0,0 +1,206 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils.utils import dice_score, lesion_f1_score, lesion_ppv, lesion_sensitivity +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Dict of scores + dice_scores = {} + ppv_scores = {} + sensitivity_scores = {} + f1_scores = {} + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 0), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Create the data loader + test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) + test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0) + + # Load the model + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=cfg["attention_unet_channels"], + strides=cfg["attention_unet_strides"], + dropout=0.1, + ) + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Run inference + with torch.no_grad(): + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # Threshold the prediction with 0.5 based on this investigation: https://github.com/ivadomed/ms-lesion-agnostic/issues/32 + pred_cpu = batch["pred"].cpu() + pred_cpu[pred_cpu < 0.5] = 0 + pred_cpu[pred_cpu >= 0.5] = 1 + # Compute the dice score + dice = dice_score(pred_cpu, batch["label"].cpu()) + ppv = lesion_ppv(batch["label"].cpu(), pred_cpu) + sensitivity = lesion_sensitivity(batch["label"].cpu(), pred_cpu) + f1 = lesion_f1_score(batch["label"].cpu(), pred_cpu) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + # Threshold the prediction with 0.5 before saving + pred = post_test_out[0]['pred'].cpu() + pred[pred < 0.5] = 0 + pred[pred >= 0.5] = 1 + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}: dice score = {dice}, f1 = {f1}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the scores + dice_scores[test_files[i]["image"]] = dice + ppv_scores[test_files[i]["image"]] = ppv + sensitivity_scores[test_files[i]["image"]] = sensitivity + f1_scores[test_files[i]["image"]] = f1 + + test_input.detach() + + + # Save the dice scores + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "ppv_scores.txt"), "w") as f: + for key, value in ppv_scores.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "sensitivity_scores.txt"), "w") as f: + for key, value in sensitivity_scores.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "f1_scores.txt"), "w") as f: + for key, value in f1_scores.items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/test_model_mednext.py b/monai/test_model_mednext.py new file mode 100644 index 0000000..9c1e6fd --- /dev/null +++ b/monai/test_model_mednext.py @@ -0,0 +1,193 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils.utils import dice_score +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +from nnunet_mednext import MedNeXt + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Dict of dice score + dice_scores = {} + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 0), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Create the data loader + test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) + test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0) + + # Load the model + net = MedNeXt( + in_channels=1, + n_channels=32, + n_classes=1, + exp_r=2, + kernel_size=3, + do_res=True, + do_res_up_down=True, + checkpoint_style="outside_block", + block_counts=[2,2,2,2,1,1,1,1,1] + ) + + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Run inference + with torch.no_grad(): + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # Threshold the prediction with 0.5 based on this investigation: https://github.com/ivadomed/ms-lesion-agnostic/issues/32 + pred_cpu = batch["pred"].cpu() + pred_cpu[pred_cpu < 0.5] = 0 + pred_cpu[pred_cpu >= 0.5] = 1 + # Compute the dice score + dice = dice_score(pred_cpu, batch["label"].cpu()) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + # Threshold the prediction with 0.5 before saving + pred = post_test_out[0]['pred'].cpu() + pred[pred < 0.5] = 0 + pred[pred >= 0.5] = 1 + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}: dice score = {dice}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the dice score + dice_scores[test_files[i]["image"]] = dice + + test_input.detach() + + + # Save the dice scores + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/test_model_optThresh.py b/monai/test_model_optThresh.py new file mode 100644 index 0000000..a51f6ff --- /dev/null +++ b/monai/test_model_optThresh.py @@ -0,0 +1,358 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils.utils import dice_score +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') +import numpy as np + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Dict of dice score + dice_scores = {} + dice_scores_0_01 = {} + dice_scores_0_02 = {} + dice_scores_0_05 = {} + dice_scores_0_1 = {} + dice_scores_0_2 = {} + dice_scores_0_3 = {} + dice_scores_0_4 = {} + dice_scores_0_5 = {} + dice_scores_0_6 = {} + dice_scores_0_7 = {} + dice_scores_0_8 = {} + dice_scores_0_9 = {} + dice_scores_0_95 = {} + dice_scores_0_98 = {} + dice_scores_0_99 = {} + + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 0), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Create the data loader + test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) + test_data_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0) + + # Load the model + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=cfg["attention_unet_channels"], + strides=cfg["attention_unet_strides"], + dropout=0.1, + ) + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Run inference + with torch.no_grad(): + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # compute the dice score + dice = dice_score(batch["pred"].cpu(), batch["label"].cpu()) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + pred = post_test_out[0]['pred'].cpu() + + pred_cpu = batch["pred"].cpu() + label_cpu = batch["label"].cpu() + + # Threshold the prediction and compute the dice score + pred_0 = pred_cpu.clone() + pred_0[pred_0 < 0.01] = 0 + pred_0[pred_0 >= 0.01] = 1 + dice = dice_score(pred_0, batch["label"].cpu()) + print(f"For thresh 0 dice score = {dice}") + + pred_0_01 = pred_cpu.clone() + pred_0_01[pred_0_01 < 0.01] = 0 + pred_0_01[pred_0_01 >= 0.01] = 1 + dice_0_01 = dice_score(pred_0_01, batch["label"].cpu()) + print(f"For thresh 0.01 dice score = {dice_0_01}") + + pred_0_02 = pred_cpu.clone() + pred_0_02[pred_0_02 < 0.02] = 0 + pred_0_02[pred_0_02 >= 0.02] = 1 + dice_0_02 = dice_score(pred_0_02, batch["label"].cpu()) + print(f"For thresh 0.02 dice score = {dice_0_02}") + + pred_0_05 = pred_cpu.clone() + pred_0_05[pred_0_05 < 0.05] = 0 + pred_0_05[pred_0_05 >= 0.05] = 1 + dice_0_05 = dice_score(pred_0_05, batch["label"].cpu()) + print(f"For thresh 0.05 dice score = {dice_0_05}") + + pred_0_1 = pred_cpu.clone() + pred_0_1[pred_0_1 < 0.1] = 0 + pred_0_1[pred_0_1 >= 0.1] = 1 + dice_0_1 = dice_score(pred_0_1, batch["label"].cpu()) + print(f"For thresh 0.1 dice score = {dice_0_1}") + + pred_0_2 = pred_cpu.clone() + pred_0_2[pred_0_2 < 0.2] = 0 + pred_0_2[pred_0_2 >= 0.2] = 1 + dice_0_2 = dice_score(pred_0_2, batch["label"].cpu()) + print(f"For thresh 0.2 dice score = {dice_0_2}") + + pred_0_3 = pred_cpu.clone() + pred_0_3[pred_0_3 < 0.3] = 0 + pred_0_3[pred_0_3 >= 0.3] = 1 + dice_0_3 = dice_score(pred_0_3, batch["label"].cpu()) + print(f"For thresh 0.3 dice score = {dice_0_3}") + + pred_0_4 = pred_cpu.clone() + pred_0_4[pred_0_4 < 0.4] = 0 + pred_0_4[pred_0_4 >= 0.4] = 1 + dice_0_4 = dice_score(pred_0_4, batch["label"].cpu()) + print(f"For thresh 0.4 dice score = {dice_0_4}") + + pred_0_5 = pred_cpu.clone() + pred_0_5[pred_0_5 < 0.5] = 0 + pred_0_5[pred_0_5 >= 0.5] = 1 + dice_0_5 = dice_score(pred_0_5, batch["label"].cpu()) + print(f"For thresh 0.5 dice score = {dice_0_5}") + + pred_0_6 = pred_cpu.clone() + pred_0_6[pred_0_6 < 0.6] = 0 + pred_0_6[pred_0_6 >= 0.6] = 1 + dice_0_6 = dice_score(pred_0_6, batch["label"].cpu()) + print(f"For thresh 0.6 dice score = {dice_0_6}") + + pred_0_7 = pred_cpu.clone() + pred_0_7[pred_0_7 < 0.7] = 0 + pred_0_7[pred_0_7 >= 0.7] = 1 + dice_0_7 = dice_score(pred_0_7, batch["label"].cpu()) + print(f"For thresh 0.7 dice score = {dice_0_7}") + + pred_0_8 = pred_cpu.clone() + pred_0_8[pred_0_8 < 0.8] = 0 + pred_0_8[pred_0_8 >= 0.8] = 1 + dice_0_8 = dice_score(pred_0_8, batch["label"].cpu()) + print(f"For thresh 0.8 dice score = {dice_0_8}") + + pred_0_9 = pred_cpu.clone() + pred_0_9[pred_0_9 < 0.9] = 0 + pred_0_9[pred_0_9 >= 0.9] = 1 + dice_0_9 = dice_score(pred_0_9, batch["label"].cpu()) + print(f"For thresh 0.9 dice score = {dice_0_9}") + + pred_0_95 = pred_cpu.clone() + pred_0_95[pred_0_95 < 0.95] = 0 + pred_0_95[pred_0_95 >= 0.95] = 1 + dice_0_95 = dice_score(pred_0_95, batch["label"].cpu()) + print(f"For thresh 0.95 dice score = {dice_0_95}") + + pred_0_98 = pred_cpu.clone() + pred_0_98[pred_0_98 < 0.98] = 0 + pred_0_98[pred_0_98 >= 0.98] = 1 + dice_0_98 = dice_score(pred_0_98, batch["label"].cpu()) + print(f"For thresh 0.98 dice score = {dice_0_98}") + + pred_0_99 = pred_cpu.clone() + pred_0_99[pred_0_99 < 0.99] = 0 + pred_0_99[pred_0_99 >= 0.99] = 1 + dice_0_99 = dice_score(pred_0_99, batch["label"].cpu()) + print(f"For thresh 0.99 dice score = {dice_0_99}") + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}: dice score = {dice}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix="pred", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the dice score + dice_scores[test_files[i]["image"]] = dice + dice_scores_0_01[test_files[i]["image"]] = dice_0_01 + dice_scores_0_02[test_files[i]["image"]] = dice_0_02 + dice_scores_0_05[test_files[i]["image"]] = dice_0_05 + dice_scores_0_1[test_files[i]["image"]] = dice_0_1 + dice_scores_0_2[test_files[i]["image"]] = dice_0_2 + dice_scores_0_3[test_files[i]["image"]] = dice_0_3 + dice_scores_0_4[test_files[i]["image"]] = dice_0_4 + dice_scores_0_5[test_files[i]["image"]] = dice_0_5 + dice_scores_0_6[test_files[i]["image"]] = dice_0_6 + dice_scores_0_7[test_files[i]["image"]] = dice_0_7 + dice_scores_0_8[test_files[i]["image"]] = dice_0_8 + dice_scores_0_9[test_files[i]["image"]] = dice_0_9 + dice_scores_0_95[test_files[i]["image"]] = dice_0_95 + dice_scores_0_98[test_files[i]["image"]] = dice_0_98 + dice_scores_0_99[test_files[i]["image"]] = dice_0_99 + + test_input.detach() + + + # Save the dice scores + with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f: + for key, value in dice_scores.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_01.txt"), "w") as f: + for key, value in dice_scores_0_01.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_02.txt"), "w") as f: + for key, value in dice_scores_0_02.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_05.txt"), "w") as f: + for key, value in dice_scores_0_05.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_1.txt"), "w") as f: + for key, value in dice_scores_0_1.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_2.txt"), "w") as f: + for key, value in dice_scores_0_2.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_3.txt"), "w") as f: + for key, value in dice_scores_0_3.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_4.txt"), "w") as f: + for key, value in dice_scores_0_4.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_5.txt"), "w") as f: + for key, value in dice_scores_0_5.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_6.txt"), "w") as f: + for key, value in dice_scores_0_6.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_7.txt"), "w") as f: + for key, value in dice_scores_0_7.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_8.txt"), "w") as f: + for key, value in dice_scores_0_8.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_9.txt"), "w") as f: + for key, value in dice_scores_0_9.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_95.txt"), "w") as f: + for key, value in dice_scores_0_95.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_98.txt"), "w") as f: + for key, value in dice_scores_0_98.items(): + f.write(f"{key}: {value}\n") + with open(os.path.join(output_dir, "dice_scores_0_99.txt"), "w") as f: + for key, value in dice_scores_0_99.items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/test_model_tta.py b/monai/test_model_tta.py new file mode 100644 index 0000000..1337f13 --- /dev/null +++ b/monai/test_model_tta.py @@ -0,0 +1,220 @@ +""" +This code is used to test the model on a test set. +It uses the class Model which was defined in the file train_monai_unet_lightning.py. +""" +import os +from monai.transforms import ( + Compose, + LoadImaged, + EnsureChannelFirstd, + Orientationd, + Spacingd, + NormalizeIntensityd, + ResizeWithPadOrCropd, + Invertd, + EnsureTyped, + SaveImage, + RandGaussianNoised, + RandFlipd, + Rand3DElasticd +) +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch, Dataset) +from monai.networks.nets import AttentionUnet +import torch +from monai.inferers import sliding_window_inference +import torch.nn.functional as F +from utils.utils import dice_score +import argparse +import yaml +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Test the model on the test set") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + parser.add_argument("--data-split", help="Data split to use (train, validation, test)", required=True, type=str) + return parser + + +def main(): + """ + This function is used to test the model on a test set. + + Args: + None + + Returns: + None + """ + # Get the parser + parser = get_parser() + args = parser.parse_args() + + # Load the config file + with open(args.config, "r") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + # Device + DEVICE = "cuda" + + # build output directory + output_dir = os.path.join(cfg["output_dir"], args.data_split +"_set") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Num test time augmentations + n_tta = 10 + + # Dict of dice score + dice_scores = [{} for i in range(n_tta)] + + # Load the data + test_files = load_decathlon_datalist(cfg["dataset"], True, args.data_split) + + #Create the test transforms + test_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=cfg["pixdim"], + mode=(2, 0), + ), + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + RandGaussianNoised( + keys=["image"], + prob=0.2, + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image"], + spatial_axis=[1], + prob=0.2, + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image"], + spatial_axis=[2], + prob=0.2, + ), + # Random elastic deformation + Rand3DElasticd( + keys=["image"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=0.2, + mode='bilinear', + ), + # ResizeWithPadOrCropd( + # keys=["image", "label"], + # spatial_size=cfg["spatial_size"], + # ), + ] + ) + + # Create the prediction post-processing function + ## For this to work I had to add cupy-cuda117==10.6.0 to the requirements + test_post_pred = Compose([ + EnsureTyped(keys=["pred"]), + Invertd(keys=["pred"], transform=test_transforms, + orig_keys=["image"], + meta_keys=["pred_meta_dict"], + nearest_interp=False, to_tensor=True), + ]) + + # Load the model + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=cfg["attention_unet_channels"], + strides=cfg["attention_unet_strides"], + dropout=0.1, + ) + net.to(DEVICE) + checkpoint = torch.load(cfg["path_to_model"], map_location=torch.device(DEVICE))["state_dict"] + # NOTE: remove the 'net.' prefix from the keys because of how the model was initialized in lightning + # https://discuss.pytorch.org/t/missing-keys-unexpected-keys-in-state-dict-when-loading-self-trained-model/22379/14 + for key in list(checkpoint.keys()): + if 'net.' in key: + checkpoint[key.replace('net.', '')] = checkpoint[key] + del checkpoint[key] + # remove the key loss_function.dice.class_weights because it is not needed + # I had the error but I don't really know why + if 'loss_function.dice.class_weight' in key: + del checkpoint[key] + net.load_state_dict(checkpoint) + net.eval() + + # Create the data loader + test_ds = [CacheDataset(data=test_files, transform=test_transforms, cache_rate=0.1, num_workers=0) for i in range(n_tta)] + + # Run inference + with torch.no_grad(): + for k in range(n_tta): + test_data_loader = DataLoader(test_ds[k], batch_size=1, shuffle=False, num_workers=0) + for i, batch in enumerate(test_data_loader): + # get the test input + test_input = batch["image"].to(DEVICE) + + # run inference + batch["pred"] = sliding_window_inference(test_input, cfg["spatial_size"], mode="gaussian", + sw_batch_size=4, predictor=net, overlap=0.5, progress=False) + + # NOTE: monai's models do not normalize the output, so we need to do it manually + if bool(F.relu(batch["pred"]).max()): + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() + else: + batch["pred"] = F.relu(batch["pred"]) + + # Threshold the prediction with 0.5 based on this investigation: https://github.com/ivadomed/ms-lesion-agnostic/issues/32 + pred_cpu = batch["pred"].cpu() + pred_cpu[pred_cpu < 0.5] = 0 + pred_cpu[pred_cpu >= 0.5] = 1 + # Compute the dice score + dice = dice_score(pred_cpu, batch["label"].cpu()) + + # post-process the prediction + post_test_out = [test_post_pred(i) for i in decollate_batch(batch)] + + # Threshold the prediction with 0.5 before saving + pred = post_test_out[0]['pred'].cpu() + pred[pred < 0.5] = 0 + pred[pred >= 0.5] = 1 + + # Get file name + file_name = test_files[i]["image"].split("/")[-1].split(".")[0] + print(f"Saving {file_name}: dice score = {dice}") + + # Save the prediction + pred_saver = SaveImage( + output_dir=output_dir , output_postfix=f"pred_{k}", output_ext=".nii.gz", + separate_folder=False, print_log=False) + # save the prediction + pred_saver(pred) + + # Save the dice score + dice_scores[k][test_files[i]["image"]] = dice + + test_input.detach() + + + # Save the dice scores + for j in range(n_tta): + with open(os.path.join(output_dir, f"dice_scores_{j}.txt"), "w") as f: + for key, value in dice_scores[j].items(): + f.write(f"{key}: {value}\n") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/train_monai_diffusion_lightning.py b/monai/train_monai_diffusion_lightning.py new file mode 100644 index 0000000..b23862b --- /dev/null +++ b/monai/train_monai_diffusion_lightning.py @@ -0,0 +1,824 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml +import nibabel as nib +from datetime import datetime +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +import time +import torch.multiprocessing + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +## Linke to other issue: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/59 +torch.multiprocessing.set_sharing_strategy('file_system') + +from utils.losses import AdapWingLoss, SoftDiceLoss + +from utils.utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions +from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR +from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss +from monai.networks.layers import Norm +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped, + RandLambdad, + CropForegroundd, + RandGaussianNoised, + LabelToContourd, + Invertd, + SaveImage, + EnsureType, + Rand3DElasticd, + RandSimulateLowResolutiond, + RandBiasFieldd, + RandAffined, + RandRotated, + RandZoomd, + RandGaussianSmoothd, + RandScaleIntensityd +) +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) + +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + # self.lesion_wise_precision_recall = lesion_wise_precision_recall + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=200 + # ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=0, + num_samples=4, + image_key="image", + image_threshold=0, + allow_smaller=True, + ), + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=self.cfg["DA_probability"], + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=self.cfg["DA_probability"], + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=self.cfg["DA_probability"], + ), + # Random elastic deformation + Rand3DElasticd( + keys=["image", "label"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=self.cfg["DA_probability"], + mode=['bilinear', 'nearest'], + ), + # Random affine transform of the image + RandAffined( + keys=["image", "label"], + prob=self.cfg["DA_probability"], + mode=('bilinear', 'nearest'), + padding_mode='zeros', + ), + # RandAdjustContrastd( + # keys=["image"], + # prob=self.cfg["DA_probability"], + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + RandGaussianNoised( + keys=["image"], + prob=self.cfg["DA_probability"], + ), + # Random simulation of low resolution + RandSimulateLowResolutiond( + keys=["image"], + zoom_range=(0.8, 1.5), + prob=self.cfg["DA_probability"] + ), + # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing + RandBiasFieldd( + keys=["image"], + coeff_range=(0.0, 0.5), + degree=3, + prob=self.cfg["DA_probability"] + ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # This normalizes the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=150), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + val_cache_rate = 0.25 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=val_cache_rate, num_workers=8) + + # define test transforms + transforms_test = val_transforms + + # Hidden because we don't use it + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, + pin_memory=True, persistent_workers=True) + + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, + persistent_workers=False) + + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # The following was done to debug : + # I was checking the image and the label to see if they were empty or not + + # # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + # # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + # return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + # Compute precision and recall + # train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) + # print("sucess") + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze(), + # "train_precision": train_precision.detach().cpu(), + # "train_recall": train_recall.detach().cpu(), + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + # precision_score, recall_score = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + # precision_score = output["train_precision"] + # recall_score = output["train_recall"] + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + # mean_precision_score = np.mean(precision_score.detach().numpy()) + # mean_recall_score = np.mean(recall_score.detach().numpy()) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + # "train_precision": mean_precision_score, + # "train_recall": mean_recall_score, + } + + self.log_dict(wandb_logs) + + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt"], + pred=self.train_step_outputs[0]["train_pred"], + ) + wandb.log({"training images": wandb.Image(fig)}) + plt.close(fig) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # compute precision and recall + # val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) + # print("sucess val") + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), + # "val_precision": val_precision.detach().cpu(), + # "val_recall": val_recall.detach().cpu(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + # val_precision, val_recall = 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + # val_precision += output["val_precision"].sum().item() + # val_recall += output["val_recall"].sum().item() + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + # mean_val_precision = (val_precision / num_items) + # mean_val_recall = (val_recall / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + # "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + # "val_precision": mean_val_precision, + # "val_recall": mean_val_recall, + } + + self.log_dict(wandb_logs) + + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + # log on to wandb + self.log_dict(wandb_logs) + + # plot 1 validation image + fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + gt=self.val_step_outputs[0]["val_gt"], + pred=self.val_step_outputs[0]["val_pred"],) + wandb.log({"validation image 1": wandb.Image(fig)}) + plt.close(fig) + + # plot another validation image + fig0 = plot_slices(image=self.val_step_outputs[1]["val_image"], + gt=self.val_step_outputs[1]["val_gt"], + pred=self.val_step_outputs[1]["val_pred"],) + wandb.log({"validation image 2": wandb.Image(fig0)}) + plt.close(fig0) + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + output_path = os.path.join(config["output_path"], str(datetime.now().date()) +"_" +str(datetime.now().time())) + os.makedirs(output_path, exist_ok=True) + + wandb.init(project=f'monai-ms-lesion-seg-unet', config=config, save_code=True, dir=output_path) + + logger.info("Building the model ...") + + # define model + + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128), + # strides=(2, 2, 2, ), + # dropout=0.1 + # ) + + # net = AttentionUnet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config["attention_unet_channels"], + # strides=config["attention_unet_strides"], + # dropout=0.1, + # ) + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_channels=(64, 128, 256, 256), + attention_levels=(False, False, True, True), + num_res_blocks=(2, 2, 2, 2), + num_head_channels=32, + with_conditioning=False, + norm_eps= 1e-6, + dropout_cattn=0.1, + ) + + # net = SwinUNETR( + # img_size=config["spatial_size"], + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # feature_size=48, + # use_checkpoint=True, + # ) + + # net.use_multiprocessing = False + + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + + logger.add(os.path.join(output_path, 'log.txt'), rotation="10 MB", level="INFO") + + # define loss function + # loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = DiceLoss(sigmoid=False, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=False, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using DiceCELoss ...") + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + best_model_path = os.path.join(output_path, "best_model.pth") + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=best_model_path) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath= best_model_path, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir=output_path, + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + config=config) + + # Saving training script to wandb + wandb.save(args.config) + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision=32, + # precision='bf16-mixed', + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/train_monai_mednext_lightning.py b/monai/train_monai_mednext_lightning.py new file mode 100644 index 0000000..48bb0de --- /dev/null +++ b/monai/train_monai_mednext_lightning.py @@ -0,0 +1,825 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml +import nibabel as nib +from datetime import datetime +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +import time +import torch.multiprocessing + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +## Linke to other issue: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/59 +torch.multiprocessing.set_sharing_strategy('file_system') + +from utils.losses import AdapWingLoss, SoftDiceLoss + +from utils.utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions +from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss +from monai.networks.layers import Norm +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped, + RandLambdad, + CropForegroundd, + RandGaussianNoised, + LabelToContourd, + Invertd, + SaveImage, + EnsureType, + Rand3DElasticd, + RandSimulateLowResolutiond, + RandBiasFieldd, + RandAffined, + RandRotated, + RandZoomd, + RandGaussianSmoothd, + RandScaleIntensityd +) +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) + +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + +# Adding Mednext model +from nnunet_mednext import MedNeXt + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + # self.lesion_wise_precision_recall = lesion_wise_precision_recall + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=200 + # ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=0, + num_samples=2, + image_key="image", + image_threshold=0, + allow_smaller=True, + ), + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=self.cfg["DA_probability"], + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=self.cfg["DA_probability"], + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=self.cfg["DA_probability"], + ), + # Random elastic deformation + Rand3DElasticd( + keys=["image", "label"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=self.cfg["DA_probability"], + mode=['bilinear', 'nearest'], + ), + # Random affine transform of the image + RandAffined( + keys=["image", "label"], + prob=self.cfg["DA_probability"], + mode=('bilinear', 'nearest'), + padding_mode='zeros', + ), + # RandAdjustContrastd( + # keys=["image"], + # prob=self.cfg["DA_probability"], + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + RandGaussianNoised( + keys=["image"], + prob=self.cfg["DA_probability"], + ), + # Random simulation of low resolution + RandSimulateLowResolutiond( + keys=["image"], + zoom_range=(0.8, 1.5), + prob=self.cfg["DA_probability"] + ), + # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing + RandBiasFieldd( + keys=["image"], + coeff_range=(0.0, 0.5), + degree=3, + prob=self.cfg["DA_probability"] + ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # This normalizes the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=150), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + val_cache_rate = 0.25 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=val_cache_rate, num_workers=8) + + # define test transforms + transforms_test = val_transforms + + # Hidden because we don't use it + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, + pin_memory=True, persistent_workers=True) + + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, + persistent_workers=False) + + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # The following was done to debug : + # I was checking the image and the label to see if they were empty or not + + # # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + # # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + # return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + # Compute precision and recall + # train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) + # print("sucess") + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze(), + # "train_precision": train_precision.detach().cpu(), + # "train_recall": train_recall.detach().cpu(), + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + # precision_score, recall_score = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + # precision_score = output["train_precision"] + # recall_score = output["train_recall"] + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + # mean_precision_score = np.mean(precision_score.detach().numpy()) + # mean_recall_score = np.mean(recall_score.detach().numpy()) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + # "train_precision": mean_precision_score, + # "train_recall": mean_recall_score, + } + + self.log_dict(wandb_logs) + + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt"], + pred=self.train_step_outputs[0]["train_pred"], + ) + wandb.log({"training images": wandb.Image(fig)}) + plt.close(fig) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # compute precision and recall + # val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) + # print("sucess val") + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), + # "val_precision": val_precision.detach().cpu(), + # "val_recall": val_recall.detach().cpu(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + # val_precision, val_recall = 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + # val_precision += output["val_precision"].sum().item() + # val_recall += output["val_recall"].sum().item() + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + # mean_val_precision = (val_precision / num_items) + # mean_val_recall = (val_recall / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + # "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + # "val_precision": mean_val_precision, + # "val_recall": mean_val_recall, + } + + self.log_dict(wandb_logs) + + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + # log on to wandb + self.log_dict(wandb_logs) + + # plot 1 validation image + fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + gt=self.val_step_outputs[0]["val_gt"], + pred=self.val_step_outputs[0]["val_pred"],) + wandb.log({"validation image 1": wandb.Image(fig)}) + plt.close(fig) + + # plot another validation image + fig0 = plot_slices(image=self.val_step_outputs[1]["val_image"], + gt=self.val_step_outputs[1]["val_gt"], + pred=self.val_step_outputs[1]["val_pred"],) + wandb.log({"validation image 2": wandb.Image(fig0)}) + plt.close(fig0) + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + output_path = os.path.join(config["output_path"], str(datetime.now().date()) +"_" +str(datetime.now().time())) + os.makedirs(output_path, exist_ok=True) + + wandb.init(project=f'monai-ms-lesion-seg-unet', config=config, save_code=True, dir=output_path) + + logger.info("Building the model ...") + + # define model + + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128), + # strides=(2, 2, 2, ), + # dropout=0.1 + # ) + + # net = AttentionUnet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config["attention_unet_channels"], + # strides=config["attention_unet_strides"], + # dropout=0.1, + # ) + + # net = SwinUNETR( + # img_size=config["spatial_size"], + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # feature_size=48, + # use_checkpoint=True, + # ) + + net = MedNeXt( + in_channels=1, + n_channels=32, + n_classes=1, + exp_r=2, + kernel_size=3, + do_res=True, + do_res_up_down=True, + checkpoint_style="outside_block", + block_counts=[2,2,2,2,1,1,1,1,1] + ) + + # net.use_multiprocessing = False + + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + + logger.add(os.path.join(output_path, 'log.txt'), rotation="10 MB", level="INFO") + + # define loss function + # loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = DiceLoss(sigmoid=False, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=False, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using DiceCELoss ...") + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + best_model_path = os.path.join(output_path, "best_model.pth") + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=best_model_path) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath= best_model_path, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir=output_path, + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + config=config) + + # Saving training script to wandb + wandb.save(args.config) + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision=32, + # precision='bf16-mixed', + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/train_monai_unet_lightning.py b/monai/train_monai_unet_lightning.py new file mode 100644 index 0000000..4685086 --- /dev/null +++ b/monai/train_monai_unet_lightning.py @@ -0,0 +1,811 @@ +import os +import argparse +from datetime import datetime +from loguru import logger +import yaml +import nibabel as nib +from datetime import datetime +import numpy as np +import wandb +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import matplotlib.pyplot as plt +import time +import torch.multiprocessing + +# Added this to solve problem with too many files open +## Link here : https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +## Linke to other issue: https://github.com/sct-pipeline/contrast-agnostic-softseg-spinalcord/issues/59 +torch.multiprocessing.set_sharing_strategy('file_system') + +from utils.losses import AdapWingLoss, SoftDiceLoss + +from utils.utils import dice_score, check_empty_patch, multiply_by_negative_one, plot_slices, remove_small_lesions +from monai.networks.nets import UNet, BasicUNet, AttentionUnet, SwinUNETR +from monai.metrics import DiceMetric +from monai.losses import DiceLoss, DiceCELoss +from monai.networks.layers import Norm +from monai.transforms import ( + EnsureChannelFirstd, + Compose, + LoadImaged, + Orientationd, + RandFlipd, + RandShiftIntensityd, + Spacingd, + RandRotate90d, + NormalizeIntensityd, + RandCropByPosNegLabeld, + BatchInverseTransform, + RandAdjustContrastd, + AsDiscreted, + RandHistogramShiftd, + ResizeWithPadOrCropd, + EnsureTyped, + RandLambdad, + CropForegroundd, + RandGaussianNoised, + LabelToContourd, + Invertd, + SaveImage, + EnsureType, + Rand3DElasticd, + RandSimulateLowResolutiond, + RandBiasFieldd, + RandAffined, + RandRotated, + RandZoomd, + RandGaussianSmoothd, + RandScaleIntensityd +) +from monai.utils import set_determinism +from monai.inferers import sliding_window_inference +from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch) + +# Added this because of following warning received: +## You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` +## which will trade-off precision for performance. For more details, +## read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision +# torch.set_float32_matmul_precision('medium' | 'high') + + +def get_parser(): + """ + This function returns the parser for the command line arguments. + """ + parser = argparse.ArgumentParser(description="Train a nnUNet model using monai") + parser.add_argument("-c", "--config", help="Path to the config file (.yml file)", required=True) + return parser + + +# create a "model"-agnostic class with PL to use different models +class Model(pl.LightningModule): + def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None): + super().__init__() + self.cfg = config + self.save_hyperparameters(ignore=['net', 'loss_function']) + self.root = data_root + self.net = net + self.lr = config["lr"] + self.loss_function = loss_function + self.optimizer_class = optimizer_class + self.save_exp_id = exp_id + self.results_path = results_path + + self.best_val_dice, self.best_val_epoch = 0, 0 + self.best_val_loss = float("inf") + + # define cropping and padding dimensions + # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference + # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches + # which could be sub-optimal. + # On the other hand, ivadomed used a patch-size that's heavily padded along the R-L direction so that + # only the SC is in context. + self.spacing = config["spatial_size"] + self.voxel_cropping_size = self.inference_roi_size = config["spatial_size"] + + # define post-processing transforms for validation, nothing fancy just making sure that it's a tensor (default) + self.val_post_pred = self.val_post_label = Compose([EnsureType()]) + + # define evaluation metric + self.soft_dice_metric = dice_score + # self.lesion_wise_precision_recall = lesion_wise_precision_recall + + # temp lists for storing outputs from training, validation, and testing + self.train_step_outputs = [] + self.val_step_outputs = [] + self.test_step_outputs = [] + + + # -------------------------------- + # FORWARD PASS + # -------------------------------- + def forward(self, x): + + out = self.net(x) + # # NOTE: MONAI's models only output the logits, not the output after the final activation function + # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the + # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock) + # # as the final block applied to the input, which is just a convolutional layer with no activation function + # # Hence, we are used Normalized ReLU to normalize the logits to the final output + # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out) + + return out # returns logits + + + # -------------------------------- + # DATA PREPARATION + # -------------------------------- + def prepare_data(self): + # set deterministic training for reproducibility + set_determinism(seed=self.cfg["seed"]) + + # define training and validation transforms + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # Normalize the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # # This crops the image around areas where the mask is non-zero + # # (the margin is added because otherwise the image would be just the size of the lesion) + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=200 + # ), + # This crops the image around a foreground object of label with ratio pos/(pos+neg) (however, it cannot pad so keeping padding after) + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=self.cfg["spatial_size"], + pos=1, + neg=0, + num_samples=4, + image_key="image", + image_threshold=0, + allow_smaller=True, + ), + # This resizes the image and the label to the spatial size defined in the config + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # Flips the image : left becomes right + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=self.cfg["DA_probability"], + ), + # Flips the image : supperior becomes inferior + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=self.cfg["DA_probability"], + ), + # Flips the image : anterior becomes posterior + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=self.cfg["DA_probability"], + ), + # Random elastic deformation + Rand3DElasticd( + keys=["image", "label"], + sigma_range=(5, 7), + magnitude_range=(50, 150), + prob=self.cfg["DA_probability"], + mode=['bilinear', 'nearest'], + ), + # Random affine transform of the image + RandAffined( + keys=["image", "label"], + prob=self.cfg["DA_probability"], + mode=('bilinear', 'nearest'), + padding_mode='zeros', + ), + # RandAdjustContrastd( + # keys=["image"], + # prob=self.cfg["DA_probability"], + # gamma=(0.5, 4.5), + # invert_image=True, + # ), + # # we add the multiplication of the image by -1 + # RandLambdad( + # keys='image', + # func=multiply_by_negative_one, + # prob=0.5 + # ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + RandGaussianNoised( + keys=["image"], + prob=self.cfg["DA_probability"], + ), + # Random simulation of low resolution + RandSimulateLowResolutiond( + keys=["image"], + zoom_range=(0.8, 1.5), + prob=self.cfg["DA_probability"] + ), + # Adding a random bias field which is usefull considering that this sometimes done for image pre-processing + RandBiasFieldd( + keys=["image"], + coeff_range=(0.0, 0.5), + degree=3, + prob=self.cfg["DA_probability"] + ), + # RandShiftIntensityd( + # keys=["image"], + # offsets=0.1, + # prob=0.2, + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], reader="NibabelReader"), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RPI"), + Spacingd( + keys=["image", "label"], + pixdim=self.cfg["pixdim"], + mode=(2, 0), + ), + # This normalizes the intensity of the image + NormalizeIntensityd( + keys=["image"], + nonzero=False, + channel_wise=False + ), + # CropForegroundd( + # keys=["image", "label"], + # source_key="label", + # margin=150), + # RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=self.cfg["spatial_size"], + # pos=1, + # neg=1, + # num_samples=4, + # image_key="image", + # image_threshold=0, + # allow_smaller=True, + # ), + ResizeWithPadOrCropd( + keys=["image", "label"], + spatial_size=self.cfg["spatial_size"], + ), + # LabelToContourd( + # keys=["image"], + # kernel_type='Laplace', + # ), + # EnsureTyped(keys=["image", "label"]), + # AsDiscreted( + # keys=["label"], + # num_classes=2, + # threshold_values=True, + # logit_thresh=0.2, + # ) + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ] + + ) + + # load the dataset + dataset = self.cfg["data"] + logger.info(f"Loading dataset: {dataset}") + train_files = load_decathlon_datalist(dataset, True, "train") + val_files = load_decathlon_datalist(dataset, True, "validation") + test_files = load_decathlon_datalist(dataset, True, "test") + + train_cache_rate = 0.5 + val_cache_rate = 0.25 + self.train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=train_cache_rate, num_workers=8) + self.val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=val_cache_rate, num_workers=8) + + # define test transforms + transforms_test = val_transforms + + # Hidden because we don't use it + # define post-processing transforms for testing; taken (with explanations) from + # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66 + self.test_post_pred = Compose([ + EnsureTyped(keys=["pred", "label"]), + Invertd(keys=["pred", "label"], transform=transforms_test, + orig_keys=["image", "label"], + meta_keys=["pred_meta_dict", "label_meta_dict"], + nearest_interp=False, to_tensor=True), + # # Remove small lesions in the label + # RandLambdad( + # keys='label', + # func=lambda label: remove_small_lesions(label, self.cfg["pixdim"]), + # prob=1.0 + # ) + ]) + self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4) + + + # -------------------------------- + # DATA LOADERS + # -------------------------------- + def train_dataloader(self): + return DataLoader(self.train_ds, batch_size=self.cfg["batch_size"], shuffle=True, num_workers=8, + pin_memory=True, persistent_workers=True) + + + def val_dataloader(self): + return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, + persistent_workers=False) + + + def test_dataloader(self): + return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) + + + # -------------------------------- + # OPTIMIZATION + # -------------------------------- + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.cfg["weight_decay"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["max_iterations"]) + return [optimizer], [scheduler] + + + # -------------------------------- + # TRAINING + # -------------------------------- + def training_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # The following was done to debug : + # I was checking the image and the label to see if they were empty or not + + # # print(inputs.shape, labels.shape) + # input_0 = inputs[0].detach().cpu().squeeze() + # # print(input_0.shape) + # label_0 = labels[0].detach().cpu().squeeze() + + # time_0 = datetime.now() + + # # save input 0 in a nifti file + # input_0_nifti = nib.Nifti1Image(input_0.numpy(), affine=np.eye(4)) + # nib.save(input_0_nifti, f"~/ms_lesion_agnostic/temp/input_0_{time_0}.nii.gz") + + # # save label in a nifti file + # label_nifti = nib.Nifti1Image(label_0.numpy(), affine=np.eye(4)) + # nib.save(label_nifti, f"~/ms_lesion_agnostic/temp/label_0_{time_0}.nii.gz") + + # # # check if any label image patch is empty in the batch + # if check_empty_patch(labels) is None: + # print(f"Empty label patch found. Skipping training step ...") + # return None + + output = self.forward(inputs) # logits + # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}") + + # get probabilities from logits + output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output) + + # calculate training loss + loss = self.loss_function(output, labels) + + # calculate train dice + # NOTE: this is done on patches (and not entire 3D volume) because SlidingWindowInference is not used here + # So, take this dice score with a lot of salt + train_soft_dice = self.soft_dice_metric(output, labels) + + # Compute precision and recall + # train_precision, train_recall = self.lesion_wise_precision_recall(output.detach().cpu(), labels.detach().cpu()) + # print("sucess") + + metrics_dict = { + "loss": loss.cpu(), + "train_soft_dice": train_soft_dice.detach().cpu(), + "train_number": len(inputs), + "train_image": inputs[0].detach().cpu().squeeze(), + "train_gt": labels[0].detach().cpu().squeeze(), + "train_pred": output[0].detach().cpu().squeeze(), + # "train_precision": train_precision.detach().cpu(), + # "train_recall": train_recall.detach().cpu(), + } + self.train_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_train_epoch_end(self): + + if self.train_step_outputs == []: + # means the training step was skipped because of empty input patch + return None + else: + train_loss, train_soft_dice = 0, 0 + # precision_score, recall_score = 0, 0 + num_items = len(self.train_step_outputs) + for output in self.train_step_outputs: + train_loss += output["loss"].item() + train_soft_dice += output["train_soft_dice"].item() + # precision_score = output["train_precision"] + # recall_score = output["train_recall"] + + mean_train_loss = (train_loss / num_items) + mean_train_soft_dice = (train_soft_dice / num_items) + # mean_precision_score = np.mean(precision_score.detach().numpy()) + # mean_recall_score = np.mean(recall_score.detach().numpy()) + + wandb_logs = { + "train_soft_dice": mean_train_soft_dice, + "train_loss": mean_train_loss, + # "train_precision": mean_precision_score, + # "train_recall": mean_recall_score, + } + + self.log_dict(wandb_logs) + + # plot the training images + fig = plot_slices(image=self.train_step_outputs[0]["train_image"], + gt=self.train_step_outputs[0]["train_gt"], + pred=self.train_step_outputs[0]["train_pred"], + ) + wandb.log({"training images": wandb.Image(fig)}) + plt.close(fig) + + # free up memory + self.train_step_outputs.clear() + wandb_logs.clear() + + + + # -------------------------------- + # VALIDATION + # -------------------------------- + def validation_step(self, batch, batch_idx): + + inputs, labels = batch["image"], batch["label"] + + # NOTE: this calculates the loss on the entire image after sliding window + outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian", + sw_batch_size=4, predictor=self.forward, overlap=0.5,) + + # get probabilities from logits + outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs) + + # calculate validation loss + loss = self.loss_function(outputs, labels) + + # post-process for calculating the evaluation metric + post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)] + post_labels = [self.val_post_label(i) for i in decollate_batch(labels)] + val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0]) + + hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float() + val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels) + + # compute precision and recall + # val_precision, val_recall = self.lesion_wise_precision_recall(post_outputs[0].detach().cpu(), post_labels[0].detach().cpu()) + # print("sucess val") + + # NOTE: there was a massive memory leak when storing cuda tensors in this dict. Hence, + # using .detach() to avoid storing the whole computation graph + # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2 + metrics_dict = { + "val_loss": loss.detach().cpu(), + "val_soft_dice": val_soft_dice.detach().cpu(), + "val_hard_dice": val_hard_dice.detach().cpu(), + "val_number": len(post_outputs), + "val_image": inputs[0].detach().cpu().squeeze(), + "val_gt": labels[0].detach().cpu().squeeze(), + "val_pred": post_outputs[0].detach().cpu().squeeze(), + # "val_precision": val_precision.detach().cpu(), + # "val_recall": val_recall.detach().cpu(), + } + self.val_step_outputs.append(metrics_dict) + + return metrics_dict + + def on_validation_epoch_end(self): + + val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0 + # val_precision, val_recall = 0, 0 + for output in self.val_step_outputs: + val_loss += output["val_loss"].sum().item() + val_soft_dice += output["val_soft_dice"].sum().item() + val_hard_dice += output["val_hard_dice"].sum().item() + num_items += output["val_number"] + # val_precision += output["val_precision"].sum().item() + # val_recall += output["val_recall"].sum().item() + + mean_val_loss = (val_loss / num_items) + mean_val_soft_dice = (val_soft_dice / num_items) + mean_val_hard_dice = (val_hard_dice / num_items) + # mean_val_precision = (val_precision / num_items) + # mean_val_recall = (val_recall / num_items) + + wandb_logs = { + "val_soft_dice": mean_val_soft_dice, + # "val_hard_dice": mean_val_hard_dice, + "val_loss": mean_val_loss, + # "val_precision": mean_val_precision, + # "val_recall": mean_val_recall, + } + + self.log_dict(wandb_logs) + + # save the best model based on validation dice score + if mean_val_soft_dice > self.best_val_dice: + self.best_val_dice = mean_val_soft_dice + self.best_val_epoch = self.current_epoch + + # save the best model based on validation loss + if mean_val_loss < self.best_val_loss: + self.best_val_loss = mean_val_loss + self.best_val_epoch = self.current_epoch + + logger.info( + f"\nCurrent epoch: {self.current_epoch}" + f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}" + # f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}" + f"\nAverage DiceLoss (VAL): {mean_val_loss:.4f}" + f"\nBest Average DiceLoss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}" + f"\n----------------------------------------------------") + + # log on to wandb + self.log_dict(wandb_logs) + + # plot 1 validation image + fig = plot_slices(image=self.val_step_outputs[0]["val_image"], + gt=self.val_step_outputs[0]["val_gt"], + pred=self.val_step_outputs[0]["val_pred"],) + wandb.log({"validation image 1": wandb.Image(fig)}) + plt.close(fig) + + # plot another validation image + fig0 = plot_slices(image=self.val_step_outputs[1]["val_image"], + gt=self.val_step_outputs[1]["val_gt"], + pred=self.val_step_outputs[1]["val_pred"],) + wandb.log({"validation image 2": wandb.Image(fig0)}) + plt.close(fig0) + + # free up memory + self.val_step_outputs.clear() + wandb_logs.clear() + + + # -------------------------------- + # TESTING + # -------------------------------- + def test_step(self, batch, batch_idx): + + test_input = batch["image"] + # print(batch["label_meta_dict"]["filename_or_obj"][0]) + batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size, + sw_batch_size=4, predictor=self.forward, overlap=0.5) + + # normalize the logits + batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"]) + + post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)] + + # make sure that the shapes of prediction and GT label are the same + # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}") + assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape + + pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu() + + # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics + # calculate soft and hard dice here (for quick overview), other metrics can be computed from + # the saved predictions using ANIMA + # 1. Dice Score + test_soft_dice = self.soft_dice_metric(pred, label) + + # binarizing the predictions + pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float() + label = (post_test_out[0]['label'].detach().cpu() > 0.5).float() + + # 1.1 Hard Dice Score + test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy()) + + metrics_dict = { + "test_hard_dice": test_hard_dice, + "test_soft_dice": test_soft_dice, + } + self.test_step_outputs.append(metrics_dict) + + return metrics_dict + + + def on_test_epoch_end(self): + + avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std() + avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \ + np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std() + + logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}") + logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}") + + self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test + self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test + + # free up memory + self.test_step_outputs.clear() + + +# -------------------------------- +# MAIN +# -------------------------------- +def main(): + # get the parser + parser = get_parser() + args= parser.parse_args() + + # load config file + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + # Setting the seed + pl.seed_everything(config["seed"], workers=True) + + # define root path for finding datalists + dataset_root = config["data"] + + # define optimizer + optimizer_class = torch.optim.Adam + + output_path = os.path.join(config["output_path"], str(datetime.now().date()) +"_" +str(datetime.now().time())) + os.makedirs(output_path, exist_ok=True) + + wandb.init(project=f'monai-ms-lesion-seg-unet', config=config, save_code=True, dir=output_path) + + logger.info("Building the model ...") + + # define model + + # net = UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=config['unet_channels'], + # strides=config['unet_strides'], + # kernel_size=3, + # up_kernel_size=3, + # num_res_units=0, + # act='PRELU', + # norm=Norm.INSTANCE, + # dropout=0.0, + # bias=True, + # adn_ordering='NDA', + # ) + + # net=UNet( + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # channels=(32, 64, 128), + # strides=(2, 2, 2, ), + # dropout=0.1 + # ) + + net = AttentionUnet( + spatial_dims=3, + in_channels=1, + out_channels=1, + channels=config["attention_unet_channels"], + strides=config["attention_unet_strides"], + dropout=0.1, + ) + + # net = SwinUNETR( + # img_size=config["spatial_size"], + # spatial_dims=3, + # in_channels=1, + # out_channels=1, + # feature_size=48, + # use_checkpoint=True, + # ) + + # net.use_multiprocessing = False + + # net = BasicUNet(spatial_dims=3, features=(32, 64, 128, 256, 32), out_channels=1) + + # net = create_nnunet_from_plans() + + logger.add(os.path.join(output_path, 'log.txt'), rotation="10 MB", level="INFO") + + # define loss function + # loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum") + # loss_func = DiceLoss(sigmoid=False, smooth_dr=1e-4) + loss_func = DiceCELoss(sigmoid=False, smooth_dr=1e-4) + # loss_func = SoftDiceLoss(smooth=1e-5) + # NOTE: tried increasing omega and decreasing epsilon but results marginally worse than the above + # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum") + #logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...") + logger.info(f"Using DiceCELoss ...") + # define callbacks + early_stopping = pl.callbacks.EarlyStopping( + monitor="val_loss", min_delta=0.00, + patience=config["early_stopping_patience"], + verbose=False, mode="min") + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch') + + best_model_path = os.path.join(output_path, "best_model.pth") + + # i.e. train by loading weights from scratch + pl_model = Model(config, data_root=dataset_root, + optimizer_class=optimizer_class, loss_function=loss_func, net=net, + exp_id="test", results_path=best_model_path) + + # saving the best model based on validation loss + checkpoint_callback_loss = pl.callbacks.ModelCheckpoint( + dirpath= best_model_path, filename='best_model', monitor='val_loss', + save_top_k=1, mode="min", save_last=True, save_weights_only=True) + + logger.info(f"Starting training from scratch ...") + # wandb logger + exp_logger = pl.loggers.WandbLogger( + name="test", + save_dir=output_path, + group="test-on-canproco", + log_model=True, # save best model using checkpoint callback + config=config) + + # Saving training script to wandb + wandb.save(args.config) + + # initialise Lightning's trainer. + trainer = pl.Trainer( + devices=1, accelerator="gpu", + logger=exp_logger, + callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping], + check_val_every_n_epoch=config["eval_num"], + max_epochs=config["max_iterations"], + precision=32, + # precision='bf16-mixed', + enable_progress_bar=True) + # profiler="simple",) # to profile the training time taken for each step + + # Train! + trainer.fit(pl_model) + logger.info(f" Training Done!") + + # Closing wandb log + wandb.finish() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/monai/utils/image.py b/monai/utils/image.py new file mode 100644 index 0000000..03e670c --- /dev/null +++ b/monai/utils/image.py @@ -0,0 +1,685 @@ +import os +import numpy as np +import nibabel as nib +import logging +from copy import deepcopy + +logger = logging.getLogger(__name__) + +class Image(object): + """ + Compact version of SCT's Image Class (https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/image.py#L245) + Create an object that behaves similarly to nibabel's image object. Useful additions include: dims, change_orientation and getNonZeroCoordinates. + """ + + def __init__(self, param=None, hdr=None, orientation=None, absolutepath=None, dim=None): + """ + :param param: string indicating a path to a image file or an `Image` object. + """ + + # initialization of all parameters + self.affine = None + self.data = None + self._path = None + self.ext = "" + + if absolutepath is not None: + self._path = os.path.abspath(absolutepath) + + # Case 1: load an image from file + if isinstance(param, str): + self.loadFromPath(param) + # Case 2: create a copy of an existing `Image` object + elif isinstance(param, type(self)): + self.copy(param) + # Case 3: create a blank image from a list of dimensions + elif isinstance(param, list): + self.data = np.zeros(param) + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + # Case 4: create an image from an existing data array + elif isinstance(param, (np.ndarray, np.generic)): + self.data = param + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + else: + raise TypeError('Image constructor takes at least one argument.') + + # Fix any mismatch between the array's datatype and the header datatype + self.fix_header_dtype() + + @property + def dim(self): + return get_dimension(self) + + @property + def orientation(self): + return get_orientation(self) + + @property + def absolutepath(self): + """ + Storage path (either actual or potential) + + Notes: + + - As several tools perform chdir() it's very important to have absolute paths + - When set, if relative: + + - If it already existed, it becomes a new basename in the old dirname + - Else, it becomes absolute (shortcut) + + Usually not directly touched (use `Image.save`), but in some cases it's + the best way to set it. + """ + return self._path + + @absolutepath.setter + def absolutepath(self, value): + if value is None: + self._path = None + return + elif not os.path.isabs(value) and self._path is not None: + value = os.path.join(os.path.dirname(self._path), value) + elif not os.path.isabs(value): + value = os.path.abspath(value) + self._path = value + + @property + def header(self): + return self.hdr + + @header.setter + def header(self, value): + self.hdr = value + + def __deepcopy__(self, memo): + return type(self)(deepcopy(self.data, memo), deepcopy(self.hdr, memo), deepcopy(self.orientation, memo), deepcopy(self.absolutepath, memo), deepcopy(self.dim, memo)) + + def copy(self, image=None): + if image is not None: + self.affine = deepcopy(image.affine) + self.data = deepcopy(image.data) + self.hdr = deepcopy(image.hdr) + self._path = deepcopy(image._path) + else: + return deepcopy(self) + + def loadFromPath(self, path): + """ + This function load an image from an absolute path using nibabel library + + :param path: path of the file from which the image will be loaded + :return: + """ + + self.absolutepath = os.path.abspath(path) + im_file = nib.load(self.absolutepath, mmap=True) + self.affine = im_file.affine.copy() + self.data = np.asanyarray(im_file.dataobj) + self.hdr = im_file.header.copy() + if path != self.absolutepath: + logger.debug("Loaded %s (%s) orientation %s shape %s", path, self.absolutepath, self.orientation, self.data.shape) + else: + logger.debug("Loaded %s orientation %s shape %s", path, self.orientation, self.data.shape) + + def change_orientation(self, orientation, inverse=False): + """ + Change orientation on image (in-place). + + :param orientation: orientation string (SCT "from" convention) + + :param inverse: if you think backwards, use this to specify that you actually\ + want to transform *from* the specified orientation, not *to*\ + it. + + """ + change_orientation(self, orientation, self, inverse=inverse) + return self + + def getNonZeroCoordinates(self, sorting=None, reverse_coord=False): + """ + This function return all the non-zero coordinates that the image contains. + Coordinate list can also be sorted by x, y, z, or the value with the parameter sorting='x', sorting='y', sorting='z' or sorting='value' + If reverse_coord is True, coordinate are sorted from larger to smaller. + + Removed Coordinate object + """ + n_dim = 1 + if self.dim[3] == 1: + n_dim = 3 + else: + n_dim = 4 + if self.dim[2] == 1: + n_dim = 2 + + if n_dim == 3: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], Z[i], self.data[X[i], Y[i], Z[i]]] for i in range(0, len(X))] + elif n_dim == 2: + try: + X, Y = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i]]] for i in range(0, len(X))] + except ValueError: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i], 0]] for i in range(0, len(X))] + + if sorting is not None: + if reverse_coord not in [True, False]: + raise ValueError('reverse_coord parameter must be a boolean') + + if sorting == 'x': + list_coordinates = sorted(list_coordinates, key=lambda el: el[0], reverse=reverse_coord) + elif sorting == 'y': + list_coordinates = sorted(list_coordinates, key=lambda el: el[1], reverse=reverse_coord) + elif sorting == 'z': + list_coordinates = sorted(list_coordinates, key=lambda el: el[2], reverse=reverse_coord) + elif sorting == 'value': + list_coordinates = sorted(list_coordinates, key=lambda el: el[3], reverse=reverse_coord) + else: + raise ValueError("sorting parameter must be either 'x', 'y', 'z' or 'value'") + + return list_coordinates + + def change_type(self, dtype): + """ + Change data type on image. + + Note: the image path is voided. + """ + change_type(self, dtype, self) + return self + + def fix_header_dtype(self): + """ + Change the header dtype to the match the datatype of the array. + """ + # Using bool for nibabel headers is unsupported, so use uint8 instead: + # `nibabel.spatialimages.HeaderDataError: data dtype "bool" not supported` + dtype_data = self.data.dtype + if dtype_data == bool: + dtype_data = np.uint8 + + dtype_header = self.hdr.get_data_dtype() + if dtype_header != dtype_data: + logger.warning(f"Image header specifies datatype '{dtype_header}', but array is of type " + f"'{dtype_data}'. Header metadata will be overwritten to use '{dtype_data}'.") + self.hdr.set_data_dtype(dtype_data) + + def save(self, path=None, dtype=None, verbose=1, mutable=False): + """ + Write an image in a nifti file + + :param path: Where to save the data, if None it will be taken from the\ + absolutepath member.\ + If path is a directory, will save to a file under this directory\ + with the basename from the absolutepath member. + + :param dtype: if not set, the image is saved in the same type as input data\ + if 'minimize', image storage space is minimized\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + + :param mutable: whether to update members with newly created path or dtype + """ + if mutable: # do all modifications in-place + # Case 1: `path` not specified + if path is None: + if self.absolutepath: # Fallback to the original filepath + path = self.absolutepath + else: + raise ValueError("Don't know where to save the image (no absolutepath or path parameter)") + # Case 2: `path` points to an existing directory + elif os.path.isdir(path): + if self.absolutepath: # Use the original filename, but save to the directory specified by `path` + path = os.path.join(os.path.abspath(path), os.path.basename(self.absolutepath)) + else: + raise ValueError("Don't know where to save the image (path parameter is dir, but absolutepath is " + "missing)") + # Case 3: `path` points to a file (or a *nonexistent* directory) so use its value as-is + # (We're okay with letting nonexistent directories slip through, because it's difficult to distinguish + # between nonexistent directories and nonexistent files. Plus, `nibabel` will catch any further errors.) + else: + pass + + if os.path.isfile(path) and verbose: + logger.warning("File %s already exists. Will overwrite it.", path) + if os.path.isabs(path): + logger.debug("Saving image to %s orientation %s shape %s", + path, self.orientation, self.data.shape) + else: + logger.debug("Saving image to %s (%s) orientation %s shape %s", + path, os.path.abspath(path), self.orientation, self.data.shape) + + # Now that `path` has been set and log messages have been written, we can assign it to the image itself + self.absolutepath = os.path.abspath(path) + + if dtype is not None: + self.change_type(dtype) + + if self.hdr is not None: + self.hdr.set_data_shape(self.data.shape) + self.fix_header_dtype() + + # nb. that copy() is important because if it were a memory map, save() would corrupt it + dataobj = self.data.copy() + affine = None + header = self.hdr.copy() if self.hdr is not None else None + nib.save(nib.nifti1.Nifti1Image(dataobj, affine, header), self.absolutepath) + if not os.path.isfile(self.absolutepath): + raise RuntimeError(f"Couldn't save image to {self.absolutepath}") + else: + # if we're not operating in-place, then make any required modifications on a throw-away copy + self.copy().save(path, dtype, verbose, mutable=True) + return self + + +class SlicerOneAxis(object): + """ + Image slicer to use when you don't care about the 2D slice orientation, + and don't want to specify them. + The slicer will just iterate through the right axis that corresponds to + its specification. + + Can help getting ranges and slice indices. + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + + def __init__(self, im, axis="IS"): + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + axis_labels = "LRPAIS" + if len(axis) != 2: + raise ValueError() + if axis[0] not in axis_labels: + raise ValueError() + if axis[1] not in axis_labels: + raise ValueError() + if axis[0] != opposite_character[axis[1]]: + raise ValueError() + + for idx_axis in range(2): + dim_nr = im.orientation.find(axis[idx_axis]) + if dim_nr != -1: + break + if dim_nr == -1: + raise ValueError() + + # SCT convention + from_dir = im.orientation[dim_nr] + self.direction = +1 if axis[0] == from_dir else -1 + self.nb_slices = im.dim[dim_nr] + self.im = im + self.axis = axis + self._slice = lambda idx: tuple([(idx if x in axis else slice(None)) for x in im.orientation]) + + def __len__(self): + return self.nb_slices + + def __getitem__(self, idx): + """ + + :return: an image slice, at slicing index idx + :param idx: slicing index (according to the slicing direction) + """ + if isinstance(idx, slice): + raise NotImplementedError() + + if idx >= self.nb_slices: + raise IndexError("I just have {} slices!".format(self.nb_slices)) + + if self.direction == -1: + idx = self.nb_slices - 1 - idx + + return self.im.data[self._slice(idx)] + +def get_dimension(im_file, verbose=1): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + Get dimension from Image or nibabel object. Manages 2D, 3D or 4D images. + + :param: im_file: Image or nibabel object + :return: nx, ny, nz, nt, px, py, pz, pt + """ + if not isinstance(im_file, (nib.nifti1.Nifti1Image, Image)): + raise TypeError("The provided image file is neither a nibabel.nifti1.Nifti1Image instance nor an Image instance") + # initializating ndims [nx, ny, nz, nt] and pdims [px, py, pz, pt] + ndims = [1, 1, 1, 1] + pdims = [1, 1, 1, 1] + data_shape = im_file.header.get_data_shape() + zooms = im_file.header.get_zooms() + for i in range(min(len(data_shape), 4)): + ndims[i] = data_shape[i] + pdims[i] = zooms[i] + return *ndims, *pdims + + +def change_orientation(im_src, orientation, im_dst=None, inverse=False): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im_src: source image + :param orientation: orientation string (SCT "from" convention) + :param im_dst: destination image (can be the source image for in-place + operation, can be unset to generate one) + :param inverse: if you think backwards, use this to specify that you actually + want to transform *from* the specified orientation, not *to* it. + :return: an image with changed orientation + + .. note:: + - the resulting image has no path member set + - if the source image is < 3D, it is reshaped to 3D and the destination is 3D + """ + + if len(im_src.data.shape) < 3: + pass # Will reshape to 3D + elif len(im_src.data.shape) == 3: + pass # OK, standard 3D volume + elif len(im_src.data.shape) == 4: + pass # OK, standard 4D volume + elif len(im_src.data.shape) == 5 and im_src.header.get_intent()[0] == "vector": + pass # OK, physical displacement field + else: + raise NotImplementedError("Don't know how to change orientation for this image") + + im_src_orientation = im_src.orientation + im_dst_orientation = orientation + if inverse: + im_src_orientation, im_dst_orientation = im_dst_orientation, im_src_orientation + + perm, inversion = _get_permutations(im_src_orientation, im_dst_orientation) + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + im_src_data = im_src.data + if len(im_src_data.shape) < 3: + im_src_data = im_src_data.reshape(tuple(list(im_src_data.shape) + ([1] * (3 - len(im_src_data.shape))))) + + # Update data by performing inversions and swaps + + # axes inversion (flip) + data = im_src_data[::inversion[0], ::inversion[1], ::inversion[2]] + + # axes manipulations (transpose) + if perm == [1, 0, 2]: + data = np.swapaxes(data, 0, 1) + elif perm == [2, 1, 0]: + data = np.swapaxes(data, 0, 2) + elif perm == [0, 2, 1]: + data = np.swapaxes(data, 1, 2) + elif perm == [2, 0, 1]: + data = np.swapaxes(data, 0, 2) # transform [2, 0, 1] to [1, 0, 2] + data = np.swapaxes(data, 0, 1) # transform [1, 0, 2] to [0, 1, 2] + elif perm == [1, 2, 0]: + data = np.swapaxes(data, 0, 2) # transform [1, 2, 0] to [0, 2, 1] + data = np.swapaxes(data, 1, 2) # transform [0, 2, 1] to [0, 1, 2] + elif perm == [0, 1, 2]: + # do nothing + pass + else: + raise NotImplementedError() + + # Update header + + im_src_aff = im_src.hdr.get_best_affine() + aff = nib.orientations.inv_ornt_aff( + np.array((perm, inversion)).T, + im_src_data.shape) + im_dst_aff = np.matmul(im_src_aff, aff) + + im_dst.header.set_qform(im_dst_aff) + im_dst.header.set_sform(im_dst_aff) + im_dst.header.set_data_shape(data.shape) + im_dst.data = data + + return im_dst + + +def _get_permutations(im_src_orientation, im_dst_orientation): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im_src_orientation str: Orientation of source image. Example: 'RPI' + :param im_dest_orientation str: Orientation of destination image. Example: 'SAL' + :return: list of axes permutations and list of inversions to achieve an orientation change + """ + + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + + perm = [0, 1, 2] + inversion = [1, 1, 1] + for i, character in enumerate(im_src_orientation): + try: + perm[i] = im_dst_orientation.index(character) + except ValueError: + perm[i] = im_dst_orientation.index(opposite_character[character]) + inversion[i] = -1 + + return perm, inversion + + +def get_orientation(im): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im: an Image + :return: reference space string (ie. what's in Image.orientation) + """ + res = "".join(nib.orientations.aff2axcodes(im.hdr.get_best_affine())) + return orientation_string_nib2sct(res) + + +def orientation_string_nib2sct(s): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :return: SCT reference space code from nibabel one + """ + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + return "".join([opposite_character[x] for x in s]) + + +def change_type(im_src, dtype, im_dst=None): + """ + Change the voxel type of the image + + :param dtype: if not set, the image is saved in standard type\ + if 'minimize', image space is minimize\ + if 'minimize_int', image space is minimize and values are approximated to integers\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + :return: + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + """ + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + if dtype is None: + return im_dst + + # get min/max from input image + min_in = np.nanmin(im_src.data) + max_in = np.nanmax(im_src.data) + + # find optimum type for the input image + if dtype in ('minimize', 'minimize_int'): + # warning: does not take intensity resolution into account, neither complex voxels + + # check if voxel values are real or integer + isInteger = True + if dtype == 'minimize': + for vox in im_src.data.flatten(): + if int(vox) != vox: + isInteger = False + break + + if isInteger: + if min_in >= 0: # unsigned + if max_in <= np.iinfo(np.uint8).max: + dtype = np.uint8 + elif max_in <= np.iinfo(np.uint16): + dtype = np.uint16 + elif max_in <= np.iinfo(np.uint32).max: + dtype = np.uint32 + elif max_in <= np.iinfo(np.uint64).max: + dtype = np.uint64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + if max_in <= np.iinfo(np.int8).max and min_in >= np.iinfo(np.int8).min: + dtype = np.int8 + elif max_in <= np.iinfo(np.int16).max and min_in >= np.iinfo(np.int16).min: + dtype = np.int16 + elif max_in <= np.iinfo(np.int32).max and min_in >= np.iinfo(np.int32).min: + dtype = np.int32 + elif max_in <= np.iinfo(np.int64).max and min_in >= np.iinfo(np.int64).min: + dtype = np.int64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + # if max_in <= np.finfo(np.float16).max and min_in >= np.finfo(np.float16).min: + # type = 'np.float16' # not supported by nibabel + if max_in <= np.finfo(np.float32).max and min_in >= np.finfo(np.float32).min: + dtype = np.float32 + elif max_in <= np.finfo(np.float64).max and min_in >= np.finfo(np.float64).min: + dtype = np.float64 + + dtype = to_dtype(dtype) + else: + dtype = to_dtype(dtype) + + # if output type is int, check if it needs intensity rescaling + if "int" in dtype.name: + # get min/max from output type + min_out = np.iinfo(dtype).min + max_out = np.iinfo(dtype).max + # before rescaling, check if there would be an intensity overflow + + if (min_in < min_out) or (max_in > max_out): + # This condition is important for binary images since we do not want to scale them + logger.warning(f"To avoid intensity overflow due to convertion to +{dtype.name}+, intensity will be rescaled to the maximum quantization scale") + # rescale intensity + data_rescaled = im_src.data * (max_out - min_out) / (max_in - min_in) + im_dst.data = data_rescaled - (data_rescaled.min() - min_out) + + # change type of data in both numpy array and nifti header + im_dst.data = getattr(np, dtype.name)(im_dst.data) + im_dst.hdr.set_data_dtype(dtype) + return im_dst + + +def to_dtype(dtype): + """ + Take a dtypeification and return an np.dtype + + :param dtype: dtypeification (string or np.dtype or None are supported for now) + :return: dtype or None + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + """ + # TODO add more or filter on things supported by nibabel + + if dtype is None: + return None + if isinstance(dtype, type): + if isinstance(dtype(0).dtype, np.dtype): + return dtype(0).dtype + if isinstance(dtype, np.dtype): + return dtype + if isinstance(dtype, str): + return np.dtype(dtype) + + raise TypeError("data type {}: {} not understood".format(dtype.__class__, dtype)) + + +def zeros_like(img, dtype=None): + """ + + :param img: reference image + :param dtype: desired data type (optional) + :return: an Image with the same shape and header, filled with zeros + + Similar to numpy.zeros_like(), the goal of the function is to show the developer's + intent and avoid doing a copy, which is slower than initialization with a constant. + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + zimg = Image(np.zeros_like(img.data), hdr=img.hdr.copy()) + if dtype is not None: + zimg.change_type(dtype) + return zimg + + +def empty_like(img, dtype=None): + """ + :param img: reference image + :param dtype: desired data type (optional) + :return: an Image with the same shape and header, whose data is uninitialized + + Similar to numpy.empty_like(), the goal of the function is to show the developer's + intent and avoid touching the allocated memory, because it will be written to + afterwards. + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + dst = change_type(img, dtype) + return dst + + +def find_zmin_zmax(im, threshold=0.1): + """ + Find the min (and max) z-slice index below which (and above which) slices only have voxels below a given threshold. + + :param im: Image object + :param threshold: threshold to apply before looking for zmin/zmax, typically corresponding to noise level. + :return: [zmin, zmax] + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + slicer = SlicerOneAxis(im, axis="IS") + + # Make sure image is not empty + if not np.any(slicer): + logger.error('Input image is empty') + + # Iterate from bottom to top until we find data + for zmin in range(0, len(slicer)): + if np.any(slicer[zmin] > threshold): + break + + # Conversely from top to bottom + for zmax in range(len(slicer) - 1, zmin, -1): + if np.any(slicer[zmax] > threshold): + break + + return zmin, zmax \ No newline at end of file diff --git a/monai/utils/losses.py b/monai/utils/losses.py new file mode 100644 index 0000000..7449032 --- /dev/null +++ b/monai/utils/losses.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import scipy +import numpy as np + + +class SoftDiceLoss(nn.Module): + ''' + soft-dice loss, useful in binary segmentation + taken from: https://github.com/CoinCheung/pytorch-loss/blob/master/soft_dice_loss.py + ''' + def __init__(self, p=1, smooth=1e-5): + super(SoftDiceLoss, self).__init__() + self.p = p + self.smooth = smooth + + def forward(self, logits, labels): + ''' + inputs: + preds: logits - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + preds = logits # F.relu(logits) / F.relu(logits).max() if bool(F.relu(logits).max()) else F.relu(logits) + + numer = (preds * labels).sum() + denor = (preds.pow(self.p) + labels.pow(self.p)).sum() + # loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth) + loss = - (2 * numer + self.smooth) / (denor + self.smooth) + return loss + +class AdapWingLoss(nn.Module): + """ + Adaptive Wing loss used for heatmap regression + Adapted from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/losses.py#L341 + + .. seealso:: + Wang, Xinyao, Liefeng Bo, and Li Fuxin. "Adaptive wing loss for robust face alignment via heatmap regression." + Proceedings of the IEEE International Conference on Computer Vision. 2019. + + Args: + theta (float): Threshold to switch between the linear and non-linear parts of the piece-wise loss function. + alpha (float): Used to adapt the behaviour of the loss function at y=0 and y=1 and make loss smooth at 0 (background). + It needs to be slightly above 2 to maintain ideal properties. + omega (float): Multiplicative factor for non linear part of the loss. + epsilon (float): factor to avoid gradient explosion. It must not be too small + NOTE: Larger omega and smaller epsilon values will increase the influence on small errors and vice versa + """ + + def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction='sum'): + self.theta = theta + self.alpha = alpha + self.omega = omega + self.epsilon = epsilon + self.reduction = reduction + super(AdapWingLoss, self).__init__() + + def forward(self, input, target): + eps = self.epsilon + batch_size = target.size()[0] + + # Adaptive Wing loss. Section 4.2 of the paper. + # Compute adaptive factor + A = self.omega * (1 / (1 + torch.pow(self.theta / eps, + self.alpha - target))) * \ + (self.alpha - target) * torch.pow(self.theta / eps, + self.alpha - target - 1) * (1 / eps) + + # Constant term to link linear and non linear part + C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target))) + + diff_hm = torch.abs(target - input) + AWingLoss = A * diff_hm - C + idx = diff_hm < self.theta + # NOTE: this is a memory-efficient version than the one in ivadomed losses.py + # where idx is True, compute the non-linear part of the loss, otherwise keep the linear part + # the non-linear parts ensures small errors (as given by idx) have a larger influence to refine the predictions at the boundaries + # the linear part makes the loss function behave more like the MSE loss, which has a linear influence + # (i.e. small errors where y=0 --> small influence --> small gradients) + AWingLoss = torch.where(idx, self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target)), AWingLoss) + + + # Mask for weighting the loss function. Section 4.3 of the paper. + mask = torch.zeros_like(target) + kernel = scipy.ndimage.generate_binary_structure(2, 2) + # For 3D segmentation tasks + if len(input.shape) == 5: + kernel = scipy.ndimage.generate_binary_structure(3, 2) + + for i in range(batch_size): + img_list = list() + img_list.append(np.round(target[i].cpu().numpy() * 255)) + img_merge = np.concatenate(img_list) + img_dilate = scipy.ndimage.binary_opening(img_merge, np.expand_dims(kernel, axis=0)) + # NOTE: why 51? the paper thresholds the dilated GT heatmap at 0.2. So, 51/255 = 0.2 + img_dilate[img_dilate < 51] = 1 # 0*omega+1 + img_dilate[img_dilate >= 51] = 1 + self.omega # 1*omega+1 + img_dilate = np.array(img_dilate, dtype=int) + + mask[i] = torch.tensor(img_dilate) + + AWingLoss *= mask + + sum_loss = torch.sum(AWingLoss) + if self.reduction == "sum": + return sum_loss + elif self.reduction == "mean": + all_pixel = torch.sum(mask) + return sum_loss / all_pixel \ No newline at end of file diff --git a/monai/utils/utils.py b/monai/utils/utils.py new file mode 100644 index 0000000..1e116ae --- /dev/null +++ b/monai/utils/utils.py @@ -0,0 +1,443 @@ +import numpy as np +import matplotlib.pyplot as plt +from torch.optim.lr_scheduler import _LRScheduler +import torch + +import torch.nn as nn +import torch.nn.functional as F + +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 + +from scipy import ndimage + + + +def dice_score(prediction, groundtruth, smooth=1.): + numer = (prediction * groundtruth).sum() + denor = (prediction + groundtruth).sum() + # loss = (2 * numer + self.smooth) / (denor + self.smooth) + dice = (2 * numer + smooth) / (denor + smooth) + return dice + +# Check if any label image patch is empty in the batch +def check_empty_patch(labels): + for i, label in enumerate(labels): + if torch.sum(label) == 0.0: + # print(f"Empty label patch found at index {i}. Skipping training step ...") + return None + return labels # If no empty patch is found, return the labels + +# Function to multiply by -1 +def multiply_by_negative_one(x): + return x * -1 + +def plot_slices(image, gt, pred, debug=False): + """ + Plot the image, ground truth and prediction of the mid-sagittal axial slice + The orientaion is assumed to RPI + """ + + # bring everything to numpy + ## added the .float() because of issue : TypeError: Got unsupported ScalarType BFloat16 + image = image.float().numpy() + gt = gt.float().numpy() + pred = pred.float().numpy() + + + mid_sagittal = image.shape[0]//2 + # plot X slices before and after the mid-sagittal slice in a grid + fig, axs = plt.subplots(3, 6, figsize=(10, 6)) + fig.suptitle('Original Image --> Ground Truth --> Prediction') + for i in range(6): + axs[0, i].imshow(image[mid_sagittal-3+i,:,:].T, cmap='gray'); axs[0, i].axis('off') + axs[1, i].imshow(gt[mid_sagittal-3+i,:,:].T); axs[1, i].axis('off') + axs[2, i].imshow(pred[mid_sagittal-3+i,:,:].T); axs[2, i].axis('off') + + # fig, axs = plt.subplots(1, 3, figsize=(10, 8)) + # fig.suptitle('Original Image --> Ground Truth --> Prediction') + # slice = image.shape[2]//2 + + # axs[0].imshow(image[:, :, slice].T, cmap='gray'); axs[0].axis('off') + # axs[1].imshow(gt[:, :, slice].T); axs[1].axis('off') + # axs[2].imshow(pred[:, :, slice].T); axs[2].axis('off') + + plt.tight_layout() + fig.show() + return fig + + +def lesion_wise_tp_fp_fn(truth, prediction): + """ + Computes the true positives, false positives, and false negatives two masks. Masks are considered true positives + if at least one voxel overlaps between the truth and the prediction. + Adapted from: https://github.com/npnl/atlas2_grand_challenge/blob/main/isles/scoring.py#L341 + + Parameters + ---------- + truth : array-like, bool + 3D array. If not boolean, will be converted. + prediction : array-like, bool + 3D array with a shape matching 'truth'. If not boolean, will be converted. + empty_value : scalar, float + Optional. Value to which to default if there are no labels. Default: 1.0. + + Returns + ------- + tp (int): 3D connected-component from the ground-truth image that overlaps at least on one voxel with the prediction image. + fp (int): 3D connected-component from the prediction image that has no voxel overlapping with the ground-truth image. + fn (int): 3d connected-component from the ground-truth image that has no voxel overlapping with the prediction image. + + Notes + ----- + This function computes lesion-wise score by defining true positive lesions (tp), false positive lesions (fp) and + false negative lesions (fn) using 3D connected-component-analysis. + + tp: 3D connected-component from the ground-truth image that overlaps at least on one voxel with the prediction image. + fp: 3D connected-component from the prediction image that has no voxel overlapping with the ground-truth image. + fn: 3d connected-component from the ground-truth image that has no voxel overlapping with the prediction image. + """ + tp, fp, fn = 0, 0, 0 + + # For each true lesion, check if there is at least one overlapping voxel. This determines true positives and + # false negatives (unpredicted lesions) + labeled_ground_truth, num_lesions = ndimage.label(truth.astype(bool)) + for idx_lesion in range(1, num_lesions+1): + lesion = labeled_ground_truth == idx_lesion + lesion_pred_sum = lesion + prediction + if(np.max(lesion_pred_sum) > 1): + tp += 1 + else: + fn += 1 + + # For each predicted lesion, check if there is at least one overlapping voxel in the ground truth. + labaled_prediction, num_pred_lesions = ndimage.label(prediction.astype(bool)) + for idx_lesion in range(1, num_pred_lesions+1): + lesion = labaled_prediction == idx_lesion + lesion_pred_sum = lesion + truth + if(np.max(lesion_pred_sum) <= 1): # No overlap + fp += 1 + + return tp, fp, fn + + +def lesion_f1_score(truth, prediction): + """ + Computes the lesion-wise F1-score between two masks by defining true positive lesions (tp), false positive lesions (fp) + and false negative lesions (fn) using 3D connected-component-analysis. + + Masks are considered true positives if at least one voxel overlaps between the truth and the prediction. + + Returns + ------- + f1_score : float + Lesion-wise F1-score as float. + Max score = 1 + Min score = 0 + If both images are empty (tp + fp + fn =0) = empty_value + """ + empty_value = 1.0 # Value to which to default if there are no labels. Default: 1.0. + + if not np.any(truth) and not np.any(prediction): + # Both reference and prediction are empty --> model learned correctly + return 1.0 + elif np.any(truth) and not np.any(prediction): + # Reference is not empty, prediction is empty --> model did not learn correctly (it's false negative) + return 0.0 + # if the predction is not empty and ref is empty, it's false positive + # if both are not empty, it's true positive + else: + tp, fp, fn = lesion_wise_tp_fp_fn(truth, prediction) + f1_score = empty_value + + # Compute f1_score + denom = tp + (fp + fn)/2 + if(denom != 0): + f1_score = tp / denom + return f1_score + + +def lesion_ppv(truth, prediction): + """ + Computes the lesion-wise positive predictive value (PPV) between two masks + Returns + ------- + ppv (float): Lesion-wise positive predictive value as float. + Max score = 1 + Min score = 0 + If both images are empty (tp + fp + fn =0) = empty_value + """ + if not np.any(truth) and not np.any(prediction): + # Both reference and prediction are empty --> model learned correctly + return 1.0 + elif np.any(truth) and not np.any(prediction): + # Reference is not empty, prediction is empty --> model did not learn correctly (it's false negative) + return 0.0 + # if the predction is not empty and ref is empty, it's false positive + # if both are not empty, it's true positive + else: + tp, fp, _ = lesion_wise_tp_fp_fn(truth, prediction) + ppv = 1.0 + + # Compute ppv + denom = tp + fp + # denom should ideally not be zero inside this else as it should be caught by the empty checks above + if(denom != 0): + ppv = tp / denom + return ppv + + +def lesion_sensitivity(truth, prediction): + """ + Computes the lesion-wise sensitivity between two masks + Returns + ------- + sensitivity (float): Lesion-wise sensitivity as float. + Max score = 1 + Min score = 0 + If both images are empty (tp + fp + fn =0) = empty_value + """ + empty_value = 1.0 # Value to which to default if there are no labels. Default: 1.0. + + if not np.any(truth) and not np.any(prediction): + # Both reference and prediction are empty --> model learned correctly + return 1.0 + # if the predction is not empty and ref is empty, it's false positive + # if both are not empty, it's true positive + else: + + tp, _, fn = lesion_wise_tp_fp_fn(truth, prediction) + sensitivity = empty_value + + # Compute sensitivity + denom = tp + fn + if(denom != 0): + sensitivity = tp / denom + return sensitivity + + +def remove_small_lesions(lesion_seg, resolution, min_volume=7.5): + """ + Remove lesions which are smaller than a given volume threshold. + + Args: + predictions (ndarray or nibabel object): Input segmentation. Image could be 2D or 3D. + resolution (list): Resolution of the image (Example: [1, 1, 1]) in mm + min_volume (float): Minimum volume of the lesion to be kept. in mm3 (Default is 5 voxels in canproco = 5*0.7*0.7*3=7.35 ) + + Returns: + ndarray or nibabel (same object as the input). + """ + # Find number of closed objects using skimage "label" + labeled_obj, num_obj = ndimage.label(np.copy(lesion_seg)) + # Compute the volume of each object + obj_volume = np.zeros(num_obj) + for i in range(num_obj): + obj_volume[i] = np.sum(labeled_obj == i+1)*np.prod(resolution) + # Remove objects with volume less than min_volume + lesion_seg = np.copy(lesion_seg) + for i in range(num_obj): + if obj_volume[i] < min_volume: + lesion_seg[labeled_obj == i+1] = 0 + labeled_obj, num_obj = ndimage.label(lesion_seg) + return lesion_seg + + +# def lesion_wise_precision_recall(prediction, groundtruth, iou_threshold=0.1): +# """ +# This function computes the lesion-wise precision and recall. + +# Args: +# prediction: predicted segmentation mask +# groundtruth: ground truth segmentation mask +# iou_threshold: threshold for intersection over union (IoU) for a lesion to be considered as true positive +# Returns: +# precision: lesion-wise precision +# recall: lesion-wise recall +# """ +# prediction_cpu = prediction#.detach().numpy() +# groundtruth_cpu = groundtruth#.detach().numpy() + +# precision = [] +# recall = [] +# # print(prediction_cpu.shape) +# for i in range(prediction_cpu.shape[0]): +# # Compute connected components in the predicted and ground truth segmentation masks +# if len(prediction_cpu.shape) == 4: +# # print("iteration") +# # binarize the prediction and ground truth +# prediction_cpu[0] = prediction_cpu[0] > 0.2 +# groundtruth_cpu[0] = groundtruth_cpu[0] > 0.2 +# # compute connected components +# pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[0], connectivity=2, return_num=True) +# gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[0], connectivity=2, return_num=True) +# # print('c', pred_labels.shape) +# # print('d', gt_labels.shape) +# if len(prediction_cpu.shape) == 5: +# # binarize the prediction and ground truth +# prediction_cpu[i][0] = prediction_cpu[i][0] > 0.2 +# groundtruth_cpu[i][0] = groundtruth_cpu[i][0] > 0.2 +# # compute connected components +# pred_labels, num_components_pred = skimage.measure.label(prediction_cpu[i][0], connectivity=2, return_num=True) +# gt_labels, num_components_gt = skimage.measure.label(groundtruth_cpu[i][0], connectivity=2, return_num=True) +# # print('e', pred_labels.shape) +# # print('f', gt_labels.shape) + +# # If there are no connected components in the predicted or ground truth segmentation masks we return 0 and continue +# if num_components_gt==0 or num_components_pred==0: +# precision+= [0] +# recall+= [0] +# continue + +# # Compute the intersection over union (IoU) between each pair of connected components +# iou_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) +# intersection_matrix = np.zeros((np.max(pred_labels), np.max(gt_labels))) +# for i in range(np.max(pred_labels)): +# for j in range(np.max(gt_labels)): +# # Compute the intersection +# intersection = np.sum((pred_labels == i + 1) * (gt_labels == j + 1)) +# # Compute the union +# union = np.sum((pred_labels == i + 1)) + np.sum((gt_labels == j + 1)) - intersection +# # Compute the IoU +# iou_matrix[i, j] = intersection / union +# # if iou_matrix[i, j] > 0: +# # print("iou_matrix", iou_matrix[i, j]) +# # Compute the intersection +# intersection_matrix[i, j] = intersection + +# # # Compute lesion-wise precision and recall +# # true_positives = np.sum(np.max(iou_matrix, axis=1) > iou_threshold) +# # false_positives = np.sum(np.max(iou_matrix, axis=0) <= iou_threshold) +# # false_negatives = np.sum(np.max(iou_matrix, axis=1) <= iou_threshold) +# # precision += [true_positives / (true_positives + false_positives)] +# # recall+= [true_positives / (true_positives + false_negatives)] + +# # Compute lesion-wise precision and recall +# true_positives = np.sum(np.max(intersection_matrix, axis=1) > iou_threshold) +# false_positives = np.sum(np.max(intersection_matrix, axis=0) <= iou_threshold) +# false_negatives = np.sum(np.max(intersection_matrix, axis=1) <= iou_threshold) +# precision += [true_positives / (true_positives + false_positives)] +# recall+= [true_positives / (true_positives + false_negatives)] + + +# # Put it back in cuda +# precision = torch.tensor(precision).cuda() +# recall = torch.tensor(recall).cuda() + +# print("precision", precision) +# print("recall", recall) +# return precision, recall + + +# ############################################################################################################ +# # NNUNet's Model +# ############################################################################################################ +# nnunet_plans = { +# "UNet_class_name": "PlainConvUNet", +# "UNet_base_num_features": 32, +# "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2, 2], +# "n_conv_per_stage_decoder": [2, 2, 2, 2, 2, 2], +# "pool_op_kernel_sizes": [ +# [1, 1, 1], +# [1, 2, 2], +# [1, 2, 2], +# [2, 2, 2], +# [2, 2, 2], +# [1, 2, 2], +# [1, 2, 2] +# ], +# "conv_kernel_sizes": [ +# [1, 3, 3], +# [1, 3, 3], +# [3, 3, 3], +# [3, 3, 3], +# [3, 3, 3], +# [3, 3, 3], +# [3, 3, 3] +# ], +# "unet_max_num_features": 320, +# } + + +# # ====================================================================================================== +# # Utils for nnUNet's Model +# # ==================================================================================================== +# class InitWeights_He(object): +# def __init__(self, neg_slope=1e-2): +# self.neg_slope = neg_slope + +# def __call__(self, module): +# if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): +# module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) +# if module.bias is not None: +# module.bias = nn.init.constant_(module.bias, 0) + + +# # ====================================================================================================== +# # Define the network based on plans json +# # ==================================================================================================== +# def create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1, deep_supervision: bool = False): +# """ +# Adapted from nnUNet's source code: +# https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + +# """ +# num_stages = len(plans["conv_kernel_sizes"]) + +# dim = len(plans["conv_kernel_sizes"][0]) +# conv_op = convert_dim_to_conv_op(dim) + +# segmentation_network_class_name = plans["UNet_class_name"] +# mapping = { +# 'PlainConvUNet': PlainConvUNet, +# 'ResidualEncoderUNet': ResidualEncoderUNet +# } +# kwargs = { +# 'PlainConvUNet': { +# 'conv_bias': True, +# 'norm_op': get_matching_instancenorm(conv_op), +# 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, +# 'dropout_op': None, 'dropout_op_kwargs': None, +# 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, +# }, +# 'ResidualEncoderUNet': { +# 'conv_bias': True, +# 'norm_op': get_matching_instancenorm(conv_op), +# 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, +# 'dropout_op': None, 'dropout_op_kwargs': None, +# 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, +# } +# } +# assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ +# 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ +# 'into either this ' \ +# 'function (get_network_from_plans) or ' \ +# 'the init of your nnUNetModule to accomodate that.' +# network_class = mapping[segmentation_network_class_name] + +# conv_or_blocks_per_stage = { +# 'n_conv_per_stage' +# if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], +# 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] +# } + +# # network class name!! +# model = network_class( +# input_channels=num_input_channels, +# n_stages=num_stages, +# features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, +# plans["unet_max_num_features"]) for i in range(num_stages)], +# conv_op=conv_op, +# kernel_sizes=plans["conv_kernel_sizes"], +# strides=plans["pool_op_kernel_sizes"], +# num_classes=num_classes, +# deep_supervision=deep_supervision, +# **conv_or_blocks_per_stage, +# **kwargs[segmentation_network_class_name] +# ) +# model.apply(InitWeights_He(1e-2)) +# if network_class == ResidualEncoderUNet: +# model.apply(init_last_bn_before_add_to_0) + +# return model \ No newline at end of file