Skip to content

Commit

Permalink
Merge branch 'master' of ssh://github.com/freesurfer/surfa
Browse files Browse the repository at this point in the history
  • Loading branch information
ahoopes committed Mar 19, 2024
2 parents b9a15b9 + 83cd12c commit 26346ae
Show file tree
Hide file tree
Showing 12 changed files with 658 additions and 179 deletions.
2 changes: 2 additions & 0 deletions surfa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .core import LabelRecoder

from .transform import Affine
from .transform import Warp
from .transform import Space
from .transform import ImageGeometry

Expand All @@ -29,6 +30,7 @@
from .io import load_affine
from .io import load_label_lookup
from .io import load_mesh
from .io import load_warp

from . import vis
from . import freesurfer
Expand Down
19 changes: 16 additions & 3 deletions surfa/core/framed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@
from surfa.core.labels import LabelLookup


# mgz now has its intent encoded in the version number
# version = (intent & 0xff ) << 8) | MGH_VERSION
# MGH_VERSION = 1
class FramedArrayIntents:
unknown = -1
mri = 0
label = 1
shape = 2
warpmap = 3
warpmap_inv = 4


class FramedArray:

def __init__(self, basedim, data, labels=None, metadata=None):
Expand Down Expand Up @@ -264,7 +276,8 @@ def _shape_changed(self):
"""
pass

def save(self, filename, fmt=None):
# optional parameter to specify FramedArray intent, default is MRI data
def save(self, filename, fmt=None, intent=FramedArrayIntents.mri):
"""
Write array to file.
Expand All @@ -276,7 +289,7 @@ def save(self, filename, fmt=None):
Optional file format to force.
"""
from surfa.io.framed import save_framed_array
save_framed_array(self, filename, fmt=fmt)
save_framed_array(self, filename, fmt=fmt, intent=intent)

def min(self, nonzero=False, frames=False):
"""
Expand Down Expand Up @@ -366,7 +379,7 @@ def percentile(self, percentiles, method='linear', nonzero=False):
data = self.data
if nonzero:
data = data[data.nonzero()]
return np.percentile(data, percentiles, interpolation=method)
return np.percentile(data, percentiles, method=method)

