Add ability to use a volumetric segmentation to create surfaces (#171)
* add ability to use labels to segment

* fix deps

* add heuristic to open segmentations as such

* split out into cleaner functions

* use new function in main widget
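
In practice, the reader-level change means an 8-bit integer volume now opens as a napari Labels layer. A minimal sketch with a placeholder file path; the entry point is `blik.reader.read_layers`, shown in the diff below:

```python
from blik.reader import read_layers

# an 8-bit integer volume (e.g. a uint8 .mrc) is now heuristically
# opened as a segmentation (Labels layer) instead of a grayscale Image
layer_tuples = read_layers("my_segmentation.mrc")  # placeholder path
for data, kwargs, layer_type in layer_tuples:
    print(kwargs["name"], layer_type)  # e.g. "<experiment_id> - segmentation" "labels"
```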
brisvag authored Jul 12, 2024
1 parent 61611e3 · commit 591bd2b
Showing 4 changed files with 144 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
"cryohub>=0.6.4",
"cryotypes>=0.2.0",
"einops",
"morphosamplers>=0.0.10",
"morphosamplers[segment]>=0.0.10",
"pydantic<2", # migration will take a while for napari
"packaging",
]
102 changes: 78 additions & 24 deletions src/blik/reader.py
@@ -17,7 +17,7 @@ def get_reader(path):


def _construct_positions_layer(
- coords, features, scale, exp_id, p_id, source, name_suffix, **pt_kwargs
+ coords, features, scale, exp_id, p_id, source, name_suffix, **points_kwargs
):
feat_defaults = (
pd.DataFrame(features.iloc[-1].to_dict(), index=[0])
@@ -44,7 +44,7 @@ def _construct_positions_layer(
"projection_mode": "all",
# "axis_labels": ('z', 'y', 'x'),
"units": 'angstrom',
- **pt_kwargs,
+ **points_kwargs,
},
"points",
)
@@ -86,12 +86,12 @@ def construct_particle_layer_tuples(
p_id=None,
source="",
name_suffix="",
- **pt_kwargs,
+ **points_kwargs,
):
"""
Constructs particle layer tuples from particle data.
- Data should be still in xyz format (will be flipped to zyx).
+ Coords should be still in xyz order (will be flipped to zyx in output).
"""
# unique id so we can connect layers safely
p_id = p_id if p_id is not None else uuid1()
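
The docstring tweak above is worth a concrete illustration: particle coordinates enter in xyz order and are flipped to napari's zyx convention. A minimal sketch, assuming `invert_xyz` (used throughout blik) simply reverses the last axis:

```python
import numpy as np
from blik.utils import invert_xyz

xyz = np.array([[10.0, 20.0, 30.0]])  # one particle position, xyz order
zyx = invert_xyz(xyz)                 # assumed -> [[30.0, 20.0, 10.0]], zyx for napari
```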
@@ -113,7 +113,7 @@
p_id=p_id,
source=source,
name_suffix=name_suffix,
- **pt_kwargs,
+ **points_kwargs,
)
ori = _construct_orientations_layer(
coords=coords,
@@ -141,7 +141,7 @@ def read_particles(particles, name_suffix="particle"):

px_size = particles.pixel_spacing
if not px_size:
warnings.warn("unknown pixel spacing, setting to 1 Angstrom", stacklevel=2)
warnings.warn(f"unknown pixel spacing for particles '{particles.experiment_id}'; setting to 1 Angstrom.", stacklevel=2)
px_size = 1

if particles.shift is not None:
@@ -161,31 +161,79 @@
)


- def read_image(image):
- px_size = image.pixel_spacing
- if not px_size:
- warnings.warn("unknown pixel spacing, setting to 1 Angstrom", stacklevel=2)
- px_size = 1
+ def construct_image_layer_tuple(
+ data,
+ scale,
+ exp_id,
+ stack=False,
+ source="",
+ **image_kwargs,
+ ):
return (
- image.data,
+ data,
{
"name": f"{image.experiment_id} - image",
"scale": [px_size] * image.data.ndim,
"metadata": {"experiment_id": image.experiment_id, "stack": image.stack},
"name": f"{exp_id} - image",
"scale": [scale] * 3,
"metadata": {"experiment_id": exp_id, "stack": stack, "source": source},
"interpolation2d": "spline36",
"interpolation3d": "linear",
"rendering": "average",
"depiction": "plane",
"blending": "translucent",
"plane": {"thickness": 5},
"projection_mode": "mean",
"depiction": "plane",
"plane": {"thickness": 5, "position": np.array(data.shape) / 2},
"rendering": "average",
# "axis_labels": ('z', 'y', 'x'),
"units": 'angstrom',
+ **image_kwargs,
},
"image",
)


+ def read_image(image):
+ return construct_image_layer_tuple(
+ data=image.data,
+ scale=image.pixel_spacing,
+ exp_id=image.experiment_id,
+ stack=image.stack,
+ source=image.source,
+ )


+ def construct_segmentation_layer_tuple(
+ data,
+ scale,
+ exp_id,
+ stack=False,
+ source="",
+ **labels_kwargs,
+ ):
+ return (
+ data,
+ {
+ "name": f"{exp_id} - segmentation",
+ "scale": [scale] * 3,
+ "metadata": {"experiment_id": exp_id, "stack": stack, "source": source},
+ "blending": "translucent",
+ # "axis_labels": ('z', 'y', 'x'),
+ "units": 'angstrom',
+ **labels_kwargs,
+ },
+ "labels",
+ )


+ def read_segmentation(image):
+ return construct_segmentation_layer_tuple(
+ data=image.data,
+ scale=image.pixel_spacing,
+ exp_id=image.experiment_id,
+ stack=image.stack,
+ source=image.source,
+ )



def read_surface_picks(path):
lines = []
with open(path, "rb") as f:
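
The two new constructors mirror each other: same metadata, but image rendering kwargs on one side and a plain Labels setup on the other. A minimal sketch of the segmentation variant, with placeholder values:

```python
import numpy as np
from blik.reader import construct_segmentation_layer_tuple

seg = np.zeros((64, 64, 64), dtype=np.int32)
tup = construct_segmentation_layer_tuple(
    data=seg,
    scale=10.0,      # Angstrom per voxel
    exp_id="TS_01",  # placeholder experiment id
)
# -> (seg, {"name": "TS_01 - segmentation", "scale": [10.0] * 3, ...}, "labels")
```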
@@ -255,13 +303,19 @@ def read_layers(*paths, **kwargs):
else:
cryohub_paths.append(path)

- data_list = cryohub.read(*cryohub_paths, **kwargs)
+ obj_list = cryohub.read(*cryohub_paths, **kwargs)
# sort so we get images first, better for some visualization circumstances
- for data in sorted(data_list, key=lambda x: not isinstance(x, ImageProtocol)):
- if isinstance(data, ImageProtocol):
- layers.append(read_image(data))
- elif isinstance(data, PoseSetProtocol):
- layers.extend(read_particles(data))
+ for obj in sorted(obj_list, key=lambda x: not isinstance(x, ImageProtocol)):
+ if not obj.pixel_spacing:
+ warnings.warn(f"unknown pixel spacing for {obj.__class__.__name__} '{obj.experiment_id}'; setting to 1 Angstrom.", stacklevel=2)
+ obj.pixel_spacing = 1
+ if isinstance(obj, ImageProtocol):
+ if np.issubdtype(obj.data.dtype, np.integer) and np.iinfo(obj.data.dtype).bits == 8:
+ layers.append(read_segmentation(obj))
+ else:
+ layers.append(read_image(obj))
+ elif isinstance(obj, PoseSetProtocol):
+ layers.extend(read_particles(obj))

for lay in layers:
lay[1]["visible"] = False # speed up loading
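The dispatch heuristic in `read_layers` is compact enough to restate on its own: any 8-bit integer volume is presumed to be a label map. A sketch with a hypothetical helper name (not part of blik):

```python
import numpy as np

def looks_like_segmentation(data) -> bool:
    # hypothetical helper restating the check used in read_layers above
    return np.issubdtype(data.dtype, np.integer) and np.iinfo(data.dtype).bits == 8

assert looks_like_segmentation(np.zeros((4, 4, 4), dtype=np.uint8))
assert not looks_like_segmentation(np.zeros((4, 4, 4), dtype=np.float32))
```
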
23 changes: 9 additions & 14 deletions src/blik/widgets/main_widget.py
@@ -13,7 +13,7 @@
from packaging.version import parse as parse_version
from scipy.spatial.transform import Rotation

- from ..reader import construct_particle_layer_tuples
+ from ..reader import construct_particle_layer_tuples, construct_segmentation_layer_tuple
from ..utils import generate_vectors, invert_xyz, layer_tuples_to_layers


@@ -171,18 +171,13 @@ def new(
if l_type == "segmentation":
for lay in layers:
if isinstance(lay, Image) and lay.metadata["experiment_id"] == exp_id:
- labels = Labels(
- np.zeros(lay.data.shape, dtype=np.int32),
- name=f"{exp_id} - segmentation",
- scale=lay.scale,
- # axis_labels=('z', 'y', 'x'),
- units='angstrom',
- metadata={
- "experiment_id": exp_id,
- "stack": lay.metadata["stack"],
- },
+ layer = construct_segmentation_layer_tuple(
+ data=np.zeros(lay.data.shape, dtype=np.int32),
+ scale=lay.scale[0],
+ exp_id=exp_id,
+ stack=lay.metadata["stack"],
)
- return [labels]
+ return layer_tuples_to_layers([layer])
elif l_type == "particles":
for lay in layers:
if lay.metadata["experiment_id"] == exp_id:
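
The refactor above makes the widget build the same layer-data tuple as the reader and then materialize it; based on its use here, `layer_tuples_to_layers` is assumed to turn LayerDataTuples into napari layer objects. A hedged sketch of that round trip:

```python
import numpy as np
from blik.reader import construct_segmentation_layer_tuple
from blik.utils import layer_tuples_to_layers

tup = construct_segmentation_layer_tuple(
    data=np.zeros((32, 32, 32), dtype=np.int32),
    scale=1.0,
    exp_id="TS_01",  # placeholder
)
(labels_layer,) = layer_tuples_to_layers([tup])  # assumed -> [napari.layers.Labels]
```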
@@ -208,7 +203,7 @@ def new(
edge_color="surface_id",
ndim=3,
# axis_labels=('z', 'y', 'x'),
- units='angstrom',
+ units="angstrom",
)

return [pts]
@@ -223,7 +218,7 @@ def new(
face_color_cycle=np.random.rand(30, 3),
ndim=3,
# axis_labels=('z', 'y', 'x'),
- units='angstrom',
+ units="angstrom",
)

return [pts]
65 changes: 56 additions & 9 deletions src/blik/widgets/picking.py
@@ -10,6 +10,7 @@
sample_volume_around_surface,
)
from morphosamplers.surface_spline import GriddedSplineSurface
+ from morphosamplers.preprocess import get_label_paths_3d
from scipy.spatial.transform import Rotation

from ..reader import construct_particle_layer_tuples
@@ -67,6 +68,44 @@ def _generate_surface_grids_from_shapes_layer(
return surface_grids, colors


+ def _generate_surface_grids_from_labels_layer(
+ surface_label,
+ spacing_A=100,
+ inside_points=None,
+ closed=False,
+ ):
+ """create a new surface representation from a segmentation."""
+ spacing_A /= surface_label.scale[0]
+ surface_grids = []
+ if inside_points is None:
+ inside_point = None
+ else:
+ inside_point = (
+ invert_xyz(inside_points.data[0]) if len(inside_points.data) else None
+ )
+
+ # doing this "custom" because we need to flip xyz
+ surfaces_lines = get_label_paths_3d(compute(surface_label.data)[0], axis=0, slicing_step=10, sampling_step=10)
+
+ for lines in surfaces_lines:
+ lines = [invert_xyz(line.astype(float)) for line in lines]
+
+ try:
+ surface_grids.append(
+ GriddedSplineSurface(
+ points=lines,
+ separation=spacing_A,
+ order=3,
+ closed=closed,
+ inside_point=inside_point,
+ )
+ )
+ except ValueError:
+ continue
+
+ return surface_grids, np.random.rand(len(surface_grids), 3)


def _resample_surfaces(image_layer, surface_grids, spacing, thickness, masked):
volumes = []
for surf in surface_grids:
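
Condensed, the new helper slices the label volume into per-label contour paths, flips each path between xyz and zyx, and fits one gridded spline surface per label, skipping labels with too few points (the spacing is also converted from Angstrom to voxels via the layer scale). A toy sketch under those assumptions, reusing the morphosamplers calls from the hunk above:

```python
import numpy as np
from morphosamplers.preprocess import get_label_paths_3d
from morphosamplers.surface_spline import GriddedSplineSurface

labels = np.zeros((50, 50, 50), dtype=np.int32)
labels[10:40, 10:40, 25] = 1  # a toy sheet-like label

surfaces_lines = get_label_paths_3d(labels, axis=0, slicing_step=10, sampling_step=10)
grids = []
for lines in surfaces_lines:
    lines = [line[..., ::-1].astype(float) for line in lines]  # xyz <-> zyx flip
    try:
        grids.append(GriddedSplineSurface(points=lines, separation=5.0, order=3, closed=False))
    except ValueError:
        continue  # not enough points to fit a spline grid for this label
```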
Expand Down Expand Up @@ -108,21 +147,29 @@ def _resample_filament(image_layer, filament, spacing, thickness):
inside_points={"nullable": True},
)
def surface(
- surface_shapes: napari.layers.Shapes,
+ surface_input: napari.layers.Layer,
inside_points: napari.layers.Points,
spacing_A=50,
closed=False,
) -> napari.types.LayerDataTuple:
"""create a new surface representation from picked surface points."""
- surface_grids, colors = _generate_surface_grids_from_shapes_layer(
- surface_shapes,
- spacing_A,
- inside_points=inside_points,
- closed=closed,
- )
+ if isinstance(surface_input, napari.layers.Shapes):
+ surface_grids, colors = _generate_surface_grids_from_shapes_layer(
+ surface_input,
+ spacing_A,
+ inside_points=inside_points,
+ closed=closed,
+ )
+ else:
+ surface_grids, colors = _generate_surface_grids_from_labels_layer(
+ surface_input,
+ spacing_A,
+ inside_points=inside_points,
+ closed=closed,
+ )

meshes = []
- exp_id = surface_shapes.metadata["experiment_id"]
+ exp_id = surface_input.metadata["experiment_id"]

for surf in surface_grids:
meshes.append(surf.mesh())
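
Because dispatch now happens on the layer type, the existing picking workflow and the new segmentation workflow share one widget. A hedged sketch of driving the new branch directly; the layer construction and experiment id here are illustrative only:

```python
import napari
import numpy as np
from blik.widgets.picking import _generate_surface_grids_from_labels_layer

data = np.zeros((50, 50, 50), dtype=np.int32)
data[10:40, 10:40, 25] = 1
seg = napari.layers.Labels(
    data,
    scale=[10.0] * 3,  # Angstrom per voxel
    metadata={"experiment_id": "TS_01"},  # placeholder
)

# spacing_A is converted to voxels internally using the layer scale
grids, colors = _generate_surface_grids_from_labels_layer(seg, spacing_A=100)
meshes = [grid.mesh() for grid in grids]  # as consumed by the widget above
```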
Expand Down Expand Up @@ -156,7 +203,7 @@ def surface(
"surface_grids": surface_grids,
"surface_colors": colors,
},
"scale": surface_shapes.scale,
"scale": surface_input.scale,
"shading": "smooth",
"colormap": colormap,
},
