Skip to content

Commit

Permalink
feat: add shape check for operators
Browse files Browse the repository at this point in the history
This augments and replaces the basic check_size function and adds a test suite. 
---------
Co-authored-by: Lena OUDJMAN <[email protected]>
Co-authored-by: Chaithya G R <[email protected]>
Co-authored-by: Pierre-Antoine Comby <[email protected]>
  • Loading branch information
Lenoush authored Aug 30, 2024
1 parent c358344 commit e4dc519
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 48 deletions.
39 changes: 39 additions & 0 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,43 @@ def __init_subclass__(cls):
if backend := getattr(cls, "backend", None):
cls.interfaces[backend] = (available, cls)

def check_shape(self, *, image=None, ksp=None):
"""
Validate the shapes of the image or k-space data against operator shapes.
Parameters
----------
image : np.ndarray, optional
If passed, the shape of image data will be checked.
ksp : np.ndarray or object, optional
If passed, the shape of the k-space data will be checked.
Raises
------
ValueError
If the shape of the provided image does not match the expected operator
shape, or if the number of k-space samples does not match the expected
number of samples.
"""
if image is not None:
image_shape = image.shape[-len(self.shape) :]
if image_shape != self.shape:
raise ValueError(
f"Image shape {image_shape} is not compatible "
f"with the operator shape {self.shape}"
)

if ksp is not None:
kspace_shape = ksp.shape[-1]
if kspace_shape != self.n_samples:
raise ValueError(
f"Kspace samples {kspace_shape} is not compatible "
f"with the operator samples {self.n_samples}"
)
if image is None and ksp is None:
raise ValueError("Nothing to check, provides image or ksp arguments")

@abstractmethod
def op(self, data):
"""Compute operator transform.
Expand Down Expand Up @@ -654,6 +691,7 @@ def op(self, data, ksp=None):
this performs for every coil \ell:
..math:: \mathcal{F}\mathcal{S}_\ell x
"""
self.check_shape(image=data, ksp=ksp)
# sense
data = auto_cast(data, self.cpx_dtype)

Expand Down Expand Up @@ -711,6 +749,7 @@ def adj_op(self, coeffs, img=None):
-------
Array in the same memory space of coeffs. (ie on cpu or gpu Memory).
"""
self.check_shape(image=img, ksp=coeffs)
coeffs = auto_cast(coeffs, self.cpx_dtype)
if self.uses_sense:
ret = self._adj_op_sense(coeffs, img)
Expand Down
4 changes: 2 additions & 2 deletions src/mrinufft/operators/interfaces/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
else:
kwargs["extra_adj_op_args"] = ["-i"]

self.raw_op = RawBartNUFFT(samples_, shape, **kwargs)
super().__init__(
samples_,
shape,
Expand All @@ -150,11 +151,10 @@ def __init__(
n_batchs=n_batchs,
n_trans=1,
smaps=smaps,
raw_op=self.raw_op,
squeeze_dims=squeeze_dims,
)

self.raw_op = RawBartNUFFT(samples_, shape, **kwargs)

@property
def norm_factor(self):
"""Normalization factor of the operator."""
Expand Down
18 changes: 3 additions & 15 deletions src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from .utils import (
CUPY_AVAILABLE,
check_size,
is_cuda_array,
is_host_array,
nvtx_mark,
Expand Down Expand Up @@ -299,11 +298,7 @@ def op(self, data, ksp_d=None):
this performs for every coil \ell:
..math:: \mathcal{F}\mathcal{S}_\ell x
"""
# monocoil
if self.uses_sense:
check_size(data, (self.n_batchs, *self.shape))
else:
check_size(data, (self.n_batchs, self.n_coils, *self.shape))
self.check_shape(image=data, ksp=ksp_d)
data = auto_cast(data, self.cpx_dtype)
# Dispatch to special case.
if self.uses_sense and is_cuda_array(data):
Expand Down Expand Up @@ -415,8 +410,8 @@ def adj_op(self, coeffs, img_d=None):
-------
Array in the same memory space of coeffs. (ie on cpu or gpu Memory).
"""
self.check_shape(image=img_d, ksp=coeffs)
coeffs = auto_cast(coeffs, self.cpx_dtype)
check_size(coeffs, (self.n_batchs, self.n_coils, self.n_samples))
# Dispatch to special case.
if self.uses_sense and is_cuda_array(coeffs):
adj_op_func = self._adj_op_sense_device
Expand Down Expand Up @@ -576,14 +571,7 @@ def data_consistency(self, image_data, obs_data):
obs_data = auto_cast(obs_data, self.cpx_dtype)
image_data = auto_cast(image_data, self.cpx_dtype)

B, C = self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape

check_size(obs_data, (B, C, K))
if self.uses_sense:
check_size(image_data, (B, *XYZ))
else:
check_size(image_data, (B, C, *XYZ))
self.check_shape(image=image_data, ksp=obs_data)

if self.uses_sense and is_host_array(image_data):
grad_func = self._dc_sense_host
Expand Down
13 changes: 5 additions & 8 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import numpy as np
import warnings

from ..base import FourierOperatorBase, with_numpy_cupy
from mrinufft._utils import proper_trajectory, get_array_module, auto_cast
from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array, check_size
from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array

GPUNUFFT_AVAILABLE = True
try:
Expand Down Expand Up @@ -446,6 +447,7 @@ def op(self, data, coeffs=None):
np.ndarray
Masked Fourier transform of the input image.
"""
self.check_shape(image=data, ksp=coeffs)
B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples

