Skip to content

Commit

Permalink
fix bug in predict multi channel img
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Sep 27, 2023
1 parent 8678392 commit 684b827
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
4 changes: 2 additions & 2 deletions ufish/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,9 @@ def predict_chunks(
"""
from .utils.img import (
check_img_axes, chunks_iterator,
process_chunk_size)
process_chunk_size, infer_img_axes)
if axes is None:
axes = self.infer_axes(img)
axes = infer_img_axes(img.shape)
check_img_axes(img, axes)
if chunk_size is None:
from .utils.img import get_default_chunk_size
Expand Down
55 changes: 30 additions & 25 deletions ufish/utils/img.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,48 +126,53 @@ def expand_df_axes(
return df


def transform_axes(img, axes: str, orig_axes: T.Optional[str] = None):
"""Re-order the axes of an image,
axes in order of 'tczyx'."""
if orig_axes is None:
new_axes = ''.join(sorted(axes, key=lambda x: 'tczyx'.index(x)))
img = np.moveaxis(
img, [axes.index(c) for c in new_axes], range(len(axes)))
return img, new_axes
else:
# recover the original axes
img = np.moveaxis(
img, [axes.index(c) for c in orig_axes], range(len(axes)))
return img, orig_axes


def map_predfunc_to_img(
predfunc: T.Callable[
[np.ndarray],
T.Tuple[pd.DataFrame, np.ndarray]
],
img: np.ndarray,
axes: str,
is_transform_axes: bool = True,
):
"""Map a prediction function to an multi-dimensional image."""
from .log import logger
yx_idx = [axes.index(c) for c in 'yx']
# move yx to the last two axes
img = np.moveaxis(img, yx_idx, [-2, -1])
new_axes = axes.replace('y', '').replace('x', '') + 'yx'
if is_transform_axes:
img, new_axes = transform_axes(img, axes)
else:
new_axes = axes
dfs = []
if len(img.shape) in (2, 3):
if (len(img.shape) == 2) or (new_axes == 'zyx'):
df, e_img = predfunc(img, axes=axes)
df = expand_df_axes(df, new_axes, [])
dfs.append(df)
elif len(img.shape) == 4:
else:
e_img = np.zeros_like(img, dtype=np.float32)
for i, img_3d in enumerate(img):
for i, sub_img in enumerate(img):
logger.info(
'Processing multi-dimensional image'
f' {i+1}/{len(img)}')
df, e_img[i] = predfunc(img_3d, axes=axes[1:])
f'Processing multi-dimensional image on axis {new_axes[0]}'
f': {i+1}/{len(img)}')
df, e_img[i] = map_predfunc_to_img(
predfunc, sub_img, new_axes[1:], False)
df = expand_df_axes(df, new_axes, [i])
dfs.append(df)
else:
assert len(img.shape) == 5
e_img = np.zeros_like(img, dtype=np.float32)
num_imgs = img.shape[0] * img.shape[1]
for i, img_4d in enumerate(img):
for j, img_3d in enumerate(img_4d):
logger.info(
'Processing multi-dimensional image'
f' {i*img.shape[1]+j+1}/{num_imgs}'
)
df, e_img[i, j] = predfunc(img_3d, axes=axes[2:])
df = expand_df_axes(df, new_axes, [i, j])
dfs.append(df)
# move yx back to the original position
e_img = np.moveaxis(e_img, [-2, -1], yx_idx)
if is_transform_axes:
e_img, _ = transform_axes(e_img, new_axes, axes)
res_df = pd.concat(dfs, ignore_index=True)
# re-order columns
res_df = res_df[list(axes)]
Expand Down

0 comments on commit 684b827

Please sign in to comment.