Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving performance of batchgenerators #113

Open
wants to merge 62 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
19a5f81
Improved performance: np.clip and one percentile call
ancestor-mithril May 15, 2023
57f21cc
Doing inplace operations
ancestor-mithril May 15, 2023
dbba889
Misc + small improvements, preferring tuples to lists
ancestor-mithril May 15, 2023
68efcfd
Improved speed of mean_std normalization + added tests
ancestor-mithril May 16, 2023
26ca086
Vectorized all normalizations per batch and per channel
ancestor-mithril May 16, 2023
7c6f377
Gaussian noise and mean_std refactoring
ancestor-mithril May 16, 2023
659e1fb
Vectorized augment_contrast and augment_brightness_additive
ancestor-mithril May 17, 2023
0a1dd79
Making convert_seg_image_to_one_hot_encoding batched and using it
ancestor-mithril May 17, 2023
b9c9161
Various small improvements
ancestor-mithril May 17, 2023
58e324b
Improving dataloader and numpy to tensor
ancestor-mithril May 17, 2023
f42847d
Setting cast to None in NumpyToTensor
ancestor-mithril May 20, 2023
0637882
Optimizing rotations
ancestor-mithril Aug 16, 2023
9b9c446
Doing batched augment_contrast
ancestor-mithril Aug 16, 2023
ff83845
Augment gamma changes
ancestor-mithril Aug 16, 2023
b29d74a
improving per channel augment gamma
ancestor-mithril Aug 16, 2023
26eb25c
Added batched brightness multiplicative transform
ancestor-mithril Aug 16, 2023
93c08d1
Doing batched augmentation only if batches are not empty
ancestor-mithril Aug 16, 2023
acbd7e9
Factored out the setup for multiplicative brightness
ancestor-mithril Aug 16, 2023
bf174e5
Added batched implementation for Gaussian Noise Transform
ancestor-mithril Aug 16, 2023
7369a8a
Makeup for Gaussian Blur Transform
ancestor-mithril Aug 16, 2023
3020ec5
Removed unittest2 dependency
ancestor-mithril Aug 17, 2023
7912a62
Improved crop and pad augmentation and spatial transforms
ancestor-mithril Aug 17, 2023
a821dee
Improving resample augmentation and resample transform
ancestor-mithril Aug 17, 2023
af70e89
Misc changes for lru_cache to take effect
ancestor-mithril Aug 17, 2023
8d7e7ac
Misc improvements to spatial transform
ancestor-mithril Aug 17, 2023
5c0794f
Improving vectorized computation by broadcasting lower dimensional op…
ancestor-mithril Aug 17, 2023
3b89bf5
Minor changes
ancestor-mithril Aug 22, 2023
f9b71d4
Solved bug with single threaded augmenter due to usage of unexisting …
ancestor-mithril Aug 22, 2023
b6db109
Further improving NumpyToTensor transform
ancestor-mithril Aug 22, 2023
d9b203a
Using pandas unique instead of np unique
ancestor-mithril Aug 22, 2023
e5b342e
Fixing batched operations (new random for each sample)
ancestor-mithril Aug 22, 2023
0821b66
Implementing batched mirror transform
ancestor-mithril Aug 22, 2023
c87865e
Misc change
ancestor-mithril Aug 22, 2023
17a0877
Misc
ancestor-mithril Aug 22, 2023
d8d9d5c
Misc
ancestor-mithril Aug 22, 2023
c4d7c4b
Solving bug with unpickle-able NumpyToTensor
ancestor-mithril Aug 22, 2023
707f3eb
Fixing mirroring
ancestor-mithril Aug 23, 2023
70adf88
Using keepdims instead of broadcasting again
ancestor-mithril Aug 28, 2023
251e74a
Fixed augment mirroring
ancestor-mithril Aug 28, 2023
5cf312b
Sorting pd.unique when needed
ancestor-mithril Aug 30, 2023
8f7da40
Minimizing array copy when data was already np.ndarray
ancestor-mithril Aug 30, 2023
03e6d2c
Adjusted test crop
ancestor-mithril Aug 30, 2023
d697376
Numpy To Tensor
ancestor-mithril Aug 31, 2023
4a1075d
Replaced 'get_range_val' with 'uniform'
ancestor-mithril Sep 6, 2023
5f890e8
Misc
ancestor-mithril Sep 6, 2023
ffb5824
Faster mirroring
ancestor-mithril Sep 6, 2023
6438729
Replaced len(ndarray.shape) to ndarray.ndim
ancestor-mithril Sep 6, 2023
8ffb23a
Rmoved callable feature from augment contrast and augment gamma.
ancestor-mithril Sep 11, 2023
ca82730
Using flynt for conversion to fstring
ancestor-mithril Sep 12, 2023
3733e75
Optimize imports
ancestor-mithril Sep 12, 2023
b81b436
Prefering tuple to list and dtype to astype
ancestor-mithril Sep 18, 2023
1f86cc0
Using rint instead of round
ancestor-mithril Sep 25, 2023
a7575e4
Merge branch 'upstream'
ancestor-mithril Dec 6, 2023
c98f477
Added new TODOS
ancestor-mithril Dec 6, 2023
57ec2fb
Disabling test
ancestor-mithril Dec 6, 2023
01e0d10
Removing redundant dependencies
ancestor-mithril Dec 6, 2023
6388e43
More inplace np.clip
ancestor-mithril Jan 10, 2024
f3b1117
Using and reducing memory allocations when casting to new type
ancestor-mithril Jan 11, 2024
c4eda10
Making resize segmentation faster without additional casting
ancestor-mithril Jan 31, 2024
d06ed91
Making resize segmentation faster without additional casting
ancestor-mithril Jan 31, 2024
6d8058c
Added callable retain_stats and contrast_range arguments back to colo…
ancestor-mithril Feb 1, 2024
2c0958e
Merge branch 'MIC-DKFZ:master' into master
ancestor-mithril Apr 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 105 additions & 84 deletions batchgenerators/augmentations/color_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from builtins import range
from typing import Tuple, Union, Callable
from typing import Tuple, Callable, Union

