diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index c28b41e1b..8d4096482 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -96,7 +96,9 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str, num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, 'nnunetv2.training.nnUNetTrainer') - + if trainer_class is None: + raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. ' + f'Please place it there (in any .py file)!') network = trainer_class.build_network_architecture( configuration_manager.network_arch_class_name, configuration_manager.network_arch_init_kwargs, diff --git a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py index 93e8f5edb..eb45d26f2 100644 --- a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py +++ b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py @@ -26,9 +26,7 @@ def __call__(self, **data_dict): b, c, *shape = seg.shape region_output = np.zeros((b, len(self.regions), *shape), dtype=bool) for region_id, region_labels in enumerate(self.regions): - if not isinstance(region_labels, (list, tuple)): - region_labels = (region_labels, ) - for label_value in region_labels: - region_output[:, region_id] |= (seg[:, self.seg_channel] == label_value) + region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels) data_dict[self.output_key] = region_output.astype(np.uint8, copy=False) return data_dict + diff --git a/nnunetv2/utilities/find_class_by_name.py b/nnunetv2/utilities/find_class_by_name.py index a345d99a7..223b3acc3 100644 --- a/nnunetv2/utilities/find_class_by_name.py +++ b/nnunetv2/utilities/find_class_by_name.py @@ -21,4 +21,4 @@ def recursive_find_python_class(folder: str, class_name: str, current_module: st tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module) if tr is not None: break - return tr \ No newline at end of file + return tr