def clip(self, a_min, a_max):
"""
Expand Down
134 changes: 27 additions & 107 deletions surfa/image/framed.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,18 +383,21 @@ def resample_like(self, target, method='linear', copy=True, fill=0):
method=method, affine=affine.matrix, fill=fill)
return self.new(interped, target_geom)

def transform(self, trf=None, method='linear', rotation='corner', resample=True, fill=0, affine=None):
def transform(self, trf=None, method='linear', rotation='corner', resample=True, fill=0):
"""
Apply an affine or non-linear transform.
**Note on deformation fields:** Until we come up with a reasonable way to represent
deformation fields, they can be implemented as multi-frame images. It is assumed that
they represent a *displacement* vector field in voxel space. So under the hood, images
will be moved into the space of the deformation field if the image geometries differ.
**The original implementation has been moved to Affine.transform and Warp.transform.
**The method is now impemeted to transform image using Affine.transform and Warp.transform.
**Note on trf argument:** It accepts Affine/Warp object, or deformation fields (4D numpy array).
Pass trf argument as a numpy array is deprecated and will be removed in the future.
It is assumed that the deformation fields represent a *displacement* vector field in voxel space.
So under the hood, images will be moved into the space of the deformation field if the image geometries differ.
Parameters
----------
trf : Affine or !class
trf : Affine/Warp or !class
Affine transform or nonlinear deformation (displacement) to apply to the image.
method : {'linear', 'nearest'}
Image interpolation method if resample is enabled.
Expand All @@ -406,8 +409,6 @@ def transform(self, trf=None, method='linear', rotation='corner', resample=True,
be updated (this is not possible if a displacement field is provided).
fill : scalar
Fill value for out-of-bounds voxels.
affine : Affine
Deprecated. Use the `trf` argument instead.
Returns
-------
Expand All @@ -418,107 +419,27 @@ def transform(self, trf=None, method='linear', rotation='corner', resample=True,
raise NotImplementedError('transform() is not yet implemented for 2D data, '
'contact andrew if you need this')

if affine is not None:
trf = affine
warnings.warn('The \'affine\' argument to transform() is deprecated. Just use '
'the first positional argument to specify a transform.',
DeprecationWarning, stacklevel=2)

# one of these two will be set by the end of the function
disp_data = None
matrix_data = None

# first try to convert it to an affine matrix. if that fails
# we assume it has to be a deformation field
try:
trf = cast_affine(trf, allow_none=False)
except ValueError:
pass

image = self.copy()
if isinstance(trf, Affine):
return trf.transform(image, method, rotation, resample, fill)

# for clarity
affine = trf

# if not resampling, just change the image vox2world matrix and return
if not resample:

# TODO: if affine is missing geometry info, do we assume that the affine
# is in world space or voxel space? let's do world for now
if affine.source is not None and affine.target is not None:
affine = affine.convert(space='world', source=self)
# TODO: must try this again once I changed everything around!!
elif affine.space is None:
warnings.warn('Affine transform is missing metadata defining its coordinate '
'space or source and target geometry. Assuming matrix is a '
'world-space transform since resample=False, but this might '
'not always be the case. Best practice is to provide the '
'correct metadata in the affine')
elif affine.space != 'world':
raise ValueError('affine must contain source and target info '
'if not in world space')

# apply forward transform to the header
transformed = self.copy()
transformed.geom.update(vox2world=affine @ affine.source.vox2world)
return transformed

# sanity check and preprocess the affine if resampling
target_geom = self.geom

if affine.source is not None and affine.target is not None:
# it should be assumed that the default affine space is voxel
# when both source and target are set
if affine.space is None:
affine = affine.copy()
affine.space = 'voxel'
#
affine = affine.convert(space='voxel', source=self)
target_geom = affine.target
elif affine.space is not None and affine.space != 'voxel':
raise ValueError('affine must contain source and target info if '
'coordinate space is not \'voxel\'')

# ensure the rotation is around the image corner before interpolating
if rotation not in ('center', 'corner'):
raise ValueError("rotation must be 'center' or 'corner'")
elif rotation == 'center':
affine = center_to_corner_rotation(affine, self.baseshape)

# make sure the matrix is actually inverted since we want a target to
# source voxel mapping for resampling
matrix_data = affine.inv().matrix
source_data = self.framed_data

else:
if not resample:
raise ValueError('transform resampling must be enabled when deformation is used')
from surfa.transform.warp import Warp
if isinstance(trf, np.ndarray):
warnings.warn('The option to pass \'trf\' argument as a numpy array is deprecated. '
'Pass \'trf\' as either an Affine or Warp object',
DeprecationWarning, stacklevel=2)

# cast deformation as a framed image data. important that the fallback geometry
# here is the current image space
deformation = cast_image(trf, fallback_geom=self.geom)
if deformation.nframes != self.basedim:
raise ValueError(f'deformation ({deformation.nframes}D) does not match '
f'dimensionality of image ({self.basedim}D)')

# since we only support deformations in the form of voxel displacement
# currently, must get the image in the space of the deformation
source_data = self.resample_like(deformation).framed_data

# make sure to use the deformation as the target geometry
target_geom = deformation.geom
image = image.resample_like(deformation)
trf = Warp(data=trf,
source=image.geom,
target=deformation.geom,
format=Warp.Format.disp_crs)

# get displacement data
disp_data = deformation.data
if isinstance(trf, Warp):
return trf.transform(image, method, fill)

# do the interpolation
interpolated = interpolate(source=source_data,
target_shape=target_geom.shape,
method=method,
affine=matrix_data,
disp=disp_data,
fill=fill)
return self.new(interpolated, target_geom)
raise ValueError("Pass \'trf\' as either an Affine or Warp object")

def reorient(self, orientation, copy=True):
"""
Expand Down Expand Up @@ -784,20 +705,19 @@ def barycenters(self, labels=None, space='image'):
one, the barycenter array will be of shape $(F, L, D)$.
"""
if labels is not None:
#

if not np.issubdtype(self.dtype, np.integer):
raise ValueError('expected int dtype for computing barycenters on 1D, '
f'but got dtype {self.dtype}')
weights = np.ones(self.baseshape, dtype=np.float32)
centers = [center_of_mass(weights, self.framed_data[..., i], labels) for i in range(self.nframes)]
else:
#

centers = [center_of_mass(self.framed_data[..., i]) for i in range(self.nframes)]

#

centers = np.squeeze(centers)

#
space = cast_space(space)
if space != 'image':
centers = self.geom.affine('image', space)(centers)
Expand Down
36 changes: 18 additions & 18 deletions surfa/image/interp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0):
raise ValueError('interpolation requires an affine transform and/or displacement field')

if method not in ('linear', 'nearest'):
raise ValueError(f'interp method must be linear or nearest, but got {method}')
raise ValueError(f'interp method must be linear or nearest, got {method}')

if not isinstance(source, np.ndarray):
raise ValueError(f'source data must a numpy array, but got input of type {source.__class__.__name__}')
raise ValueError(f'source data must be a numpy array, got {source.__class__.__name__}')

if source.ndim != 4:
raise ValueError(f'source data must be 4D, but got input of shape {target_shape}')
raise ValueError(f'source data must be 4D, but got input of shape {source.shape}')

target_shape = tuple(target_shape)
if len(target_shape) != 3:
Expand All @@ -53,7 +53,7 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0):
use_affine = affine is not None
if use_affine:
if not isinstance(affine, np.ndarray):
raise ValueError(f'affine must a numpy array, but got input of type {source.__class__.__name__}')
raise ValueError(f'affine must be a numpy array, got {affine.__class__.__name__}')
if not np.array_equal(affine.shape, (4, 4)):
raise ValueError(f'affine must be 4x4, but got input of shape {affine.shape}')
# only supports float32 affines for now
Expand All @@ -63,9 +63,9 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0):
use_disp = disp is not None
if use_disp:
if not isinstance(disp, np.ndarray):
raise ValueError(f'source data must a numpy array, but got input of type {source.__class__.__name__}')
raise ValueError(f'source data must be a numpy array, got {disp.__class__.__name__}')
if not np.array_equal(disp.shape[:-1], target_shape):
raise ValueError(f'displacement field shape {disp.shape[:-1]} must match target shape {target_shape}')
raise ValueError(f'warp shape {disp.shape[:-1]} must match target shape {target_shape}')

