Skip to content

Commit

Permalink
update examples add imshow
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Jun 8, 2024
1 parent 0869b62 commit 1d53ecb
Show file tree
Hide file tree
Showing 19 changed files with 360 additions and 210 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ qapp.exec()
- `tensorstore.TensorStore`
- `zarr`
- You can add support for your own storage class by subclassing `ndv.DataWrapper`
and implementing a couple methods
and implementing a couple methods. (This doesn't require modifying ndv,
but contributions of new wrappers are welcome!)

See examples for each of these array types in [examples](./examples/)

Expand Down
18 changes: 5 additions & 13 deletions examples/custom_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from ndv import DataWrapper
import ndv

if TYPE_CHECKING:
from ndv import Indices, Sizes
Expand All @@ -16,10 +16,10 @@ def __init__(self, shape: tuple[int, ...]) -> None:
self._data = np.random.randint(0, 256, shape)

def __getitem__(self, item: Any) -> np.ndarray:
return self._data[item]
return self._data[item] # type: ignore [no-any-return]


class MyWrapper(DataWrapper[MyArrayThing]):
class MyWrapper(ndv.DataWrapper[MyArrayThing]):
@classmethod
def supports(cls, data: Any) -> bool:
if isinstance(data, MyArrayThing):
Expand All @@ -36,13 +36,5 @@ def isel(self, indexers: Indices) -> Any:
return self.data[idx]


if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer

qapp = QtWidgets.QApplication([])
data = MyArrayThing((10, 3, 512, 512))
v = NDViewer(data, channel_axis=1)
v.show()
qapp.exec()
data = MyArrayThing((10, 3, 512, 512))
ndv.imshow(data)
11 changes: 2 additions & 9 deletions examples/dask_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dask.array.core import map_blocks
except ImportError:
raise ImportError("Please `pip install dask[array]` to run this example.")
import ndv

frame_size = (1024, 1024)

Expand All @@ -21,12 +22,4 @@ def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None:
chunks += [(x,) for x in frame_size]
dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8)

if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer

qapp = QtWidgets.QApplication([])
v = NDViewer(dask_arr)
v.show()
qapp.exec()
v = ndv.imshow(dask_arr)
17 changes: 3 additions & 14 deletions examples/jax_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,7 @@
import jax.numpy as jnp
except ImportError:
raise ImportError("Please install jax to run this example")
from numpy_arr import generate_5d_sine_wave
import ndv

# Example usage
array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions
sine_wave_5d = jnp.asarray(generate_5d_sine_wave(array_shape))

if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer

qapp = QtWidgets.QApplication([])
v = NDViewer(sine_wave_5d, channel_axis=1)
v.show()
qapp.exec()
jax_arr = jnp.asarray(ndv.data.nd_sine_wave())
v = ndv.imshow(jax_arr)
65 changes: 6 additions & 59 deletions examples/numpy_arr.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,11 @@
from __future__ import annotations

import numpy as np


def generate_5d_sine_wave(
shape: tuple[int, int, int, int, int],
amplitude: float = 240,
base_frequency: float = 5,
) -> np.ndarray:
"""5D dataset."""
# Unpack the dimensions
angle_dim, freq_dim, phase_dim, ny, nx = shape

# Create an empty array to hold the data
output = np.zeros(shape)

# Define spatial coordinates for the last two dimensions
half_per = base_frequency * np.pi
x = np.linspace(-half_per, half_per, nx)
y = np.linspace(-half_per, half_per, ny)
y, x = np.meshgrid(y, x)

# Iterate through each parameter in the higher dimensions
for phase_idx in range(phase_dim):
for freq_idx in range(freq_dim):
for angle_idx in range(angle_dim):
# Calculate phase and frequency
phase = np.pi / phase_dim * phase_idx
frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step

# Calculate angle
angle = np.pi / angle_dim * angle_idx
# Rotate x and y coordinates
xr = np.cos(angle) * x - np.sin(angle) * y
np.sin(angle) * x + np.cos(angle) * y

# Compute the sine wave
sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase)
sine_wave += amplitude * 0.5

# Assign to the output array
output[angle_idx, freq_idx, phase_idx] = sine_wave

return output

import ndv

try:
from skimage import data

img = data.cells3d()
except Exception:
img = generate_5d_sine_wave((10, 3, 8, 512, 512))


if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer
img = ndv.data.cells3d()
except Exception as e:
print(e)
img = ndv.data.nd_sine_wave((10, 3, 8, 512, 512))

qapp = QtWidgets.QApplication([])
v = NDViewer(img)
v.show()
qapp.exec()
ndv.imshow(img)
20 changes: 4 additions & 16 deletions examples/pyopencl_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,14 @@
import pyopencl as cl
import pyopencl.array as cl_array
except ImportError:
raise ImportError("Please install jax to run this example")
from numpy_arr import generate_5d_sine_wave
raise ImportError("Please install pyopencl to run this example")
import ndv

