diff --git a/LICENSE b/LICENSE index a4e53b6..54e32df 100644 --- a/LICENSE +++ b/LICENSE @@ -1,22 +1,28 @@ +BSD 3-Clause License -The MIT License (MIT) +Copyright (c) 2023, Howard Hughes Medical Institute -Copyright (c) 2023 William Patton +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index 86de2c8..352e0aa 100644 --- a/README.md +++ b/README.md @@ -1,67 +1,79 @@ -# napari-cellulus +

A napari plugin for cellulus

-[![License MIT](https://img.shields.io/pypi/l/napari-cellulus.svg?color=green)](https://github.com/funkelab/napari-cellulus/raw/main/LICENSE) -[![PyPI](https://img.shields.io/pypi/v/napari-cellulus.svg?color=green)](https://pypi.org/project/napari-cellulus) -[![Python Version](https://img.shields.io/pypi/pyversions/napari-cellulus.svg?color=green)](https://python.org) -[![tests](https://github.com/funkelab/napari-cellulus/workflows/tests/badge.svg)](https://github.com/funkelab/napari-cellulus/actions) -[![codecov](https://codecov.io/gh/funkelab/napari-cellulus/branch/main/graph/badge.svg)](https://codecov.io/gh/funkelab/napari-cellulus) -[![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-cellulus)](https://napari-hub.org/plugins/napari-cellulus) +- **[Introduction](#introduction)** +- **[Installation](#installation)** +- **[Getting Started](#getting-started)** +- **[Citation](#citation)** +- **[Issues](#issues)** -A Napari plugin for Cellulus: Unsupervised Learning of Object-Centric Embeddings for Cell Instance Segmentation in Microscopy Images +### Introduction ----------------------------------- +This repository hosts the code for the napari plugin built around **cellulus**, which was described in the **[preprint](https://arxiv.org/pdf/2310.08501.pdf)** titled **Unsupervised Learning of *Object-Centric Embeddings* for Cell Instance Segmentation in Microscopy Images**. -This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. +*cellulus* is a deep learning based method which can be used to obtain instance-segmentation of objects in microscopy images in an unsupervised fashion i.e. requiring no ground truth labels during training. - +### Installation -## Installation +One could execute these lines of code below to create a new environment and install dependencies. -You can install `napari-cellulus` via [pip]: +1. Create a new environment called `napari-cellulus`: - pip install napari-cellulus +```bash +conda create -y -n napari-cellulus python==3.9 +``` +2. Activate the newly-created environment: +``` +conda activate napari-cellulus +``` -To install latest development version : +3a. If using a GPU, install pytorch cuda dependencies: - pip install git+https://github.com/funkelab/napari-cellulus.git +```bash +conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia +``` +3b. otherwise (if using a CPU or MPS), run: -## Contributing +```bash +pip install torch torchvision +``` -Contributions are very welcome. Tests can be run with [tox], please ensure -the coverage at least stays the same before you submit a pull request. +4. Install the package from github: -## License +```bash +pip install git+https://github.com/funkelab/napari-cellulus.git +``` -Distributed under the terms of the [MIT] license, -"napari-cellulus" is free and open source software +### Getting Started -## Issues +Run the following commands in a terminal window: +``` +conda activate napari-cellulus +napari +``` -If you encounter any problems, please [file an issue] along with a detailed description. +Next, select `Cellulus` from the `Plugins` drop-down menu. -[napari]: https://github.com/napari/napari -[Cookiecutter]: https://github.com/audreyr/cookiecutter -[@napari]: https://github.com/napari -[MIT]: http://opensource.org/licenses/MIT -[BSD-3]: http://opensource.org/licenses/BSD-3-Clause -[GNU GPL v3.0]: http://www.gnu.org/licenses/gpl-3.0.txt -[GNU LGPL v3.0]: http://www.gnu.org/licenses/lgpl-3.0.txt -[Apache Software License 2.0]: http://www.apache.org/licenses/LICENSE-2.0 -[Mozilla Public License 2.0]: https://www.mozilla.org/media/MPL/2.0/index.txt -[cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin +### Citation -[file an issue]: https://github.com/funkelab/napari-cellulus/issues +If you find our work useful in your research, please consider citing: -[napari]: https://github.com/napari/napari -[tox]: https://tox.readthedocs.io/en/latest/ -[pip]: https://pypi.org/project/pip/ -[PyPI]: https://pypi.org/ + +```bibtex +@misc{wolf2023unsupervised, + title={Unsupervised Learning of Object-Centric Embeddings for Cell Instance Segmentation in Microscopy Images}, + author={Steffen Wolf and Manan Lalit and Henry Westmacott and Katie McDole and Jan Funke}, + year={2023}, + eprint={2310.08501}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +### Issues + +If you encounter any problems, please **[file an issue](https://github.com/funkelab/napari-cellulus/issues)** along with a description. diff --git a/setup.cfg b/setup.cfg index 7280387..ce0031b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,9 +5,8 @@ description = A Napari plugin for Cellulus: Unsupervised Learning of Object-Cent long_description = file: README.md long_description_content_type = text/markdown url = https://github.com/funkelab/napari-cellulus -author = William Patton -author_email = wllmpttn24@gmail.com -license = MIT +author = William Patton, Manan Lalit +author_email = wllmpttn24@gmail.com, lalitm@janelia.hhmi.org license_files = LICENSE classifiers = Development Status :: 2 - Pre-Alpha @@ -31,19 +30,16 @@ project_urls = [options] packages = find: install_requires = - numpy - magicgui + napari[all] + pyqtgraph qtpy cellulus @ git+https://github.com/funkelab/cellulus - matplotlib - torch - gunpowder - + pre-commit python_requires = >=3.8 include_package_data = True package_dir = - =src + = src # add your package requirements here diff --git a/src/napari_cellulus/__init__.py b/src/napari_cellulus/__init__.py index 47a9747..bf80da0 100644 --- a/src/napari_cellulus/__init__.py +++ b/src/napari_cellulus/__init__.py @@ -1,5 +1,5 @@ -__version__ = "0.0.3" +__version__ = "0.1.0" -from ._sample_data import tissuenet_sample +from .sample_data import tissue_net_sample -__all__ = ("tissuenet_sample",) +__all__ = ("tissue_net_sample",) diff --git a/src/napari_cellulus/_tests/test_samples.py b/src/napari_cellulus/_tests/test_samples.py deleted file mode 100644 index 7952513..0000000 --- a/src/napari_cellulus/_tests/test_samples.py +++ /dev/null @@ -1,3 +0,0 @@ -def test_open(make_napari_viewer): - viewer = make_napari_viewer() - viewer.open_sample(plugin="napari-cellulus", sample="tissuenet_sample") diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py deleted file mode 100644 index 9b3e4ae..0000000 --- a/src/napari_cellulus/dataset.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import List - -import gunpowder as gp -import numpy as np -from napari.layers import Image -from torch.utils.data import IterableDataset - -from .gp.nodes.napari_image_source import NapariImageSource -from .meta_data import NapariDatasetMetaData - - -class NapariDataset(IterableDataset): # type: ignore - def __init__( - self, - layer: Image, - axis_names: List[str], - crop_size: int, - control_point_spacing: int, - control_point_jitter: float, - ): - """A dataset that serves random samples from a zarr container. - - Args: - - layer: - - The napari layer to use. - The data should have shape `(s, c, [t,] [z,] y, x)`, where - `s` = # of samples, `c` = # of channels, `t` = # of frames, and - `z`/`y`/`x` are spatial extents. The dataset should have an - `"axis_names"` attribute that contains the names of the used - axes, e.g., `["s", "c", "y", "x"]` for a 2D dataset. - - axis_names: - - The names of the axes in the napari layer. - - crop_size: - - The size of data crops used during training (distinct from the - "patch size" of the method: from each crop, multiple patches - will be randomly selected and the loss computed on them). This - should be equal to the input size of the model that predicts - the OCEs. - - control_point_spacing: - - The distance in pixels between control points used for elastic - deformation of the raw data. - - control_point_jitter: - - How much to jitter the control points for elastic deformation - of the raw data, given as the standard deviation of a normal - distribution with zero mean. - """ - - self.layer = layer - self.axis_names = axis_names - self.control_point_spacing = control_point_spacing - self.control_point_jitter = control_point_jitter - self.__read_meta_data() - print(f"Number of spatial dims is {self.num_spatial_dims}") - self.crop_size = (crop_size,) * self.num_spatial_dims - self.__setup_pipeline() - - def __iter__(self): - return iter(self.__yield_sample()) - - def __setup_pipeline(self): - self.raw = gp.ArrayKey("RAW") - # treat all dimensions as spatial, with a voxel size = 1 - voxel_size = gp.Coordinate((1,) * self.num_dims) - offset = gp.Coordinate((0,) * self.num_dims) - shape = gp.Coordinate(self.layer.data.shape) - raw_spec = gp.ArraySpec( - roi=gp.Roi(offset, voxel_size * shape), - dtype=np.float32, - interpolatable=True, - voxel_size=voxel_size, - ) - - if self.num_channels == 0 and self.num_samples == 0: - self.pipeline = ( - NapariImageSource( - self.layer, self.raw, raw_spec, self.spatial_dims - ) - + gp.RandomLocation() - + gp.Unsqueeze([self.raw], 0) - + gp.Unsqueeze([self.raw], 0) - ) - elif self.num_channels == 0 and self.num_samples != 0: - self.pipeline = ( - NapariImageSource( - self.layer, self.raw, raw_spec, self.spatial_dims - ) - + gp.RandomLocation() - + gp.Unsqueeze([self.raw], 1) - ) - elif self.num_channels != 0 and self.num_samples == 0: - self.pipeline = ( - NapariImageSource( - self.layer, self.raw, raw_spec, self.spatial_dims - ) - + gp.RandomLocation() - + gp.Unsqueeze([self.raw], 0) - ) - elif self.num_channels != 0 and self.num_samples != 0: - self.pipeline = ( - NapariImageSource( - self.layer, self.raw, raw_spec, self.spatial_dims - ) - + gp.RandomLocation() - ) - - def __yield_sample(self): - """An infinite generator of crops.""" - - with gp.build(self.pipeline): - while True: - # request one sample, all channels, plus crop dimensions - request = gp.BatchRequest() - if self.num_channels == 0 and self.num_samples == 0: - request[self.raw] = gp.ArraySpec( - roi=gp.Roi( - (0,) * (self.num_dims), - self.crop_size, - ) - ) - elif self.num_channels == 0 and self.num_samples != 0: - request[self.raw] = gp.ArraySpec( - roi=gp.Roi( - (0,) * (self.num_dims), (1, *self.crop_size) - ) - ) - elif self.num_channels != 0 and self.num_samples == 0: - request[self.raw] = gp.ArraySpec( - roi=gp.Roi( - (0,) * (self.num_dims), - (self.num_channels, *self.crop_size), - ) - ) - elif self.num_channels != 0 and self.num_samples != 0: - request[self.raw] = gp.ArraySpec( - roi=gp.Roi( - (0,) * (self.num_dims), - (1, self.num_channels, *self.crop_size), - ) - ) - sample = self.pipeline.request_batch(request) - yield sample[self.raw].data[0] - - def __read_meta_data(self): - meta_data = NapariDatasetMetaData( - self.layer.data.shape, self.axis_names - ) - - self.num_dims = meta_data.num_dims - self.num_spatial_dims = meta_data.num_spatial_dims - self.num_channels = meta_data.num_channels - self.num_samples = meta_data.num_samples - self.sample_dim = meta_data.sample_dim - self.channel_dim = meta_data.channel_dim - self.time_dim = meta_data.time_dim - self.spatial_dims = meta_data.spatial_dims - - def get_num_channels(self): - return 1 if self.num_channels == 0 else self.num_channels - - def get_num_spatial_dims(self): - return self.num_spatial_dims diff --git a/src/napari_cellulus/_tests/__init__.py b/src/napari_cellulus/datasets/__init__.py similarity index 100% rename from src/napari_cellulus/_tests/__init__.py rename to src/napari_cellulus/datasets/__init__.py diff --git a/src/napari_cellulus/meta_data.py b/src/napari_cellulus/datasets/meta_data.py similarity index 96% rename from src/napari_cellulus/meta_data.py rename to src/napari_cellulus/datasets/meta_data.py index 7e8751e..35d69fe 100644 --- a/src/napari_cellulus/meta_data.py +++ b/src/napari_cellulus/datasets/meta_data.py @@ -11,7 +11,7 @@ def __init__(self, shape, axis_names): self.channel_dim = None self.time_dim = None self.spatial_array: Tuple[int, ...] = () - self.spatial_dims = () + self.spatial_dims: Tuple[int, ...] = () for dim, axis_name in enumerate(axis_names): if axis_name == "s": self.sample_dim = dim diff --git a/src/napari_cellulus/datasets/napari_dataset.py b/src/napari_cellulus/datasets/napari_dataset.py new file mode 100644 index 0000000..3952ab8 --- /dev/null +++ b/src/napari_cellulus/datasets/napari_dataset.py @@ -0,0 +1,244 @@ +from typing import List + +import gunpowder as gp +import numpy as np +from napari.layers import Image +from torch.utils.data import IterableDataset + +from .meta_data import NapariDatasetMetaData +from .napari_image_source import NapariImageSource + + +class NapariDataset(IterableDataset): + def __init__( + self, + layer: Image, + axis_names: List[str], + crop_size: int, + density: float, + kappa: float, + normalization_factor: float, + ): + self.layer = layer + self.axis_names = axis_names + self.__read_meta_data() + self.crop_size = (crop_size,) * self.num_spatial_dims + self.normalization_factor = normalization_factor + self.density = density + self.kappa = kappa + self.output_shape = tuple(int(_ - 16) for _ in self.crop_size) + self.normalization_factor = normalization_factor + self.unbiased_shape = tuple( + int(_ - (2 * self.kappa)) for _ in self.output_shape + ) + self.__setup_pipeline() + + def __read_meta_data(self): + meta_data = NapariDatasetMetaData( + self.layer.data.shape, self.axis_names + ) + + self.num_dims = meta_data.num_dims + self.num_spatial_dims = meta_data.num_spatial_dims + self.num_channels = meta_data.num_channels + self.num_samples = meta_data.num_samples + self.sample_dim = meta_data.sample_dim + self.channel_dim = meta_data.channel_dim + self.time_dim = meta_data.time_dim + self.spatial_dims = meta_data.spatial_dims + self.spatial_array = meta_data.spatial_array + + def __setup_pipeline(self): + self.raw = gp.ArrayKey("RAW") + + # treat all dimensions as spatial, with a voxel size = 1 + voxel_size = gp.Coordinate((1,) * self.num_dims) + offset = gp.Coordinate((0,) * self.num_dims) + shape = gp.Coordinate(self.layer.data.shape) + raw_spec = gp.ArraySpec( + roi=gp.Roi(offset, voxel_size * shape), + dtype=np.float32, + interpolatable=True, + voxel_size=voxel_size, + ) + if self.num_samples == 0: + self.pipeline = ( + NapariImageSource( + image=self.layer, + key=self.raw, + spec=raw_spec, + spatial_dims=self.spatial_dims, + ) + + gp.Unsqueeze([self.raw], 0) + + gp.RandomLocation() + ) + else: + self.pipeline = ( + NapariImageSource( + image=self.layer, + key=self.raw, + spec=raw_spec, + spatial_dims=self.spatial_dims, + ) + + gp.RandomLocation() + ) + + def __iter__(self): + return iter(self.__yield_sample()) + + def __yield_sample(self): + with gp.build(self.pipeline): + while True: + array_is_zero = True + while array_is_zero: + # request one sample, all channels, plus crop dimensions + request = gp.BatchRequest() + if self.num_channels == 0 and self.num_samples == 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, + self.crop_size, + ) + ) + elif self.num_channels == 0 and self.num_samples != 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, (1, *self.crop_size) + ) + ) + elif self.num_channels != 0 and self.num_samples == 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, + (self.num_channels, *self.crop_size), + ) + ) + elif self.num_channels != 0 and self.num_samples != 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * self.num_dims, + (1, self.num_channels, *self.crop_size), + ) + ) + sample = self.pipeline.request_batch(request) + sample_data = sample[self.raw].data[0] + # if missing a channel, this is added by the Model class + + if np.max(sample_data) <= 0.0: + pass + else: + array_is_zero = False + ( + anchor_samples, + reference_samples, + ) = self.sample_coordinates() + yield sample_data, anchor_samples, reference_samples + + def get_num_samples(self): + return self.num_samples + + def get_num_channels(self): + return self.num_channels + + def get_num_spatial_dims(self): + return self.num_spatial_dims + + def get_num_dims(self): + return self.num_dims + + def get_spatial_dims(self): + return self.spatial_dims + + def get_spatial_array(self): + return self.spatial_array + + def sample_offsets_within_radius(self, radius, number_offsets): + if self.num_spatial_dims == 2: + offsets_x = np.random.randint( + -radius, radius + 1, size=2 * number_offsets + ) + offsets_y = np.random.randint( + -radius, radius + 1, size=2 * number_offsets + ) + offsets_coordinates = np.stack((offsets_x, offsets_y), axis=1) + elif self.num_spatial_dims == 3: + offsets_x = np.random.randint( + -radius, radius + 1, size=3 * number_offsets + ) + offsets_y = np.random.randint( + -radius, radius + 1, size=3 * number_offsets + ) + offsets_z = np.random.randint( + -radius, radius + 1, size=3 * number_offsets + ) + offsets_coordinates = np.stack( + (offsets_x, offsets_y, offsets_z), axis=1 + ) + + in_circle = (offsets_coordinates**2).sum(axis=1) < radius**2 + offsets_coordinates = offsets_coordinates[in_circle] + not_zero = np.absolute(offsets_coordinates).sum(axis=1) > 0 + offsets_coordinates = offsets_coordinates[not_zero] + + if len(offsets_coordinates) < number_offsets: + return self.sample_offsets_within_radius(radius, number_offsets) + + return offsets_coordinates[:number_offsets] + + def sample_coordinates(self): + num_anchors = self.get_num_anchors() + num_references = self.get_num_references() + + if self.num_spatial_dims == 2: + anchor_coordinates_x = np.random.randint( + self.kappa, + self.output_shape[0] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates_y = np.random.randint( + self.kappa, + self.output_shape[1] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates = np.stack( + (anchor_coordinates_x, anchor_coordinates_y), axis=1 + ) + elif self.num_spatial_dims == 3: + anchor_coordinates_x = np.random.randint( + self.kappa, + self.output_shape[0] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates_y = np.random.randint( + self.kappa, + self.output_shape[1] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates_z = np.random.randint( + self.kappa, + self.output_shape[2] - self.kappa + 1, + size=num_anchors, + ) + anchor_coordinates = np.stack( + ( + anchor_coordinates_x, + anchor_coordinates_y, + anchor_coordinates_z, + ), + axis=1, + ) + anchor_samples = np.repeat(anchor_coordinates, num_references, axis=0) + offset_in_pos_radius = self.sample_offsets_within_radius( + self.kappa, len(anchor_samples) + ) + reference_samples = anchor_samples + offset_in_pos_radius + + return anchor_samples, reference_samples + + def get_num_anchors(self): + return int( + self.density * self.unbiased_shape[0] * self.unbiased_shape[1] + ) + + def get_num_references(self): + return int(self.density * self.kappa**2 * np.pi) diff --git a/src/napari_cellulus/gp/nodes/napari_image_source.py b/src/napari_cellulus/datasets/napari_image_source.py similarity index 75% rename from src/napari_cellulus/gp/nodes/napari_image_source.py rename to src/napari_cellulus/datasets/napari_image_source.py index a13e996..bee332f 100644 --- a/src/napari_cellulus/gp/nodes/napari_image_source.py +++ b/src/napari_cellulus/datasets/napari_image_source.py @@ -7,7 +7,7 @@ class NapariImageSource(gp.BatchProvider): """ - A gunpowder interface to a napari Image + A gunpowder node to pull data from a napari Image Args: image (Image): The napari image layer to pull data from @@ -20,15 +20,11 @@ def __init__( ): self.array_spec = spec self.image = gp.Array( - normalize( - image.data.astype(np.float32), - pmin=1, - pmax=99.8, - axis=spatial_dims, + data=normalize( + image.data.astype(np.float32), 1, 99.8, axis=spatial_dims ), - self.array_spec, + spec=spec, ) - self.spatial_dims = spatial_dims self.key = key def setup(self): diff --git a/src/napari_cellulus/gp/__init__.py b/src/napari_cellulus/gp/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/napari_cellulus/gp/nodes/__init__.py b/src/napari_cellulus/gp/nodes/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/napari_cellulus/gui_helpers.py b/src/napari_cellulus/gui_helpers.py deleted file mode 100644 index a4aa620..0000000 --- a/src/napari_cellulus/gui_helpers.py +++ /dev/null @@ -1,45 +0,0 @@ -from magicgui.widgets import FunctionGui, create_widget -from matplotlib.backends.backend_qt5agg import ( - FigureCanvasQTAgg, -) -from matplotlib.backends.backend_qt5agg import ( - NavigationToolbar2QT as NavigationToolbar, -) -from matplotlib.figure import Figure -from qtpy import QtWidgets - - -class MplCanvas(FigureCanvasQTAgg): - def __init__(self, parent=None, width=5, height=4, dpi=100): - fig = Figure(figsize=(width, height), dpi=dpi) - self.axes = fig.add_subplot(111) - super().__init__(fig) - fig.set_tight_layout(True) - - -class MainWindow(QtWidgets.QMainWindow): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - sc = MplCanvas(self, width=5, height=4, dpi=100) - - # Create toolbar, passing canvas as first parament, parent (self, the MainWindow) as second. - toolbar = NavigationToolbar(sc, self) - - layout = QtWidgets.QVBoxLayout() - layout.addWidget(toolbar) - layout.addWidget(sc) - - # Create a placeholder widget to hold our toolbar and canvas. - widget = QtWidgets.QWidget() - widget.setLayout(layout) - self.setCentralWidget(widget) - - self.show() - - -def layer_choice_widget(viewer, annotation, **kwargs) -> FunctionGui: - widget = create_widget(annotation=annotation, **kwargs) - viewer.layers.events.inserted.connect(widget.reset_choices) - viewer.layers.events.removed.connect(widget.reset_choices) - return widget diff --git a/src/napari_cellulus/napari.yaml b/src/napari_cellulus/napari.yaml index 1c22723..39ac969 100644 --- a/src/napari_cellulus/napari.yaml +++ b/src/napari_cellulus/napari.yaml @@ -2,16 +2,22 @@ name: napari-cellulus display_name: Cellulus contributions: commands: - - id: napari-cellulus.tissuenet_sample - python_name: napari_cellulus._sample_data:tissuenet_sample + - id: napari-cellulus.tissue_net_sample + python_name: napari_cellulus.sample_data:tissue_net_sample title: Load sample data from Cellulus - - id: napari-cellulus.SegmentationWidget - python_name: napari_cellulus.widgets._widget:SegmentationWidget + - id: napari-cellulus.fluo_n2dl_hela_sample + python_name: napari_cellulus.sample_data:fluo_n2dl_hela_sample + title: Load sample data from Cellulus + - id: napari-cellulus.Widget + python_name: napari_cellulus.widget:Widget title: Cellulus sample_data: - - command: napari-cellulus.tissuenet_sample - display_name: Cellulus - key: tissuenet_sample + - command: napari-cellulus.tissue_net_sample + display_name: TissueNet + key: tissue_net_sample + - command: napari-cellulus.fluo_n2dl_hela_sample + display_name: Fluo-N2DL-HeLa + key: fluo_n2dl_hela_sample widgets: - - command: napari-cellulus.SegmentationWidget + - command: napari-cellulus.Widget display_name: Cellulus diff --git a/src/napari_cellulus/_sample_data.py b/src/napari_cellulus/sample_data.py similarity index 50% rename from src/napari_cellulus/_sample_data.py rename to src/napari_cellulus/sample_data.py index 159abd1..2a4efc0 100644 --- a/src/napari_cellulus/_sample_data.py +++ b/src/napari_cellulus/sample_data.py @@ -1,12 +1,28 @@ from pathlib import Path import numpy as np +import tifffile -TISSUENET_SAMPLE = Path(__file__).parent / "sample_data/tissuenet-sample.npy" +TISSUE_NET_SAMPLE = Path(__file__).parent / "sample_data/tissue_net_sample.npy" +FLUO_N2DL_HELA = Path(__file__).parent / "sample_data/fluo_n2dl_hela.tif" -def tissuenet_sample(): - (x, y) = np.load(TISSUENET_SAMPLE, "r") +def fluo_n2dl_hela_sample(): + x = tifffile.imread(FLUO_N2DL_HELA) + return [ + ( + x, + { + "name": "Raw", + "metadata": {"axes": ["s", "c", "y", "x"]}, + }, + "image", + ) + ] + + +def tissue_net_sample(): + (x, y) = np.load(TISSUE_NET_SAMPLE, "r") x = x.transpose(0, 3, 1, 2) y = y.transpose(0, 3, 1, 2).astype(np.uint8) return [ diff --git a/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif b/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif new file mode 100644 index 0000000..f28870c Binary files /dev/null and b/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif differ diff --git a/src/napari_cellulus/sample_data/tissuenet-sample.npy b/src/napari_cellulus/sample_data/tissue_net_sample.npy similarity index 100% rename from src/napari_cellulus/sample_data/tissuenet-sample.npy rename to src/napari_cellulus/sample_data/tissue_net_sample.npy diff --git a/src/napari_cellulus/widget.py b/src/napari_cellulus/widget.py new file mode 100644 index 0000000..bbad8b8 --- /dev/null +++ b/src/napari_cellulus/widget.py @@ -0,0 +1,950 @@ +from pathlib import Path + +import gunpowder as gp +import numpy as np +import pyqtgraph as pg +import torch +from attrs import asdict +from cellulus.configs.experiment_config import ExperimentConfig +from cellulus.configs.inference_config import InferenceConfig +from cellulus.configs.model_config import ModelConfig +from cellulus.configs.train_config import TrainConfig +from cellulus.criterions import get_loss +from cellulus.models import get_model +from cellulus.train import train_iteration +from cellulus.utils.mean_shift import mean_shift_segmentation +from cellulus.utils.misc import size_filter +from napari.qt.threading import thread_worker +from napari.utils.events import Event +from qtpy.QtCore import Qt +from qtpy.QtWidgets import ( + QButtonGroup, + QCheckBox, + QComboBox, + QGridLayout, + QLabel, + QLineEdit, + QMainWindow, + QPushButton, + QRadioButton, + QScrollArea, + QVBoxLayout, +) +from scipy.ndimage import binary_fill_holes +from scipy.ndimage import distance_transform_edt as dtedt +from skimage.filters import threshold_otsu +from tqdm import tqdm + +from .datasets.napari_dataset import NapariDataset +from .datasets.napari_image_source import NapariImageSource + + +class Model(torch.nn.Module): + def __init__(self, model, selected_axes): + super().__init__() + self.model = model + self.selected_axes = selected_axes + + def forward(self, x): + if "s" in self.selected_axes and "c" in self.selected_axes: + pass + elif "s" in self.selected_axes and "c" not in self.selected_axes: + + x = torch.unsqueeze(x, 1) + elif "s" not in self.selected_axes and "c" in self.selected_axes: + pass + elif "s" not in self.selected_axes and "c" not in self.selected_axes: + x = torch.unsqueeze(x, 1) + return self.model(x) + + @staticmethod + def select_and_add_coordinates(outputs, coordinates): + selections = [] + # outputs.shape = (b, c, h, w) or (b, c, d, h, w) + for output, coordinate in zip(outputs, coordinates): + if output.ndim == 3: + selection = output[:, coordinate[:, 1], coordinate[:, 0]] + elif output.ndim == 4: + selection = output[ + :, coordinate[:, 2], coordinate[:, 1], coordinate[:, 0] + ] + selection = selection.transpose(1, 0) + selection += coordinate + selections.append(selection) + + # selection.shape = (b, c, p) where p is the number of selected positions + return torch.stack(selections, dim=0) + + def set_infer(self, p_salt_pepper, num_infer_iterations, device): + self.model.eval() + self.model.set_infer(p_salt_pepper, num_infer_iterations, device) + + +class Widget(QMainWindow): + def __init__(self, napari_viewer): + super().__init__() + self.viewer = napari_viewer + self.scroll = QScrollArea() + # initialize outer layout + layout = QVBoxLayout() + + # initialize individual grid layouts from top to bottom + self.grid_0 = QGridLayout() # title + self.set_grid_0() + self.grid_1 = QGridLayout() # device + self.set_grid_1() + self.grid_2 = QGridLayout() # raw image selector + self.set_grid_2() + self.grid_3 = QGridLayout() # train configs + self.set_grid_3() + self.grid_4 = QGridLayout() # model configs + self.set_grid_4() + self.grid_5 = QGridLayout() # loss plot and train/stop button + self.set_grid_5() + self.grid_6 = QGridLayout() # inference + self.set_grid_6() + self.grid_7 = QGridLayout() # feedback + self.set_grid_7() + self.create_configs() # configs + self.viewer.dims.events.current_step.connect( + self.update_inference_widgets + ) # listen to viewer slider + + layout.addLayout(self.grid_0) + layout.addLayout(self.grid_1) + layout.addLayout(self.grid_2) + layout.addLayout(self.grid_3) + layout.addLayout(self.grid_4) + layout.addLayout(self.grid_5) + layout.addLayout(self.grid_6) + layout.addLayout(self.grid_7) + self.set_scroll_area(layout) + self.viewer.layers.events.inserted.connect(self.update_raw_selector) + self.viewer.layers.events.removed.connect(self.update_raw_selector) + + def update_raw_selector(self, event): + count = 0 + for i in range(self.raw_selector.count() - 1, -1, -1): + if self.raw_selector.itemText(i) == f"{event.value}": + # remove item + self.raw_selector.removeItem(i) + count = 1 + if count == 0: + self.raw_selector.addItems([f"{event.value}"]) + + def set_grid_0(self): + text_label = QLabel("

Cellulus

") + method_description_label = QLabel( + 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' + ) + self.grid_0.addWidget(text_label, 0, 0, 1, 1) + self.grid_0.addWidget(method_description_label, 1, 0, 2, 1) + + def set_grid_1(self): + device_label = QLabel(self) + device_label.setText("Device") + self.device_combo_box = QComboBox(self) + self.device_combo_box.addItem("cpu") + self.device_combo_box.addItem("cuda:0") + self.device_combo_box.addItem("mps") + self.device_combo_box.setCurrentText("mps") + self.grid_1.addWidget(device_label, 0, 0, 1, 1) + self.grid_1.addWidget(self.device_combo_box, 0, 1, 1, 1) + + def set_grid_2(self): + self.raw_selector = QComboBox(self) + for layer in self.viewer.layers: + self.raw_selector.addItem(f"{layer}") + self.grid_2.addWidget(self.raw_selector, 0, 0, 1, 5) + # Initialize Checkboxes + self.s_check_box = QCheckBox("s/t") + self.c_check_box = QCheckBox("c") + self.z_check_box = QCheckBox("z") + self.y_check_box = QCheckBox("y") + self.x_check_box = QCheckBox("x") + self.grid_2.addWidget(self.s_check_box, 1, 0, 1, 1) + self.grid_2.addWidget(self.c_check_box, 1, 1, 1, 1) + self.grid_2.addWidget(self.z_check_box, 1, 2, 1, 1) + self.grid_2.addWidget(self.y_check_box, 1, 3, 1, 1) + self.grid_2.addWidget(self.x_check_box, 1, 4, 1, 1) + + def set_grid_3(self): + crop_size_label = QLabel(self) + crop_size_label.setText("Crop Size") + self.crop_size_line = QLineEdit(self) + self.crop_size_line.setAlignment(Qt.AlignCenter) + self.crop_size_line.setText("252") + batch_size_label = QLabel(self) + batch_size_label.setText("Batch Size") + self.batch_size_line = QLineEdit(self) + self.batch_size_line.setAlignment(Qt.AlignCenter) + self.batch_size_line.setText("8") + max_iterations_label = QLabel(self) + max_iterations_label.setText("Max iterations") + self.max_iterations_line = QLineEdit(self) + self.max_iterations_line.setAlignment(Qt.AlignCenter) + self.max_iterations_line.setText("100000") + self.grid_3.addWidget(crop_size_label, 0, 0, 1, 1) + self.grid_3.addWidget(self.crop_size_line, 0, 1, 1, 1) + self.grid_3.addWidget(batch_size_label, 1, 0, 1, 1) + self.grid_3.addWidget(self.batch_size_line, 1, 1, 1, 1) + self.grid_3.addWidget(max_iterations_label, 2, 0, 1, 1) + self.grid_3.addWidget(self.max_iterations_line, 2, 1, 1, 1) + + def set_grid_4(self): + feature_maps_label = QLabel(self) + feature_maps_label.setText("Number of feature maps") + self.feature_maps_line = QLineEdit(self) + self.feature_maps_line.setAlignment(Qt.AlignCenter) + self.feature_maps_line.setText("24") + feature_maps_increase_label = QLabel(self) + feature_maps_increase_label.setText("Feature maps inc. factor") + self.feature_maps_increase_line = QLineEdit(self) + self.feature_maps_increase_line.setAlignment(Qt.AlignCenter) + self.feature_maps_increase_line.setText("3") + self.train_model_from_scratch_checkbox = QCheckBox( + "Train model from scratch" + ) + + self.train_model_from_scratch_checkbox.setChecked(False) + self.grid_4.addWidget(feature_maps_label, 0, 0, 1, 1) + self.grid_4.addWidget(self.feature_maps_line, 0, 1, 1, 1) + self.grid_4.addWidget(feature_maps_increase_label, 1, 0, 1, 1) + self.grid_4.addWidget(self.feature_maps_increase_line, 1, 1, 1, 1) + self.grid_4.addWidget( + self.train_model_from_scratch_checkbox, 2, 0, 1, 2 + ) + + def set_grid_5(self): + self.losses_widget = pg.PlotWidget() + self.losses_widget.setBackground((37, 41, 49)) + styles = {"color": "white", "font-size": "16px"} + self.losses_widget.setLabel("left", "Loss", **styles) + self.losses_widget.setLabel("bottom", "Iterations", **styles) + self.start_training_button = QPushButton("Start training") + self.start_training_button.setFixedSize(140, 30) + self.stop_training_button = QPushButton("Stop training") + self.stop_training_button.setFixedSize(140, 30) + + self.grid_5.addWidget(self.losses_widget, 0, 0, 4, 4) + self.grid_5.addWidget(self.start_training_button, 5, 0, 1, 2) + self.grid_5.addWidget(self.stop_training_button, 5, 2, 1, 2) + self.start_training_button.clicked.connect( + self.prepare_for_start_training + ) + self.stop_training_button.clicked.connect( + self.prepare_for_stop_training + ) + + def set_grid_6(self): + threshold_label = QLabel("Threshold") + self.threshold_line = QLineEdit(self) + self.threshold_line.setAlignment(Qt.AlignCenter) + self.threshold_line.setText(None) + + bandwidth_label = QLabel("Bandwidth") + self.bandwidth_line = QLineEdit(self) + self.bandwidth_line.setAlignment(Qt.AlignCenter) + + self.radio_button_group = QButtonGroup(self) + self.radio_button_cell = QRadioButton("Cell") + self.radio_button_nucleus = QRadioButton("Nucleus") + self.radio_button_group.addButton(self.radio_button_nucleus) + self.radio_button_group.addButton(self.radio_button_cell) + + self.radio_button_nucleus.setChecked(True) + self.min_size_label = QLabel("Minimum Size") + self.min_size_line = QLineEdit(self) + self.min_size_line.setAlignment(Qt.AlignCenter) + self.start_inference_button = QPushButton("Start inference") + self.start_inference_button.setFixedSize(140, 30) + self.stop_inference_button = QPushButton("Stop inference") + self.stop_inference_button.setFixedSize(140, 30) + + self.grid_6.addWidget(threshold_label, 0, 0, 1, 1) + self.grid_6.addWidget(self.threshold_line, 0, 1, 1, 1) + self.grid_6.addWidget(bandwidth_label, 1, 0, 1, 1) + self.grid_6.addWidget(self.bandwidth_line, 1, 1, 1, 1) + self.grid_6.addWidget(self.radio_button_cell, 2, 0, 1, 1) + self.grid_6.addWidget(self.radio_button_nucleus, 2, 1, 1, 1) + self.grid_6.addWidget(self.min_size_label, 3, 0, 1, 1) + self.grid_6.addWidget(self.min_size_line, 3, 1, 1, 1) + self.grid_6.addWidget(self.start_inference_button, 4, 0, 1, 1) + self.grid_6.addWidget(self.stop_inference_button, 4, 1, 1, 1) + self.start_inference_button.clicked.connect( + self.prepare_for_start_inference + ) + self.stop_inference_button.clicked.connect( + self.prepare_for_stop_inference + ) + + def set_grid_7(self): + # Initialize Feedback Button + feedback_label = QLabel( + 'Please share any feedback here.' + ) + self.grid_7.addWidget(feedback_label, 0, 0, 2, 1) + + def set_scroll_area(self, layout): + self.scroll.setLayout(layout) + self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.scroll.setWidgetResizable(True) + + self.setFixedWidth(300) + self.setCentralWidget(self.scroll) + + def get_selected_axes(self): + names = [] + for name, check_box in zip( + "sczyx", + [ + self.s_check_box, + self.c_check_box, + self.z_check_box, + self.y_check_box, + self.x_check_box, + ], + ): + if check_box.isChecked(): + names.append(name) + + return names + + def create_configs(self): + self.train_config = TrainConfig( + crop_size=[int(self.crop_size_line.text())], + batch_size=int(self.batch_size_line.text()), + max_iterations=int(self.max_iterations_line.text()), + device=self.device_combo_box.currentText(), + ) + self.model_config = ModelConfig( + num_fmaps=int(self.feature_maps_line.text()), + fmap_inc_factor=int(self.feature_maps_increase_line.text()), + ) + + self.experiment_config = ExperimentConfig( + train_config=asdict(self.train_config), + model_config=asdict(self.model_config), + ) + self.losses, self.iterations = [], [] + self.start_iteration = 0 + self.model_dir = "/tmp/models" + self.thresholds = [] + self.band_widths = [] + self.min_sizes = [] + if len(self.thresholds) == 0: + self.threshold_line.setEnabled(False) + if len(self.band_widths) == 0: + self.bandwidth_line.setEnabled(False) + if len(self.min_sizes) == 0: + self.min_size_line.setEnabled(False) + + def update_inference_widgets(self, event: Event): + if self.s_check_box.isChecked(): + shape = event.value + sample_index = shape[0] + if len(self.thresholds) == self.napari_dataset.get_num_samples(): + if self.thresholds[sample_index]!=None: + self.threshold_line.setText( + str(round(self.thresholds[sample_index], 3)) + ) + if len(self.band_widths) == self.napari_dataset.get_num_samples(): + if self.band_widths[sample_index]!=None: + self.bandwidth_line.setText( + str(round(self.band_widths[sample_index], 3)) + ) + if len(self.min_sizes) == self.napari_dataset.get_num_samples(): + if self.min_sizes[sample_index]!=None: + self.min_size_line.setText( + str(round(self.min_sizes[sample_index], 3)) + ) + + def prepare_for_start_training(self): + self.start_training_button.setEnabled(False) + self.stop_training_button.setEnabled(True) + self.threshold_line.setEnabled(False) + self.bandwidth_line.setEnabled(False) + self.radio_button_nucleus.setEnabled(False) + self.radio_button_cell.setEnabled(False) + self.min_size_line.setEnabled(False) + self.start_inference_button.setEnabled(False) + self.stop_inference_button.setEnabled(False) + + self.train_worker = self.train() + self.train_worker.yielded.connect(self.on_yield_training) + self.train_worker.start() + + @thread_worker + def train(self): + for layer in self.viewer.layers: + if f"{layer}" == self.raw_selector.currentText(): + raw_image_layer = layer + break + + if not Path(self.model_dir).exists(): + Path(self.model_dir).mkdir() + + # Turn layer into dataset + self.napari_dataset = NapariDataset( + layer=raw_image_layer, + axis_names=self.get_selected_axes(), + crop_size=self.train_config.crop_size[0], # list to integer + density=self.train_config.density, + kappa=self.train_config.kappa, + normalization_factor=self.experiment_config.normalization_factor, + ) + # Create dataloader + train_dataloader = torch.utils.data.DataLoader( + dataset=self.napari_dataset, + batch_size=self.train_config.batch_size, + drop_last=True, + num_workers=self.train_config.num_workers, + pin_memory=True, + ) + # Set model + model_original = get_model( + in_channels=self.napari_dataset.get_num_channels() + if self.napari_dataset.get_num_channels() != 0 + else 1, + out_channels=self.napari_dataset.get_num_spatial_dims(), + num_fmaps=self.model_config.num_fmaps, + fmap_inc_factor=self.model_config.fmap_inc_factor, + features_in_last_layer=self.model_config.features_in_last_layer, + downsampling_factors=[ + tuple(factor) + for factor in self.model_config.downsampling_factors + ], + num_spatial_dims=self.napari_dataset.get_num_spatial_dims(), + ) + + # Set device + self.device = torch.device(self.train_config.device) + model = Model( + model=model_original, selected_axes=self.get_selected_axes() + ) + self.model = model.to(self.device) + + # Initialize model weights + if self.model_config.initialize: + for _name, layer in self.model.named_modules(): + if isinstance(layer, torch.nn.modules.conv._ConvNd): + torch.nn.init.kaiming_normal_( + layer.weight, nonlinearity="relu" + ) + + # Set loss + criterion = get_loss( + regularizer_weight=self.train_config.regularizer_weight, + temperature=self.train_config.temperature, + density=self.train_config.density, + num_spatial_dims=self.napari_dataset.get_num_spatial_dims(), + device=self.device, + ) + + # Set optimizer + self.optimizer = torch.optim.Adam( + self.model.parameters(), + lr=self.train_config.initial_learning_rate, + weight_decay=0.01, + ) + + # Resume training + if self.train_model_from_scratch_checkbox.isChecked(): + self.losses, self.iterations = [], [] + self.start_iteration = 0 + self.losses_widget.clear() + + else: + if self.model_config.checkpoint is None: + pass + else: + print(f"Resuming model from {self.model_config.checkpoint}") + state = torch.load( + self.model_config.checkpoint, map_location=self.device + ) + self.start_iteration = state["iterations"][-1] + 1 + self.model.load_state_dict( + state["model_state_dict"], strict=True + ) + self.optimizer.load_state_dict(state["optim_state_dict"]) + self.losses, self.iterations = ( + state["losses"], + state["iterations"], + ) + + # Call Train Iteration + for iteration, batch in tqdm( + zip( + range(self.start_iteration, self.train_config.max_iterations), + train_dataloader, + ) + ): + loss, oce_loss, prediction = train_iteration( + batch, + model=self.model, + criterion=criterion, + optimizer=self.optimizer, + device=self.device, + ) + yield loss, iteration + + def on_yield_training(self, loss_iteration): + loss, iteration = loss_iteration + print(f"===> Iteration: {iteration}, loss: {loss:.6f}") + self.iterations.append(iteration) + self.losses.append(loss) + self.losses_widget.plot(self.iterations, self.losses) + + def prepare_for_stop_training(self): + self.start_training_button.setEnabled(True) + self.stop_training_button.setEnabled(True) + if len(self.thresholds) == 0: + self.threshold_line.setEnabled(False) + else: + self.threshold_line.setEnabled(True) + if len(self.band_widths) == 0: + self.bandwidth_line.setEnabled(False) + else: + self.bandwidth_line.setEnabled(True) + self.radio_button_nucleus.setEnabled(True) + self.radio_button_cell.setEnabled(True) + if len(self.min_sizes) == 0: + self.min_size_line.setEnabled(False) + else: + self.min_size_line.setEnabled(True) + self.start_inference_button.setEnabled(True) + self.stop_inference_button.setEnabled(True) + if self.train_worker is not None: + state = { + "model_state_dict": self.model.state_dict(), + "optim_state_dict": self.optimizer.state_dict(), + "iterations": self.iterations, + "losses": self.losses, + } + checkpoint_file_name = Path("/tmp/models") / "last.pth" + torch.save(state, checkpoint_file_name) + self.train_worker.quit() + self.model_config.checkpoint = checkpoint_file_name + + def prepare_for_start_inference(self): + self.start_training_button.setEnabled(False) + self.stop_training_button.setEnabled(False) + self.threshold_line.setEnabled(False) + self.bandwidth_line.setEnabled(False) + self.radio_button_nucleus.setEnabled(False) + self.radio_button_cell.setEnabled(False) + self.min_size_line.setEnabled(False) + self.start_inference_button.setEnabled(False) + self.stop_inference_button.setEnabled(True) + + self.inference_config = InferenceConfig( + crop_size=[min(self.napari_dataset.get_spatial_array()) + 16], + post_processing="cell" + if self.radio_button_cell.isChecked() + else "nucleus", + ) + + self.inference_worker = self.infer() + # self.inference_worker.yielded.connect(self.on_yield_infer) + self.inference_worker.returned.connect(self.on_return_infer) + self.inference_worker.start() + + def prepare_for_stop_inference(self): + self.start_training_button.setEnabled(True) + self.stop_training_button.setEnabled(True) + self.threshold_line.setEnabled(True) + self.bandwidth_line.setEnabled(True) + self.radio_button_nucleus.setEnabled(True) + self.radio_button_cell.setEnabled(True) + self.min_size_line.setEnabled(True) + self.start_inference_button.setEnabled(True) + self.stop_inference_button.setEnabled(True) + if self.napari_dataset.get_num_samples() == 0: + self.threshold_line.setText(str(round(self.thresholds[0], 3))) + self.bandwidth_line.setText(str(round(self.band_widths[0], 3))) + self.min_size_line.setText(str(round(self.min_sizes[0], 3))) + + @thread_worker + def infer(self): + for layer in self.viewer.layers: + if f"{layer}" == self.raw_selector.currentText(): + raw_image_layer = layer + break + + self.thresholds = ( + [None] * self.napari_dataset.get_num_samples() + if self.napari_dataset.get_num_samples() != 0 + else [None] * 1 + ) + if ( + self.inference_config.bandwidth is None + and len(self.band_widths) == 0 + ): + self.band_widths = ( + [0.5 * self.experiment_config.object_size] + * self.napari_dataset.get_num_samples() + if self.napari_dataset.get_num_samples() != 0 + else [0.5 * self.experiment_config.object_size] + ) + + if self.inference_config.min_size is None and len(self.min_sizes) == 0: + if self.napari_dataset.get_num_spatial_dims() == 2: + self.min_sizes = ( + [ + int( + 0.1 + * np.pi + * (self.experiment_config.object_size**2) + / 4 + ) + ] + * self.napari_dataset.get_num_samples() + if self.napari_dataset.get_num_samples() != 0 + else [ + int( + 0.1 + * np.pi + * (self.experiment_config.object_size**2) + / 4 + ) + ] + ) + elif ( + self.napari_dataset.get_num_spatial_dims() == 3 + and len(self.min_sizes) == 0 + ): + self.min_sizes = ( + [ + int( + 0.1 + * 4.0 + / 3.0 + * np.pi + * (self.experiment_config.object_size**3) + / 8 + ) + ] + * self.napari_dataset.get_num_samples() + if self.napari_dataset.get_num_samples() != 0 + else [ + int( + 0.1 + * 4.0 + / 3.0 + * np.pi + * (self.experiment_config.object_size**3) + / 8 + ) + ] + ) + + # set in eval mode + self.model.eval() + self.model.set_infer( + p_salt_pepper=self.inference_config.p_salt_pepper, + num_infer_iterations=self.inference_config.num_infer_iterations, + device=self.device, + ) + + if self.napari_dataset.get_num_spatial_dims() == 2: + crop_size_tuple = (self.inference_config.crop_size[0],) * 2 + + elif self.napari_dataset.get_num_spatial_dims() == 3: + crop_size_tuple = (self.inference_config.crop_size[0],) * 3 + + input_shape = gp.Coordinate( + ( + 1, + self.napari_dataset.get_num_channels() + if self.napari_dataset.get_num_channels() != 0 + else 1, + *crop_size_tuple, + ) + ) + + if self.napari_dataset.get_num_channels() == 0: + output_shape = gp.Coordinate( + self.model( + torch.zeros((1, *crop_size_tuple), dtype=torch.float32).to( + self.device + ) + ).shape + ) + else: + output_shape = gp.Coordinate( + self.model( + torch.zeros( + (1, 1, *crop_size_tuple), dtype=torch.float32 + ).to(self.device) + ).shape + ) + + voxel_size = ( + (1,) * 2 + if self.napari_dataset.get_num_spatial_dims() == 2 + else (1,) * 3 + ) + + input_size = gp.Coordinate(input_shape[2:]) * gp.Coordinate(voxel_size) + output_size = gp.Coordinate(output_shape[2:]) * gp.Coordinate( + voxel_size + ) + context = (input_size - output_size) // 2 + raw = gp.ArrayKey("RAW") + prediction = gp.ArrayKey("PREDICT") + scan_request = gp.BatchRequest() + + # scan_request.add(raw, input_size) + scan_request[raw] = gp.Roi( + (-8,) * (self.napari_dataset.get_num_spatial_dims()), + crop_size_tuple, + ) + scan_request.add(prediction, output_size) + predict = gp.torch.Predict( + self.model, + inputs={"x": raw}, + outputs={0: prediction}, + array_specs={prediction: gp.ArraySpec(voxel_size=voxel_size)}, + ) + + pipeline = NapariImageSource( + image=raw_image_layer, + key=raw, + spec=gp.ArraySpec( + gp.Roi( + (0,) * self.napari_dataset.get_num_spatial_dims(), + raw_image_layer.data.shape[ + -self.napari_dataset.get_num_spatial_dims() : + ], + ), + voxel_size=(1,) * self.napari_dataset.get_num_spatial_dims(), + ), + spatial_dims=self.napari_dataset.get_spatial_dims(), + ) + + if self.napari_dataset.get_num_samples() == 0: + pipeline += ( + gp.Pad(raw, context, mode="reflect") + + gp.Unsqueeze([raw], 0) + + predict + + gp.Scan(scan_request) + ) + else: + pipeline += ( + gp.Pad(raw, context, mode="reflect") + + predict + + gp.Scan(scan_request) + ) + + request = gp.BatchRequest() + request.add( + prediction, + raw_image_layer.data.shape[ + -self.napari_dataset.get_num_spatial_dims() : + ], + ) + + # Obtain Embeddings + print("Predicting Embeddings ...") + + with gp.build(pipeline): + batch = pipeline.request_batch(request) + + embeddings = batch.arrays[prediction].data + embeddings_centered = np.zeros_like(embeddings) + foreground_mask = np.zeros_like(embeddings[:, 0:1, ...], dtype=bool) + colormaps = ["red", "green", "blue"] + + # Obtain Object Centered Embeddings + for sample in tqdm(range(embeddings.shape[0])): + embeddings_sample = embeddings[sample] + embeddings_std = embeddings_sample[-1, ...] + embeddings_mean = embeddings_sample[ + np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ... + ].copy() + threshold = threshold_otsu(embeddings_std) + + self.thresholds[sample] = threshold + binary_mask = embeddings_std < threshold + foreground_mask[sample] = binary_mask[np.newaxis, ...] + embeddings_centered_sample = embeddings_sample.copy() + embeddings_mean_masked = ( + binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean + ) + if embeddings_centered_sample.shape[0] == 3: + c_x = embeddings_mean_masked[0, 0] + c_y = embeddings_mean_masked[0, 1] + c_x = c_x[c_x != 0].mean() + c_y = c_y[c_y != 0].mean() + embeddings_centered_sample[0] -= c_x + embeddings_centered_sample[1] -= c_y + elif embeddings_centered_sample.shape[0] == 3: + c_x = embeddings_mean_masked[0, 0] + c_y = embeddings_mean_masked[0, 1] + c_z = embeddings_mean_masked[0, 2] + c_x = c_x[c_x != 0].mean() + c_y = c_y[c_y != 0].mean() + c_z = c_z[c_z != 0].mean() + embeddings_centered_sample[0] -= c_x + embeddings_centered_sample[1] -= c_y + embeddings_centered_sample[2] -= c_z + + embeddings_centered[sample] = embeddings_centered_sample + + embeddings_layers = [ + ( + embeddings_centered[:, i : i + 1, ...].copy(), + { + "name": "Offset (" + + "zyx"[self.napari_dataset.get_num_spatial_dims() - i] + + ")" + if i < self.napari_dataset.get_num_spatial_dims() + else "Uncertainty", + "colormap": colormaps[ + self.napari_dataset.get_num_spatial_dims() - i + ] + if i < self.napari_dataset.get_num_spatial_dims() + else "gray", + "blending": "additive", + }, + "image", + ) + for i in range(self.napari_dataset.get_num_spatial_dims() + 1) + ] + print("Clustering Objects in the obtained Foreground Mask ...") + detection = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16) + for sample in tqdm(range(embeddings.shape[0])): + embeddings_sample = embeddings[sample] + embeddings_std = embeddings_sample[-1, ...] + embeddings_mean = embeddings_sample[ + np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ... + ].copy() + + detection_sample = mean_shift_segmentation( + embeddings_mean, + embeddings_std, + bandwidth=self.band_widths[sample], + min_size=self.inference_config.min_size, + reduction_probability=self.inference_config.reduction_probability, + threshold=self.thresholds[sample], + seeds=None, + ) + detection[sample, 0, ...] = detection_sample + + print("Converting Detections to Segmentations ...") + segmentation = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16) + if self.radio_button_cell.isChecked(): + for sample in tqdm(range(embeddings.shape[0])): + segmentation_sample = detection[sample, 0].copy() + distance_foreground = dtedt(segmentation_sample == 0) + expanded_mask = ( + distance_foreground < self.inference_config.grow_distance + ) + distance_background = dtedt(expanded_mask) + segmentation_sample[ + distance_background < self.inference_config.shrink_distance + ] = 0 + segmentation[sample, 0, ...] = segmentation_sample + elif self.radio_button_nucleus.isChecked(): + raw_image = raw_image_layer.data + for sample in tqdm(range(embeddings.shape[0])): + segmentation_sample = detection[sample, 0] + if ( + self.napari_dataset.get_num_samples() == 0 + and self.napari_dataset.get_num_channels() == 0 + ): + raw_image_sample = raw_image + elif ( + self.napari_dataset.get_num_samples() != 0 + and self.napari_dataset.get_num_channels() == 0 + ): + raw_image_sample = raw_image[sample] + elif ( + self.napari_dataset.get_num_samples() == 0 + and self.napari_dataset.get_num_channels() != 0 + ): + raw_image_sample = raw_image[0] + else: + raw_image_sample = raw_image[sample, 0] + + ids = np.unique(segmentation_sample) + ids = ids[ids != 0] + + for id_ in ids: + segmentation_id_mask = segmentation_sample == id_ + if self.napari_dataset.get_num_spatial_dims() == 2: + y, x = np.where(segmentation_id_mask) + y_min, y_max, x_min, x_max = ( + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + elif self.napari_dataset.get_num_spatial_dims() == 3: + z, y, x = np.where(segmentation_id_mask) + z_min, z_max, y_min, y_max, x_min, x_max = ( + np.min(z), + np.max(z), + np.min(y), + np.max(y), + np.min(x), + np.max(x), + ) + raw_image_masked = raw_image_sample[segmentation_id_mask] + threshold = threshold_otsu(raw_image_masked) + mask = segmentation_id_mask & ( + raw_image_sample > threshold + ) + + if self.napari_dataset.get_num_spatial_dims() == 2: + mask_small = binary_fill_holes( + mask[y_min : y_max + 1, x_min : x_max + 1] + ) + mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small + y, x = np.where(mask) + segmentation[sample, 0, y, x] = id_ + elif self.napari_dataset.get_num_spatial_dims() == 3: + mask_small = binary_fill_holes( + mask[ + z_min : z_max + 1, + y_min : y_max + 1, + x_min : x_max + 1, + ] + ) + mask[ + z_min : z_max + 1, + y_min : y_max + 1, + x_min : x_max + 1, + ] = mask_small + z, y, x = np.where(mask) + segmentation[sample, 0, z, y, x] = id_ + + print("Removing small objects ...") + + # size filter - remove small objects + for sample in tqdm(range(embeddings.shape[0])): + segmentation[sample, 0, ...] = size_filter( + segmentation[sample, 0], self.min_sizes[sample] + ) + return ( + embeddings_layers + + [(foreground_mask, {"name": "Foreground Mask"}, "labels")] + + [(detection, {"name": "Detection"}, "labels")] + + [(segmentation, {"name": "Segmentation"}, "labels")] + ) + + def on_return_infer(self, layers): + for data, metadata, layer_type in layers: + if layer_type == "image": + self.viewer.add_image(data, **metadata) + elif layer_type == "labels": + self.viewer.add_labels(data.astype(int), **metadata) + self.viewer.layers["Offset (x)"].visible = False + self.viewer.layers["Offset (y)"].visible = False + self.viewer.layers["Uncertainty"].visible = False + self.viewer.layers["Foreground Mask"].visible = False + self.viewer.layers["Detection"].visible = False + self.viewer.layers["Segmentation"].visible = True + self.inference_worker.quit() + self.prepare_for_stop_inference() diff --git a/src/napari_cellulus/widgets/__init__.py b/src/napari_cellulus/widgets/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py deleted file mode 100644 index 09e1dc5..0000000 --- a/src/napari_cellulus/widgets/_widget.py +++ /dev/null @@ -1,765 +0,0 @@ -import os -import time - -import gunpowder as gp -import napari -import numpy as np -import pyqtgraph as pg -import torch -from cellulus.criterions import get_loss -from cellulus.models import get_model -from cellulus.train import train_iteration -from cellulus.utils.mean_shift import mean_shift_segmentation -from magicgui import magic_factory - -# widget stuff -from napari.qt.threading import thread_worker -from qtpy.QtCore import Qt -from qtpy.QtWidgets import ( - QCheckBox, - QComboBox, - QGridLayout, - QGroupBox, - QHBoxLayout, - QLabel, - QLineEdit, - QMainWindow, - QPushButton, - QScrollArea, - QVBoxLayout, - QWidget, -) -from scipy.ndimage import distance_transform_edt as dtedt -from skimage.filters import threshold_otsu -from superqt import QCollapsible - -from ..dataset import NapariDataset -from ..gp.nodes.napari_image_source import NapariImageSource - -# local package imports -from ..gui_helpers import layer_choice_widget - -############ GLOBALS ################### -time_now = 0 -_train_config = None -_model_config = None -_segment_config = None -_model = None -_optimizer = None -_scheduler = None -_dataset = None - - -class SegmentationWidget(QMainWindow): - def __init__(self, napari_viewer): - super().__init__() - self.widget = QWidget() - self.scroll = QScrollArea() - self.viewer = napari_viewer - - # initialize train_config and model_config - - self.train_config = None - self.model_config = None - self.segment_config = None - - # initialize losses and iterations - self.losses = [] - self.iterations = [] - - # initialize mode. this will change to 'training' and 'segmentring' later - self.mode = "configuring" - - # initialize UI components - method_description_label = QLabel( - 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' - ) - - # specify layout - outer_layout = QVBoxLayout() - - # Initialize object size widget - object_size_label = QLabel(self) - object_size_label.setText("Object Size [px]:") - self.object_size_line = QLineEdit(self) - self.object_size_line.setText("30") - device_label = QLabel(self) - device_label.setText("Device") - self.device_combo_box = QComboBox(self) - self.device_combo_box.addItem("cpu") - self.device_combo_box.addItem("cuda:0") - self.device_combo_box.addItem("mps") - self.device_combo_box.setCurrentText("mps") - - grid_layout = QGridLayout() - grid_layout.addWidget(object_size_label, 0, 0) - grid_layout.addWidget(self.object_size_line, 0, 1) - grid_layout.addWidget(device_label, 1, 0) - grid_layout.addWidget(self.device_combo_box, 1, 1) - - # Initialize train configs widget - collapsible_train_configs = QCollapsible("Train Configs", self) - collapsible_train_configs.addWidget(self.create_train_configs_widget) - - # Initialize model configs widget - collapsible_model_configs = QCollapsible("Model Configs", self) - collapsible_model_configs.addWidget(self.create_model_configs_widget) - - # Initialize loss/iterations widget - - self.losses_widget = pg.PlotWidget() - self.losses_widget.setBackground((37, 41, 49)) - styles = {"color": "white", "font-size": "16px"} - self.losses_widget.setLabel("left", "Loss", **styles) - self.losses_widget.setLabel("bottom", "Iterations", **styles) - # self.canvas = MplCanvas(self, width=5, height=10, dpi=100) - # canvas_layout = QVBoxLayout() - # canvas_layout.addWidget(self.canvas) - # if len(self.iterations) == 0: - # self.loss_plot = self.canvas.axes.plot([], [], label="")[0] - # self.canvas.axes.legend() - # self.canvas.axes.set_xlabel("Iterations") - # self.canvas.axes.set_ylabel("Loss") - # plot_container_widget = QWidget() - # plot_container_widget.setLayout(canvas_layout) - - # Initialize Layer Choice - self.raw_selector = layer_choice_widget( - self.viewer, annotation=napari.layers.Image, name="raw" - ) - - # Initialize Checkboxes - self.s_checkbox = QCheckBox("s") - self.c_checkbox = QCheckBox("c") - self.t_checkbox = QCheckBox("t") - self.z_checkbox = QCheckBox("z") - self.y_checkbox = QCheckBox("y") - self.x_checkbox = QCheckBox("x") - - axis_layout = QHBoxLayout() - axis_layout.addWidget(self.s_checkbox) - axis_layout.addWidget(self.c_checkbox) - axis_layout.addWidget(self.t_checkbox) - axis_layout.addWidget(self.z_checkbox) - axis_layout.addWidget(self.y_checkbox) - axis_layout.addWidget(self.x_checkbox) - axis_selector = QGroupBox("Axis Names:") - axis_selector.setLayout(axis_layout) - - # Initialize Train Button - self.train_button = QPushButton("Train", self) - self.train_button.clicked.connect(self.prepare_for_training) - - # Initialize Save and Load Widget - collapsible_save_load_widget = QCollapsible( - "Save and Load Model", self - ) - save_load_layout = QHBoxLayout() - save_model_button = QPushButton("Save Model", self) - load_model_button = QPushButton("Load Model", self) - save_load_layout.addWidget(save_model_button) - save_load_layout.addWidget(load_model_button) - save_load_widget = QWidget() - save_load_widget.setLayout(save_load_layout) - collapsible_save_load_widget.addWidget(save_load_widget) - - # Initialize Segment Configs widget - collapsible_segment_configs = QCollapsible("Inference Configs", self) - collapsible_segment_configs.addWidget( - self.create_segment_configs_widget - ) - - # Initialize Segment Button - self.segment_button = QPushButton("Segment", self) - self.segment_button.clicked.connect(self.prepare_for_segmenting) - - # Initialize progress bar - # self.pbar = QProgressBar(self) - - # Initialize Feedback Button - self.feedback_label = QLabel( - 'Please share any feedback here.' - ) - - # Add all components to outer_layout - - outer_layout.addWidget(method_description_label) - outer_layout.addLayout(grid_layout) - # outer_layout.addWidget(global_params_widget) - outer_layout.addWidget(collapsible_train_configs) - outer_layout.addWidget(collapsible_model_configs) - # outer_layout.addWidget(plot_container_widget) - outer_layout.addWidget(self.losses_widget) - outer_layout.addWidget(self.raw_selector.native) - outer_layout.addWidget(axis_selector) - outer_layout.addWidget(self.train_button) - outer_layout.addWidget(collapsible_save_load_widget) - outer_layout.addWidget(collapsible_segment_configs) - outer_layout.addWidget(self.segment_button) - outer_layout.addWidget(self.feedback_label) - outer_layout.setSpacing(20) - self.widget.setLayout(outer_layout) - - self.scroll.setWidget(self.widget) - self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) - self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.scroll.setWidgetResizable(True) - - self.setFixedWidth(400) - self.setCentralWidget(self.scroll) - - @property - def create_train_configs_widget(self): - @magic_factory(call_button="Save") - def train_configs_widget( - crop_size: int = 252, - batch_size: int = 8, - max_iterations: int = 100_000, - initial_learning_rate: float = 4e-5, - temperature: float = 10.0, - regularizer_weight: float = 1e-5, - reduce_mean: bool = True, - density: float = 0.1, - kappa: float = 10.0, - num_workers: int = 8, - control_point_spacing: int = 64, - control_point_jitter: float = 2.0, - ): - global _train_config - # Specify what should happen when 'Save' button is pressed - _train_config = { - "crop_size": crop_size, - "batch_size": batch_size, - "max_iterations": max_iterations, - "initial_learning_rate": initial_learning_rate, - "temperature": temperature, - "regularizer_weight": regularizer_weight, - "reduce_mean": reduce_mean, - "density": density, - "kappa": kappa, - "num_workers": num_workers, - "control_point_spacing": control_point_spacing, - "control_point_jitter": control_point_jitter, - } - - if not hasattr(self, "__create_train_configs_widget"): - self.__create_train_configs_widget = train_configs_widget() - self.__create_train_configs_widget_native = ( - self.__create_train_configs_widget.native - ) - return self.__create_train_configs_widget_native - - @property - def create_model_configs_widget(self): - @magic_factory(call_button="Save") - def model_configs_widget( - num_fmaps: int = 24, - fmap_inc_factor: int = 3, - features_in_last_layer: int = 64, - downsampling_factors: int = 2, - downsampling_layers: int = 1, - initialize: bool = True, - ): - global _model_config - # Specify what should happen when 'Save' button is pressed - _model_config = { - "num_fmaps": num_fmaps, - "fmap_inc_factor": fmap_inc_factor, - "features_in_last_layer": features_in_last_layer, - "downsampling_factors": downsampling_factors, - "downsampling_layers": downsampling_layers, - "initialize": initialize, - } - - if not hasattr(self, "__create_model_configs_widget"): - self.__create_model_configs_widget = model_configs_widget() - self.__create_model_configs_widget_native = ( - self.__create_model_configs_widget.native - ) - return self.__create_model_configs_widget_native - - @property - def create_segment_configs_widget(self): - @magic_factory(call_button="Save") - def segment_configs_widget( - crop_size: int = 252, - p_salt_pepper: float = 0.01, - num_infer_iterations: int = 16, - bandwidth: int = 7, - reduction_probability: float = 0.1, - min_size: int = 25, - grow_distance: int = 3, - shrink_distance: int = 6, - ): - global _segment_config - # Specify what should happen when 'Save' button is pressed - _segment_config = { - "crop_size": crop_size, - "p_salt_pepper": p_salt_pepper, - "num_infer_iterations": num_infer_iterations, - "bandwidth": bandwidth, - "reduction_probability": reduction_probability, - "min_size": min_size, - "grow_distance": grow_distance, - "shrink_distance": shrink_distance, - } - - if not hasattr(self, "__create_segment_configs_widget"): - self.__create_segment_configs_widget = segment_configs_widget() - self.__create_segment_configs_widget_native = ( - self.__create_segment_configs_widget.native - ) - return self.__create_segment_configs_widget_native - - def get_selected_axes(self): - names = [] - for name, checkbox in zip( - "sctzyx", - [ - self.s_checkbox, - self.c_checkbox, - self.t_checkbox, - self.z_checkbox, - self.y_checkbox, - self.x_checkbox, - ], - ): - if checkbox.isChecked(): - names.append(name) - return names - - def prepare_for_training(self): - - # check if train_config object exists - global _train_config, _model_config, _model, _optimizer - - if _train_config is None: - # set default values - _train_config = { - "crop_size": 252, - "batch_size": 8, - "max_iterations": 100_000, - "initial_learning_rate": 4e-5, - "temperature": 10.0, - "regularizer_weight": 1e-5, - "reduce_mean": True, - "density": 0.1, - "kappa": 10.0, - "num_workers": 8, - "control_point_spacing": 64, - "control_point_jitter": 2.0, - } - - # check if model_config object exists - if _model_config is None: - _model_config = { - "num_fmaps": 24, - "fmap_inc_factor": 3, - "features_in_last_layer": 64, - "downsampling_factors": 2, - "downsampling_layers": 1, - "initialize": True, - } - - print(self.sender()) - self.update_mode(self.sender()) - - if self.mode == "training": - self.worker = self.train_napari() - self.worker.yielded.connect(self.on_yield) - self.worker.returned.connect(self.on_return) - self.worker.start() - elif self.mode == "configuring": - state = { - "iteration": self.iterations[-1], - "model_state_dict": _model.state_dict(), - "optim_state_dict": _optimizer.state_dict(), - "iterations": self.iterations, - "losses": self.losses, - } - filename = os.path.join("models", "last.pth") - torch.save(state, filename) - self.worker.quit() - - def update_mode(self, sender): - - if self.train_button.text() == "Train" and sender == self.train_button: - self.train_button.setText("Pause") - self.mode = "training" - self.segment_button.setEnabled(False) - elif ( - self.train_button.text() == "Pause" and sender == self.train_button - ): - self.train_button.setText("Train") - self.mode = "configuring" - self.segment_button.setEnabled(True) - elif ( - self.segment_button.text() == "Segment" - and sender == self.segment_button - ): - self.segment_button.setText("Pause") - self.mode = "segmenting" - self.train_button.setEnabled(False) - elif ( - self.segment_button.text() == "Pause" - and sender == self.segment_button - ): - self.segment_button.setText("Segment") - self.mode = "configuring" - self.train_button.setEnabled(True) - - @thread_worker - def train_napari(self): - - global _train_config, _model_config, _model, _scheduler, _optimizer, _dataset - - # Turn layer into dataset - _dataset = NapariDataset( - layer=self.raw_selector.value, - axis_names=self.get_selected_axes(), - crop_size=_train_config["crop_size"], - control_point_spacing=_train_config["control_point_spacing"], - control_point_jitter=_train_config["control_point_jitter"], - ) - - if not os.path.exists("models"): - os.makedirs("models") - - # create train dataloader - dataloader = torch.utils.data.DataLoader( - dataset=_dataset, - batch_size=_train_config["batch_size"], - drop_last=True, - num_workers=_train_config["num_workers"], - pin_memory=True, - ) - - downsampling_factors = [ - [ - _model_config["downsampling_factors"], - ] - * _dataset.get_num_spatial_dims() - ] * _model_config["downsampling_layers"] - - # set model - _model = get_model( - in_channels=_dataset.get_num_channels(), - out_channels=_dataset.get_num_spatial_dims(), - num_fmaps=_model_config["num_fmaps"], - fmap_inc_factor=_model_config["fmap_inc_factor"], - features_in_last_layer=_model_config["features_in_last_layer"], - downsampling_factors=[ - tuple(factor) for factor in downsampling_factors - ], - num_spatial_dims=_dataset.get_num_spatial_dims(), - ) - - # set device - device = torch.device(self.device_combo_box.currentText()) - - _model = _model.to(device) - - # initialize model weights - if _model_config["initialize"]: - for _name, layer in _model.named_modules(): - if isinstance(layer, torch.nn.modules.conv._ConvNd): - torch.nn.init.kaiming_normal_( - layer.weight, nonlinearity="relu" - ) - - # set loss - criterion = get_loss( - regularizer_weight=_train_config["regularizer_weight"], - temperature=_train_config["temperature"], - kappa=_train_config["kappa"], - density=_train_config["density"], - num_spatial_dims=_dataset.get_num_spatial_dims(), - reduce_mean=_train_config["reduce_mean"], - device=device, - ) - - # set optimizer - _optimizer = torch.optim.Adam( - _model.parameters(), - lr=_train_config["initial_learning_rate"], - ) - - # set scheduler: - - def lambda_(iteration): - return pow( - (1 - ((iteration) / _train_config["max_iterations"])), 0.9 - ) - - # resume training - if len(self.iterations) == 0: - start_iteration = 0 - else: - start_iteration = self.iterations[-1] - - if not os.path.exists("models/last.pth"): - pass - else: - print("Resuming model from 'models/last.pth'") - state = torch.load("models/last.pth", map_location=device) - start_iteration = state["iteration"] + 1 - self.iterations = state["iterations"] - self.losses = state["losses"] - _model.load_state_dict(state["model_state_dict"], strict=True) - _optimizer.load_state_dict(state["optim_state_dict"]) - - # call `train_iteration` - for iteration, batch in zip( - range(start_iteration, _train_config["max_iterations"]), - dataloader, - ): - _scheduler = torch.optim.lr_scheduler.LambdaLR( - _optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 - ) - - train_loss, prediction = train_iteration( - batch, - model=_model, - criterion=criterion, - optimizer=_optimizer, - device=device, - ) - _scheduler.step() - yield (iteration, train_loss) - - def on_yield(self, step_data): - if self.mode == "training": - global time_now - iteration, loss = step_data - current_time = time.time() - time_elapsed = current_time - time_now - time_now = current_time - print( - f"iteration {iteration}, loss {loss}, seconds/iteration {time_elapsed}" - ) - self.iterations.append(iteration) - self.losses.append(loss) - self.update_canvas() - elif self.mode == "segmenting": - print(step_data) - # self.pbar.setValue(step_data) - - def update_canvas(self): - self.losses_widget.plot(self.iterations, self.losses) - - # self.loss_plot.set_xdata(self.iterations) - # self.loss_plot.set_ydata(self.losses) - # self.canvas.axes.relim() - # self.canvas.axes.autoscale_view() - # self.canvas.draw() - - def on_return(self, layers): - # Describes what happens once segment button has completed - - for data, metadata, layer_type in layers: - if layer_type == "image": - self.viewer.add_image(data, **metadata) - elif layer_type == "labels": - self.viewer.add_labels(data.astype(int), **metadata) - self.update_mode(self.segment_button) - - def prepare_for_segmenting(self): - global _segment_config - # check if segment_config exists - if _segment_config is None: - _segment_config = { - "crop_size": 252, - "p_salt_pepper": 0.01, - "num_infer_iterations": 16, - "bandwidth": None, - "reduction_probability": 0.1, - "min_size": None, - "grow_distance": 3, - "shrink_distance": 6, - } - - # update mode - self.update_mode(self.sender()) - - if self.mode == "segmenting": - self.worker = self.segment_napari() - self.worker.yielded.connect(self.on_yield) - self.worker.returned.connect(self.on_return) - self.worker.start() - elif self.mode == "configuring": - self.worker.quit() - - @thread_worker - def segment_napari(self): - global _segment_config, _model, _dataset - - raw = self.raw_selector.value - - if _segment_config["bandwidth"] is None: - _segment_config["bandwidth"] = int( - 0.5 * float(self.object_size_line.text()) - ) - if _segment_config["min_size"] is None: - _segment_config["min_size"] = int( - 0.1 * np.pi * (float(self.object_size_line.text()) ** 2) / 4 - ) - _model.eval() - - num_spatial_dims = _dataset.num_spatial_dims - num_channels = _dataset.num_channels - spatial_dims = _dataset.spatial_dims - num_samples = _dataset.num_samples - - print( - f"Num spatial dims {num_spatial_dims} num channels {num_channels} spatial_dims {spatial_dims} num_samples {num_samples}" - ) - - crop_size = (_segment_config["crop_size"],) * num_spatial_dims - device = self.device_combo_box.currentText() - - num_channels_temp = 1 if num_channels == 0 else num_channels - - voxel_size = gp.Coordinate((1,) * num_spatial_dims) - _model.set_infer( - p_salt_pepper=_segment_config["p_salt_pepper"], - num_infer_iterations=_segment_config["num_infer_iterations"], - device=device, - ) - - input_shape = gp.Coordinate((1, num_channels_temp, *crop_size)) - output_shape = gp.Coordinate( - _model( - torch.zeros( - (1, num_channels_temp, *crop_size), dtype=torch.float32 - ).to(device) - ).shape - ) - input_size = ( - gp.Coordinate(input_shape[-num_spatial_dims:]) * voxel_size - ) - output_size = ( - gp.Coordinate(output_shape[-num_spatial_dims:]) * voxel_size - ) - context = (input_size - output_size) / 2 - - raw_key = gp.ArrayKey("RAW") - prediction_key = gp.ArrayKey("PREDICT") - scan_request = gp.BatchRequest() - scan_request.add(raw_key, input_size) - scan_request.add(prediction_key, output_size) - - predict = gp.torch.Predict( - _model, - inputs={"raw": raw_key}, - outputs={0: prediction_key}, - array_specs={prediction_key: gp.ArraySpec(voxel_size=voxel_size)}, - ) - pipeline = NapariImageSource( - raw, - raw_key, - gp.ArraySpec( - gp.Roi( - (0,) * num_spatial_dims, - raw.data.shape[-num_spatial_dims:], - ), - voxel_size=voxel_size, - ), - spatial_dims, - ) - - if num_samples == 0 and num_channels == 0: - pipeline += ( - gp.Pad(raw_key, context) - + gp.Unsqueeze([raw_key], 0) - + gp.Unsqueeze([raw_key], 0) - + predict - + gp.Scan(scan_request) - ) - elif num_samples != 0 and num_channels == 0: - pipeline += ( - gp.Pad(raw_key, context) - + gp.Unsqueeze([raw_key], 1) - + predict - + gp.Scan(scan_request) - ) - elif num_samples == 0 and num_channels != 0: - pipeline += ( - gp.Pad(raw_key, context) - + gp.Unsqueeze([raw_key], 0) - + predict - + gp.Scan(scan_request) - ) - elif num_samples != 0 and num_channels != 0: - pipeline += ( - gp.Pad(raw_key, context) + predict + gp.Scan(scan_request) - ) - - # request to pipeline for ROI of whole image/volume - request = gp.BatchRequest() - request.add(prediction_key, raw.data.shape[-num_spatial_dims:]) - counter = 0 - with gp.build(pipeline): - batch = pipeline.request_batch(request) - yield counter - counter += 0.1 - - prediction = batch.arrays[prediction_key].data - colormaps = ["red", "green", "blue"] - prediction_layers = [ - ( - prediction[:, i : i + 1, ...].copy(), - { - "name": "offset-" + "zyx"[num_spatial_dims - i] - if i < num_spatial_dims - else "std", - "colormap": colormaps[num_spatial_dims - i] - if i < num_spatial_dims - else "gray", - "blending": "additive", - }, - "image", - ) - for i in range(num_spatial_dims + 1) - ] - - foreground = np.zeros_like(prediction[:, 0:1, ...], dtype=bool) - for sample in range(prediction.shape[0]): - embeddings = prediction[sample] - embeddings_std = embeddings[-1, ...] - thresh = threshold_otsu(embeddings_std) - print(f"Threshold for sample {sample} is {thresh}") - binary_mask = embeddings_std < thresh - foreground[sample, 0, ...] = binary_mask - - labels = np.zeros_like(prediction[:, 0:1, ...], dtype=np.uint64) - for sample in range(prediction.shape[0]): - embeddings = prediction[sample] - embeddings_std = embeddings[-1, ...] - embeddings_mean = embeddings[np.newaxis, :num_spatial_dims, ...] - segmentation = mean_shift_segmentation( - embeddings_mean, - embeddings_std, - _segment_config["bandwidth"], - _segment_config["min_size"], - _segment_config["reduction_probability"], - ) - labels[sample, 0, ...] = segmentation - - pp_labels = np.zeros_like(prediction[:, 0:1, ...], dtype=np.uint64) - for sample in range(prediction.shape[0]): - segmentation = labels[sample, 0] - distance_foreground = dtedt(segmentation == 0) - expanded_mask = ( - distance_foreground < _segment_config["grow_distance"] - ) - distance_background = dtedt(expanded_mask) - segmentation[ - distance_background < _segment_config["shrink_distance"] - ] = 0 - pp_labels[sample, 0, ...] = segmentation - return ( - prediction_layers - + [(foreground, {"name": "Foreground"}, "labels")] - + [(labels, {"name": "Segmentation"}, "labels")] - + [(pp_labels, {"name": "Post Processed"}, "labels")] - )