Skip to content

Commit

Permalink
Merge branch 'MIC-DKFZ:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzjalen authored May 21, 2024
2 parents 3f5a1bd + 7372db8 commit f08c7bd
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 7 deletions.
2 changes: 1 addition & 1 deletion documentation/resenc_presets.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ The presets differ from `ResEncUNetPlanner` in two ways:
- They set new default values for `gpu_memory_target_in_gb` to target the respective VRAM consumptions
- They remove the batch size cap of 0.05 (= previously one batch could not cover mode pixels than 5% of the entire dataset, not it can be arbitrarily large)

The preset are merely there ot make life easier, and to provide standardized configurations people can benchmark with.
The presets are merely there to make life easier, and to provide standardized configurations people can benchmark with.
You can easily adapt the GPU memory target to match your GPU, and to scale beyond 40GB of GPU memory.

Here is an example for how to scale to 80GB VRAM on Dataset003_Liver:
Expand Down
110 changes: 110 additions & 0 deletions nnunetv2/dataset_conversion/Dataset042_BraTS18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import multiprocessing
import shutil

import SimpleITK as sitk
import numpy as np
from tqdm import tqdm
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw


def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:
# use this for segmentation only!!!
# nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3
img = sitk.ReadImage(in_file)
img_npy = sitk.GetArrayFromImage(img)

uniques = np.unique(img_npy)
for u in uniques:
if u not in [0, 1, 2, 4]:
raise RuntimeError('unexpected label')

seg_new = np.zeros_like(img_npy)
seg_new[img_npy == 4] = 3
seg_new[img_npy == 2] = 1
seg_new[img_npy == 1] = 2
img_corr = sitk.GetImageFromArray(seg_new)
img_corr.CopyInformation(img)
sitk.WriteImage(img_corr, out_file)


def convert_labels_back_to_BraTS(seg: np.ndarray):
new_seg = np.zeros_like(seg)
new_seg[seg == 1] = 2
new_seg[seg == 3] = 4
new_seg[seg == 2] = 1
return new_seg


def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):
a = sitk.ReadImage(join(input_folder, filename))
b = sitk.GetArrayFromImage(a)
c = convert_labels_back_to_BraTS(b)
d = sitk.GetImageFromArray(c)
d.CopyInformation(a)
sitk.WriteImage(d, join(output_folder, filename))


def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str,
num_processes: int = 12):
"""
reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the
"""
maybe_mkdir_p(output_folder)
nii = subfiles(input_folder, suffix='.nii.gz', join=False)
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))


if __name__ == '__main__':
brats_data_dir = ...

task_id = 42
task_name = "BraTS2018"

foldername = "Dataset%03.0d_%s" % (task_id, task_name)

# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(labelstr)

case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='Brats', join=False)
case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="Brats", join=False)

print("copying hggs")
for c in tqdm(case_ids_hgg):
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))
print("copying lggs")
for c in tqdm(case_ids_lgg):
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))

generate_dataset_json(out_base,
channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
labels={
'background': 0,
'whole tumor': (1, 2, 3),
'tumor core': (2, 3),
'enhancing tumor': (3,)
},
num_training_cases=(len(case_ids_lgg) + len(case_ids_hgg)),
file_ending='.nii',
regions_class_order=(1, 2, 3),
license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
dataset_release='1.0')
110 changes: 110 additions & 0 deletions nnunetv2/dataset_conversion/Dataset043_BraTS19.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import multiprocessing
import shutil

import SimpleITK as sitk
import numpy as np
from tqdm import tqdm
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw


def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:
# use this for segmentation only!!!
# nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3
img = sitk.ReadImage(in_file)
img_npy = sitk.GetArrayFromImage(img)

uniques = np.unique(img_npy)
for u in uniques:
if u not in [0, 1, 2, 4]:
raise RuntimeError('unexpected label')

seg_new = np.zeros_like(img_npy)
seg_new[img_npy == 4] = 3
seg_new[img_npy == 2] = 1
seg_new[img_npy == 1] = 2
img_corr = sitk.GetImageFromArray(seg_new)
img_corr.CopyInformation(img)
sitk.WriteImage(img_corr, out_file)


def convert_labels_back_to_BraTS(seg: np.ndarray):
new_seg = np.zeros_like(seg)
new_seg[seg == 1] = 2
new_seg[seg == 3] = 4
new_seg[seg == 2] = 1
return new_seg


