Skip to content

Commit

Permalink
Merge branch 'MIC-DKFZ:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzjalen authored Jun 19, 2024
2 parents 456a537 + ec229b5 commit 433afc4
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 4 deletions.
2 changes: 1 addition & 1 deletion documentation/how_to_use_nnunet.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ a plot of the training (blue) and validation (red) loss during training. Also sh
average over all cases but pretend that there is only one validation case from which we sample patches). The reason for
this is that the 'global Dice' is easy to compute during training and is still quite useful to evaluate whether a model
is training at all or not. A proper validation takes way too long to be done each epoch. It is run at the end of the training.
- validation_raw: in this folder are the predicted validation cases after the training has finished. The summary.json file in here
- validation: in this folder are the predicted validation cases after the training has finished. The summary.json file in here
contains the validation metrics (a mean over all cases is provided at the start of the file). If `--npz` was set then
the compressed softmax outputs (saved as .npz files) are in here as well.

Expand Down
1 change: 1 addition & 0 deletions nnunetv2/inference/data_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
yield item
[p.join() for p in processes]


class PreprocessAdapter(DataLoader):
def __init__(self, list_of_lists: List[List[str]],
list_of_segs_from_prev_stage_files: Union[None, List[str]],
Expand Down
1 change: 1 addition & 0 deletions nnunetv2/preprocessing/resampling/default_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
return new_shape



def determine_do_sep_z_and_axis(
force_separate_z: bool,
current_spacing,
Expand Down
5 changes: 4 additions & 1 deletion nnunetv2/training/dataloading/data_loader_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def generate_train_batch(self):
images.append(tmp['image'])
segs.append(tmp['segmentation'])
data_all = torch.stack(images)
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
if isinstance(segs[0], list):
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
else:
seg_all = torch.stack(segs)
del segs, images
if torch is not None:
torch.set_num_threads(torch_nthreads)
Expand Down
5 changes: 4 additions & 1 deletion nnunetv2/training/dataloading/data_loader_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def generate_train_batch(self):
images.append(tmp['image'])
segs.append(tmp['segmentation'])
data_all = torch.stack(images)
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
if isinstance(segs[0], list):
seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))]
else:
seg_all = torch.stack(segs)
del segs, images
if torch is not None:
torch.set_num_threads(torch_nthreads)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
from typing import Union, Tuple, List

import numpy as np
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
from batchgeneratorsv2.transforms.intensity.contrast import ContrastTransform, BGContrast
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
from batchgeneratorsv2.transforms.nnunet.random_binary_operator import ApplyRandomBinaryOperatorTransform
from batchgeneratorsv2.transforms.nnunet.remove_connected_components import \
RemoveRandomConnectedComponentFromOneHotEncodingTransform
from batchgeneratorsv2.transforms.nnunet.seg_to_onehot import MoveSegAsOneHotToDataTransform
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert3DTo2DTransform, Convert2DTo3DTransform
from batchgeneratorsv2.transforms.utils.random import RandomTransform
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.training.nnUNetTrainer.variants.training_length import nnUNetTrainer_Xepochs

Expand Down Expand Up @@ -25,3 +49,155 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
self.inference_allowed_mirroring_axes = mirror_axes
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes


class nnUNetTrainer_onlyMirror01_DASegOrd0(nnUNetTrainer_onlyMirror01):
@staticmethod
def get_training_transforms(
patch_size: Union[np.ndarray, Tuple[int]],
rotation_for_DA: RandomScalar,
deep_supervision_scales: Union[List, Tuple, None],
mirror_axes: Tuple[int, ...],
do_dummy_2d_data_aug: bool,
use_mask_for_norm: List[bool] = None,
is_cascaded: bool = False,
foreground_labels: Union[Tuple[int, ...], List[int]] = None,
regions: List[Union[List[int], Tuple[int, ...], int]] = None,
ignore_label: int = None,
) -> BasicTransform:
transforms = []
if do_dummy_2d_data_aug:
ignore_axes = (0,)
transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
else:
patch_size_spatial = patch_size
ignore_axes = None
transforms.append(
SpatialTransform(
patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
p_rotation=0.2,
rotation=rotation_for_DA, p_scaling=0.2, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1,
bg_style_seg_sampling=False, mode_seg='nearest'
)
)

