From 0e65f934e8316527d79d0851a06c823b2bc5d77b Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Thu, 17 Oct 2024 09:26:14 -0400 Subject: [PATCH] wip --- mvc.py | 13 +++ src/ndv/v2ctl.py | 152 +++++++++++++++++++++++++++++ src/ndv/v2view.py | 66 +++++++++++++ src/ndv/viewer/_backends/_vispy.py | 3 + src/ndv/viewer_v2.py | 2 +- x.py | 8 +- 6 files changed, 238 insertions(+), 6 deletions(-) create mode 100644 mvc.py create mode 100644 src/ndv/v2ctl.py create mode 100644 src/ndv/v2view.py diff --git a/mvc.py b/mvc.py new file mode 100644 index 0000000..693c861 --- /dev/null +++ b/mvc.py @@ -0,0 +1,13 @@ +import numpy as np +from qtpy.QtWidgets import QApplication + +from ndv.v2ctl import ViewerController +from ndv.v2view import ViewerView + +app = QApplication([]) + +viewer = ViewerController(ViewerView()) # ultimately, this will be the public api +model = viewer.model +viewer.data = np.random.rand(96, 64, 128).astype(np.float32) +viewer.view.show() # temp +app.exec() diff --git a/src/ndv/v2ctl.py b/src/ndv/v2ctl.py new file mode 100644 index 0000000..d2b60f5 --- /dev/null +++ b/src/ndv/v2ctl.py @@ -0,0 +1,152 @@ +from collections.abc import Container, Hashable, Mapping, Sequence +from typing import Any, Protocol + +from psygnal import SignalInstance + +from .models._array_display_model import ArrayDisplayModel, AxisKey +from .viewer._backends._protocols import PImageHandle +from .viewer._data_wrapper import DataWrapper + + +class ViewP(Protocol): + currentIndexChanged: SignalInstance + + def create_sliders(self, coords: Mapping[Hashable, Sequence]) -> None: ... + def current_index(self) -> Mapping[AxisKey, int]: ... + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: ... + def add_image_to_canvas(self, data: Any) -> PImageHandle: ... + def hide_sliders( + self, axes_to_hide: Container[Hashable], *, show_remainder: bool = ... + ) -> None: ... + + +class ViewerController: + _data_wrapper: DataWrapper | None + _display_model: ArrayDisplayModel + + def __init__(self, view: ViewP, model: ArrayDisplayModel | None = None) -> None: + self.view = view + self.model = model or ArrayDisplayModel() + self._data_wrapper = None + self.view.currentIndexChanged.connect(self.on_slider_value_changed) + + @property + def model(self) -> ArrayDisplayModel: + """Return the display model for the viewer.""" + return self._display_model + + @model.setter + def model(self, display_model: ArrayDisplayModel) -> None: + """Set the display model for the viewer.""" + display_model = ArrayDisplayModel.model_validate(display_model) + previous_model: ArrayDisplayModel | None = getattr(self, "_display_model", None) + if previous_model is not None: + self._set_model_connected(previous_model, False) + + self._display_model = display_model + self._set_model_connected(display_model) + + def _set_model_connected( + self, model: ArrayDisplayModel, connect: bool = True + ) -> None: + """Connect or disconnect the model to/from the viewer. + + We do this in a single method so that we are sure to connect and disconnect + the same events in the same order. + """ + _connect = "connect" if connect else "disconnect" + + for obj, callback in [ + # (model.events.visible_axes, self._on_visible_axes_changed), + # the current_index attribute itself is immutable + (model.current_index.value_changed, self._on_current_index_changed), + # (model.events.channel_axis, self._on_channel_axis_changed), + # TODO: lut values themselves are mutable evented objects... + # so we need to connect to their events as well + # (model.luts.value_changed, self._on_luts_changed), + ]: + getattr(obj, _connect)(callback) + + def _on_current_index_changed(self) -> None: + value = self.model.current_index + self.view.set_current_index(value) + self._update_canvas() + + @property + def data(self) -> Any: + """Return data being displayed.""" + if self._data_wrapper is None: + return None + return self._data_wrapper.data + + @data.setter + def data(self, data: Any) -> None: + """Set the data to be displayed.""" + if data is None: + self._data_wrapper = None + return + self._data_wrapper = DataWrapper.create(data) + dims = self._data_wrapper.dims + coords = { + self._canonicalize_axis_key(ax, dims): c + for ax, c in self._data_wrapper.coords.items() + } + self.view.create_sliders(coords) + self._update_visible_sliders() + self._update_canvas() + + def on_slider_value_changed(self) -> None: + """Update the model when slider value changes.""" + slider_values = self.view.current_index() + self.model.current_index.update(slider_values) + return + self._update_canvas() + + def _update_canvas(self) -> None: + if not self._data_wrapper: + return + idx_request = self._current_index_request() + data = self._data_wrapper.isel(idx_request) + if hdl := getattr(self, "_handle", None): + hdl.remove() + self._handle = self.view.add_image_to_canvas(data) + + def _current_index_request(self) -> Mapping[int, int | slice]: + # Generate cannocalized index request + if self._data_wrapper is None: + return {} + + dims = self._data_wrapper.dims + idx_request = { + self._canonicalize_axis_key(ax, dims): v + for ax, v in self.model.current_index.items() + } + for ax in self.model.visible_axes: + ax_ = self._canonicalize_axis_key(ax, dims) + if not isinstance(idx_request.get(ax_), slice): + idx_request[ax_] = slice(None) + return idx_request + + def _update_visible_sliders(self) -> None: + """Update which sliders are visible based on the current model.""" + dims = self._data_wrapper.dims + visible_axes = { + self._canonicalize_axis_key(ax, dims) for ax in self.model.visible_axes + } + self.view.hide_sliders(visible_axes, show_remainder=True) + + def _canonicalize_axis_key(self, axis: AxisKey, dims: Sequence[Hashable]) -> int: + """Return positive index for AxisKey (which can be +/- int or label).""" + # TODO: improve performance by indexing ahead of time + if isinstance(axis, int): + ndims = len(dims) + ax = axis if axis >= 0 else len(dims) + axis + if ax >= ndims: + raise IndexError( + f"Axis index {axis} out of bounds for data with {ndims} dimensions" + ) + return ax + try: + return dims.index(axis) + except ValueError as e: + raise IndexError(f"Axis label {axis} not found in data dimensions") from e diff --git a/src/ndv/v2view.py b/src/ndv/v2view.py new file mode 100644 index 0000000..8b9bfc9 --- /dev/null +++ b/src/ndv/v2view.py @@ -0,0 +1,66 @@ +from collections.abc import Container, Hashable, Mapping, Sequence +from typing import Any + +from qtpy.QtCore import Qt, Signal +from qtpy.QtWidgets import QFormLayout, QVBoxLayout, QWidget +from superqt import QLabeledSlider + +from .models._array_display_model import AxisKey +from .viewer._backends import get_canvas_class +from .viewer._backends._protocols import PImageHandle + + +class ViewerView(QWidget): + currentIndexChanged = Signal() + + def __init__(self, parent: QWidget | None = None): + super().__init__(parent) + self._sliders: dict[Hashable, QLabeledSlider] = {} + self._canvas = get_canvas_class()() + self._canvas.set_ndim(2) + layout = QVBoxLayout(self) + self._slider_layout = QFormLayout() + self._slider_layout.setFieldGrowthPolicy( + QFormLayout.FieldGrowthPolicy.AllNonFixedFieldsGrow + ) + layout.addWidget(self._canvas.qwidget()) + layout.addLayout(self._slider_layout) + + def create_sliders(self, coords: Mapping[Hashable, Sequence]) -> None: + """Update sliders with the given coordinate ranges.""" + for axis, _coords in coords.items(): + sld = QLabeledSlider(Qt.Orientation.Horizontal) + sld.valueChanged.connect(self.currentIndexChanged.emit) + if isinstance(_coords, range): + sld.setRange(_coords.start, _coords.stop - 1) + sld.setSingleStep(_coords.step) + self._slider_layout.addRow(str(axis), sld) + self._sliders[axis] = sld + self.currentIndexChanged.emit() + + def add_image_to_canvas(self, data: Any) -> PImageHandle: + """Add image data to the canvas.""" + hdl = self._canvas.add_image(data) + self._canvas.set_range() + return hdl + + def hide_sliders( + self, axes_to_hide: Container[Hashable], show_remainder: bool = True + ) -> None: + """Hide sliders based on visible axes.""" + for ax, slider in self._sliders.items(): + if ax in axes_to_hide: + self._slider_layout.setRowVisible(slider, False) + elif show_remainder: + self._slider_layout.setRowVisible(slider, True) + + def current_index(self) -> Mapping[AxisKey, int | slice]: + """Return the current value of the sliders.""" + return {axis: slider.value() for axis, slider in self._sliders.items()} + + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: + """Set the current value of the sliders.""" + for axis, val in value.items(): + if isinstance(val, slice): + raise NotImplementedError("Slices are not supported yet") + self._sliders[axis].setValue(val) diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index e74a05e..e0500a0 100755 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -547,6 +547,9 @@ def set_range( When called with no arguments, the range is set to the full extent of the data. """ + # temporary + self._camera.set_range() + return _x = [0.0, 0.0] _y = [0.0, 0.0] _z = [0.0, 0.0] diff --git a/src/ndv/viewer_v2.py b/src/ndv/viewer_v2.py index 9fb6975..8f1c4ee 100644 --- a/src/ndv/viewer_v2.py +++ b/src/ndv/viewer_v2.py @@ -213,10 +213,10 @@ def _update_canvas(self) -> None: return idx_request = self._current_index_request() data = self._data_wrapper.isel(idx_request) - if hdl := getattr(self, "_handle", None): hdl.remove() self._handle = self._canvas.add_image(data) + self._canvas.set_range() def _on_channel_axis_changed(self, value: AxisKey) -> None: print("Channel axis changed:", value) diff --git a/x.py b/x.py index 3de65e4..c0d8a33 100644 --- a/x.py +++ b/x.py @@ -1,17 +1,15 @@ import numpy as np from qtpy.QtWidgets import QApplication -from rich import print from ndv.viewer_v2 import Viewer app = QApplication([]) v = Viewer() -v.data = np.random.rand(8, 64, 128) +v.data = np.random.rand(96, 64, 128).astype(np.float32) v.model.luts[1] = "viridis" - -# v.model.visible_axes = (0, 1) -print(v.model) +v.model.visible_axes = (-2, -1) +# print(v.model) v.show() v.model.current_index.update({0: 3, 1: 32, 2: 12}) app.exec()