def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):
a = sitk.ReadImage(join(input_folder, filename))
b = sitk.GetArrayFromImage(a)
c = convert_labels_back_to_BraTS(b)
d = sitk.GetImageFromArray(c)
d.CopyInformation(a)
sitk.WriteImage(d, join(output_folder, filename))


def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str,
num_processes: int = 12):
"""
reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the
"""
maybe_mkdir_p(output_folder)
nii = subfiles(input_folder, suffix='.nii.gz', join=False)
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))


if __name__ == '__main__':
brats_data_dir = ...

task_id = 43
task_name = "BraTS2019"

foldername = "Dataset%03.0d_%s" % (task_id, task_name)

# setting up nnU-Net folders
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
maybe_mkdir_p(imagestr)
maybe_mkdir_p(labelstr)

case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='BraTS', join=False)
case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="BraTS", join=False)

print("copying hggs")
for c in tqdm(case_ids_hgg):
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))
print("copying lggs")
for c in tqdm(case_ids_lgg):
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"),
join(labelstr, c + '.nii'))

generate_dataset_json(out_base,
channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
labels={
'background': 0,
'whole tumor': (1, 2, 3),
'tumor core': (2, 3),
'enhancing tumor': (3,)
},
num_training_cases=(len(case_ids_hgg) + len(case_ids_lgg)),
file_ending='.nii',
regions_class_order=(1, 2, 3),
license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
dataset_release='1.0')
3 changes: 1 addition & 2 deletions nnunetv2/experiment_planning/verify_dataset_integrity.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def verify_dataset_integrity(folder: str, num_processes: int = 8) -> None:
with multiprocessing.get_context("spawn").Pool(num_processes) as p:
result = p.starmap(
verify_labels,
zip([join(folder, 'labelsTr', i) for i in labelfiles], [reader_writer_class] * len(labelfiles),
[expected_labels] * len(labelfiles))
zip(labelfiles, [reader_writer_class] * len(labelfiles), [expected_labels] * len(labelfiles))
)
if not all(result):
raise RuntimeError(
Expand Down
2 changes: 2 additions & 0 deletions nnunetv2/imageio/nibabel_reader_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class NibabelIO(BaseReaderWriter):
IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg!
"""
supported_file_endings = [
'.nii',
'.nii.gz',
'.nrrd',
'.mha'
Expand Down Expand Up @@ -107,6 +108,7 @@ class NibabelIOWithReorient(BaseReaderWriter):
IMPORTANT: Run nnUNetv2_plot_overlay_pngs to verify that this did not destroy the alignment of data and seg!
"""
supported_file_endings = [
'.nii',
'.nii.gz',
'.nrrd',
'.mha'
Expand Down
3 changes: 2 additions & 1 deletion nnunetv2/imageio/simpleitk_reader_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class SimpleITKIO(BaseReaderWriter):
supported_file_endings = [
'.nii.gz',
'.nrrd',
'.mha'
'.mha',
'.gipl'
]

def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
Expand Down
8 changes: 5 additions & 3 deletions nnunetv2/inference/JHU_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def predict_from_data_iterator(self,


if __name__ == '__main__':
# /home/isensee/JHU_trained_model
# python nnunetv2/inference/JHU_inference.py /home/isensee/Downloads/AbdomenAtlasTest /home/isensee/Downloads/AbdomenAtlasTest_pred -model /home/isensee/temp/JHU/trained_model_ep3850
# /home/isensee/temp/JHU/trained_model_ep3850
# /home/isensee/Downloads/AbdomenAtlasTest
# /home/isensee/Downloads/AbdomenAtlasTest_pred

Expand All @@ -158,6 +159,7 @@ def predict_from_data_iterator(self,
parser.add_argument('input_dir', type=str)
parser.add_argument('output_dir', type=str)
parser.add_argument('-model', required=True, type=str)
parser.add_argument('--disable_tqdm', required=False, action='store_true', default=False)
args = parser.parse_args()

predictor = JHUPredictor(
Expand All @@ -166,9 +168,9 @@ def predict_from_data_iterator(self,
use_mirroring=True,
perform_everything_on_device=True,
device=torch.device('cuda', 0),
verbose=True,
verbose=False,
verbose_preprocessing=False,
allow_tqdm=True
allow_tqdm=not args.disable_tqdm
)

predictor.initialize_from_trained_model_folder(
Expand Down

0 comments on commit f08c7bd

Please sign in to comment.