Skip to content

Commit

Permalink
Merge branch 'MIC-DKFZ:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzjalen authored Apr 29, 2024
2 parents e2c2e43 + 5db9604 commit fc36fd5
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 3 additions & 1 deletion nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')

if trainer_class is None:
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
f'Please place it there (in any .py file)!')
network = trainer_class.build_network_architecture(
configuration_manager.network_arch_class_name,
configuration_manager.network_arch_init_kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def __call__(self, **data_dict):
b, c, *shape = seg.shape
region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)
for region_id, region_labels in enumerate(self.regions):
if not isinstance(region_labels, (list, tuple)):
region_labels = (region_labels, )
for label_value in region_labels:
region_output[:, region_id] |= (seg[:, self.seg_channel] == label_value)
region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)
return data_dict

2 changes: 1 addition & 1 deletion nnunetv2/utilities/find_class_by_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ def recursive_find_python_class(folder: str, class_name: str, current_module: st
tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module)
if tr is not None:
break
return tr
return tr

0 comments on commit fc36fd5

Please sign in to comment.