op_func = self.raw_op.op
Expand Down Expand Up @@ -485,6 +487,7 @@ def adj_op(self, coeffs, data=None):
np.ndarray
Inverse discrete Fourier transform of the input coefficients.
"""
self.check_shape(image=data, ksp=coeffs)
B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples

adj_op_func = self.raw_op.adj_op
Expand Down Expand Up @@ -658,14 +661,8 @@ def data_consistency(self, image_data, obs_data):
image_data = auto_cast(image_data, self.cpx_dtype)

B, C = self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape

check_size(obs_data, (B, C, K))
if self.uses_sense:
check_size(image_data, (B, *XYZ))
else:
check_size(image_data, (B, C, *XYZ))

self.check_shape(image=image_data, ksp=obs_data)
# dispatch
if is_host_array(image_data) and is_host_array(obs_data):
grad_func = self._dc_host
Expand Down
4 changes: 2 additions & 2 deletions src/mrinufft/operators/interfaces/sigpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
**kwargs,
):
samples_ = proper_trajectory(samples, normalize="unit")
raw_op = RawSigpyNUFFT(samples_, shape, n_trans=n_trans, **kwargs)

super().__init__(
samples_,
Expand All @@ -122,11 +123,10 @@ def __init__(
n_batchs=n_batchs,
n_trans=n_trans,
smaps=smaps,
raw_op=raw_op,
squeeze_dims=squeeze_dims,
)

self.raw_op = RawSigpyNUFFT(samples_, shape, n_trans=n_trans, **kwargs)

@property
def norm_factor(self):
"""Normalization factor of the operator."""
Expand Down
13 changes: 8 additions & 5 deletions src/mrinufft/operators/interfaces/tfnufft.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tensorflow MRI Nufft Operators."""

import numpy as np

from ..base import FourierOperatorBase, with_tensorflow
from mrinufft._utils import proper_trajectory