# Set up OpenCL context and queue
context = cl.create_some_context(interactive=False)
queue = cl.CommandQueue(context)


# Example usage
array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions
sine_wave_5d = generate_5d_sine_wave(array_shape)
cl_sine_wave = cl_array.to_device(queue, sine_wave_5d)
gpu_data = cl_array.to_device(queue, ndv.data.nd_sine_wave())


if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer

qapp = QtWidgets.QApplication([])
v = NDViewer(cl_sine_wave, channel_axis=1)
v.show()
qapp.exec()
ndv.imshow(gpu_data)
13 changes: 3 additions & 10 deletions examples/sparse_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import numpy as np

import ndv

shape = (256, 4, 512, 512)
N = int(np.prod(shape) * 0.001)
coords = np.random.randint(low=0, high=shape, size=(N, len(shape))).T
Expand All @@ -16,13 +18,4 @@
# Create the sparse array from the coordinates and data
sparse_array = sparse.COO(coords, data, shape=shape)


if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer

qapp = QtWidgets.QApplication([])
v = NDViewer(sparse_array, channel_axis=1)
v.show()
qapp.exec()
ndv.imshow(sparse_array)
35 changes: 20 additions & 15 deletions examples/tensorstore_arr.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from __future__ import annotations

import numpy as np
import tensorstore as ts
from qtpy import QtWidgets
try:
import tensorstore as ts
except ImportError:
raise ImportError("Please install tensorstore to run this example")

from ndv import NDViewer

shape = (10, 4, 3, 512, 512)
import ndv

data = ndv.data.cells3d()

ts_array = ts.open(
{"driver": "zarr", "kvstore": {"driver": "memory"}},
{
"driver": "zarr",
"kvstore": {"driver": "memory"},
"transform": {
# tensorstore supports labeled dimensions
"input_labels": ["z", "c", "y", "x"],
},
},
create=True,
shape=shape,
dtype=ts.uint8,
shape=data.shape,
dtype=data.dtype,
).result()
ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8)
ts_array = ts_array[ts.d[:].label["t", "c", "z", "y", "x"]]
ts_array[:] = ndv.data.cells3d()

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(ts_array)
v.show()
qapp.exec()
ndv.imshow(ts_array)
21 changes: 11 additions & 10 deletions examples/torch_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
import torch
except ImportError:
raise ImportError("Please install torch to run this example")
from numpy_arr import generate_5d_sine_wave
from qtpy import QtWidgets

from ndv import NDViewer
import warnings

import ndv

warnings.filterwarnings("ignore", "Named tensors") # Named tensors are experimental

# Example usage
array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions
sine_wave_5d = torch.asarray(generate_5d_sine_wave(array_shape))
try:
torch_data = torch.tensor(ndv.data.nd_sine_wave(), names=("t", "c", "z", "y", "x"))
except TypeError:
print("Named tensors are not supported in your version of PyTorch")
torch_data = torch.tensor(ndv.data.nd_sine_wave())

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(sine_wave_5d, channel_axis=1)
v.show()
qapp.exec()
ndv.imshow(torch_data)
16 changes: 6 additions & 10 deletions examples/xarray_arr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from __future__ import annotations

import xarray as xr
from qtpy import QtWidgets

from ndv import NDViewer
try:
import xarray as xr
except ImportError:
raise ImportError("Please install xarray to run this example")
import ndv

da = xr.tutorial.open_dataset("air_temperature").air

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(da, colormaps=["thermal"], channel_mode="composite")
v.show()
qapp.exec()
ndv.imshow(da, cmap="thermal")
17 changes: 8 additions & 9 deletions examples/zarr_arr.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from __future__ import annotations

import zarr
import zarr.storage
from qtpy import QtWidgets
import ndv

try:
import zarr
import zarr.storage
except ImportError:
raise ImportError("Please `pip install zarr aiohttp` to run this example")

from ndv import NDViewer

URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr"
zarr_arr = zarr.open(URL, mode="r")

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(zarr_arr["s0"])
v.show()
qapp.exec()
ndv.imshow(zarr_arr["s0"])
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ pretty = true
minversion = "7.0"
testpaths = ["tests"]
filterwarnings = ["error"]
markers = ["allow_leaks: mark test to allow widget leaks"]

# https://coverage.readthedocs.io/
[tool.coverage.report]
Expand Down
6 changes: 4 additions & 2 deletions src/ndv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

from typing import TYPE_CHECKING

from . import data
from .util import imshow
from .viewer._data_wrapper import DataWrapper
from .viewer._stack_viewer import NDViewer
from .viewer._viewer import NDViewer

__all__ = ["NDViewer", "DataWrapper"]
__all__ = ["NDViewer", "DataWrapper", "imshow", "data"]


if TYPE_CHECKING:
Expand Down
Loading

0 comments on commit 1d53ecb

Please sign in to comment.