Skip to content

Commit

Permalink
slicing in viewer.Scene, tests pass locally
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Jul 6, 2023
1 parent e115005 commit 89216bc
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 88 deletions.
26 changes: 13 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,19 @@ left_edge = ds.domain_center - ds.arr([40, 40, 40], 'kpc')
right_edge = ds.domain_center + ds.arr([40, 40, 40], 'kpc')
res = (600, 600, 600)

yt_scene.add_to_viewer(viewer,
ds,
("enzo", "Temperature"),
left_edge = left_edge,
right_edge = right_edge,
resolution = res)

yt_scene.add_to_viewer(viewer,
ds,
("enzo", "Density"),
left_edge = left_edge,
right_edge = right_edge,
resolution = res)
yt_scene.add_region(viewer,
ds,
("enzo", "Temperature"),
left_edge=left_edge,
right_edge=right_edge,
resolution=res)

yt_scene.add_region(viewer,
ds,
("enzo", "Density"),
left_edge=left_edge,
right_edge=right_edge,
resolution=res)

nbscreenshot(viewer)
```
Expand Down
75 changes: 54 additions & 21 deletions src/yt_napari/_model_ingestor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
from unyt import unit_object, unit_registry, unyt_array
from unyt import unit_object, unit_registry, unyt_array, unyt_quantity

from yt_napari._data_model import DataContainer, InputModel
from yt_napari._ds_cache import dataset_cache
Expand Down Expand Up @@ -332,45 +332,78 @@ def _load_3D_regions(ds, m_data: DataContainer, layer_list: list) -> list:
return layer_list


def _load_2D_slices(ds, m_data: DataContainer, layer_list: list) -> list:
def _process_slice(
ds,
normal: Union[str, int],
center: Optional[unyt_array] = None,
width: Optional[unyt_quantity] = None,
height: Optional[unyt_quantity] = None,
resolution: Optional[Tuple[int, int]] = (400, 400),
periodic: Optional[bool] = False,
) -> tuple:
# returns a slice frb and a LayerDomain for a slice
axis_id = ds.coordinates.axis_id
normal_ax = axis_id[normal]
x_axis = axis_id[ds.coordinates.image_axis_name[normal][0]]
y_axis = axis_id[ds.coordinates.image_axis_name[normal][1]]

if center is None:
center = ds.domain_center
if width is None:
width = ds.domain_width[x_axis]
if height is None:
height = ds.domain_width[y_axis]

LE = ds.arr([0.0, 0.0], "code_length")
RE = ds.arr([0.0, 0.0], "code_length")
LE[0] = center[x_axis] - width / 2.0
RE[0] = center[x_axis] + width / 2.0
LE[1] = center[y_axis] - height / 2.0
RE[1] = center[y_axis] + height / 2.0

slc = ds.slice(normal_ax, center[normal_ax])
frb = slc.to_frb(
width=width,
height=height,
center=center,
resolution=resolution,
periodic=periodic,
)

layer_domain = LayerDomain(
left_edge=LE, right_edge=RE, resolution=resolution, n_d=2
)

return frb, layer_domain


def _load_2D_slices(ds, m_data: DataContainer, layer_list: list) -> list:

for slice in m_data.selections.slices:
if slice.normal == "":
continue
normal_ax = axis_id[slice.normal]

if slice.center is None:
c = ds.domain_center
c = None
else:
c = ds.arr(slice.center.value, slice.center.unit)

slc = ds.slice(slice.normal, c[normal_ax])

x_axis = axis_id[ds.coordinates.image_axis_name[slice.normal][0]]
y_axis = axis_id[ds.coordinates.image_axis_name[slice.normal][1]]
LE = ds.arr([0.0, 0.0], "code_length")
RE = ds.arr([0.0, 0.0], "code_length")
if slice.slice_width is None:
w = ds.domain_width[x_axis]
w = None
else:
w = ds.quan(slice.slice_width.value, slice.slice_width.unit)
LE[0] = c[x_axis] - w / 2.0
RE[0] = c[x_axis] + w / 2.0

if slice.slice_height is None:
h = ds.domain_width[y_axis]
h = None
else:
h = ds.quan(slice.slice_height.value, slice.slice_height.unit)
LE[1] = c[y_axis] - h / 2.0
RE[1] = c[y_axis] + h / 2.0

res = slice.resolution
layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res, n_d=2)

frb = slc.to_frb(
frb, layer_domain = _process_slice(
ds,
slice.normal,
center=c,
width=w,
height=h,
center=c,
resolution=slice.resolution,
periodic=slice.periodic,
)
Expand Down
33 changes: 21 additions & 12 deletions src/yt_napari/_tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,26 @@ def test_viewer(make_napari_viewer, yt_ds, caplog):
# test add_to_viewer
sc = Scene()
res = (10, 10, 10)
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), resolution=res)

expected_layers = 1
assert len(viewer.layers) == expected_layers

sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), translate=10, resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), translate=10, resolution=res)
assert "translate is calculated internally" in caplog.text
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), scale=10, resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), scale=10, resolution=res)
assert "scale is calculated internally" in caplog.text

expected_layers += 2 # the above will add layers!
assert len(viewer.layers) == expected_layers

sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), resolution=res)
expected_layers += 1
assert len(viewer.layers) == expected_layers

# build a new scene so it builds from prior
sc = Scene()
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"))
sc.add_region(viewer, yt_ds, ("gas", "density"))
expected_layers += 1
assert len(viewer.layers) == expected_layers

Expand All @@ -56,8 +56,8 @@ def test_sanitize_layers(make_napari_viewer, yt_ds):

sc = Scene()
res = (10, 10, 10)
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), name="layer0", resolution=res)
sc.add_to_viewer(viewer, yt_ds, ("gas", "mass"), name="layer1", resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), name="layer0", resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "mass"), name="layer1", resolution=res)

clean_layers = sc._sanitize_layers(["layer0", "layer1"], viewer.layers)
assert len(clean_layers) == 2
Expand Down Expand Up @@ -90,8 +90,8 @@ def test_get_data_range(make_napari_viewer, yt_ds):

sc = Scene()
res = (10, 10, 10)
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), name="layer0", resolution=res)
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), name="layer1", resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), name="layer0", resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), name="layer1", resolution=res)
expected = (viewer.layers[0].data.min(), viewer.layers[0].data.max())
actual = sc.get_data_range(viewer.layers)
assert np.allclose(actual, expected)
Expand All @@ -109,8 +109,8 @@ def test_cross_layer_features(make_napari_viewer, yt_ds):

sc = Scene()
res = (10, 10, 10)
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), name="layer0", resolution=res)
sc.add_to_viewer(viewer, yt_ds, ("gas", "density"), name="layer1", resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), name="layer0", resolution=res)
sc.add_region(viewer, yt_ds, ("gas", "density"), name="layer1", resolution=res)

sc.set_across_layers(viewer.layers, "colormap", "viridis")
assert all([layer.colormap == "viridis"] for layer in viewer.layers)
Expand All @@ -120,7 +120,7 @@ def test_cross_layer_features(make_napari_viewer, yt_ds):
for layer in viewer.layers:
assert np.allclose(layer.contrast_limits, expected)

sc.add_to_viewer(
sc.add_region(
viewer,
yt_ds,
("gas", "density"),
Expand All @@ -130,3 +130,12 @@ def test_cross_layer_features(make_napari_viewer, yt_ds):
)
linked = get_linked_layers(viewer.layers["layer2"])
assert viewer.layers["layer1"] in linked


def test_viewer_slices(make_napari_viewer, yt_ds):
viewer = make_napari_viewer()
sc = Scene()
res = (50, 50)
sc.add_slice(viewer, yt_ds, "x", ("gas", "density"), resolution=res)

assert len(viewer.layers) == 1
Loading

0 comments on commit 89216bc

Please sign in to comment.