Skip to content

Commit

Permalink
Merge branch 'MIC-DKFZ:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzjalen authored Apr 25, 2024
2 parents 6c995dc + f14188b commit e2c2e43
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 16 deletions.
4 changes: 2 additions & 2 deletions nnunetv2/inference/data_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(self, list_of_lists: List[List[str]],

def generate_train_batch(self):
idx = self.get_indices()[0]
files, seg_prev_stage, ofile = self._data[idx][0]
files, seg_prev_stage, ofile = self._data[idx]
# if we have a segmentation from the previous stage we have to process it together with the images so that we
# can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
# preprocessing and then there might be misalignments
Expand Down Expand Up @@ -190,7 +190,7 @@ def __init__(self, list_of_images: List[np.ndarray],

def generate_train_batch(self):
idx = self.get_indices()[0]
image, seg_prev_stage, props, ofname = self._data[idx][0]
image, seg_prev_stage, props, ofname = self._data[idx]
# if we have a segmentation from the previous stage we have to process it together with the images so that we
# can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
# preprocessing and then there might be misalignments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import multiprocessing
import shutil
from time import sleep
from typing import Tuple
from typing import Tuple, Union

import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,13 @@ def __init__(self, regions: Union[List, Tuple],

def __call__(self, **data_dict):
seg = data_dict.get(self.seg_key)
num_regions = len(self.regions)
if seg is not None:
seg_shp = seg.shape
output_shape = list(seg_shp)
output_shape[1] = num_regions
region_output = np.zeros(output_shape, dtype=seg.dtype)
for b in range(seg_shp[0]):
for region_id, region_source_labels in enumerate(self.regions):
if not isinstance(region_source_labels, (list, tuple)):
region_source_labels = (region_source_labels, )
for label_value in region_source_labels:
region_output[b, region_id][seg[b, self.seg_channel] == label_value] = 1
data_dict[self.output_key] = region_output
b, c, *shape = seg.shape
region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)
for region_id, region_labels in enumerate(self.regions):
if not isinstance(region_labels, (list, tuple)):
region_labels = (region_labels, )
for label_value in region_labels:
region_output[:, region_id] |= (seg[:, self.seg_channel] == label_value)
data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)
return data_dict
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nnunetv2"
version = "2.4.1"
version = "2.4.2"
requires-python = ">=3.9"
description = "nnU-Net is a framework for out-of-the box image segmentation."
readme = "README.md"
Expand Down

0 comments on commit e2c2e43

Please sign in to comment.