import numpy as np
from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter

from batchgenerators.augmentations.utils import general_cc_var_num_channels, illumination_jitter, get_broadcast_axes, \
reverse_broadcast


def get_augment_contrast_factor(contrast_range: Union[Tuple[float, float], Callable[[], float]],
per_channel: bool,
size: int,
broadcast_size: int):
if per_channel:
factor = []
for _ in range(size):
if callable(contrast_range):
factor.append(contrast_range())
elif contrast_range[0] < 1 and np.random.random() < 0.5:
factor.append(np.random.uniform(contrast_range[0], 1))
else:
factor.append(np.random.uniform(max(contrast_range[0], 1), contrast_range[1]))

def augment_contrast(data_sample: np.ndarray,
contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25),
preserve_range: bool = True,
per_channel: bool = True,
p_per_channel: float = 1) -> np.ndarray:
if not per_channel:
factor = reverse_broadcast(np.array(factor), get_broadcast_axes(broadcast_size))
else:
if callable(contrast_range):
factor = contrast_range()
elif contrast_range[0] < 1 and np.random.random() < 0.5:
factor = np.random.uniform(contrast_range[0], 1)
else:
if np.random.random() < 0.5 and contrast_range[0] < 1:
factor = np.random.uniform(contrast_range[0], 1)
else:
factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1])
factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1])

for c in range(data_sample.shape[0]):
if np.random.uniform() < p_per_channel:
mn = data_sample[c].mean()
if preserve_range:
minm = data_sample[c].min()
maxm = data_sample[c].max()
return factor

data_sample[c] = (data_sample[c] - mn) * factor + mn

if preserve_range:
data_sample[c][data_sample[c] < minm] = minm
data_sample[c][data_sample[c] > maxm] = maxm
else:
for c in range(data_sample.shape[0]):
if np.random.uniform() < p_per_channel:
if callable(contrast_range):
factor = contrast_range()
else:
if np.random.random() < 0.5 and contrast_range[0] < 1:
factor = np.random.uniform(contrast_range[0], 1)
else:
factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1])

mn = data_sample[c].mean()
if preserve_range:
minm = data_sample[c].min()
maxm = data_sample[c].max()

data_sample[c] = (data_sample[c] - mn) * factor + mn

