Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support more array types, add imshow #3

Merged
merged 16 commits into from
Jun 8, 2024
18 changes: 16 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,25 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11"]

test-array-libs:
uses: pyapp-kit/workflows/.github/workflows/test-pyrepo.yml@v2
with:
os: ${{ matrix.os }}
python-version: ${{ matrix.python-version }}
extras: "test,third_party_arrays"
coverage-upload: artifact
qt: pyqt6
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ["3.9", "3.12"]

upload_coverage:
if: always()
needs: [test]
needs: [test, test-array-libs]
uses: pyapp-kit/workflows/.github/workflows/upload-coverage.yml@v2
secrets:
codecov_token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
34 changes: 21 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
Simple, fast-loading, asynchronous, n-dimensional viewer for Qt, with minimal dependencies.

```python
from qtpy import QtWidgets
from ndv import NDViewer
from skimage import data # just for example data here

qapp = QtWidgets.QApplication([])
v = NDViewer(data.cells3d())
v.show()
qapp.exec()
import ndv

data = ndv.data.cells3d()
# or ndv.data.nd_sine_wave()
# or *any* arraylike object (see support below)
ndv.imshow(data)
```

![Montage](https://github.com/pyapp-kit/ndv/assets/1609449/712861f7-ddcb-4ecd-9a4c-ba5f0cc1ee2c)
Expand All @@ -27,12 +25,22 @@ qapp.exec()
- sliders support integer as well as slice (range)-based slicing
- colormaps provided by [cmap](https://github.com/tlambert03/cmap)
- supports [vispy](https://github.com/vispy/vispy) and [pygfx](https://github.com/pygfx/pygfx) backends
- supports any numpy-like duck arrays, with special support for features in:
- `xarray.DataArray`
- supports any numpy-like duck arrays, including (but not limited to):
- `numpy.ndarray`
- `cupy.ndarray`
- `dask.array.Array`
- `tensorstore.TensorStore`
- `zarr`
- `dask`
- `jax.Array`
- `pyopencl.array.Array`
- `sparse.COO`
- `tensorstore.TensorStore` (supports named dimensions)
- `torch.Tensor` (supports named dimensions)
- `xarray.DataArray` (supports named dimensions)
- `zarr` (supports named dimensions)
- You can add support for your own storage class by subclassing `ndv.DataWrapper`
and implementing a couple methods. (This doesn't require modifying ndv,
but contributions of new wrappers are welcome!)

See examples for each of these array types in [examples](./examples/)

## Installation

Expand Down
40 changes: 40 additions & 0 deletions examples/custom_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

import ndv

if TYPE_CHECKING:
from ndv import Indices, Sizes


class MyArrayThing:
def __init__(self, shape: tuple[int, ...]) -> None:
self.shape = shape
self._data = np.random.randint(0, 256, shape)

def __getitem__(self, item: Any) -> np.ndarray:
return self._data[item] # type: ignore [no-any-return]


class MyWrapper(ndv.DataWrapper[MyArrayThing]):
@classmethod
def supports(cls, data: Any) -> bool:
if isinstance(data, MyArrayThing):
return True
return False

def sizes(self) -> Sizes:
"""Return a mapping of {dim: size} for the data"""
return {f"dim_{k}": v for k, v in enumerate(self.data.shape)}

def isel(self, indexers: Indices) -> Any:
"""Convert mapping of {dim: index} to conventional indexing"""
idx = tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape)))
return self.data[idx]


data = MyArrayThing((10, 3, 512, 512))
ndv.imshow(data)
11 changes: 2 additions & 9 deletions examples/dask_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dask.array.core import map_blocks
except ImportError:
raise ImportError("Please `pip install dask[array]` to run this example.")
import ndv

frame_size = (1024, 1024)

Expand All @@ -21,12 +22,4 @@ def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None:
chunks += [(x,) for x in frame_size]
dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8)

if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer

qapp = QtWidgets.QApplication([])
v = NDViewer(dask_arr)
v.show()
qapp.exec()
v = ndv.imshow(dask_arr)
16 changes: 3 additions & 13 deletions examples/jax_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
import jax.numpy as jnp
except ImportError:
raise ImportError("Please install jax to run this example")
from numpy_arr import generate_5d_sine_wave
from qtpy import QtWidgets
import ndv

from ndv import NDViewer

# Example usage
array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions
sine_wave_5d = jnp.asarray(generate_5d_sine_wave(array_shape))

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(sine_wave_5d, channel_axis=1)
v.show()
qapp.exec()
jax_arr = jnp.asarray(ndv.data.nd_sine_wave())
v = ndv.imshow(jax_arr)
65 changes: 6 additions & 59 deletions examples/numpy_arr.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,11 @@
from __future__ import annotations

import numpy as np


def generate_5d_sine_wave(
shape: tuple[int, int, int, int, int],
amplitude: float = 240,
base_frequency: float = 5,
) -> np.ndarray:
"""5D dataset."""
# Unpack the dimensions
angle_dim, freq_dim, phase_dim, ny, nx = shape

# Create an empty array to hold the data
output = np.zeros(shape)