Expand Down Expand Up @@ -79,6 +80,7 @@ def op(self, data):
-------
Tensor
"""
self.check_shape(image=data)
if self.uses_sense:
data_d = data * self.smaps
else:
Expand All @@ -95,24 +97,25 @@ def op(self, data):
return coeff

@with_tensorflow
def adj_op(self, data):
def adj_op(self, coeffs):
"""
Backward Operation.
Parameters
----------
data: Tensor
coeffs: Tensor
Returns
-------
Tensor
"""
self.check_shape(ksp=coeffs)
if self.uses_density:
data_d = data * self.density
coeffs_d = coeffs * self.density
else:
data_d = data
coeffs_d = coeffs
img = tfnufft.nufft(
data_d,
coeffs_d,
self.samples,
self.shape,
transform_type="type_1",
Expand Down
21 changes: 13 additions & 8 deletions src/mrinufft/operators/interfaces/torchkbnufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from mrinufft.operators.base import FourierOperatorBase, with_torch
from mrinufft._utils import proper_trajectory
from mrinufft.operators.interfaces.utils import is_cuda_tensor
from mrinufft.operators.interfaces.utils import (
is_cuda_tensor,
)
import numpy as np


TORCH_AVAILABLE = True
try:
import torchkbnufft as tkbn
Expand Down Expand Up @@ -123,6 +126,7 @@ def op(self, data, out=None):
-------
Tensor: Non-uniform Fourier transform of the input image.
"""
self.check_shape(image=data, ksp=out)
B, C, XYZ = self.n_batchs, self.n_coils, self.shape
data = data.reshape((B, 1 if self.uses_sense else C, *XYZ))
data = data.to(self.device, copy=False)
Expand All @@ -137,28 +141,29 @@ def op(self, data, out=None):
return self._safe_squeeze(kdata)

@with_torch
def adj_op(self, data, out=None):
def adj_op(self, coeffs, out=None):
"""Backward Operation.
Parameters
----------
data: Tensor
coeffs: Tensor
Returns
-------
Tensor
"""
self.check_shape(image=out, ksp=coeffs)
B, C, K, XYZ = self.n_batchs, self.n_coils, self.n_samples, self.shape
data = data.reshape((B, C, K))
data = data.to(self.device, copy=False)
coeffs = coeffs.reshape((B, C, K))
coeffs = coeffs.to(self.device, copy=False)

if self.smaps is not None:
self.smaps = self.smaps.to(data.dtype, copy=False)
self.smaps = self.smaps.to(coeffs.dtype, copy=False)
if self.density:
data = data * self.density
coeffs = coeffs * self.density

img = self._tkb_adj_op.forward(
data=data, omega=self.samples.t(), smaps=self.smaps
data=coeffs, omega=self.samples.t(), smaps=self.smaps
)
img = img.reshape((B, 1 if self.uses_sense else C, *XYZ))
img /= self.norm_factor
Expand Down
2 changes: 0 additions & 2 deletions src/mrinufft/operators/interfaces/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .utils import (
check_error,
check_size,
sizeof_fmt,
)

Expand All @@ -18,7 +17,6 @@

__all__ = [
"check_error",
"check_size",
"sizeof_fmt",
"get_maxThreadBlock",
"CUPY_AVAILABLE",
Expand Down
6 changes: 0 additions & 6 deletions src/mrinufft/operators/interfaces/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,3 @@ def sizeof_fmt(num, suffix="B"):
return f"{num:3.1f}{unit}{suffix}"
num /= 1024.0
return f"{num:.1f}Yi{suffix}"


def check_size(array_like, shape):
"""Check if array_like has a matching shape."""
if np.prod(array_like.shape) != np.prod(shape):
raise ValueError(f"Expected array with {shape}, got {array_like.shape}.")
3 changes: 3 additions & 0 deletions src/mrinufft/operators/stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def _ifftz(data):
@with_numpy_cupy
def op(self, data, ksp=None):
"""Forward operator."""
self.check_shape(image=data, ksp=ksp)
# Dispatch to special case.
data = auto_cast(data, self.cpx_dtype)

Expand Down Expand Up @@ -646,6 +647,8 @@ def _op_calibless_device(self, data, ksp=None):
@with_numpy_cupy
def adj_op(self, coeffs, img=None):
"""Adjoint operator."""
if img is not None:
self.check_shape(image=img, ksp=coeffs)
# Dispatch to special case.
coeffs = auto_cast(coeffs, self.cpx_dtype)

Expand Down
1 change: 1 addition & 0 deletions tests/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def test_batch_adj_op(
kspace_data = to_interface(kspace_data, array_interface)

kspace_flat = kspace_data.reshape(-1, operator.n_coils, operator.n_samples)

image_flat = [None] * operator.n_batchs
for i in range(len(image_flat)):
image_flat[i] = from_interface(
Expand Down
Loading

0 comments on commit e4dc519

Please sign in to comment.