if preserve_range:
data_sample[c][data_sample[c] < minm] = minm
data_sample[c][data_sample[c] > maxm] = maxm
def augment_contrast(data_sample: np.ndarray,
ancestor-mithril marked this conversation as resolved.
Show resolved Hide resolved
contrast_range: Union[Tuple[float, float], Callable[[], float]] = (0.75, 1.25),
preserve_range: bool = True,
per_channel: bool = True,
p_per_channel: float = 1,
batched=False) -> np.ndarray:
mask = np.random.uniform(size=data_sample.shape[:2] if batched else data_sample.shape[0]) < p_per_channel
if np.any(mask):
workon = data_sample[mask]
factor = get_augment_contrast_factor(contrast_range, per_channel, len(workon), workon.ndim)
axes = tuple(range(1, workon.ndim))
mean = workon.mean(axis=axes, keepdims=True)
if preserve_range:
minm = workon.min(axis=axes, keepdims=True)
maxm = workon.max(axis=axes, keepdims=True)

data_sample[mask] = workon * factor + mean * (1 - factor) # writing directly in data_sample

if preserve_range:
np.clip(data_sample[mask], minm, maxm, out=data_sample[mask])

return data_sample


def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel:bool=True, p_per_channel:float=1.):
def augment_brightness_additive(data_sample, mu: float, sigma: float, per_channel: bool = True,
p_per_channel: float = 1.):
"""
data_sample must have shape (c, x, y(, z)))
:param data_sample:
Expand All @@ -80,27 +82,29 @@ def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel
:param p_per_channel:
:return:
"""
if not per_channel:
rnd_nb = np.random.normal(mu, sigma)
for c in range(data_sample.shape[0]):
if np.random.uniform() <= p_per_channel:
data_sample[c] += rnd_nb
size = data_sample.shape[0]
if per_channel:
rnd_nb = np.random.normal(mu, sigma, size=size)
else:
for c in range(data_sample.shape[0]):
if np.random.uniform() <= p_per_channel:
rnd_nb = np.random.normal(mu, sigma)
data_sample[c] += rnd_nb
rnd_nb = np.repeat(np.random.normal(mu, sigma), size)
rnd_nb[np.random.uniform(size=size) > p_per_channel] = 0.0
data_sample += reverse_broadcast(rnd_nb, get_broadcast_axes(data_sample.ndim))
return data_sample


def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True):
multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1])
if not per_channel:
data_sample *= multiplier
else:
for c in range(data_sample.shape[0]):
multiplier = np.random.uniform(multiplier_range[0], multiplier_range[1])
data_sample[c] *= multiplier
def setup_augment_brightness_multiplicative(per_channel: bool, batched: bool, shape: Tuple[int, ...]):
if per_channel:
if batched:
return shape[:2] + (1,) * (len(shape) - 2)
return (shape[0],) + (1,) * (len(shape) - 1)
if batched:
return (shape[0],) + (1,) * (len(shape) - 1)
return (1,) * len(shape)


def augment_brightness_multiplicative(data_sample, multiplier_range=(0.5, 2), per_channel=True, batched=False):
size = setup_augment_brightness_multiplicative(per_channel, batched, data_sample.shape)
data_sample *= np.random.uniform(multiplier_range[0], multiplier_range[1], size=size)
return data_sample