# Define spatial coordinates for the last two dimensions
half_per = base_frequency * np.pi
x = np.linspace(-half_per, half_per, nx)
y = np.linspace(-half_per, half_per, ny)
y, x = np.meshgrid(y, x)

# Iterate through each parameter in the higher dimensions
for phase_idx in range(phase_dim):
for freq_idx in range(freq_dim):
for angle_idx in range(angle_dim):
# Calculate phase and frequency
phase = np.pi / phase_dim * phase_idx
frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step

# Calculate angle
angle = np.pi / angle_dim * angle_idx
# Rotate x and y coordinates
xr = np.cos(angle) * x - np.sin(angle) * y
np.sin(angle) * x + np.cos(angle) * y

# Compute the sine wave
sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase)
sine_wave += amplitude * 0.5

# Assign to the output array
output[angle_idx, freq_idx, phase_idx] = sine_wave

return output

import ndv

try:
from skimage import data

img = data.cells3d()
except Exception:
img = generate_5d_sine_wave((10, 3, 8, 512, 512))


if __name__ == "__main__":
from qtpy import QtWidgets

from ndv import NDViewer
img = ndv.data.cells3d()
except Exception as e:
print(e)
img = ndv.data.nd_sine_wave((10, 3, 8, 512, 512))

qapp = QtWidgets.QApplication([])
v = NDViewer(img)
v.show()
qapp.exec()
ndv.imshow(img)
17 changes: 17 additions & 0 deletions examples/pyopencl_arr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

try:
import pyopencl as cl
import pyopencl.array as cl_array
except ImportError:
raise ImportError("Please install pyopencl to run this example")
import ndv

# Set up OpenCL context and queue
context = cl.create_some_context(interactive=False)
queue = cl.CommandQueue(context)


gpu_data = cl_array.to_device(queue, ndv.data.nd_sine_wave())

ndv.imshow(gpu_data)
21 changes: 21 additions & 0 deletions examples/sparse_arr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

try:
import sparse
except ImportError:
raise ImportError("Please install sparse to run this example")

import numpy as np

import ndv

shape = (256, 4, 512, 512)
N = int(np.prod(shape) * 0.001)
coords = np.random.randint(low=0, high=shape, size=(N, len(shape))).T
data = np.random.randint(0, 256, N)


# Create the sparse array from the coordinates and data
sparse_array = sparse.COO(coords, data, shape=shape)

ndv.imshow(sparse_array)
35 changes: 20 additions & 15 deletions examples/tensorstore_arr.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
from __future__ import annotations

import numpy as np
import tensorstore as ts
from qtpy import QtWidgets
try:
import tensorstore as ts
except ImportError:
raise ImportError("Please install tensorstore to run this example")

from ndv import NDViewer

shape = (10, 4, 3, 512, 512)
import ndv

data = ndv.data.cells3d()

ts_array = ts.open(
{"driver": "zarr", "kvstore": {"driver": "memory"}},
{
"driver": "zarr",
"kvstore": {"driver": "memory"},
"transform": {
# tensorstore supports labeled dimensions
"input_labels": ["z", "c", "y", "x"],
},
},
create=True,
shape=shape,
dtype=ts.uint8,
shape=data.shape,
dtype=data.dtype,
).result()
ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8)
ts_array = ts_array[ts.d[:].label["t", "c", "z", "y", "x"]]
ts_array[:] = ndv.data.cells3d()

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(ts_array)
v.show()
qapp.exec()
ndv.imshow(ts_array)
21 changes: 21 additions & 0 deletions examples/torch_arr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

try:
import torch
except ImportError:
raise ImportError("Please install torch to run this example")

import warnings

import ndv

warnings.filterwarnings("ignore", "Named tensors") # Named tensors are experimental

# Example usage
try:
torch_data = torch.tensor(ndv.data.nd_sine_wave(), names=("t", "c", "z", "y", "x"))
except TypeError:
print("Named tensors are not supported in your version of PyTorch")
torch_data = torch.tensor(ndv.data.nd_sine_wave())

ndv.imshow(torch_data)
16 changes: 6 additions & 10 deletions examples/xarray_arr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from __future__ import annotations

import xarray as xr
from qtpy import QtWidgets

from ndv import NDViewer
try:
import xarray as xr
except ImportError:
raise ImportError("Please install xarray to run this example")
import ndv

da = xr.tutorial.open_dataset("air_temperature").air

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(da, colormaps=["thermal"], channel_mode="composite")
v.show()
qapp.exec()
ndv.imshow(da, cmap="thermal")
17 changes: 8 additions & 9 deletions examples/zarr_arr.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from __future__ import annotations

import zarr
import zarr.storage
from qtpy import QtWidgets
import ndv

try:
import zarr
import zarr.storage
except ImportError:
raise ImportError("Please `pip install zarr aiohttp` to run this example")

from ndv import NDViewer

URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr"
zarr_arr = zarr.open(URL, mode="r")

if __name__ == "__main__":
qapp = QtWidgets.QApplication([])
v = NDViewer(zarr_arr["s0"])
v.show()
qapp.exec()
ndv.imshow(zarr_arr["s0"])
Loading
Loading