# TODO: figure out what would cause this
if not disp.flags.c_contiguous and not disp.flags.f_contiguous:
Expand Down Expand Up @@ -130,10 +130,10 @@ ctypedef fused datatype:

@cython.boundscheck(False)
@cython.wraparound(False)
def interp_3d_fortran_nearest(datatype[::1, :, :, :] source,
def interp_3d_fortran_nearest(const datatype[::1, :, :, :] source,
np.ndarray[np.int_t, ndim=1] target_shape,
float[:, ::1] mat,
float[::1, :, :, :] disp,
const float[:, ::1] mat,
const float[::1, :, :, :] disp,
datatype fill_value,
bint use_affine,
bint use_disp):
Expand Down Expand Up @@ -242,10 +242,10 @@ def interp_3d_fortran_nearest(datatype[::1, :, :, :] source,

@cython.boundscheck(False)
@cython.wraparound(False)
def interp_3d_fortran_linear(datatype[::1, :, :, :] source,
def interp_3d_fortran_linear(const datatype[::1, :, :, :] source,
np.ndarray[np.int_t, ndim=1] target_shape,
float[:, ::1] mat,
float[::1, :, :, :] disp,
const float[:, ::1] mat,
const float[::1, :, :, :] disp,
datatype fill_value,
bint use_affine,
bint use_disp):
Expand Down Expand Up @@ -380,10 +380,10 @@ def interp_3d_fortran_linear(datatype[::1, :, :, :] source,

@cython.boundscheck(False)
@cython.wraparound(False)
def interp_3d_contiguous_nearest(datatype[:, :, :, ::1] source,
def interp_3d_contiguous_nearest(const datatype[:, :, :, ::1] source,
np.ndarray[np.int_t, ndim=1] target_shape,
float[:, ::1] mat,
float[:, :, :, ::1] disp,
const float[:, ::1] mat,
const float[:, :, :, ::1] disp,
datatype fill_value,
bint use_affine,
bint use_disp):
Expand Down Expand Up @@ -492,10 +492,10 @@ def interp_3d_contiguous_nearest(datatype[:, :, :, ::1] source,

@cython.boundscheck(False)
@cython.wraparound(False)
def interp_3d_contiguous_linear(datatype[:, :, :, ::1] source,
def interp_3d_contiguous_linear(const datatype[:, :, :, ::1] source,
np.ndarray[np.int_t, ndim=1] target_shape,
float[:, ::1] mat,
float[:, :, :, ::1] disp,
const float[:, ::1] mat,
const float[:, :, :, ::1] disp,
datatype fill_value,
bint use_affine,
bint use_disp):
Expand Down
1 change: 1 addition & 0 deletions surfa/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .framed import load_volume
from .framed import load_slice
from .framed import load_overlay
from .framed import load_warp
from .labels import load_label_lookup
from .mesh import load_mesh
Loading

0 comments on commit 26346ae

Please sign in to comment.