Expand All @@ -110,38 +114,55 @@ def augment_gamma(data_sample, gamma_range=(0.5, 2), invert_image=False, epsilon
data_sample = - data_sample

if not per_channel:
retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
if retain_stats_here:
retain_stats = retain_stats() if callable(retain_stats) else retain_stats
if retain_stats:
mn = data_sample.mean()
sd = data_sample.std()
if np.random.random() < 0.5 and gamma_range[0] < 1:
if gamma_range[0] < 1 and np.random.random() < 0.5:
gamma = np.random.uniform(gamma_range[0], 1)
else:
gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
minm = data_sample.min()
rnge = data_sample.max() - minm
data_sample = np.power(((data_sample - minm) / float(rnge + epsilon)), gamma) * rnge + minm
if retain_stats_here:
data_sample = data_sample - data_sample.mean()
data_sample = data_sample / (data_sample.std() + 1e-8) * sd
data_sample = data_sample + mn
if retain_stats:
data_sample -= data_sample.mean()
data_sample *= sd / (data_sample.std() + 1e-8)
data_sample += mn
else:
for c in range(data_sample.shape[0]):
retain_stats_here = retain_stats() if callable(retain_stats) else retain_stats
if retain_stats_here:
mn = data_sample[c].mean()
sd = data_sample[c].std()
if np.random.random() < 0.5 and gamma_range[0] < 1:
gamma = np.random.uniform(gamma_range[0], 1)
shape_0 = data_sample.shape[0]
gamma = []
gamma_l = max(gamma_range[0], 1)
for i in range(shape_0):
if gamma_range[0] < 1 and np.random.random() < 0.5:
gamma.append(np.random.uniform(gamma_range[0], 1))
else:
gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
minm = data_sample[c].min()
rnge = data_sample[c].max() - minm
data_sample[c] = np.power(((data_sample[c] - minm) / float(rnge + epsilon)), gamma) * float(rnge + epsilon) + minm
if retain_stats_here:
data_sample[c] = data_sample[c] - data_sample[c].mean()
data_sample[c] = data_sample[c] / (data_sample[c].std() + 1e-8) * sd
data_sample[c] = data_sample[c] + mn
gamma.append(np.random.uniform(gamma_l, gamma_range[1]))
gamma = np.array(gamma)

axes = tuple(range(1, data_sample.ndim))

if callable(retain_stats):
retain_stats = [retain_stats() for _ in range(shape_0)]
else:
retain_stats = [retain_stats] * shape_0
retain_stats_here = any(retain_stats)
if retain_stats_here:
mn = data_sample[retain_stats].mean(axis=axes, keepdims=True)
sd = data_sample[retain_stats].mean(axis=axes, keepdims=True)

minm = data_sample.min(axis=axes, keepdims=True)
rnge = data_sample.max(axis=axes, keepdims=True) - minm + epsilon

broadcast_axes = get_broadcast_axes(data_sample.ndim)
gamma = reverse_broadcast(gamma, broadcast_axes)
data_sample = np.power((data_sample - minm) / rnge, gamma) * rnge + minm

if retain_stats_here:
data_sample[retain_stats] -= data_sample[retain_stats].mean(axis=axes, keepdims=True)
data_sample[retain_stats] *= sd / (data_sample[retain_stats].std(axis=axes, keepdims=True) + 1e-8)
data_sample[retain_stats] += mn

if invert_image:
data_sample = - data_sample
return data_sample
Expand Down
82 changes: 38 additions & 44 deletions batchgenerators/augmentations/crop_and_pad_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from builtins import range
import numpy as np
from batchgenerators.augmentations.utils import pad_nd_image
from typing import Union, Sequence


def center_crop(data, crop_size, seg=None):
Expand All @@ -30,28 +30,24 @@ def get_lbs_for_random_crop(crop_size, data_shape, margins):
:param margins:
:return:
"""
lbs = []
for i in range(len(data_shape) - 2):
if data_shape[i+2] - crop_size[i] - margins[i] > margins[i]:
lbs.append(np.random.randint(margins[i], data_shape[i+2] - crop_size[i] - margins[i]))
else:
lbs.append((data_shape[i+2] - crop_size[i]) // 2)
return lbs
new_shape = data_shape - crop_size
mask = new_shape > 2 * margins
new_shape[mask] = np.random.randint(margins[mask], new_shape[mask] - margins[mask])
new_shape[~mask] //= 2
return new_shape


def get_lbs_for_center_crop(crop_size, data_shape):
"""
:param crop_size:
:param data_shape: (b,c,x,y(,z)) must be the whole thing!
:param data_shape: (b,c,x,y(,z)) must be the only x,y(,z)!
:return:
"""
lbs = []
for i in range(len(data_shape) - 2):
lbs.append((data_shape[i + 2] - crop_size[i]) // 2)
return lbs
return (data_shape - crop_size) // 2


def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center",
def crop(data: Union[Sequence[np.ndarray], np.ndarray], seg: Union[Sequence[np.ndarray], np.ndarray] = None,
crop_size=128, margins=(0, 0, 0), crop_type="center",
pad_mode='constant', pad_kwargs={'constant_values': 0},
pad_mode_seg='constant', pad_kwargs_seg={'constant_values': 0}):
"""
Expand All @@ -69,44 +65,39 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center",
:param crop_type: random or center
:return:
"""
if not isinstance(data, (list, tuple, np.ndarray)):
raise TypeError("data has to be either a numpy array or a list")

data_shape = tuple([len(data)] + list(data[0].shape))
data_shape = (len(data),) + data[0].shape
data_dtype = data[0].dtype
dim = len(data_shape) - 2

if seg is not None:
seg_shape = tuple([len(seg)] + list(seg[0].shape))
seg_shape = (len(seg),) + seg[0].shape
seg_dtype = seg[0].dtype

if not isinstance(seg, (list, tuple, np.ndarray)):
raise TypeError("data has to be either a numpy array or a list")

assert all([i == j for i, j in zip(seg_shape[2:], data_shape[2:])]), "data and seg must have the same spatial " \
"dimensions. Data: %s, seg: %s" % \
(str(data_shape), str(seg_shape))
assert np.array_equal(seg_shape[2:], data_shape[2:]), "data and seg must have the same spatial dimensions. " \
f"Data: {data_shape}, seg: {seg_shape}"

if type(crop_size) not in (tuple, list, np.ndarray):
crop_size = [crop_size] * dim
crop_size = (crop_size,) * dim
else:
assert len(crop_size) == len(
data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \
"data (2d/3d)"
assert len(crop_size) == dim, ("If you provide a list/tuple as center crop make sure it has the same dimension "
"as your data (2d/3d)")
crop_size = np.asarray(crop_size)

if not isinstance(margins, (np.ndarray, tuple, list)):
margins = [margins] * dim
margins = (margins,) * dim
margins = np.asarray(margins)

data_return = np.zeros([data_shape[0], data_shape[1]] + list(crop_size), dtype=data_dtype)
data_return = np.zeros((data_shape[0], data_shape[1], *crop_size), dtype=data_dtype)
if seg is not None:
seg_return = np.zeros([seg_shape[0], seg_shape[1]] + list(crop_size), dtype=seg_dtype)
seg_return = np.zeros((seg_shape[0], seg_shape[1], *crop_size), dtype=seg_dtype)
else:
seg_return = None

for b in range(data_shape[0]):
data_shape_here = [data_shape[0]] + list(data[b].shape)
data_first_dim = data[b].shape[0]
data_shape_here = np.array(data[b].shape[1:])
if seg is not None:
seg_shape_here = [seg_shape[0]] + list(seg[b].shape)
seg_first_dim = seg[b].shape[0]

if crop_type == "center":
lbs = get_lbs_for_center_crop(crop_size, data_shape_here)
Expand All @@ -115,22 +106,25 @@ def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center",
else:
raise NotImplementedError("crop_type must be either center or random")

need_to_pad = [[0, 0]] + [[abs(min(0, lbs[d])),
abs(min(0, data_shape_here[d + 2] - (lbs[d] + crop_size[d])))]
for d in range(dim)]
zero = np.zeros(dim, dtype=int)
temp1 = np.abs(np.minimum(lbs, zero))
lbs_plus_crop_size = lbs + crop_size
temp2 = np.abs(np.minimum(zero, data_shape_here - lbs_plus_crop_size))
need_to_pad = ((0, 0),) + tuple(zip(temp1, temp2))
need_to_pad = np.array(need_to_pad)

# we should crop first, then pad -> reduces i/o for memmaps, reduces RAM usage and improves speed
ubs = [min(lbs[d] + crop_size[d], data_shape_here[d+2]) for d in range(dim)]
lbs = [max(0, lbs[d]) for d in range(dim)]
ubs = np.minimum(data_shape_here, lbs_plus_crop_size)
lbs = np.maximum(zero, lbs)

slicer_data = [slice(0, data_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)]
data_cropped = data[b][tuple(slicer_data)]
slicer_data = (slice(0, data_first_dim), *[slice(lbs[d], ubs[d]) for d in range(dim)])
data_cropped = data[b][slicer_data]

if seg_return is not None:
slicer_seg = [slice(0, seg_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)]
seg_cropped = seg[b][tuple(slicer_seg)]
slicer_data = (slice(0, seg_first_dim),) + slicer_data[1:]
seg_cropped = seg[b][slicer_data]

if any([i > 0 for j in need_to_pad for i in j]):
if np.any(need_to_pad):
data_return[b] = np.pad(data_cropped, need_to_pad, pad_mode, **pad_kwargs)
if seg_return is not None:
seg_return[b] = np.pad(seg_cropped, need_to_pad, pad_mode_seg, **pad_kwargs_seg)
Expand Down
Loading