if do_dummy_2d_data_aug:
transforms.append(Convert2DTo3DTransform())

transforms.append(RandomTransform(
GaussianNoiseTransform(
noise_variance=(0, 0.1),
p_per_channel=1,
synchronize_channels=True
), apply_probability=0.1
))
transforms.append(RandomTransform(
GaussianBlurTransform(
blur_sigma=(0.5, 1.),
synchronize_channels=False,
synchronize_axes=False,
p_per_channel=0.5, benchmark=True
), apply_probability=0.2
))
transforms.append(RandomTransform(
MultiplicativeBrightnessTransform(
multiplier_range=BGContrast((0.75, 1.25)),
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
ContrastTransform(
contrast_range=BGContrast((0.75, 1.25)),
preserve_range=True,
synchronize_channels=False,
p_per_channel=1
), apply_probability=0.15
))
transforms.append(RandomTransform(
SimulateLowResolutionTransform(
scale=(0.5, 1),
synchronize_channels=False,
synchronize_axes=True,
ignore_axes=ignore_axes,
allowed_channels=None,
p_per_channel=0.5
), apply_probability=0.25
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=1,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.1
))
transforms.append(RandomTransform(
GammaTransform(
gamma=BGContrast((0.7, 1.5)),
p_invert_image=0,
synchronize_channels=False,
p_per_channel=1,
p_retain_stats=1
), apply_probability=0.3
))
if mirror_axes is not None and len(mirror_axes) > 0:
transforms.append(
MirrorTransform(
allowed_axes=mirror_axes
)
)

if use_mask_for_norm is not None and any(use_mask_for_norm):
transforms.append(MaskImageTransform(
apply_to_channels=[i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
channel_idx_in_seg=0,
set_outside_to=0,
))

transforms.append(
RemoveLabelTansform(-1, 0)
)
if is_cascaded:
assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
transforms.append(
MoveSegAsOneHotToDataTransform(
source_channel_idx=1,
all_labels=foreground_labels,
remove_channel_from_source=True
)
)
transforms.append(
RandomTransform(
ApplyRandomBinaryOperatorTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
strel_size=(1, 8),
p_per_label=1
), apply_probability=0.4
)
)
transforms.append(
RandomTransform(
RemoveRandomConnectedComponentFromOneHotEncodingTransform(
channel_idx=list(range(-len(foreground_labels), 0)),
fill_with_other_class_p=0,
dont_do_if_covers_more_than_x_percent=0.15,
p_per_label=1
), apply_probability=0.2
)
)

if regions is not None:
# the ignore label must also be converted
transforms.append(
ConvertSegmentationToRegionsTransform(
regions=list(regions) + [ignore_label] if ignore_label is not None else regions,
channel_in_seg=0
)
)

if deep_supervision_scales is not None:
transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))

return ComposeTransforms(transforms)
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,20 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.num_epochs = 1000

class nnUNetTrainer_500epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.num_epochs = 500


class nnUNetTrainer_750epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.num_epochs = 750


class nnUNetTrainer_2000epochs(nnUNetTrainer):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
Expand Down
4 changes: 3 additions & 1 deletion nnunetv2/utilities/plans_handling/plans_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(self, configuration_dict: dict):
conv_op = convert_dim_to_conv_op(dim)
instnorm = get_matching_instancenorm(dimension=dim)

convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage"

arch_dict = {
'network_class_name': network_class_name,
'arch_kwargs': {
Expand All @@ -64,7 +66,7 @@ def __init__(self, configuration_dict: dict):
"conv_op": conv_op.__module__ + '.' + conv_op.__name__,
"kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]),
"strides": deepcopy(self.configuration["pool_op_kernel_sizes"]),
"n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]),
convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]),
"n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]),
"conv_bias": True,
"norm_op": instnorm.__module__ + '.' + instnorm.__name__,
Expand Down

0 comments on commit 433afc4

Please sign in to comment.