diff --git a/documentation/dataset_format.md b/documentation/dataset_format.md index de6c9936b..cd8433aed 100644 --- a/documentation/dataset_format.md +++ b/documentation/dataset_format.md @@ -26,7 +26,8 @@ T2 MRI, …) and FILE_ENDING is the file extension used by your image format (.p The dataset.json file connects channel names with the channel identifiers in the 'channel_names' key (see below for details). Side note: Typically, each channel/modality needs to be stored in a separate file and is accessed with the XXXX channel identifier. -Exception are natural images (RGB; .png) where the three color channels can all be stored in one file (see the [road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example). +Exception are natural images (RGB; .png) where the three color channels can all be stored in one file (see the +[road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example). **Segmentations** must share the same geometry with their corresponding images (same shape etc.). Segmentations are integer maps with each value representing a semantic class. The background must be 0. If there is no background, then @@ -57,14 +58,14 @@ of what the raw data was provided in! This is for performance reasons. By default, the following file formats are supported: + - NaturalImage2DIO: .png, .bmp, .tif - NibabelIO: .nii.gz, .nrrd, .mha - NibabelIOWithReorient: .nii.gz, .nrrd, .mha. This reader will reorient images to RAS! - SimpleITKIO: .nii.gz, .nrrd, .mha - Tiff3DIO: .tif, .tiff. 3D tif images! Since TIF does not have a standardized way of storing spacing information, -nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains three numbers -(no units, no comma. Just separated by whitespace), one for each dimension. - +nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains this information (see +[here](#datasetjson)). The file extension lists are not exhaustive and depend on what the backend supports. For example, nibabel and SimpleITK support more than the three given here. The file endings given here are just the ones we tested! @@ -200,6 +201,27 @@ There is a utility with which you can generate the dataset.json automatically. Y [here](../nnunetv2/dataset_conversion/generate_dataset_json.py). See our examples in [dataset_conversion](../nnunetv2/dataset_conversion) for how to use it. And read its documentation! +As described above, a json file that contains spacing information is required for TIFF files. +An example for a 3D TIFF stack with units corresponding to 7.6 in x and y, 80 in z is: + +``` +{ + "spacing": [7.6, 7.6, 80.0] +} +``` + +Within the dataset folder, this file (named `cell6.json` in this example) would be placed in the following folders: + + nnUNet_raw/Dataset123_Foo/ + ├── dataset.json + ├── imagesTr + │   ├── cell6.json + │   └── cell6_0000.tif + └── labelsTr + ├── cell6.json + └── cell6.tif + + ## How to use nnU-Net v1 Tasks If you are migrating from the old nnU-Net, convert your existing datasets with `nnUNetv2_convert_old_nnUNet_dataset`! diff --git a/documentation/installation_instructions.md b/documentation/installation_instructions.md index 409e5ebea..edb03b06d 100644 --- a/documentation/installation_instructions.md +++ b/documentation/installation_instructions.md @@ -82,6 +82,6 @@ easy identification. Note that these commands simply execute python scripts. If you installed nnU-Net in a virtual environment, this environment must be activated when executing the commands. You can see what scripts/functions are executed by -checking the entry_points in the setup.py file. +checking the project.scripts in the [pyproject.toml](../pyproject.toml) file. All nnU-Net commands have a `-h` option which gives information on how to use them. diff --git a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py index c779dcff1..229fed1f1 100644 --- a/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py +++ b/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -1,4 +1,3 @@ -import os.path import shutil from copy import deepcopy from functools import lru_cache @@ -81,6 +80,10 @@ def __init__(self, dataset_name_or_id: Union[str, int], self.plans = None + if isfile(join(self.raw_dataset_folder, 'splits_final.json')): + _maybe_copy_splits_file(join(self.raw_dataset_folder, 'splits_final.json'), + join(preprocessed_folder, 'splits_final.json')) + def determine_reader_writer(self): example_image = self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0] return determine_reader_writer_from_dataset_json(self.dataset_json, example_image) @@ -642,5 +645,23 @@ def load_plans(self, fname: str): self.plans = load_json(fname) +def _maybe_copy_splits_file(splits_file: str, target_fname: str): + if not isfile(target_fname): + shutil.copy(splits_file, target_fname) + else: + # split already exists, do not copy, but check that the splits match. + # This code allows target_fname to contain more splits than splits_file. This is OK. + splits_source = load_json(splits_file) + splits_target = load_json(target_fname) + # all folds in the source file must match the target file + for i in range(len(splits_source)): + train_source = set(splits_source[i]['train']) + train_target = set(splits_target[i]['train']) + assert train_target == train_source + val_source = set(splits_source[i]['val']) + val_target = set(splits_target[i]['val']) + assert val_source == val_target + + if __name__ == '__main__': ExperimentPlanner(2, 8).plan_experiment() diff --git a/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/nnunetv2/experiment_planning/plan_and_preprocess_api.py index 13490a427..74a070385 100644 --- a/nnunetv2/experiment_planning/plan_and_preprocess_api.py +++ b/nnunetv2/experiment_planning/plan_and_preprocess_api.py @@ -1,17 +1,16 @@ -import shutil from typing import List, Type, Optional, Tuple, Union -import nnunetv2 -from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, subfiles, load_json +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json +import nnunetv2 +from nnunetv2.configuration import default_num_processes from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed -from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name, maybe_convert_to_dataset_name +from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name from nnunetv2.utilities.find_class_by_name import recursive_find_python_class from nnunetv2.utilities.plans_handling.plans_handler import PlansManager -from nnunetv2.configuration import default_num_processes from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets diff --git a/nnunetv2/experiment_planning/verify_dataset_integrity.py b/nnunetv2/experiment_planning/verify_dataset_integrity.py index 61175d069..71f84bff4 100644 --- a/nnunetv2/experiment_planning/verify_dataset_integrity.py +++ b/nnunetv2/experiment_planning/verify_dataset_integrity.py @@ -76,7 +76,7 @@ def check_cases(image_files: List[str], label_file: str, expected_num_channels: if not np.allclose(spacing_seg, spacing_images): print('Error: Spacing mismatch between segmentation and corresponding images. \nSpacing images: %s. ' '\nSpacing seg: %s. \nImage files: %s. \nSeg file: %s\n' % - (shape_image, shape_seg, image_files, label_file)) + (spacing_images, spacing_seg, image_files, label_file)) ret = False # check modalities diff --git a/nnunetv2/imageio/natural_image_reager_writer.py b/nnunetv2/imageio/natural_image_reader_writer.py similarity index 100% rename from nnunetv2/imageio/natural_image_reager_writer.py rename to nnunetv2/imageio/natural_image_reader_writer.py diff --git a/nnunetv2/imageio/reader_writer_registry.py b/nnunetv2/imageio/reader_writer_registry.py index e2921e688..606334ce0 100644 --- a/nnunetv2/imageio/reader_writer_registry.py +++ b/nnunetv2/imageio/reader_writer_registry.py @@ -4,7 +4,7 @@ from batchgenerators.utilities.file_and_folder_operations import join import nnunetv2 -from nnunetv2.imageio.natural_image_reager_writer import NaturalImage2DIO +from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO from nnunetv2.imageio.tif_reader_writer import Tiff3DIO diff --git a/nnunetv2/inference/examples.py b/nnunetv2/inference/examples.py index b57a39831..a66d98f8b 100644 --- a/nnunetv2/inference/examples.py +++ b/nnunetv2/inference/examples.py @@ -12,7 +12,7 @@ tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index 3e3e3dd56..cfc9e9c85 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -1,4 +1,5 @@ import inspect +import itertools import multiprocessing import os import traceback @@ -39,7 +40,7 @@ def __init__(self, tile_step_size: float = 0.5, use_gaussian: bool = True, use_mirroring: bool = True, - perform_everything_on_gpu: bool = True, + perform_everything_on_device: bool = True, device: torch.device = torch.device('cuda'), verbose: bool = False, verbose_preprocessing: bool = False, @@ -59,10 +60,10 @@ def __init__(self, # why would I ever want to do that. Stupid dobby. This kills DDP inference... pass if device.type != 'cuda': - print(f'perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False') - perform_everything_on_gpu = False + print(f'perform_everything_on_device=True is only supported for cuda devices! Setting this to False') + perform_everything_on_device = False self.device = device - self.perform_everything_on_gpu = perform_everything_on_gpu + self.perform_everything_on_device = perform_everything_on_device def initialize_from_trained_model_folder(self, model_training_output_dir: str, use_folds: Union[Tuple[Union[int, str]], None], @@ -110,7 +111,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str, self.label_manager = plans_manager.get_label_manager(dataset_json) if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \ and not isinstance(self.network, OptimizedModule): - print('compiling network') + print('Using torch.compile') self.network = torch.compile(self.network) def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, @@ -129,12 +130,13 @@ def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, self.allowed_mirroring_axes = inference_allowed_mirroring_axes self.label_manager = plans_manager.get_label_manager(dataset_json) allow_compile = True - allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) + allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and ( + os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) allow_compile = allow_compile and not isinstance(self.network, OptimizedModule) if isinstance(self.network, DistributedDataParallel): allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule) if allow_compile: - print('compiling network') + print('Using torch.compile') self.network = torch.compile(self.network) @staticmethod @@ -352,7 +354,7 @@ def predict_from_data_iterator(self, else: print(f'\nPredicting image of shape {data.shape}:') - print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}') + print(f'perform_everything_on_device: {self.perform_everything_on_device}') properties = preprocessed['data_properties'] @@ -360,7 +362,6 @@ def predict_from_data_iterator(self, # npy files proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) while not proceed: - # print('sleeping') sleep(0.1) proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) @@ -453,56 +454,33 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE. SEE convert_predicted_logits_to_segmentation_with_correct_shape """ - # we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as - # default and not have the entire program crash in case of GPU out of memory. Neat. That should make - # things a lot faster for some datasets. - original_perform_everything_on_gpu = self.perform_everything_on_gpu + n_threads = torch.get_num_threads() + torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads) with torch.no_grad(): prediction = None - if self.perform_everything_on_gpu: - try: - for params in self.list_of_parameters: - - # messing with state dict names... - if not isinstance(self.network, OptimizedModule): - self.network.load_state_dict(params) - else: - self.network._orig_mod.load_state_dict(params) - - if prediction is None: - prediction = self.predict_sliding_window_return_logits(data) - else: - prediction += self.predict_sliding_window_return_logits(data) - - if len(self.list_of_parameters) > 1: - prediction /= len(self.list_of_parameters) - - except RuntimeError: - print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. ' - 'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...') - print('Error:') - traceback.print_exc() - prediction = None - self.perform_everything_on_gpu = False - - if prediction is None: - for params in self.list_of_parameters: - # messing with state dict names... - if not isinstance(self.network, OptimizedModule): - self.network.load_state_dict(params) - else: - self.network._orig_mod.load_state_dict(params) - - if prediction is None: - prediction = self.predict_sliding_window_return_logits(data) - else: - prediction += self.predict_sliding_window_return_logits(data) - if len(self.list_of_parameters) > 1: - prediction /= len(self.list_of_parameters) - - print('Prediction done, transferring to CPU if needed') + + for params in self.list_of_parameters: + + # messing with state dict names... + if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(params) + else: + self.network._orig_mod.load_state_dict(params) + + # why not leave prediction on device if perform_everything_on_device? Because this may cause the + # second iteration to crash due to OOM. Grabbing tha twith try except cause way more bloated code than + # this actually saves computation time + if prediction is None: + prediction = self.predict_sliding_window_return_logits(data).to('cpu') + else: + prediction += self.predict_sliding_window_return_logits(data).to('cpu') + + if len(self.list_of_parameters) > 1: + prediction /= len(self.list_of_parameters) + + if self.verbose: print('Prediction done') prediction = prediction.to('cpu') - self.perform_everything_on_gpu = original_perform_everything_on_gpu + torch.set_num_threads(n_threads) return prediction def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]): @@ -548,24 +526,66 @@ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!' - num_predictons = 2 ** len(mirror_axes) - if 0 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2,))), (2,)) - if 1 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (3,))), (3,)) - if 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (4,))), (4,)) - if 0 in mirror_axes and 1 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2, 3))), (2, 3)) - if 0 in mirror_axes and 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2, 4))), (2, 4)) - if 1 in mirror_axes and 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (3, 4))), (3, 4)) - if 0 in mirror_axes and 1 in mirror_axes and 2 in mirror_axes: - prediction += torch.flip(self.network(torch.flip(x, (2, 3, 4))), (2, 3, 4)) - prediction /= num_predictons + axes_combinations = [ + c for i in range(len(mirror_axes)) for c in itertools.combinations([m + 2 for m in mirror_axes], i + 1) + ] + for axes in axes_combinations: + prediction += torch.flip(self.network(torch.flip(x, (*axes,))), (*axes,)) + prediction /= (len(axes_combinations) + 1) return prediction + def _internal_predict_sliding_window_return_logits(self, + data: torch.Tensor, + slicers, + do_on_device: bool = True, + ): + predicted_logits = n_predictions = prediction = gaussian = workon = None + results_device = self.device if do_on_device else torch.device('cpu') + + try: + empty_cache(self.device) + + # move data to device + if self.verbose: + print(f'move image to device {results_device}') + data = data.to(results_device) + + # preallocate arrays + if self.verbose: + print(f'preallocating results arrays on device {results_device}') + predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, + device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device) + if self.use_gaussian: + gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, + value_scaling_factor=10, + device=results_device) + + if self.verbose: print('running prediction') + if not self.allow_tqdm and self.verbose: print(f'{len(slicers)} steps') + for sl in tqdm(slicers, disable=not self.allow_tqdm): + workon = data[sl][None] + workon = workon.to(self.device, non_blocking=False) + + prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device) + + predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction) + n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1) + + predicted_logits /= n_predictions + # check for infs + if torch.any(torch.isinf(predicted_logits)): + raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, ' + 'reduce value_scaling_factor in compute_gaussian or increase the dtype of ' + 'predicted_logits to fp32') + except Exception as e: + del predicted_logits, n_predictions, prediction, gaussian, workon + empty_cache(self.device) + empty_cache(results_device) + raise e + return predicted_logits + def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ -> Union[np.ndarray, torch.Tensor]: assert isinstance(input_image, torch.Tensor) @@ -595,49 +615,24 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) - # preallocate results and num_predictions - results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu') - if self.verbose: print('preallocating arrays') - try: - data = data.to(self.device) - predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), - dtype=torch.half, - device=results_device) - n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, - device=results_device) - if self.use_gaussian: - gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, - value_scaling_factor=1000, - device=results_device) - except RuntimeError: - # sometimes the stuff is too large for GPUs. In that case fall back to CPU - results_device = torch.device('cpu') - data = data.to(results_device) - predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), - dtype=torch.half, - device=results_device) - n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, - device=results_device) - if self.use_gaussian: - gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, - value_scaling_factor=1000, - device=results_device) - finally: - empty_cache(self.device) - - if self.verbose: print('running prediction') - for sl in tqdm(slicers, disable=not self.allow_tqdm): - workon = data[sl][None] - workon = workon.to(self.device, non_blocking=False) - - prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device) - - predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction) - n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1) - - predicted_logits /= n_predictions - empty_cache(self.device) - return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])] + if self.perform_everything_on_device and self.device != 'cpu': + # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device + try: + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, + self.perform_everything_on_device) + except RuntimeError: + print( + 'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU') + empty_cache(self.device) + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, False) + else: + predicted_logits = self._internal_predict_sliding_window_return_logits(data, slicers, + self.perform_everything_on_device) + + empty_cache(self.device) + # revert padding + predicted_logits = predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])] + return predicted_logits def predict_entry_point_modelfolder(): @@ -685,6 +680,9 @@ def predict_entry_point_modelfolder(): help="Use this to set the device the inference should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, + help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' + 'jobs)') print( "\n#######################################################################\nPlease cite the following paper " @@ -717,9 +715,11 @@ def predict_entry_point_modelfolder(): predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=device, - verbose=args.verbose) + verbose=args.verbose, + allow_tqdm=not args.disable_progress_bar, + verbose_preprocessing=args.verbose) predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk) predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, overwrite=not args.continue_prediction, @@ -789,6 +789,9 @@ def predict_entry_point(): help="Use this to set the device the inference should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, + help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' + 'jobs)') print( "\n#######################################################################\nPlease cite the following paper " @@ -826,10 +829,11 @@ def predict_entry_point(): predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=device, verbose=args.verbose, - verbose_preprocessing=False) + verbose_preprocessing=args.verbose, + allow_tqdm=not args.disable_progress_bar) predictor.initialize_from_trained_model_folder( model_folder, args.f, @@ -849,7 +853,7 @@ def predict_entry_point(): # args.step_size, # use_gaussian=True, # use_mirroring=not args.disable_tta, - # perform_everything_on_gpu=True, + # perform_everything_on_device=True, # verbose=args.verbose, # save_probabilities=args.save_probabilities, # overwrite=not args.continue_prediction, @@ -865,19 +869,20 @@ def predict_entry_point(): if __name__ == '__main__': # predict a bunch of files from nnunetv2.paths import nnUNet_results, nnUNet_raw + predictor = nnUNetPredictor( tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, allow_tqdm=True - ) + ) predictor.initialize_from_trained_model_folder( join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), - use_folds=(0, ), + use_folds=(0,), checkpoint_name='checkpoint_final.pth', ) predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), @@ -888,18 +893,18 @@ def predict_entry_point(): # predict a numpy array from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) ret = predictor.predict_single_npy_array(img, props, None, None, False) iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1) ret = predictor.predict_from_data_iterator(iterator, False, 1) - # predictor = nnUNetPredictor( # tile_step_size=0.5, # use_gaussian=True, # use_mirroring=True, - # perform_everything_on_gpu=True, + # perform_everything_on_device=True, # device=torch.device('cuda', 0), # verbose=False, # allow_tqdm=True @@ -915,4 +920,3 @@ def predict_entry_point(): # num_processes_preprocessing=2, num_processes_segmentation_export=2, # folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres', # num_parts=1, part_id=0) - diff --git a/nnunetv2/inference/readme.md b/nnunetv2/inference/readme.md index 721952888..4f832a158 100644 --- a/nnunetv2/inference/readme.md +++ b/nnunetv2/inference/readme.md @@ -57,7 +57,7 @@ Example: tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, + perform_everything_on_device=True, device=torch.device('cuda', 0), verbose=False, verbose_preprocessing=False, diff --git a/nnunetv2/inference/sliding_window_prediction.py b/nnunetv2/inference/sliding_window_prediction.py index 07316cfa7..a6f8ebbae 100644 --- a/nnunetv2/inference/sliding_window_prediction.py +++ b/nnunetv2/inference/sliding_window_prediction.py @@ -17,10 +17,10 @@ def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: tmp[tuple(center_coords)] = 1 gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) - gaussian_importance_map = torch.from_numpy(gaussian_importance_map).type(dtype).to(device) + gaussian_importance_map = torch.from_numpy(gaussian_importance_map) gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor - gaussian_importance_map = gaussian_importance_map.type(dtype) + gaussian_importance_map = gaussian_importance_map.type(dtype).to(device) # gaussian_importance_map cannot be 0, otherwise we may end up with nans! gaussian_importance_map[gaussian_importance_map == 0] = torch.min( diff --git a/nnunetv2/preprocessing/normalization/default_normalization_schemes.py b/nnunetv2/preprocessing/normalization/default_normalization_schemes.py index 3c90a919f..705d477c8 100644 --- a/nnunetv2/preprocessing/normalization/default_normalization_schemes.py +++ b/nnunetv2/preprocessing/normalization/default_normalization_schemes.py @@ -32,7 +32,7 @@ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by default. """ - image = image.astype(self.target_dtype) + image = image.astype(self.target_dtype, copy=False) if self.use_mask_for_norm is not None and self.use_mask_for_norm: # negative values in the segmentation encode the 'outside' region (think zero values around the brain as # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image. @@ -45,7 +45,8 @@ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: else: mean = image.mean() std = image.std() - image = (image - mean) / (max(std, 1e-8)) + image -= mean + image /= (max(std, 1e-8)) return image @@ -54,13 +55,15 @@ class CTNormalization(ImageNormalization): def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: assert self.intensityproperties is not None, "CTNormalization requires intensity properties" - image = image.astype(self.target_dtype) mean_intensity = self.intensityproperties['mean'] std_intensity = self.intensityproperties['std'] lower_bound = self.intensityproperties['percentile_00_5'] upper_bound = self.intensityproperties['percentile_99_5'] - image = np.clip(image, lower_bound, upper_bound) - image = (image - mean_intensity) / max(std_intensity, 1e-8) + + image = image.astype(self.target_dtype, copy=False) + np.clip(image, lower_bound, upper_bound, out=image) + image -= mean_intensity + image /= max(std_intensity, 1e-8) return image @@ -68,16 +71,16 @@ class NoNormalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: - return image.astype(self.target_dtype) + return image.astype(self.target_dtype, copy=False) class RescaleTo01Normalization(ImageNormalization): leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: - image = image.astype(self.target_dtype) - image = image - image.min() - image = image / np.clip(image.max(), a_min=1e-8, a_max=None) + image = image.astype(self.target_dtype, copy=False) + image -= image.min() + image /= np.clip(image.max(), a_min=1e-8, a_max=None) return image @@ -89,7 +92,7 @@ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: "Your images do not seem to be RGB images" assert image.max() <= 255, "RGB images are uint 8, for whatever reason I found pixel values greater than 255" \ ". Your images do not seem to be RGB images" - image = image.astype(self.target_dtype) - image = image / 255. + image = image.astype(self.target_dtype, copy=False) + image /= 255. return image diff --git a/nnunetv2/training/dataloading/utils.py b/nnunetv2/training/dataloading/utils.py index bd145b4a9..352d18285 100644 --- a/nnunetv2/training/dataloading/utils.py +++ b/nnunetv2/training/dataloading/utils.py @@ -1,13 +1,93 @@ +from __future__ import annotations import multiprocessing import os -from multiprocessing import Pool from typing import List +from pathlib import Path +from warnings import warn import numpy as np from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles from nnunetv2.configuration import default_num_processes +def find_broken_image_and_labels( + path_to_data_dir: str | Path, +) -> tuple[set[str], set[str]]: + """ + Iterates through all numpys and tries to read them once to see if a ValueError is raised. + If so, the case id is added to the respective set and returned for potential fixing. + + :path_to_data_dir: Path/str to the preprocessed directory containing the npys and npzs. + :returns: Tuple of a set containing the case ids of the broken npy images and a set of the case ids of broken npy segmentations. + """ + content = os.listdir(path_to_data_dir) + unique_ids = [c[:-4] for c in content if c.endswith(".npz")] + failed_data_ids = set() + failed_seg_ids = set() + for unique_id in unique_ids: + # Try reading data + try: + np.load(path_to_data_dir / (unique_id + ".npy"), "r") + except ValueError: + failed_data_ids.add(unique_id) + # Try reading seg + try: + np.load(path_to_data_dir / (unique_id + "_seg.npy"), "r") + except ValueError: + failed_seg_ids.add(unique_id) + + return failed_data_ids, failed_seg_ids + + +def try_fix_broken_npy(path_do_data_dir: Path, case_ids: set[str], fix_image: bool): + """ + Receives broken case ids and tries to fix them by re-extracting the npz file (up to 5 times). + + :param case_ids: Set of case ids that are broken. + :param path_do_data_dir: Path to the preprocessed directory containing the npys and npzs. + :raises ValueError: If the npy file could not be unpacked after 5 tries. -- + """ + for case_id in case_ids: + for i in range(5): + try: + key = "data" if fix_image else "seg" + suffix = ".npy" if fix_image else "_seg.npy" + read_npz = np.load(path_do_data_dir / (case_id + ".npz"), "r")[key] + np.save(path_do_data_dir / (case_id + suffix), read_npz) + # Try loading the just saved image. + np.load(path_do_data_dir / (case_id + suffix), "r") + break + except ValueError: + if i == 4: + raise ValueError( + f"Could not unpack {case_id + suffix} after 5 tries!" + ) + continue + + +def verify_or_stratify_npys(path_to_data_dir: str | Path) -> None: + """ + This re-reads the npy files after unpacking. Should there be a loading issue with any, it will try to unpack this file again and overwrites the existing. + If the new file does not get saved correctly 5 times, it will raise an error with the file name to the user. Does the same for images and segmentations. + :param path_to_data_dir: Path to the preprocessed directory containing the npys and npzs. + :raises ValueError: If the npy file could not be unpacked after 5 tries. -- + Otherwise an obscured error will be raised later during training (depending when the broken file is sampled) + """ + path_to_data_dir = Path(path_to_data_dir) + # Check for broken image and segmentation npys + failed_data_ids, failed_seg_ids = find_broken_image_and_labels(path_to_data_dir) + + if len(failed_data_ids) != 0 or len(failed_seg_ids) != 0: + warn( + f"Found {len(failed_data_ids)} faulty data npys and {len(failed_seg_ids)}!\n" + + f"Faulty images: {failed_data_ids}; Faulty segmentations: {failed_seg_ids})\n" + + "Trying to fix them now." + ) + # Try to fix the broken npys by reextracting the npz. If that fails, raise error + try_fix_broken_npy(path_to_data_dir, failed_data_ids, fix_image=True) + try_fix_broken_npy(path_to_data_dir, failed_seg_ids, fix_image=False) + + def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False) -> None: try: a = np.load(npz_file) # inexpensive, no compression is done here. This just reads metadata diff --git a/nnunetv2/training/loss/compound_losses.py b/nnunetv2/training/loss/compound_losses.py index 9db0a4227..eaeb5d8e0 100644 --- a/nnunetv2/training/loss/compound_losses.py +++ b/nnunetv2/training/loss/compound_losses.py @@ -38,11 +38,10 @@ def forward(self, net_output: torch.Tensor, target: torch.Tensor): if self.ignore_label is not None: assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ '(DC_and_CE_loss)' - mask = (target != self.ignore_label).bool() + mask = target != self.ignore_label # remove ignore label from target, replace with one of the known labels. It doesn't matter because we # ignore gradients in those areas anyway - target_dice = torch.clone(target) - target_dice[target == self.ignore_label] = 0 + target_dice = torch.where(mask, target, 0) num_fg = mask.sum() else: target_dice = target @@ -50,7 +49,7 @@ def forward(self, net_output: torch.Tensor, target: torch.Tensor): dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ if self.weight_dice != 0 else 0 - ce_loss = self.ce(net_output, target[:, 0].long()) \ + ce_loss = self.ce(net_output, target[:, 0]) \ if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 result = self.weight_ce * ce_loss + self.weight_dice * dc_loss diff --git a/nnunetv2/training/loss/deep_supervision.py b/nnunetv2/training/loss/deep_supervision.py index 03141e809..952e3f715 100644 --- a/nnunetv2/training/loss/deep_supervision.py +++ b/nnunetv2/training/loss/deep_supervision.py @@ -1,3 +1,4 @@ +import torch from torch import nn @@ -11,25 +12,19 @@ def __init__(self, loss, weight_factors=None): If weights are None, all w will be 1. """ super(DeepSupervisionWrapper, self).__init__() - self.weight_factors = weight_factors + assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0" + self.weight_factors = tuple(weight_factors) self.loss = loss def forward(self, *args): - for i in args: - assert isinstance(i, (tuple, list)), f"all args must be either tuple or list, got {type(i)}" - # we could check for equal lengths here as well but we really shouldn't overdo it with checks because - # this code is executed a lot of times! + assert all([isinstance(i, (tuple, list)) for i in args]), \ + f"all args must be either tuple or list, got {[type(i) for i in args]}" + # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because + # this code is executed a lot of times! if self.weight_factors is None: - weights = [1] * len(args[0]) + weights = (1, ) * len(args[0]) else: weights = self.weight_factors - # we initialize the loss like this instead of 0 to ensure it sits on the correct device, not sure if that's - # really necessary - l = weights[0] * self.loss(*[j[0] for j in args]) - for i, inputs in enumerate(zip(*args)): - if i == 0: - continue - l += weights[i] * self.loss(*inputs) - return l \ No newline at end of file + return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0]) diff --git a/nnunetv2/training/loss/dice.py b/nnunetv2/training/loss/dice.py index af554908b..574435754 100644 --- a/nnunetv2/training/loss/dice.py +++ b/nnunetv2/training/loss/dice.py @@ -74,18 +74,18 @@ def forward(self, x, y, loss_mask=None): x = self.apply_nonlin(x) # make everything shape (b, c) - axes = list(range(2, len(x.shape))) + axes = tuple(range(2, x.ndim)) + with torch.no_grad(): - if len(x.shape) != len(y.shape): + if x.ndim != y.ndim: y = y.view((y.shape[0], 1, *y.shape[1:])) if x.shape == y.shape: # if this is the case then gt is probably already a one hot encoding y_onehot = y else: - gt = y.long() y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.bool) - y_onehot.scatter_(1, gt, 1) + y_onehot.scatter_(1, y.long(), 1) if not self.do_bg: y_onehot = y_onehot[:, 1:] @@ -96,15 +96,19 @@ def forward(self, x, y, loss_mask=None): if not self.do_bg: x = x[:, 1:] - intersect = (x * y_onehot).sum(axes) if loss_mask is None else (x * y_onehot * loss_mask).sum(axes) - sum_pred = x.sum(axes) if loss_mask is None else (x * loss_mask).sum(axes) - - if self.ddp and self.batch_dice: - intersect = AllGatherGrad.apply(intersect).sum(0) - sum_pred = AllGatherGrad.apply(sum_pred).sum(0) - sum_gt = AllGatherGrad.apply(sum_gt).sum(0) + if loss_mask is None: + intersect = (x * y_onehot).sum(axes) + sum_pred = x.sum(axes) + else: + intersect = (x * y_onehot * loss_mask).sum(axes) + sum_pred = (x * loss_mask).sum(axes) if self.batch_dice: + if self.ddp: + intersect = AllGatherGrad.apply(intersect).sum(0) + sum_pred = AllGatherGrad.apply(sum_pred).sum(0) + sum_gt = AllGatherGrad.apply(sum_gt).sum(0) + intersect = intersect.sum(0) sum_pred = sum_pred.sum(0) sum_gt = sum_gt.sum(0) @@ -128,22 +132,18 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): :return: """ if axes is None: - axes = tuple(range(2, len(net_output.size()))) - - shp_x = net_output.shape - shp_y = gt.shape + axes = tuple(range(2, net_output.ndim)) with torch.no_grad(): - if len(shp_x) != len(shp_y): - gt = gt.view((shp_y[0], 1, *shp_y[1:])) + if net_output.ndim != gt.ndim: + gt = gt.view((gt.shape[0], 1, *gt.shape[1:])) if net_output.shape == gt.shape: # if this is the case then gt is probably already a one hot encoding y_onehot = gt else: - gt = gt.long() - y_onehot = torch.zeros(shp_x, device=net_output.device) - y_onehot.scatter_(1, gt, 1) + y_onehot = torch.zeros(net_output.shape, device=net_output.device) + y_onehot.scatter_(1, gt.long(), 1) tp = net_output * y_onehot fp = net_output * (1 - y_onehot) @@ -152,7 +152,7 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): if mask is not None: with torch.no_grad(): - mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for i in range(2, len(tp.shape))])) + mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)])) tp *= mask_here fp *= mask_here fn *= mask_here diff --git a/nnunetv2/training/loss/robust_ce_loss.py b/nnunetv2/training/loss/robust_ce_loss.py index ad4665919..3399e3ae9 100644 --- a/nnunetv2/training/loss/robust_ce_loss.py +++ b/nnunetv2/training/loss/robust_ce_loss.py @@ -10,7 +10,7 @@ class RobustCrossEntropyLoss(nn.CrossEntropyLoss): input must be logits, not probabilities! """ def forward(self, input: Tensor, target: Tensor) -> Tensor: - if len(target.shape) == len(input.shape): + if target.ndim == input.ndim: assert target.shape[1] == 1 target = target[:, 0] return super().forward(input, target.long()) @@ -30,4 +30,3 @@ def forward(self, inp, target): num_voxels = np.prod(res.shape, dtype=np.int64) res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) return res.mean() - diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 074ed8572..7b694b8ab 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -11,6 +11,8 @@ import numpy as np import torch +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ @@ -50,13 +52,13 @@ from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler from nnunetv2.utilities.collate_outputs import collate_outputs +from nnunetv2.utilities.crossval_split import generate_crossval_split from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy from nnunetv2.utilities.get_network_from_plans import get_network_from_plans from nnunetv2.utilities.helpers import empty_cache, dummy_context from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager -from sklearn.model_selection import KFold from torch import autocast, nn from torch import distributed as dist from torch.cuda import device_count @@ -146,6 +148,7 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic self.num_val_iterations_per_epoch = 50 self.num_epochs = 1000 self.current_epoch = 0 + self.enable_deep_supervision = True ### Dealing with labels/regions self.label_manager = self.plans_manager.get_label_manager(dataset_json) @@ -153,7 +156,7 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic # needed for predictions. We do sigmoid in case of (overlapping) regions self.num_input_channels = None # -> self.initialize() - self.network = None # -> self._get_network() + self.network = None # -> self.build_network_architecture() self.optimizer = self.lr_scheduler = None # -> self.initialize self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None self.loss = None # -> self.initialize @@ -201,13 +204,16 @@ def initialize(self): self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, self.dataset_json) - self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, - self.configuration_manager, - self.num_input_channels, - enable_deep_supervision=True).to(self.device) + self.network = self.build_network_architecture( + self.plans_manager, + self.dataset_json, + self.configuration_manager, + self.num_input_channels, + self.enable_deep_supervision, + ).to(self.device) # compile network for free speedup if self._do_i_compile(): - self.print_to_log_file('Compiling network...') + self.print_to_log_file('Using torch.compile...') self.network = torch.compile(self.network) self.optimizer, self.lr_scheduler = self.configure_optimizers() @@ -267,7 +273,7 @@ def build_network_architecture(plans_manager: PlansManager, num_input_channels, enable_deep_supervision: bool = True) -> nn.Module: """ - his is where you build the architecture according to the plans. There is no obligation to use + This is where you build the architecture according to the plans. There is no obligation to use get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what you want. Even ignore the plans and just return something static (as long as it can process the requested patch size) @@ -289,8 +295,11 @@ def build_network_architecture(plans_manager: PlansManager, num_input_channels, deep_supervision=enable_deep_supervision) def _get_deep_supervision_scales(self): - deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( - self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + if self.enable_deep_supervision: + deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( + self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + else: + deep_supervision_scales = None # for train and val_transforms return deep_supervision_scales def _set_batch_size_and_oversample(self): @@ -299,8 +308,6 @@ def _set_batch_size_and_oversample(self): self.batch_size = self.configuration_manager.batch_size else: # batch size is distributed over DDP workers and we need to change oversample_percent for each worker - batch_sizes = [] - oversample_percents = [] world_size = dist.get_world_size() my_rank = dist.get_rank() @@ -309,36 +316,38 @@ def _set_batch_size_and_oversample(self): assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ 'GPUs... Duh.' - batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int) - - for rank in range(world_size): - if (rank + 1) * batch_size_per_GPU > global_batch_size: - batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size) - else: - batch_size = batch_size_per_GPU - - batch_sizes.append(batch_size) - - sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) - sample_id_high = np.sum(batch_sizes) - - if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): - oversample_percents.append(0.0) - elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): - oversample_percents.append(1.0) - else: - percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size - oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - - sample_id_low / global_batch_size) / percent_covered_by_this_rank) - oversample_percents.append(oversample_percent_here) + batch_size_per_GPU = [global_batch_size // world_size] * world_size + batch_size_per_GPU = [batch_size_per_GPU[i] + 1 + if (batch_size_per_GPU[i] * world_size + i) < global_batch_size + else batch_size_per_GPU[i] + for i in range(len(batch_size_per_GPU))] + assert sum(batch_size_per_GPU) == global_batch_size + + sample_id_low = 0 if my_rank == 0 else np.sum(batch_size_per_GPU[:my_rank]) + sample_id_high = np.sum(batch_size_per_GPU[:my_rank + 1]) + + # This is how oversampling is determined in DataLoader + # round(self.batch_size * (1 - self.oversample_foreground_percent)) + # We need to use the same scheme here because an oversample of 0.33 with a batch size of 2 will be rounded + # to an oversample of 0.5 (1 sample random, one oversampled). This may get lost if we just numerically + # compute oversample + oversample = [True if not i < round(global_batch_size * (1 - self.oversample_foreground_percent)) else False + for i in range(global_batch_size)] + + if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): + oversample_percent = 0.0 + elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): + oversample_percent = 1.0 + else: + oversample_percent = sum(oversample[sample_id_low:sample_id_high]) / batch_size_per_GPU[my_rank] - print("worker", my_rank, "oversample", oversample_percents[my_rank]) - print("worker", my_rank, "batch_size", batch_sizes[my_rank]) + print("worker", my_rank, "oversample", oversample_percent) + print("worker", my_rank, "batch_size", batch_size_per_GPU[my_rank]) # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank]) # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank]) - self.batch_size = batch_sizes[my_rank] - self.oversample_foreground_percent = oversample_percents[my_rank] + self.batch_size = batch_size_per_GPU[my_rank] + self.oversample_foreground_percent = oversample_percent def _build_loss(self): if self.label_manager.has_regions: @@ -352,17 +361,24 @@ def _build_loss(self): 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) - deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + if self.is_ddp and not self._do_i_compile(): + # very strange and stupid interaction. DDP crashes and complains about unused parameters due to + # weights[-1] = 0. Interestingly this crash doesn't happen with torch.compile enabled. Strange stuff. + # Anywho, the simple fix is to set a very low weight to this. + weights[-1] = 1e-6 + else: + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): @@ -506,9 +522,9 @@ def plot_network_architecture(self): def do_split(self): """ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, - so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + so always the same) and save it as splits_final.json file in the preprocessed data directory. Sometimes you may want to create your own split for various reasons. For this you will need to create your own - splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to use a random 80:20 data split. @@ -527,15 +543,8 @@ def do_split(self): # if the split file does not exist we need to create it if not isfile(splits_file): self.print_to_log_file("Creating new 5-fold cross-validation split...") - splits = [] - all_keys_sorted = np.sort(list(dataset.keys())) - kfold = KFold(n_splits=5, shuffle=True, random_state=12345) - for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): - train_keys = np.array(all_keys_sorted)[train_idx] - test_keys = np.array(all_keys_sorted)[test_idx] - splits.append({}) - splits[-1]['train'] = list(train_keys) - splits[-1]['val'] = list(test_keys) + all_keys_sorted = list(np.sort(list(dataset.keys()))) + splits = generate_crossval_split(all_keys_sorted, seed=12345, n_splits=5) save_json(splits, splits_file) else: @@ -599,10 +608,15 @@ def get_dataloaders(self): # needed for deep supervision: how much do we need to downscale the segmentation targets for the different # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() - rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ - self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + ( + rotation_for_DA, + do_dummy_2d_data_aug, + initial_patch_size, + mirror_axes, + ) = self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() # training pipeline tr_transforms = self.get_training_transforms( @@ -669,19 +683,21 @@ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): return dl_tr, dl_val @staticmethod - def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], - rotation_for_DA: dict, - deep_supervision_scales: Union[List, Tuple], - mirror_axes: Tuple[int, ...], - do_dummy_2d_data_aug: bool, - order_resampling_data: int = 3, - order_resampling_seg: int = 1, - border_val_seg: int = -1, - use_mask_for_norm: List[bool] = None, - is_cascaded: bool = False, - foreground_labels: Union[Tuple[int, ...], List[int]] = None, - regions: List[Union[List[int], Tuple[int, ...], int]] = None, - ignore_label: int = None) -> AbstractTransform: + def get_training_transforms( + patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple, None], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, + ) -> AbstractTransform: tr_transforms = [] if do_dummy_2d_data_aug: ignore_axes = (0,) @@ -761,11 +777,13 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], return tr_transforms @staticmethod - def get_validation_transforms(deep_supervision_scales: Union[List, Tuple], - is_cascaded: bool = False, - foreground_labels: Union[Tuple[int, ...], List[int]] = None, - regions: List[Union[List[int], Tuple[int, ...], int]] = None, - ignore_label: int = None) -> AbstractTransform: + def get_validation_transforms( + deep_supervision_scales: Union[List, Tuple, None], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None, + ) -> AbstractTransform: val_transforms = [] val_transforms.append(RemoveLabelTransform(-1, 0)) @@ -794,9 +812,13 @@ def set_deep_supervision_enabled(self, enabled: bool): chances you need to change this as well! """ if self.is_ddp: - self.network.module.decoder.deep_supervision = enabled + mod = self.network.module else: - self.network.decoder.deep_supervision = enabled + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + mod.decoder.deep_supervision = enabled def on_train_start(self): if not self.was_initialized: @@ -805,7 +827,7 @@ def on_train_start(self): maybe_mkdir_p(self.output_folder) # make sure deep supervision is on in the network - self.set_deep_supervision_enabled(True) + self.set_deep_supervision_enabled(self.enable_deep_supervision) self.print_plans() empty_cache(self.device) @@ -855,9 +877,11 @@ def on_train_end(self): old_stdout = sys.stdout with open(os.devnull, 'w') as f: sys.stdout = f - if self.dataloader_train is not None: + if self.dataloader_train is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_train._finish() - if self.dataloader_val is not None: + if self.dataloader_val is not None and \ + isinstance(self.dataloader_train, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)): self.dataloader_val._finish() sys.stdout = old_stdout @@ -940,9 +964,10 @@ def validation_step(self, batch: dict) -> dict: del data l = self.loss(output, target) - # we only need the output with the highest output resolution - output = output[0] - target = target[0] + # we only need the output with the highest output resolution (if DS enabled) + if self.enable_deep_supervision: + output = output[0] + target = target[0] # the following is needed for online evaluation. Fake dice (green line) axes = [0] + list(range(2, output.ndim)) @@ -1011,8 +1036,7 @@ def on_validation_epoch_end(self, val_outputs: List[dict]): else: loss_here = np.mean(outputs_collated['loss']) - global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in - zip(tp, fp, fn)]] + global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in zip(tp, fp, fn)]] mean_fg_dice = np.nanmean(global_dc_per_class) self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) @@ -1024,7 +1048,6 @@ def on_epoch_start(self): def on_epoch_end(self): self.logger.log('epoch_end_timestamps', time(), self.current_epoch) - # todo find a solution for this stupid shit self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in @@ -1115,8 +1138,18 @@ def perform_actual_validation(self, save_probabilities: bool = False): self.set_deep_supervision_enabled(False) self.network.eval() + if self.is_ddp and self.batch_size == 1 and self.enable_deep_supervision and self._do_i_compile(): + self.print_to_log_file("WARNING! batch size is 1 during training and torch.compile is enabled. If you " + "encounter crashes in validation then this is because torch.compile forgets " + "to trigger a recompilation of the model with deep supervision disabled. " + "This causes torch.flip to complain about getting a tuple as input. Just rerun the " + "validation with --val (exactly the same as before) and then it will work. " + "Why? Because --val triggers nnU-Net to ONLY run validation meaning that the first " + "forward pass (where compile is triggered) already has deep supervision disabled. " + "This is exactly what we need in perform_actual_validation") + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, - perform_everything_on_gpu=True, device=self.device, verbose=False, + perform_everything_on_device=True, device=self.device, verbose=False, verbose_preprocessing=False, allow_tqdm=False) predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, self.dataset_json, self.__class__.__name__, @@ -1131,7 +1164,11 @@ def perform_actual_validation(self, save_probabilities: bool = False): # the validation keys across the workers. _, val_keys = self.do_split() if self.is_ddp: + last_barrier_at_idx = len(val_keys) // dist.get_world_size() - 1 + val_keys = val_keys[self.local_rank:: dist.get_world_size()] + # we cannot just have barriers all over the place because the number of keys each GPU receives can be + # different dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, @@ -1144,7 +1181,7 @@ def perform_actual_validation(self, save_probabilities: bool = False): results = [] - for k in dataset_val.keys(): + for i, k in enumerate(dataset_val.keys()): proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, allowed_num_queued=2) while not proceed: @@ -1163,15 +1200,10 @@ def perform_actual_validation(self, save_probabilities: bool = False): warnings.simplefilter("ignore") data = torch.from_numpy(data) + self.print_to_log_file(f'{k}, shape {data.shape}, rank {self.local_rank}') output_filename_truncated = join(validation_output_folder, k) - try: - prediction = predictor.predict_sliding_window_return_logits(data) - except RuntimeError: - predictor.perform_everything_on_gpu = False - prediction = predictor.predict_sliding_window_return_logits(data) - predictor.perform_everything_on_gpu = True - + prediction = predictor.predict_sliding_window_return_logits(data) prediction = prediction.cpu() # this needs to go into background processes @@ -1219,6 +1251,9 @@ def perform_actual_validation(self, save_probabilities: bool = False): self.dataset_json), ) )) + # if we don't barrier from time to time we will get nccl timeouts for large datsets. Yuck. + if self.is_ddp and i < last_barrier_at_idx and (i + 1) % 20 == 0: + dist.barrier() _ = [r.get() for r in results] @@ -1233,7 +1268,9 @@ def perform_actual_validation(self, save_probabilities: bool = False): self.dataset_json["file_ending"], self.label_manager.foreground_regions if self.label_manager.has_regions else self.label_manager.foreground_labels, - self.label_manager.ignore_label, chill=True) + self.label_manager.ignore_label, chill=True, + num_processes=default_num_processes * dist.get_world_size() if + self.is_ddp else default_num_processes) self.print_to_log_file("Validation complete", also_print_to_console=True) self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) diff --git a/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py b/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py index 6c12ecc84..e7de92cf0 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py +++ b/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py @@ -1,25 +1,39 @@ import torch -from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import \ - nnUNetTrainerBenchmark_5epochs +from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import ( + nnUNetTrainerBenchmark_5epochs, +) from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs): - def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) self._set_batch_size_and_oversample() - num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, - self.dataset_json) + num_input_channels = determine_num_input_channels( + self.plans_manager, self.configuration_manager, self.dataset_json + ) patch_size = self.configuration_manager.patch_size dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device) - dummy_target = [ - torch.round( - torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) * - max(self.label_manager.all_labels) - ) for k in self._get_deep_supervision_scales()] - self.dummy_batch = {'data': dummy_data, 'target': dummy_target} + if self.enable_deep_supervision: + dummy_target = [ + torch.round( + torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) + * max(self.label_manager.all_labels) + ) + for k in self._get_deep_supervision_scales() + ] + else: + raise NotImplementedError("This trainer does not support deep supervision") + self.dummy_batch = {"data": dummy_data, "target": dummy_target} def get_dataloaders(self): return None, None diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py index bd9c31c0b..7250fb845 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py @@ -93,7 +93,7 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): @staticmethod def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: dict, - deep_supervision_scales: Union[List, Tuple], + deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, order_resampling_data: int = 3, @@ -233,9 +233,9 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], tr_transforms.append( BrightnessGradientAdditiveTransform( - lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))), + _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), - max_strength=lambda x, y: np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5), + max_strength=_brightness_gradient_additive_max_strength, mean_centered=False, same_for_all_channels=False, p_per_sample=0.3, @@ -245,9 +245,9 @@ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], tr_transforms.append( LocalGammaTransform( - lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))), + _brightnessadditive_localgamma_transform_scale, (-0.5, 1.5), - lambda: np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4), + _local_gamma_gamma, same_for_all_channels=False, p_per_sample=0.3, p_per_channel=0.5 @@ -354,6 +354,18 @@ def get_dataloaders(self): return mt_gen_train, mt_gen_val +def _brightnessadditive_localgamma_transform_scale(x, y): + return np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))) + + +def _brightness_gradient_additive_max_strength(_x, _y): + return np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5) + + +def _local_gamma_gamma(): + return np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4) + + class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5): def get_dataloaders(self): """ diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py index 527e26292..17f3586db 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py @@ -10,7 +10,7 @@ class nnUNetTrainerNoDA(nnUNetTrainer): @staticmethod def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], rotation_for_DA: dict, - deep_supervision_scales: Union[List, Tuple], + deep_supervision_scales: Union[List, Tuple, None], mirror_axes: Tuple[int, ...], do_dummy_2d_data_aug: bool, order_resampling_data: int = 1, diff --git a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py index c8432dfdc..fdc0fea64 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py +++ b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py @@ -7,27 +7,35 @@ class nnUNetTrainerCELoss(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = RobustCrossEntropyLoss(weight=None, - ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100) - - deep_supervision_scales = self._get_deep_supervision_scales() + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = RobustCrossEntropyLoss( + weight=None, ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100 + ) # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerCELoss_5epochs(nnUNetTrainerCELoss): - def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): """used for debugging plans etc""" super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) self.num_epochs = 5 diff --git a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py index d34c87f66..58993c6fc 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py +++ b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py @@ -25,17 +25,18 @@ def _build_loss(self): 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp}, apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1) - deep_supervision_scales = self._get_deep_supervision_scales() - - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 - - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss @@ -54,17 +55,18 @@ def _build_loss(self): ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) - deep_supervision_scales = self._get_deep_supervision_scales() + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerDiceCELoss_noSmooth_50epochs(nnUNetTrainer_50epochs): diff --git a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py index afb3fe18e..5eff10e8c 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py +++ b/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py @@ -7,63 +7,70 @@ class nnUNetTrainerTopk10Loss(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, - k=10) + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = TopKLoss( + ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, k=10 + ) - deep_supervision_scales = self._get_deep_supervision_scales() + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerTopk10LossLS01(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, - k=10, label_smoothing=0.1) + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = TopKLoss( + ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, + k=10, + label_smoothing=0.1, + ) - deep_supervision_scales = self._get_deep_supervision_scales() + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer): def _build_loss(self): - assert not self.label_manager.has_regions, 'regions not supported by this trainer' - loss = DC_and_topk_loss({'batch_dice': self.configuration_manager.batch_dice, - 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, - {'k': 10, - 'label_smoothing': 0.0}, - weight_ce=1, weight_dice=1, - ignore_label=self.label_manager.ignore_label) + assert not self.label_manager.has_regions, "regions not supported by this trainer" + loss = DC_and_topk_loss( + {"batch_dice": self.configuration_manager.batch_dice, "smooth": 1e-5, "do_bg": False, "ddp": self.is_ddp}, + {"k": 10, "label_smoothing": 0.0}, + weight_ce=1, + weight_dice=1, + ignore_label=self.label_manager.ignore_label, + ) + if self.enable_deep_supervision: + deep_supervision_scales = self._get_deep_supervision_scales() - deep_supervision_scales = self._get_deep_supervision_scales() + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2**i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 - # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases - # this gives higher resolution outputs more weight in the loss - weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) - weights[-1] = 0 - - # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 - weights = weights / weights.sum() - # now wrap the loss - loss = DeepSupervisionWrapper(loss, weights) + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) return loss diff --git a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py index 34f9b554f..1152fbeb4 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py +++ b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py @@ -1,114 +1,16 @@ -import torch -from torch import autocast - -from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss -from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer -from nnunetv2.utilities.helpers import dummy_context -from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels -from torch.nn.parallel import DistributedDataParallel as DDP +import torch class nnUNetTrainerNoDeepSupervision(nnUNetTrainer): - def _build_loss(self): - if self.label_manager.has_regions: - loss = DC_and_BCE_loss({}, - {'batch_dice': self.configuration_manager.batch_dice, - 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, - use_ignore_label=self.label_manager.ignore_label is not None, - dice_class=MemoryEfficientSoftDiceLoss) - else: - loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, - 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, - ignore_label=self.label_manager.ignore_label, - dice_class=MemoryEfficientSoftDiceLoss) - return loss - - def _get_deep_supervision_scales(self): - return None - - def initialize(self): - if not self.was_initialized: - self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, - self.dataset_json) - - self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, - self.configuration_manager, - self.num_input_channels, - enable_deep_supervision=False).to(self.device) - - self.optimizer, self.lr_scheduler = self.configure_optimizers() - # if ddp, wrap in DDP wrapper - if self.is_ddp: - self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) - self.network = DDP(self.network, device_ids=[self.local_rank]) - - self.loss = self._build_loss() - self.was_initialized = True - else: - raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " - "That should not happen.") - - def set_deep_supervision_enabled(self, enabled: bool): - pass - - def validation_step(self, batch: dict) -> dict: - data = batch['data'] - target = batch['target'] - - data = data.to(self.device, non_blocking=True) - if isinstance(target, list): - target = [i.to(self.device, non_blocking=True) for i in target] - else: - target = target.to(self.device, non_blocking=True) - - self.optimizer.zero_grad(set_to_none=True) - - # Autocast is a little bitch. - # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. - # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) - # So autocast will only be active if we have a cuda device. - with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): - output = self.network(data) - del data - l = self.loss(output, target) - - # the following is needed for online evaluation. Fake dice (green line) - axes = [0] + list(range(2, output.ndim)) - - if self.label_manager.has_regions: - predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() - else: - # no need for softmax - output_seg = output.argmax(1)[:, None] - predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) - predicted_segmentation_onehot.scatter_(1, output_seg, 1) - del output_seg - - if self.label_manager.has_ignore_label: - if not self.label_manager.has_regions: - mask = (target != self.label_manager.ignore_label).float() - # CAREFUL that you don't rely on target after this line! - target[target == self.label_manager.ignore_label] = 0 - else: - mask = 1 - target[:, -1:] - # CAREFUL that you don't rely on target after this line! - target = target[:, :-1] - else: - mask = None - - tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) - - tp_hard = tp.detach().cpu().numpy() - fp_hard = fp.detach().cpu().numpy() - fn_hard = fn.detach().cpu().numpy() - if not self.label_manager.has_regions: - # if we train with regions all segmentation heads predict some kind of foreground. In conventional - # (softmax training) there needs tobe one output for the background. We are not interested in the - # background Dice - # [1:] in order to remove background - tp_hard = tp_hard[1:] - fp_hard = fp_hard[1:] - fn_hard = fn_hard[1:] - - return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} \ No newline at end of file + def __init__( + self, + plans: dict, + configuration: str, + fold: int, + dataset_json: dict, + unpack_dataset: bool = True, + device: torch.device = torch.device("cuda"), + ): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.enable_deep_supervision = False diff --git a/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py b/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py index 89fef482c..467a6fd04 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py +++ b/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Tuple import torch @@ -59,6 +60,13 @@ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) return dl_tr, dl_val + def _set_batch_size_and_oversample(self): + old_oversample = deepcopy(self.oversample_foreground_percent) + super()._set_batch_size_and_oversample() + self.oversample_foreground_percent = old_oversample + self.print_to_log_file(f"Ignore previous message about oversample_foreground_percent. " + f"oversample_foreground_percent overwritten to {self.oversample_foreground_percent}") + class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, diff --git a/nnunetv2/utilities/crossval_split.py b/nnunetv2/utilities/crossval_split.py new file mode 100644 index 000000000..472603b00 --- /dev/null +++ b/nnunetv2/utilities/crossval_split.py @@ -0,0 +1,16 @@ +from typing import List + +import numpy as np +from sklearn.model_selection import KFold + + +def generate_crossval_split(train_identifiers: List[str], seed=12345, n_splits=5) -> List[dict[str, List[str]]]: + splits = [] + kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed) + for i, (train_idx, test_idx) in enumerate(kfold.split(train_identifiers)): + train_keys = np.array(train_identifiers)[train_idx] + test_keys = np.array(train_identifiers)[test_idx] + splits.append({}) + splits[-1]['train'] = list(train_keys) + splits[-1]['val'] = list(test_keys) + return splits diff --git a/pyproject.toml b/pyproject.toml index 8d1369414..91bc31563 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nnunetv2" -version = "2.2" +version = "2.2.1" requires-python = ">=3.9" description = "nnU-Net is a framework for out-of-the box image segmentation." readme = "readme.md"