Skip to content

Commit

Permalink
increase the batch_size in blend_3d
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Sep 27, 2023
1 parent 684b827 commit 4862c33
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
12 changes: 9 additions & 3 deletions ufish/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,13 @@ def _enhance_img2d(self, img: np.ndarray) -> np.ndarray:
def _enhance_img3d(
self, img: np.ndarray, batch_size: int = 4) -> np.ndarray:
"""Enhance a 3D image."""
logger.info(
f'Enhancing 3D image in shape {img.shape}, '
f'batch size: {batch_size}')
output = np.zeros_like(img, dtype=np.float32)
for i in range(0, output.shape[0], batch_size):
logger.info(
f'Enhancing slice {i}-{i+batch_size}/{output.shape[0]}')
_slice = img[i:i+batch_size][:, np.newaxis]
output[i:i+batch_size] = self.infer(_slice)[:, 0]
return output
Expand All @@ -319,10 +324,11 @@ def _enhance_2d_or_3d(
'Image does not have a z axis, ' +
'cannot blend along z axis.')
from .utils.img import enhance_blend_3d
enh_func = partial(
self._enhance_img3d, batch_size=batch_size)
logger.info(
"Blending 3D image from 3 directions: z, y, x.")
output = enhance_blend_3d(
img, enh_func, axes=axes)
img, self._enhance_img3d, axes=axes,
batch_size=batch_size)
else:
output = self._enhance_img3d(img, batch_size=batch_size)
else:
Expand Down
21 changes: 17 additions & 4 deletions ufish/utils/img.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,24 +307,37 @@ def chunks_iterator(

def enhance_blend_3d(
img: np.ndarray,
enh_func: T.Callable[[np.ndarray], np.ndarray],
enh_func: T.Callable[[np.ndarray, int], np.ndarray],
axes: str,
batch_size: int = 4,
) -> np.ndarray:
"""Run enhancement along 3 directions and blend the results.
Args:
enh_func: Enhancement function.
img: Image to enhance.
axes: Axes of the image.
batch_size: Batch size for enhancement.
"""
if axes != 'zyx':
# move z to the first axis
z_idx = axes.index('z')
img = np.moveaxis(img, z_idx, 0)
enh_z = enh_func(img)
enh_y = enh_func(np.moveaxis(img, 1, 0))
enh_z = enh_func(img, batch_size)
zimg_size = np.array(img.shape[1:]).prod()

img_y = np.moveaxis(img, 1, 0)
yimg_size = np.array(img_y.shape[1:]).prod()
factor_y = int(zimg_size / yimg_size)
bz_y = max(batch_size * factor_y, 1)
enh_y = enh_func(img_y, bz_y)
enh_y = np.moveaxis(enh_y, 0, 1)
enh_x = enh_func(np.moveaxis(img, 2, 0))

img_x = np.moveaxis(img, 2, 0)
ximg_size = np.array(img_x.shape[1:]).prod()
factor_x = int(zimg_size / ximg_size)
bz_x = max(batch_size * factor_x, 1)
enh_x = enh_func(img_x, bz_x)
enh_x = np.moveaxis(enh_x, 0, 2)
enh_img = enh_z * enh_y * enh_x
return enh_img
Expand Down

0 comments on commit 4862c33

Please sign in to comment.