diff --git a/documentation/how_to_use_nnunet.md b/documentation/how_to_use_nnunet.md
index 290ae8873..8b34b4045 100644
--- a/documentation/how_to_use_nnunet.md
+++ b/documentation/how_to_use_nnunet.md
@@ -163,7 +163,7 @@ a plot of the training (blue) and validation (red) loss during training. Also sh
 average over all cases but pretend that there is only one validation case from which we sample patches). The reason
 for this is that the 'global Dice' is easy to compute during training and is still quite useful to evaluate whether a
 model is training at all or not. A proper validation takes way too long to be done each epoch. It is run at the end of
 the training.
-- validation_raw: in this folder are the predicted validation cases after the training has finished. The summary.json file in here
+- validation: in this folder are the predicted validation cases after the training has finished. The summary.json file in here
 contains the validation metrics (a mean over all cases is provided at the start of the file). If `--npz` was set then
 the compressed softmax outputs (saved as .npz files) are in here as well.
diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index f71c73a2f..2486bf6df 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -118,6 +118,7 @@ def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
             yield item
     [p.join() for p in processes]
 
+
 class PreprocessAdapter(DataLoader):
     def __init__(self, list_of_lists: List[List[str]],
                  list_of_segs_from_prev_stage_files: Union[None, List[str]],
diff --git a/nnunetv2/preprocessing/resampling/default_resampling.py b/nnunetv2/preprocessing/resampling/default_resampling.py
index b710c7dac..1b05a9bcc 100644
--- a/nnunetv2/preprocessing/resampling/default_resampling.py
+++ b/nnunetv2/preprocessing/resampling/default_resampling.py
@@ -31,6 +31,7 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
     return new_shape
 
 
+
 def determine_do_sep_z_and_axis(
         force_separate_z: bool,
         current_spacing,
diff --git a/nnunetv2/training/dataloading/data_loader_2d.py b/nnunetv2/training/dataloading/data_loader_2d.py
index 655a7aae5..08bfad87a 100644
--- a/nnunetv2/training/dataloading/data_loader_2d.py
+++ b/nnunetv2/training/dataloading/data_loader_2d.py
@@ -101,7 +101,10 @@ def generate_train_batch(self):
             images.append(tmp['image'])
             segs.append(tmp['segmentation'])
         data_all = torch.stack(images)
-        seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+        if isinstance(segs[0], list):
+            seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+        else:
+            seg_all = torch.stack(segs)
         del segs, images
         if torch is not None:
             torch.set_num_threads(torch_nthreads)
diff --git a/nnunetv2/training/dataloading/data_loader_3d.py b/nnunetv2/training/dataloading/data_loader_3d.py
index 3131e1f09..d17928475 100644
--- a/nnunetv2/training/dataloading/data_loader_3d.py
+++ b/nnunetv2/training/dataloading/data_loader_3d.py
@@ -64,7 +64,10 @@ def generate_train_batch(self):
             images.append(tmp['image'])
             segs.append(tmp['segmentation'])
         data_all = torch.stack(images)
-        seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+        if isinstance(segs[0], list):
+            seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
+        else:
+            seg_all = torch.stack(segs)
         del segs, images
         if torch is not None:
             torch.set_num_threads(torch_nthreads)
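Note on the two data loader hunks above: the 'segmentation' entry of a sample is a list of tensors (one per deep-supervision scale) when deep supervision is enabled, but a single tensor otherwise; the old one-liner only handled the list case. A minimal standalone sketch of both cases (illustrative shapes, not code from this patch):

import torch

# Deep supervision: each sample's 'segmentation' is a list with one tensor per
# output resolution, e.g. shapes (1, 64, 64) and (1, 32, 32).
segs = [[torch.zeros(1, 64, 64), torch.zeros(1, 32, 32)] for _ in range(2)]
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
print([tuple(t.shape) for t in seg_all])  # [(2, 1, 64, 64), (2, 1, 32, 32)]

# No deep supervision: each sample yields a single tensor. The old list
# comprehension would then iterate over the channel dimension (len(segs[0]) == 1
# here) and build mis-shaped (2, 64, 64) entries; direct stacking is correct.
segs = [torch.zeros(1, 64, 64) for _ in range(2)]
seg_all = torch.stack(segs)
print(tuple(seg_all.shape))  # (2, 1, 64, 64)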
diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py
index 7bf946013..f7fa420c5 100644
--- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py
+++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py
@@ -1,3 +1,27 @@
+from typing import Union, Tuple, List
+
+import numpy as np
+from batchgeneratorsv2.helpers.scalar_type import RandomScalar
+from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
+from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
+from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
+from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
+from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
+from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
+from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
+    RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
+from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
+from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
+from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
+from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
+from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
+from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
+from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
+from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform
+from batchgeneratorsv2.transforms.utils.random import RandomTransform
+from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
+from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
 from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
 from nnunetv2.training.nnUNetTrainer.variants.training_length import nnUNetTrainer_Xepochs
 
@@ -25,3 +49,155 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
         self.inference_allowed_mirroring_axes = mirror_axes
 
         return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+class nnUNetTrainer_onlyMirror01_DASegOrd0(nnUNetTrainer_onlyMirror01):
+    @staticmethod
+    def get_training_transforms(
+            patch_size: Union[np.ndarray, Tuple[int]],
+            rotation_for_DA: RandomScalar,
+            deep_supervision_scales: Union[List, Tuple, None],
+            mirror_axes: Tuple[int, ...],
+            do_dummy_2d_data_aug: bool,
+            use_mask_for_norm: List[bool] = None,
+            is_cascaded: bool = False,
+            foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+            regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+            ignore_label: int = None,
+    ) -> BasicTransform:
+        transforms = []
+        if do_dummy_2d_data_aug:
+            ignore_axes = (0,)
+            transforms.append(Convert3DTo2DTransform())
+            patch_size_spatial = patch_size[1:]
+        else:
+            patch_size_spatial = patch_size
+            ignore_axes = None
+        transforms.append(
+            SpatialTransform(
+                patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
+                p_rotation=0.2,
+                rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,
+                bg_style_seg_sampling=False, mode_seg='nearest'
+            )
+        )
+
+        if do_dummy_2d_data_aug:
+            transforms.append(Convert2DTo3DTransform())
+
+        transforms.append(RandomTransform(
+            GaussianNoiseTransform(
+                noise_variance=(0, 0.1),
+                p_per_channel=1,
+                synchronize_channels=True
+            ), apply_probability=0.1
+        ))
+        transforms.append(RandomTransform(
+            GaussianBlurTransform(
+                blur_sigma=(0.5, 1.),
+                synchronize_channels=False,
+                synchronize_axes=False,
+                p_per_channel=0.5, benchmark=True
+            ), apply_probability=0.2
+        ))
+        transforms.append(RandomTransform(
+            MultiplicativeBrightnessTransform(
+                multiplier_range=BGContrast((0.75, 1.25)),
+                synchronize_channels=False,
+                p_per_channel=1
+            ), apply_probability=0.15
+        ))
+        transforms.append(RandomTransform(
+            ContrastTransform(
+                contrast_range=BGContrast((0.75, 1.25)),
+                preserve_range=True,
+                synchronize_channels=False,
+                p_per_channel=1
+            ), apply_probability=0.15
+        ))
+        transforms.append(RandomTransform(
+            SimulateLowResolutionTransform(
+                scale=(0.5, 1),
+                synchronize_channels=False,
+                synchronize_axes=True,
+                ignore_axes=ignore_axes,
+                allowed_channels=None,
+                p_per_channel=0.5
+            ), apply_probability=0.25
+        ))
+        transforms.append(RandomTransform(
+            GammaTransform(
+                gamma=BGContrast((0.7, 1.5)),
+                p_invert_image=1,
+                synchronize_channels=False,
+                p_per_channel=1,
+                p_retain_stats=1
+            ), apply_probability=0.1
+        ))
+        transforms.append(RandomTransform(
+            GammaTransform(
+                gamma=BGContrast((0.7, 1.5)),
+                p_invert_image=0,
+                synchronize_channels=False,
+                p_per_channel=1,
+                p_retain_stats=1
+            ), apply_probability=0.3
+        ))
+        if mirror_axes is not None and len(mirror_axes) > 0:
+            transforms.append(
+                MirrorTransform(
+                    allowed_axes=mirror_axes
+                )
+            )
+
+        if use_mask_for_norm is not None and any(use_mask_for_norm):
+            transforms.append(MaskImageTransform(
+                apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
+                channel_idx_in_seg=0,
+                set_outside_to=0,
+            ))
+
+        transforms.append(
+            RemoveLabelTansform(-1, 0)
+        )
+        if is_cascaded:
+            assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
+            transforms.append(
+                MoveSegAsOneHotToDataTransform(
+                    source_channel_idx=1,
+                    all_labels=foreground_labels,
+                    remove_channel_from_source=True
+                )
+            )
+            transforms.append(
+                RandomTransform(
+                    ApplyRandomBinaryOperatorTransform(
+                        channel_idx=list(range(-len(foreground_labels), 0)),
+                        strel_size=(1, 8),
+                        p_per_label=1
+                    ), apply_probability=0.4
+                )
+            )
+            transforms.append(
+                RandomTransform(
+                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+                        channel_idx=list(range(-len(foreground_labels), 0)),
+                        fill_with_other_class_p=0,
+                        dont_do_if_covers_more_than_x_percent=0.15,
+                        p_per_label=1
+                    ), apply_probability=0.2
+                )
+            )
+
+        if regions is not None:
+            # the ignore label must also be converted
+            transforms.append(
+                ConvertSegmentationToRegionsTransform(
+                    regions=list(regions) + [ignore_label] if ignore_label is not None else regions,
+                    channel_in_seg=0
+                )
+            )
+
+        if deep_supervision_scales is not None:
+            transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))
+
+        return ComposeTransforms(transforms)
\ No newline at end of file
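The new nnUNetTrainer_onlyMirror01_DASegOrd0 inherits the restricted mirroring of nnUNetTrainer_onlyMirror01 (mirror axes 0 and 1 only) and overrides get_training_transforms with a copy of the transform pipeline in which SpatialTransform samples segmentations with nearest-neighbour interpolation (mode_seg='nearest', bg_style_seg_sampling=False), i.e. interpolation order 0, which is presumably what the DASegOrd0 suffix refers to. Like any trainer variant it can be selected at training time via the -tr flag (dataset ID, configuration and fold below are placeholders):

nnUNetv2_train DATASET_ID 3d_fullres FOLD -tr nnUNetTrainer_onlyMirror01_DASegOrd0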
diff --git a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
index 5d05ac570..2804711aa 100644
--- a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
+++ b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
@@ -77,6 +77,20 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
         super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
         self.num_epochs = 1000
 
+class nnUNetTrainer_500epochs(nnUNetTrainer):
+    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+                 device: torch.device = torch.device('cuda')):
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+        self.num_epochs = 500
+
+
+class nnUNetTrainer_750epochs(nnUNetTrainer):
+    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+                 device: torch.device = torch.device('cuda')):
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+        self.num_epochs = 750
+
+
 class nnUNetTrainer_2000epochs(nnUNetTrainer):
     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                  device: torch.device = torch.device('cuda')):
diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py
index 11b76dfcb..eab65885a 100644
--- a/nnunetv2/utilities/plans_handling/plans_handler.py
+++ b/nnunetv2/utilities/plans_handling/plans_handler.py
@@ -54,6 +54,8 @@ def __init__(self, configuration_dict: dict):
         conv_op = convert_dim_to_conv_op(dim)
         instnorm = get_matching_instancenorm(dimension=dim)
 
+        convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"
+
         arch_dict = {
             'network_class_name': network_class_name,
             'arch_kwargs': {
@@ -64,7 +66,7 @@ def __init__(self, configuration_dict: dict):
                 "conv_op": conv_op.__module__ + '.' + conv_op.__name__,
                 "kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]),
                 "strides": deepcopy(self.configuration["pool_op_kernel_sizes"]),
-                "n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]),
+                convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]),
                 "n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]),
                 "conv_bias": True,
                 "norm_op": instnorm.__module__ + '.' + instnorm.__name__,
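Background on the plans_handler change: in dynamic_network_architectures, PlainConvUNet takes its encoder stage depths as n_conv_per_stage, while the residual variants (e.g. ResidualEncoderUNet) take n_blocks_per_stage, so the architecture dict has to pick the key from the class name. A standalone sketch of the selection (stage-depth values are made up for illustration, not from any real plans file):

# Key selection mirroring the plans_handler hunk above.
unet_class_name = "ResidualEncoderUNet"  # would normally be parsed from the plans
convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"
arch_kwargs = {
    convs_or_blocks: [2, 2, 2, 2],          # encoder stage depths
    "n_conv_per_stage_decoder": [2, 2, 2],  # decoder key is unchanged
}
print(arch_kwargs)  # {'n_blocks_per_stage': [2, 2, 2, 2], 'n_conv_per_stage_decoder': [2, 2, 2]}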