diff --git a/LICENSE b/LICENSE
index a4e53b6..54e32df 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,22 +1,28 @@
+BSD 3-Clause License
-The MIT License (MIT)
+Copyright (c) 2023, Howard Hughes Medical Institute
-Copyright (c) 2023 William Patton
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index 86de2c8..352e0aa 100644
--- a/README.md
+++ b/README.md
@@ -1,67 +1,79 @@
-# napari-cellulus
+
+A napari plugin for cellulus
-[![License MIT](https://img.shields.io/pypi/l/napari-cellulus.svg?color=green)](https://github.com/funkelab/napari-cellulus/raw/main/LICENSE)
-[![PyPI](https://img.shields.io/pypi/v/napari-cellulus.svg?color=green)](https://pypi.org/project/napari-cellulus)
-[![Python Version](https://img.shields.io/pypi/pyversions/napari-cellulus.svg?color=green)](https://python.org)
-[![tests](https://github.com/funkelab/napari-cellulus/workflows/tests/badge.svg)](https://github.com/funkelab/napari-cellulus/actions)
-[![codecov](https://codecov.io/gh/funkelab/napari-cellulus/branch/main/graph/badge.svg)](https://codecov.io/gh/funkelab/napari-cellulus)
-[![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-cellulus)](https://napari-hub.org/plugins/napari-cellulus)
+- **[Introduction](#introduction)**
+- **[Installation](#installation)**
+- **[Getting Started](#getting-started)**
+- **[Citation](#citation)**
+- **[Issues](#issues)**
-A Napari plugin for Cellulus: Unsupervised Learning of Object-Centric Embeddings for Cell Instance Segmentation in Microscopy Images
+### Introduction
-----------------------------------
+This repository hosts the code for the napari plugin built around **cellulus**, the method described in the **[preprint](https://arxiv.org/pdf/2310.08501.pdf)** titled **Unsupervised Learning of *Object-Centric Embeddings* for Cell Instance Segmentation in Microscopy Images**.
-This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template.
+*cellulus* is a deep-learning-based method for obtaining instance segmentations of objects in microscopy images in an unsupervised fashion, i.e., requiring no ground-truth labels during training.
-
+### Installation
-## Installation
+Execute the lines below to create a new environment and install the dependencies.
-You can install `napari-cellulus` via [pip]:
+1. Create a new environment called `napari-cellulus`:
- pip install napari-cellulus
+```bash
+conda create -y -n napari-cellulus python==3.9
+```
+2. Activate the newly-created environment:
+```bash
+conda activate napari-cellulus
+```
-To install latest development version :
+3a. If using a GPU, install the PyTorch CUDA dependencies:
- pip install git+https://github.com/funkelab/napari-cellulus.git
+```bash
+conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+```
+3b. Otherwise (if using a CPU or MPS), run:
-## Contributing
+```bash
+pip install torch torchvision
+```
-Contributions are very welcome. Tests can be run with [tox], please ensure
-the coverage at least stays the same before you submit a pull request.
+4. Install the package from GitHub:
-## License
+```bash
+pip install git+https://github.com/funkelab/napari-cellulus.git
+```
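+5. (Optional) Verify the installation; this should print the installed version (`0.1.0` for this release):
+```bash
+python -c "import napari_cellulus; print(napari_cellulus.__version__)"
+```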
-Distributed under the terms of the [MIT] license,
-"napari-cellulus" is free and open source software
+### Getting Started
-## Issues
+Run the following commands in a terminal window:
+```bash
+conda activate napari-cellulus
+napari
+```
-If you encounter any problems, please [file an issue] along with a detailed description.
+Next, select `Cellulus` from the `Plugins` drop-down menu.
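+
+Alternatively, the bundled sample data can be opened from a Python session. Below is a minimal sketch using napari's public API and the sample keys registered by this plugin (`tissue_net_sample`, `fluo_n2dl_hela_sample`):
+
+```python
+import napari
+
+viewer = napari.Viewer()
+# load the bundled TissueNet sample registered by napari-cellulus
+viewer.open_sample(plugin="napari-cellulus", sample="tissue_net_sample")
+napari.run()
+```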
-[napari]: https://github.com/napari/napari
-[Cookiecutter]: https://github.com/audreyr/cookiecutter
-[@napari]: https://github.com/napari
-[MIT]: http://opensource.org/licenses/MIT
-[BSD-3]: http://opensource.org/licenses/BSD-3-Clause
-[GNU GPL v3.0]: http://www.gnu.org/licenses/gpl-3.0.txt
-[GNU LGPL v3.0]: http://www.gnu.org/licenses/lgpl-3.0.txt
-[Apache Software License 2.0]: http://www.apache.org/licenses/LICENSE-2.0
-[Mozilla Public License 2.0]: https://www.mozilla.org/media/MPL/2.0/index.txt
-[cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin
+### Citation
-[file an issue]: https://github.com/funkelab/napari-cellulus/issues
+If you find our work useful in your research, please consider citing:
-[napari]: https://github.com/napari/napari
-[tox]: https://tox.readthedocs.io/en/latest/
-[pip]: https://pypi.org/project/pip/
-[PyPI]: https://pypi.org/
+
+```bibtex
+@misc{wolf2023unsupervised,
+ title={Unsupervised Learning of Object-Centric Embeddings for Cell Instance Segmentation in Microscopy Images},
+ author={Steffen Wolf and Manan Lalit and Henry Westmacott and Katie McDole and Jan Funke},
+ year={2023},
+ eprint={2310.08501},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG}
+}
+```
+
+### Issues
+
+If you encounter any problems, please **[file an issue](https://github.com/funkelab/napari-cellulus/issues)** along with a detailed description.
diff --git a/setup.cfg b/setup.cfg
index 7280387..ce0031b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,9 +5,8 @@ description = A Napari plugin for Cellulus: Unsupervised Learning of Object-Cent
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/funkelab/napari-cellulus
-author = William Patton
-author_email = wllmpttn24@gmail.com
-license = MIT
+author = William Patton, Manan Lalit
+author_email = wllmpttn24@gmail.com, lalitm@janelia.hhmi.org
license_files = LICENSE
classifiers =
Development Status :: 2 - Pre-Alpha
@@ -31,19 +30,16 @@ project_urls =
[options]
packages = find:
install_requires =
- numpy
- magicgui
+ napari[all]
+ pyqtgraph
qtpy
cellulus @ git+https://github.com/funkelab/cellulus
- matplotlib
- torch
- gunpowder
-
+ pre-commit
python_requires = >=3.8
include_package_data = True
package_dir =
- =src
+ = src
# add your package requirements here
diff --git a/src/napari_cellulus/__init__.py b/src/napari_cellulus/__init__.py
index 47a9747..bf80da0 100644
--- a/src/napari_cellulus/__init__.py
+++ b/src/napari_cellulus/__init__.py
@@ -1,5 +1,5 @@
-__version__ = "0.0.3"
+__version__ = "0.1.0"
-from ._sample_data import tissuenet_sample
+from .sample_data import tissue_net_sample
-__all__ = ("tissuenet_sample",)
+__all__ = ("tissue_net_sample",)
diff --git a/src/napari_cellulus/_tests/test_samples.py b/src/napari_cellulus/_tests/test_samples.py
deleted file mode 100644
index 7952513..0000000
--- a/src/napari_cellulus/_tests/test_samples.py
+++ /dev/null
@@ -1,3 +0,0 @@
-def test_open(make_napari_viewer):
- viewer = make_napari_viewer()
- viewer.open_sample(plugin="napari-cellulus", sample="tissuenet_sample")
diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py
deleted file mode 100644
index 9b3e4ae..0000000
--- a/src/napari_cellulus/dataset.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from typing import List
-
-import gunpowder as gp
-import numpy as np
-from napari.layers import Image
-from torch.utils.data import IterableDataset
-
-from .gp.nodes.napari_image_source import NapariImageSource
-from .meta_data import NapariDatasetMetaData
-
-
-class NapariDataset(IterableDataset): # type: ignore
- def __init__(
- self,
- layer: Image,
- axis_names: List[str],
- crop_size: int,
- control_point_spacing: int,
- control_point_jitter: float,
- ):
- """A dataset that serves random samples from a zarr container.
-
- Args:
-
- layer:
-
- The napari layer to use.
- The data should have shape `(s, c, [t,] [z,] y, x)`, where
- `s` = # of samples, `c` = # of channels, `t` = # of frames, and
- `z`/`y`/`x` are spatial extents. The dataset should have an
- `"axis_names"` attribute that contains the names of the used
- axes, e.g., `["s", "c", "y", "x"]` for a 2D dataset.
-
- axis_names:
-
- The names of the axes in the napari layer.
-
- crop_size:
-
- The size of data crops used during training (distinct from the
- "patch size" of the method: from each crop, multiple patches
- will be randomly selected and the loss computed on them). This
- should be equal to the input size of the model that predicts
- the OCEs.
-
- control_point_spacing:
-
- The distance in pixels between control points used for elastic
- deformation of the raw data.
-
- control_point_jitter:
-
- How much to jitter the control points for elastic deformation
- of the raw data, given as the standard deviation of a normal
- distribution with zero mean.
- """
-
- self.layer = layer
- self.axis_names = axis_names
- self.control_point_spacing = control_point_spacing
- self.control_point_jitter = control_point_jitter
- self.__read_meta_data()
- print(f"Number of spatial dims is {self.num_spatial_dims}")
- self.crop_size = (crop_size,) * self.num_spatial_dims
- self.__setup_pipeline()
-
- def __iter__(self):
- return iter(self.__yield_sample())
-
- def __setup_pipeline(self):
- self.raw = gp.ArrayKey("RAW")
- # treat all dimensions as spatial, with a voxel size = 1
- voxel_size = gp.Coordinate((1,) * self.num_dims)
- offset = gp.Coordinate((0,) * self.num_dims)
- shape = gp.Coordinate(self.layer.data.shape)
- raw_spec = gp.ArraySpec(
- roi=gp.Roi(offset, voxel_size * shape),
- dtype=np.float32,
- interpolatable=True,
- voxel_size=voxel_size,
- )
-
- if self.num_channels == 0 and self.num_samples == 0:
- self.pipeline = (
- NapariImageSource(
- self.layer, self.raw, raw_spec, self.spatial_dims
- )
- + gp.RandomLocation()
- + gp.Unsqueeze([self.raw], 0)
- + gp.Unsqueeze([self.raw], 0)
- )
- elif self.num_channels == 0 and self.num_samples != 0:
- self.pipeline = (
- NapariImageSource(
- self.layer, self.raw, raw_spec, self.spatial_dims
- )
- + gp.RandomLocation()
- + gp.Unsqueeze([self.raw], 1)
- )
- elif self.num_channels != 0 and self.num_samples == 0:
- self.pipeline = (
- NapariImageSource(
- self.layer, self.raw, raw_spec, self.spatial_dims
- )
- + gp.RandomLocation()
- + gp.Unsqueeze([self.raw], 0)
- )
- elif self.num_channels != 0 and self.num_samples != 0:
- self.pipeline = (
- NapariImageSource(
- self.layer, self.raw, raw_spec, self.spatial_dims
- )
- + gp.RandomLocation()
- )
-
- def __yield_sample(self):
- """An infinite generator of crops."""
-
- with gp.build(self.pipeline):
- while True:
- # request one sample, all channels, plus crop dimensions
- request = gp.BatchRequest()
- if self.num_channels == 0 and self.num_samples == 0:
- request[self.raw] = gp.ArraySpec(
- roi=gp.Roi(
- (0,) * (self.num_dims),
- self.crop_size,
- )
- )
- elif self.num_channels == 0 and self.num_samples != 0:
- request[self.raw] = gp.ArraySpec(
- roi=gp.Roi(
- (0,) * (self.num_dims), (1, *self.crop_size)
- )
- )
- elif self.num_channels != 0 and self.num_samples == 0:
- request[self.raw] = gp.ArraySpec(
- roi=gp.Roi(
- (0,) * (self.num_dims),
- (self.num_channels, *self.crop_size),
- )
- )
- elif self.num_channels != 0 and self.num_samples != 0:
- request[self.raw] = gp.ArraySpec(
- roi=gp.Roi(
- (0,) * (self.num_dims),
- (1, self.num_channels, *self.crop_size),
- )
- )
- sample = self.pipeline.request_batch(request)
- yield sample[self.raw].data[0]
-
- def __read_meta_data(self):
- meta_data = NapariDatasetMetaData(
- self.layer.data.shape, self.axis_names
- )
-
- self.num_dims = meta_data.num_dims
- self.num_spatial_dims = meta_data.num_spatial_dims
- self.num_channels = meta_data.num_channels
- self.num_samples = meta_data.num_samples
- self.sample_dim = meta_data.sample_dim
- self.channel_dim = meta_data.channel_dim
- self.time_dim = meta_data.time_dim
- self.spatial_dims = meta_data.spatial_dims
-
- def get_num_channels(self):
- return 1 if self.num_channels == 0 else self.num_channels
-
- def get_num_spatial_dims(self):
- return self.num_spatial_dims
diff --git a/src/napari_cellulus/_tests/__init__.py b/src/napari_cellulus/datasets/__init__.py
similarity index 100%
rename from src/napari_cellulus/_tests/__init__.py
rename to src/napari_cellulus/datasets/__init__.py
diff --git a/src/napari_cellulus/meta_data.py b/src/napari_cellulus/datasets/meta_data.py
similarity index 96%
rename from src/napari_cellulus/meta_data.py
rename to src/napari_cellulus/datasets/meta_data.py
index 7e8751e..35d69fe 100644
--- a/src/napari_cellulus/meta_data.py
+++ b/src/napari_cellulus/datasets/meta_data.py
@@ -11,7 +11,7 @@ def __init__(self, shape, axis_names):
self.channel_dim = None
self.time_dim = None
self.spatial_array: Tuple[int, ...] = ()
- self.spatial_dims = ()
+ self.spatial_dims: Tuple[int, ...] = ()
for dim, axis_name in enumerate(axis_names):
if axis_name == "s":
self.sample_dim = dim
diff --git a/src/napari_cellulus/datasets/napari_dataset.py b/src/napari_cellulus/datasets/napari_dataset.py
new file mode 100644
index 0000000..3952ab8
--- /dev/null
+++ b/src/napari_cellulus/datasets/napari_dataset.py
@@ -0,0 +1,244 @@
+from typing import List
+
+import gunpowder as gp
+import numpy as np
+from napari.layers import Image
+from torch.utils.data import IterableDataset
+
+from .meta_data import NapariDatasetMetaData
+from .napari_image_source import NapariImageSource
+
+
+class NapariDataset(IterableDataset):
+ def __init__(
+ self,
+ layer: Image,
+ axis_names: List[str],
+ crop_size: int,
+ density: float,
+ kappa: float,
+ normalization_factor: float,
+ ):
+ self.layer = layer
+ self.axis_names = axis_names
+ self.__read_meta_data()
+ self.crop_size = (crop_size,) * self.num_spatial_dims
+ self.normalization_factor = normalization_factor
+ self.density = density
+ self.kappa = kappa
+        # the network output is 16 pixels smaller than the crop in each
+        # spatial dimension (8 pixels of context on each side)
+        self.output_shape = tuple(int(_ - 16) for _ in self.crop_size)
+ self.unbiased_shape = tuple(
+ int(_ - (2 * self.kappa)) for _ in self.output_shape
+ )
+ self.__setup_pipeline()
+
+ def __read_meta_data(self):
+ meta_data = NapariDatasetMetaData(
+ self.layer.data.shape, self.axis_names
+ )
+
+ self.num_dims = meta_data.num_dims
+ self.num_spatial_dims = meta_data.num_spatial_dims
+ self.num_channels = meta_data.num_channels
+ self.num_samples = meta_data.num_samples
+ self.sample_dim = meta_data.sample_dim
+ self.channel_dim = meta_data.channel_dim
+ self.time_dim = meta_data.time_dim
+ self.spatial_dims = meta_data.spatial_dims
+ self.spatial_array = meta_data.spatial_array
+
+ def __setup_pipeline(self):
+ self.raw = gp.ArrayKey("RAW")
+
+ # treat all dimensions as spatial, with a voxel size = 1
+ voxel_size = gp.Coordinate((1,) * self.num_dims)
+ offset = gp.Coordinate((0,) * self.num_dims)
+ shape = gp.Coordinate(self.layer.data.shape)
+ raw_spec = gp.ArraySpec(
+ roi=gp.Roi(offset, voxel_size * shape),
+ dtype=np.float32,
+ interpolatable=True,
+ voxel_size=voxel_size,
+ )
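+        # when the layer has no sample axis, unsqueeze a leading dimension so
+        # every crop leaves the pipeline with a consistent sample dimension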
+ if self.num_samples == 0:
+ self.pipeline = (
+ NapariImageSource(
+ image=self.layer,
+ key=self.raw,
+ spec=raw_spec,
+ spatial_dims=self.spatial_dims,
+ )
+ + gp.Unsqueeze([self.raw], 0)
+ + gp.RandomLocation()
+ )
+ else:
+ self.pipeline = (
+ NapariImageSource(
+ image=self.layer,
+ key=self.raw,
+ spec=raw_spec,
+ spatial_dims=self.spatial_dims,
+ )
+ + gp.RandomLocation()
+ )
+
+ def __iter__(self):
+ return iter(self.__yield_sample())
+
+ def __yield_sample(self):
+ with gp.build(self.pipeline):
+ while True:
+ array_is_zero = True
+ while array_is_zero:
+ # request one sample, all channels, plus crop dimensions
+ request = gp.BatchRequest()
+ if self.num_channels == 0 and self.num_samples == 0:
+ request[self.raw] = gp.ArraySpec(
+ roi=gp.Roi(
+ (0,) * self.num_dims,
+ self.crop_size,
+ )
+ )
+ elif self.num_channels == 0 and self.num_samples != 0:
+ request[self.raw] = gp.ArraySpec(
+ roi=gp.Roi(
+ (0,) * self.num_dims, (1, *self.crop_size)
+ )
+ )
+ elif self.num_channels != 0 and self.num_samples == 0:
+ request[self.raw] = gp.ArraySpec(
+ roi=gp.Roi(
+ (0,) * self.num_dims,
+ (self.num_channels, *self.crop_size),
+ )
+ )
+ elif self.num_channels != 0 and self.num_samples != 0:
+ request[self.raw] = gp.ArraySpec(
+ roi=gp.Roi(
+ (0,) * self.num_dims,
+ (1, self.num_channels, *self.crop_size),
+ )
+ )
+ sample = self.pipeline.request_batch(request)
+ sample_data = sample[self.raw].data[0]
+ # if missing a channel, this is added by the Model class
+
+                    # resample until a crop with non-zero content is drawn
+                    if np.max(sample_data) > 0.0:
+                        array_is_zero = False
+                        (
+                            anchor_samples,
+                            reference_samples,
+                        ) = self.sample_coordinates()
+                        yield sample_data, anchor_samples, reference_samples
+
+ def get_num_samples(self):
+ return self.num_samples
+
+ def get_num_channels(self):
+ return self.num_channels
+
+ def get_num_spatial_dims(self):
+ return self.num_spatial_dims
+
+ def get_num_dims(self):
+ return self.num_dims
+
+ def get_spatial_dims(self):
+ return self.spatial_dims
+
+ def get_spatial_array(self):
+ return self.spatial_array
+
+ def sample_offsets_within_radius(self, radius, number_offsets):
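+        # rejection sampling: draw integer offsets in [-radius, radius]^d,
+        # keep those strictly inside the disk/sphere and not equal to the
+        # zero offset, and recurse until enough valid offsets remain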
+ if self.num_spatial_dims == 2:
+ offsets_x = np.random.randint(
+ -radius, radius + 1, size=2 * number_offsets
+ )
+ offsets_y = np.random.randint(
+ -radius, radius + 1, size=2 * number_offsets
+ )
+ offsets_coordinates = np.stack((offsets_x, offsets_y), axis=1)
+ elif self.num_spatial_dims == 3:
+ offsets_x = np.random.randint(
+ -radius, radius + 1, size=3 * number_offsets
+ )
+ offsets_y = np.random.randint(
+ -radius, radius + 1, size=3 * number_offsets
+ )
+ offsets_z = np.random.randint(
+ -radius, radius + 1, size=3 * number_offsets
+ )
+ offsets_coordinates = np.stack(
+ (offsets_x, offsets_y, offsets_z), axis=1
+ )
+
+ in_circle = (offsets_coordinates**2).sum(axis=1) < radius**2
+ offsets_coordinates = offsets_coordinates[in_circle]
+ not_zero = np.absolute(offsets_coordinates).sum(axis=1) > 0
+ offsets_coordinates = offsets_coordinates[not_zero]
+
+ if len(offsets_coordinates) < number_offsets:
+ return self.sample_offsets_within_radius(radius, number_offsets)
+
+ return offsets_coordinates[:number_offsets]
+
+ def sample_coordinates(self):
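+        # anchors are drawn from the "unbiased" interior of the output (at
+        # least kappa away from the border), so every reference point within
+        # radius kappa of an anchor stays inside the output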
+ num_anchors = self.get_num_anchors()
+ num_references = self.get_num_references()
+
+ if self.num_spatial_dims == 2:
+ anchor_coordinates_x = np.random.randint(
+ self.kappa,
+ self.output_shape[0] - self.kappa + 1,
+ size=num_anchors,
+ )
+ anchor_coordinates_y = np.random.randint(
+ self.kappa,
+ self.output_shape[1] - self.kappa + 1,
+ size=num_anchors,
+ )
+ anchor_coordinates = np.stack(
+ (anchor_coordinates_x, anchor_coordinates_y), axis=1
+ )
+ elif self.num_spatial_dims == 3:
+ anchor_coordinates_x = np.random.randint(
+ self.kappa,
+ self.output_shape[0] - self.kappa + 1,
+ size=num_anchors,
+ )
+ anchor_coordinates_y = np.random.randint(
+ self.kappa,
+ self.output_shape[1] - self.kappa + 1,
+ size=num_anchors,
+ )
+ anchor_coordinates_z = np.random.randint(
+ self.kappa,
+ self.output_shape[2] - self.kappa + 1,
+ size=num_anchors,
+ )
+ anchor_coordinates = np.stack(
+ (
+ anchor_coordinates_x,
+ anchor_coordinates_y,
+ anchor_coordinates_z,
+ ),
+ axis=1,
+ )
+ anchor_samples = np.repeat(anchor_coordinates, num_references, axis=0)
+ offset_in_pos_radius = self.sample_offsets_within_radius(
+ self.kappa, len(anchor_samples)
+ )
+ reference_samples = anchor_samples + offset_in_pos_radius
+
+ return anchor_samples, reference_samples
+
+ def get_num_anchors(self):
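+        # expected anchor count: density * unbiased area (2D case)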
+ return int(
+ self.density * self.unbiased_shape[0] * self.unbiased_shape[1]
+ )
+
+ def get_num_references(self):
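+        # expected references per anchor: density * area of a radius-kappa disk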
+ return int(self.density * self.kappa**2 * np.pi)
diff --git a/src/napari_cellulus/gp/nodes/napari_image_source.py b/src/napari_cellulus/datasets/napari_image_source.py
similarity index 75%
rename from src/napari_cellulus/gp/nodes/napari_image_source.py
rename to src/napari_cellulus/datasets/napari_image_source.py
index a13e996..bee332f 100644
--- a/src/napari_cellulus/gp/nodes/napari_image_source.py
+++ b/src/napari_cellulus/datasets/napari_image_source.py
@@ -7,7 +7,7 @@
class NapariImageSource(gp.BatchProvider):
"""
- A gunpowder interface to a napari Image
+ A gunpowder node to pull data from a napari Image
Args:
image (Image):
The napari image layer to pull data from
@@ -20,15 +20,11 @@ def __init__(
):
self.array_spec = spec
self.image = gp.Array(
- normalize(
- image.data.astype(np.float32),
- pmin=1,
- pmax=99.8,
- axis=spatial_dims,
+ data=normalize(
+ image.data.astype(np.float32), 1, 99.8, axis=spatial_dims
),
- self.array_spec,
+ spec=spec,
)
- self.spatial_dims = spatial_dims
self.key = key
def setup(self):
diff --git a/src/napari_cellulus/gp/__init__.py b/src/napari_cellulus/gp/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/napari_cellulus/gp/nodes/__init__.py b/src/napari_cellulus/gp/nodes/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/napari_cellulus/gui_helpers.py b/src/napari_cellulus/gui_helpers.py
deleted file mode 100644
index a4aa620..0000000
--- a/src/napari_cellulus/gui_helpers.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from magicgui.widgets import FunctionGui, create_widget
-from matplotlib.backends.backend_qt5agg import (
- FigureCanvasQTAgg,
-)
-from matplotlib.backends.backend_qt5agg import (
- NavigationToolbar2QT as NavigationToolbar,
-)
-from matplotlib.figure import Figure
-from qtpy import QtWidgets
-
-
-class MplCanvas(FigureCanvasQTAgg):
- def __init__(self, parent=None, width=5, height=4, dpi=100):
- fig = Figure(figsize=(width, height), dpi=dpi)
- self.axes = fig.add_subplot(111)
- super().__init__(fig)
- fig.set_tight_layout(True)
-
-
-class MainWindow(QtWidgets.QMainWindow):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- sc = MplCanvas(self, width=5, height=4, dpi=100)
-
- # Create toolbar, passing canvas as first parament, parent (self, the MainWindow) as second.
- toolbar = NavigationToolbar(sc, self)
-
- layout = QtWidgets.QVBoxLayout()
- layout.addWidget(toolbar)
- layout.addWidget(sc)
-
- # Create a placeholder widget to hold our toolbar and canvas.
- widget = QtWidgets.QWidget()
- widget.setLayout(layout)
- self.setCentralWidget(widget)
-
- self.show()
-
-
-def layer_choice_widget(viewer, annotation, **kwargs) -> FunctionGui:
- widget = create_widget(annotation=annotation, **kwargs)
- viewer.layers.events.inserted.connect(widget.reset_choices)
- viewer.layers.events.removed.connect(widget.reset_choices)
- return widget
diff --git a/src/napari_cellulus/napari.yaml b/src/napari_cellulus/napari.yaml
index 1c22723..39ac969 100644
--- a/src/napari_cellulus/napari.yaml
+++ b/src/napari_cellulus/napari.yaml
@@ -2,16 +2,22 @@ name: napari-cellulus
display_name: Cellulus
contributions:
commands:
- - id: napari-cellulus.tissuenet_sample
- python_name: napari_cellulus._sample_data:tissuenet_sample
+ - id: napari-cellulus.tissue_net_sample
+ python_name: napari_cellulus.sample_data:tissue_net_sample
title: Load sample data from Cellulus
- - id: napari-cellulus.SegmentationWidget
- python_name: napari_cellulus.widgets._widget:SegmentationWidget
+ - id: napari-cellulus.fluo_n2dl_hela_sample
+ python_name: napari_cellulus.sample_data:fluo_n2dl_hela_sample
+ title: Load sample data from Cellulus
+ - id: napari-cellulus.Widget
+ python_name: napari_cellulus.widget:Widget
title: Cellulus
sample_data:
- - command: napari-cellulus.tissuenet_sample
- display_name: Cellulus
- key: tissuenet_sample
+ - command: napari-cellulus.tissue_net_sample
+ display_name: TissueNet
+ key: tissue_net_sample
+ - command: napari-cellulus.fluo_n2dl_hela_sample
+ display_name: Fluo-N2DL-HeLa
+ key: fluo_n2dl_hela_sample
widgets:
- - command: napari-cellulus.SegmentationWidget
+ - command: napari-cellulus.Widget
display_name: Cellulus
diff --git a/src/napari_cellulus/_sample_data.py b/src/napari_cellulus/sample_data.py
similarity index 50%
rename from src/napari_cellulus/_sample_data.py
rename to src/napari_cellulus/sample_data.py
index 159abd1..2a4efc0 100644
--- a/src/napari_cellulus/_sample_data.py
+++ b/src/napari_cellulus/sample_data.py
@@ -1,12 +1,28 @@
from pathlib import Path
import numpy as np
+import tifffile
-TISSUENET_SAMPLE = Path(__file__).parent / "sample_data/tissuenet-sample.npy"
+TISSUE_NET_SAMPLE = Path(__file__).parent / "sample_data/tissue_net_sample.npy"
+FLUO_N2DL_HELA = Path(__file__).parent / "sample_data/fluo_n2dl_hela.tif"
-def tissuenet_sample():
- (x, y) = np.load(TISSUENET_SAMPLE, "r")
+def fluo_n2dl_hela_sample():
+ x = tifffile.imread(FLUO_N2DL_HELA)
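+    # napari sample-data hook: returns a list of LayerData tuples,
+    # i.e. (data, metadata, layer_type)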
+ return [
+ (
+ x,
+ {
+ "name": "Raw",
+ "metadata": {"axes": ["s", "c", "y", "x"]},
+ },
+ "image",
+ )
+ ]
+
+
+def tissue_net_sample():
+ (x, y) = np.load(TISSUE_NET_SAMPLE, "r")
x = x.transpose(0, 3, 1, 2)
y = y.transpose(0, 3, 1, 2).astype(np.uint8)
return [
diff --git a/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif b/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif
new file mode 100644
index 0000000..f28870c
Binary files /dev/null and b/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif differ
diff --git a/src/napari_cellulus/sample_data/tissuenet-sample.npy b/src/napari_cellulus/sample_data/tissue_net_sample.npy
similarity index 100%
rename from src/napari_cellulus/sample_data/tissuenet-sample.npy
rename to src/napari_cellulus/sample_data/tissue_net_sample.npy
diff --git a/src/napari_cellulus/widget.py b/src/napari_cellulus/widget.py
new file mode 100644
index 0000000..bbad8b8
--- /dev/null
+++ b/src/napari_cellulus/widget.py
@@ -0,0 +1,950 @@
+from pathlib import Path
+
+import gunpowder as gp
+import numpy as np
+import pyqtgraph as pg
+import torch
+from attrs import asdict
+from cellulus.configs.experiment_config import ExperimentConfig
+from cellulus.configs.inference_config import InferenceConfig
+from cellulus.configs.model_config import ModelConfig
+from cellulus.configs.train_config import TrainConfig
+from cellulus.criterions import get_loss
+from cellulus.models import get_model
+from cellulus.train import train_iteration
+from cellulus.utils.mean_shift import mean_shift_segmentation
+from cellulus.utils.misc import size_filter
+from napari.qt.threading import thread_worker
+from napari.utils.events import Event
+from qtpy.QtCore import Qt
+from qtpy.QtWidgets import (
+ QButtonGroup,
+ QCheckBox,
+ QComboBox,
+ QGridLayout,
+ QLabel,
+ QLineEdit,
+ QMainWindow,
+ QPushButton,
+ QRadioButton,
+ QScrollArea,
+ QVBoxLayout,
+)
+from scipy.ndimage import binary_fill_holes
+from scipy.ndimage import distance_transform_edt as dtedt
+from skimage.filters import threshold_otsu
+from tqdm import tqdm
+
+from .datasets.napari_dataset import NapariDataset
+from .datasets.napari_image_source import NapariImageSource
+
+
+class Model(torch.nn.Module):
+ def __init__(self, model, selected_axes):
+ super().__init__()
+ self.model = model
+ self.selected_axes = selected_axes
+
+ def forward(self, x):
+ if "s" in self.selected_axes and "c" in self.selected_axes:
+ pass
+ elif "s" in self.selected_axes and "c" not in self.selected_axes:
+
+ x = torch.unsqueeze(x, 1)
+ elif "s" not in self.selected_axes and "c" in self.selected_axes:
+ pass
+ elif "s" not in self.selected_axes and "c" not in self.selected_axes:
+ x = torch.unsqueeze(x, 1)
+ return self.model(x)
+
+ @staticmethod
+ def select_and_add_coordinates(outputs, coordinates):
+ selections = []
+ # outputs.shape = (b, c, h, w) or (b, c, d, h, w)
+ for output, coordinate in zip(outputs, coordinates):
+ if output.ndim == 3:
+ selection = output[:, coordinate[:, 1], coordinate[:, 0]]
+ elif output.ndim == 4:
+ selection = output[
+ :, coordinate[:, 2], coordinate[:, 1], coordinate[:, 0]
+ ]
+ selection = selection.transpose(1, 0)
+ selection += coordinate
+ selections.append(selection)
+
+ # selection.shape = (b, c, p) where p is the number of selected positions
+ return torch.stack(selections, dim=0)
+
+ def set_infer(self, p_salt_pepper, num_infer_iterations, device):
+ self.model.eval()
+ self.model.set_infer(p_salt_pepper, num_infer_iterations, device)
+
+
+class Widget(QMainWindow):
+ def __init__(self, napari_viewer):
+ super().__init__()
+ self.viewer = napari_viewer
+ self.scroll = QScrollArea()
+ # initialize outer layout
+ layout = QVBoxLayout()
+
+ # initialize individual grid layouts from top to bottom
+ self.grid_0 = QGridLayout() # title
+ self.set_grid_0()
+ self.grid_1 = QGridLayout() # device
+ self.set_grid_1()
+ self.grid_2 = QGridLayout() # raw image selector
+ self.set_grid_2()
+ self.grid_3 = QGridLayout() # train configs
+ self.set_grid_3()
+ self.grid_4 = QGridLayout() # model configs
+ self.set_grid_4()
+ self.grid_5 = QGridLayout() # loss plot and train/stop button
+ self.set_grid_5()
+ self.grid_6 = QGridLayout() # inference
+ self.set_grid_6()
+ self.grid_7 = QGridLayout() # feedback
+ self.set_grid_7()
+ self.create_configs() # configs
+ self.viewer.dims.events.current_step.connect(
+ self.update_inference_widgets
+ ) # listen to viewer slider
+
+ layout.addLayout(self.grid_0)
+ layout.addLayout(self.grid_1)
+ layout.addLayout(self.grid_2)
+ layout.addLayout(self.grid_3)
+ layout.addLayout(self.grid_4)
+ layout.addLayout(self.grid_5)
+ layout.addLayout(self.grid_6)
+ layout.addLayout(self.grid_7)
+ self.set_scroll_area(layout)
+ self.viewer.layers.events.inserted.connect(self.update_raw_selector)
+ self.viewer.layers.events.removed.connect(self.update_raw_selector)
+
+ def update_raw_selector(self, event):
+ count = 0
+ for i in range(self.raw_selector.count() - 1, -1, -1):
+ if self.raw_selector.itemText(i) == f"{event.value}":
+ # remove item
+ self.raw_selector.removeItem(i)
+ count = 1
+ if count == 0:
+ self.raw_selector.addItems([f"{event.value}"])
+
+ def set_grid_0(self):
+        text_label = QLabel("Cellulus")
+        method_description_label = QLabel(
+            "Unsupervised Learning of Object-Centric Embeddings "
+            "for Cell Instance Segmentation in Microscopy Images. "
+            "If you are using this in your research, please cite us. "
+            "https://github.com/funkelab/cellulus"
+        )
+ self.grid_0.addWidget(text_label, 0, 0, 1, 1)
+ self.grid_0.addWidget(method_description_label, 1, 0, 2, 1)
+
+ def set_grid_1(self):
+ device_label = QLabel(self)
+ device_label.setText("Device")
+ self.device_combo_box = QComboBox(self)
+ self.device_combo_box.addItem("cpu")
+ self.device_combo_box.addItem("cuda:0")
+ self.device_combo_box.addItem("mps")
+ self.device_combo_box.setCurrentText("mps")
+ self.grid_1.addWidget(device_label, 0, 0, 1, 1)
+ self.grid_1.addWidget(self.device_combo_box, 0, 1, 1, 1)
+
+ def set_grid_2(self):
+ self.raw_selector = QComboBox(self)
+ for layer in self.viewer.layers:
+ self.raw_selector.addItem(f"{layer}")
+ self.grid_2.addWidget(self.raw_selector, 0, 0, 1, 5)
+ # Initialize Checkboxes
+ self.s_check_box = QCheckBox("s/t")
+ self.c_check_box = QCheckBox("c")
+ self.z_check_box = QCheckBox("z")
+ self.y_check_box = QCheckBox("y")
+ self.x_check_box = QCheckBox("x")
+ self.grid_2.addWidget(self.s_check_box, 1, 0, 1, 1)
+ self.grid_2.addWidget(self.c_check_box, 1, 1, 1, 1)
+ self.grid_2.addWidget(self.z_check_box, 1, 2, 1, 1)
+ self.grid_2.addWidget(self.y_check_box, 1, 3, 1, 1)
+ self.grid_2.addWidget(self.x_check_box, 1, 4, 1, 1)
+
+ def set_grid_3(self):
+ crop_size_label = QLabel(self)
+ crop_size_label.setText("Crop Size")
+ self.crop_size_line = QLineEdit(self)
+ self.crop_size_line.setAlignment(Qt.AlignCenter)
+ self.crop_size_line.setText("252")
+ batch_size_label = QLabel(self)
+ batch_size_label.setText("Batch Size")
+ self.batch_size_line = QLineEdit(self)
+ self.batch_size_line.setAlignment(Qt.AlignCenter)
+ self.batch_size_line.setText("8")
+ max_iterations_label = QLabel(self)
+ max_iterations_label.setText("Max iterations")
+ self.max_iterations_line = QLineEdit(self)
+ self.max_iterations_line.setAlignment(Qt.AlignCenter)
+ self.max_iterations_line.setText("100000")
+ self.grid_3.addWidget(crop_size_label, 0, 0, 1, 1)
+ self.grid_3.addWidget(self.crop_size_line, 0, 1, 1, 1)
+ self.grid_3.addWidget(batch_size_label, 1, 0, 1, 1)
+ self.grid_3.addWidget(self.batch_size_line, 1, 1, 1, 1)
+ self.grid_3.addWidget(max_iterations_label, 2, 0, 1, 1)
+ self.grid_3.addWidget(self.max_iterations_line, 2, 1, 1, 1)
+
+ def set_grid_4(self):
+ feature_maps_label = QLabel(self)
+ feature_maps_label.setText("Number of feature maps")
+ self.feature_maps_line = QLineEdit(self)
+ self.feature_maps_line.setAlignment(Qt.AlignCenter)
+ self.feature_maps_line.setText("24")
+ feature_maps_increase_label = QLabel(self)
+ feature_maps_increase_label.setText("Feature maps inc. factor")
+ self.feature_maps_increase_line = QLineEdit(self)
+ self.feature_maps_increase_line.setAlignment(Qt.AlignCenter)
+ self.feature_maps_increase_line.setText("3")
+ self.train_model_from_scratch_checkbox = QCheckBox(
+ "Train model from scratch"
+ )
+
+ self.train_model_from_scratch_checkbox.setChecked(False)
+ self.grid_4.addWidget(feature_maps_label, 0, 0, 1, 1)
+ self.grid_4.addWidget(self.feature_maps_line, 0, 1, 1, 1)
+ self.grid_4.addWidget(feature_maps_increase_label, 1, 0, 1, 1)
+ self.grid_4.addWidget(self.feature_maps_increase_line, 1, 1, 1, 1)
+ self.grid_4.addWidget(
+ self.train_model_from_scratch_checkbox, 2, 0, 1, 2
+ )
+
+ def set_grid_5(self):
+ self.losses_widget = pg.PlotWidget()
+ self.losses_widget.setBackground((37, 41, 49))
+ styles = {"color": "white", "font-size": "16px"}
+ self.losses_widget.setLabel("left", "Loss", **styles)
+ self.losses_widget.setLabel("bottom", "Iterations", **styles)
+ self.start_training_button = QPushButton("Start training")
+ self.start_training_button.setFixedSize(140, 30)
+ self.stop_training_button = QPushButton("Stop training")
+ self.stop_training_button.setFixedSize(140, 30)
+
+ self.grid_5.addWidget(self.losses_widget, 0, 0, 4, 4)
+ self.grid_5.addWidget(self.start_training_button, 5, 0, 1, 2)
+ self.grid_5.addWidget(self.stop_training_button, 5, 2, 1, 2)
+ self.start_training_button.clicked.connect(
+ self.prepare_for_start_training
+ )
+ self.stop_training_button.clicked.connect(
+ self.prepare_for_stop_training
+ )
+
+ def set_grid_6(self):
+ threshold_label = QLabel("Threshold")
+ self.threshold_line = QLineEdit(self)
+ self.threshold_line.setAlignment(Qt.AlignCenter)
+        self.threshold_line.setText("")
+
+ bandwidth_label = QLabel("Bandwidth")
+ self.bandwidth_line = QLineEdit(self)
+ self.bandwidth_line.setAlignment(Qt.AlignCenter)
+
+ self.radio_button_group = QButtonGroup(self)
+ self.radio_button_cell = QRadioButton("Cell")
+ self.radio_button_nucleus = QRadioButton("Nucleus")
+ self.radio_button_group.addButton(self.radio_button_nucleus)
+ self.radio_button_group.addButton(self.radio_button_cell)
+
+ self.radio_button_nucleus.setChecked(True)
+ self.min_size_label = QLabel("Minimum Size")
+ self.min_size_line = QLineEdit(self)
+ self.min_size_line.setAlignment(Qt.AlignCenter)
+ self.start_inference_button = QPushButton("Start inference")
+ self.start_inference_button.setFixedSize(140, 30)
+ self.stop_inference_button = QPushButton("Stop inference")
+ self.stop_inference_button.setFixedSize(140, 30)
+
+ self.grid_6.addWidget(threshold_label, 0, 0, 1, 1)
+ self.grid_6.addWidget(self.threshold_line, 0, 1, 1, 1)
+ self.grid_6.addWidget(bandwidth_label, 1, 0, 1, 1)
+ self.grid_6.addWidget(self.bandwidth_line, 1, 1, 1, 1)
+ self.grid_6.addWidget(self.radio_button_cell, 2, 0, 1, 1)
+ self.grid_6.addWidget(self.radio_button_nucleus, 2, 1, 1, 1)
+ self.grid_6.addWidget(self.min_size_label, 3, 0, 1, 1)
+ self.grid_6.addWidget(self.min_size_line, 3, 1, 1, 1)
+ self.grid_6.addWidget(self.start_inference_button, 4, 0, 1, 1)
+ self.grid_6.addWidget(self.stop_inference_button, 4, 1, 1, 1)
+ self.start_inference_button.clicked.connect(
+ self.prepare_for_start_inference
+ )
+ self.stop_inference_button.clicked.connect(
+ self.prepare_for_stop_inference
+ )
+
+ def set_grid_7(self):
+ # Initialize Feedback Button
+ feedback_label = QLabel(
+ 'Please share any feedback here.'
+ )
+ self.grid_7.addWidget(feedback_label, 0, 0, 2, 1)
+
+ def set_scroll_area(self, layout):
+ self.scroll.setLayout(layout)
+ self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
+ self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
+ self.scroll.setWidgetResizable(True)
+
+ self.setFixedWidth(300)
+ self.setCentralWidget(self.scroll)
+
+ def get_selected_axes(self):
+ names = []
+ for name, check_box in zip(
+ "sczyx",
+ [
+ self.s_check_box,
+ self.c_check_box,
+ self.z_check_box,
+ self.y_check_box,
+ self.x_check_box,
+ ],
+ ):
+ if check_box.isChecked():
+ names.append(name)
+
+ return names
+
+ def create_configs(self):
+ self.train_config = TrainConfig(
+ crop_size=[int(self.crop_size_line.text())],
+ batch_size=int(self.batch_size_line.text()),
+ max_iterations=int(self.max_iterations_line.text()),
+ device=self.device_combo_box.currentText(),
+ )
+ self.model_config = ModelConfig(
+ num_fmaps=int(self.feature_maps_line.text()),
+ fmap_inc_factor=int(self.feature_maps_increase_line.text()),
+ )
+
+ self.experiment_config = ExperimentConfig(
+ train_config=asdict(self.train_config),
+ model_config=asdict(self.model_config),
+ )
+ self.losses, self.iterations = [], []
+ self.start_iteration = 0
+ self.model_dir = "/tmp/models"
+ self.thresholds = []
+ self.band_widths = []
+ self.min_sizes = []
+ if len(self.thresholds) == 0:
+ self.threshold_line.setEnabled(False)
+ if len(self.band_widths) == 0:
+ self.bandwidth_line.setEnabled(False)
+ if len(self.min_sizes) == 0:
+ self.min_size_line.setEnabled(False)
+
+ def update_inference_widgets(self, event: Event):
+ if self.s_check_box.isChecked():
+ shape = event.value
+ sample_index = shape[0]
+ if len(self.thresholds) == self.napari_dataset.get_num_samples():
+                if self.thresholds[sample_index] is not None:
+ self.threshold_line.setText(
+ str(round(self.thresholds[sample_index], 3))
+ )
+ if len(self.band_widths) == self.napari_dataset.get_num_samples():
+                if self.band_widths[sample_index] is not None:
+ self.bandwidth_line.setText(
+ str(round(self.band_widths[sample_index], 3))
+ )
+ if len(self.min_sizes) == self.napari_dataset.get_num_samples():
+                if self.min_sizes[sample_index] is not None:
+ self.min_size_line.setText(
+ str(round(self.min_sizes[sample_index], 3))
+ )
+
+ def prepare_for_start_training(self):
+ self.start_training_button.setEnabled(False)
+ self.stop_training_button.setEnabled(True)
+ self.threshold_line.setEnabled(False)
+ self.bandwidth_line.setEnabled(False)
+ self.radio_button_nucleus.setEnabled(False)
+ self.radio_button_cell.setEnabled(False)
+ self.min_size_line.setEnabled(False)
+ self.start_inference_button.setEnabled(False)
+ self.stop_inference_button.setEnabled(False)
+
+ self.train_worker = self.train()
+ self.train_worker.yielded.connect(self.on_yield_training)
+ self.train_worker.start()
+
+ @thread_worker
+ def train(self):
+ for layer in self.viewer.layers:
+ if f"{layer}" == self.raw_selector.currentText():
+ raw_image_layer = layer
+ break
+
+ if not Path(self.model_dir).exists():
+ Path(self.model_dir).mkdir()
+
+ # Turn layer into dataset
+ self.napari_dataset = NapariDataset(
+ layer=raw_image_layer,
+ axis_names=self.get_selected_axes(),
+ crop_size=self.train_config.crop_size[0], # list to integer
+ density=self.train_config.density,
+ kappa=self.train_config.kappa,
+ normalization_factor=self.experiment_config.normalization_factor,
+ )
+ # Create dataloader
+ train_dataloader = torch.utils.data.DataLoader(
+ dataset=self.napari_dataset,
+ batch_size=self.train_config.batch_size,
+ drop_last=True,
+ num_workers=self.train_config.num_workers,
+ pin_memory=True,
+ )
+ # Set model
+ model_original = get_model(
+ in_channels=self.napari_dataset.get_num_channels()
+ if self.napari_dataset.get_num_channels() != 0
+ else 1,
+ out_channels=self.napari_dataset.get_num_spatial_dims(),
+ num_fmaps=self.model_config.num_fmaps,
+ fmap_inc_factor=self.model_config.fmap_inc_factor,
+ features_in_last_layer=self.model_config.features_in_last_layer,
+ downsampling_factors=[
+ tuple(factor)
+ for factor in self.model_config.downsampling_factors
+ ],
+ num_spatial_dims=self.napari_dataset.get_num_spatial_dims(),
+ )
+
+ # Set device
+ self.device = torch.device(self.train_config.device)
+ model = Model(
+ model=model_original, selected_axes=self.get_selected_axes()
+ )
+ self.model = model.to(self.device)
+
+ # Initialize model weights
+ if self.model_config.initialize:
+ for _name, layer in self.model.named_modules():
+ if isinstance(layer, torch.nn.modules.conv._ConvNd):
+ torch.nn.init.kaiming_normal_(
+ layer.weight, nonlinearity="relu"
+ )
+
+ # Set loss
+ criterion = get_loss(
+ regularizer_weight=self.train_config.regularizer_weight,
+ temperature=self.train_config.temperature,
+ density=self.train_config.density,
+ num_spatial_dims=self.napari_dataset.get_num_spatial_dims(),
+ device=self.device,
+ )
+
+ # Set optimizer
+ self.optimizer = torch.optim.Adam(
+ self.model.parameters(),
+ lr=self.train_config.initial_learning_rate,
+ weight_decay=0.01,
+ )
+
+ # Resume training
+ if self.train_model_from_scratch_checkbox.isChecked():
+ self.losses, self.iterations = [], []
+ self.start_iteration = 0
+ self.losses_widget.clear()
+
+ else:
+ if self.model_config.checkpoint is None:
+ pass
+ else:
+ print(f"Resuming model from {self.model_config.checkpoint}")
+ state = torch.load(
+ self.model_config.checkpoint, map_location=self.device
+ )
+ self.start_iteration = state["iterations"][-1] + 1
+ self.model.load_state_dict(
+ state["model_state_dict"], strict=True
+ )
+ self.optimizer.load_state_dict(state["optim_state_dict"])
+ self.losses, self.iterations = (
+ state["losses"],
+ state["iterations"],
+ )
+
+ # Call Train Iteration
+ for iteration, batch in tqdm(
+ zip(
+ range(self.start_iteration, self.train_config.max_iterations),
+ train_dataloader,
+ )
+ ):
+ loss, oce_loss, prediction = train_iteration(
+ batch,
+ model=self.model,
+ criterion=criterion,
+ optimizer=self.optimizer,
+ device=self.device,
+ )
+ yield loss, iteration
+
+ def on_yield_training(self, loss_iteration):
+ loss, iteration = loss_iteration
+ print(f"===> Iteration: {iteration}, loss: {loss:.6f}")
+ self.iterations.append(iteration)
+ self.losses.append(loss)
+ self.losses_widget.plot(self.iterations, self.losses)
+
+ def prepare_for_stop_training(self):
+ self.start_training_button.setEnabled(True)
+ self.stop_training_button.setEnabled(True)
+ if len(self.thresholds) == 0:
+ self.threshold_line.setEnabled(False)
+ else:
+ self.threshold_line.setEnabled(True)
+ if len(self.band_widths) == 0:
+ self.bandwidth_line.setEnabled(False)
+ else:
+ self.bandwidth_line.setEnabled(True)
+ self.radio_button_nucleus.setEnabled(True)
+ self.radio_button_cell.setEnabled(True)
+ if len(self.min_sizes) == 0:
+ self.min_size_line.setEnabled(False)
+ else:
+ self.min_size_line.setEnabled(True)
+ self.start_inference_button.setEnabled(True)
+ self.stop_inference_button.setEnabled(True)
+ if self.train_worker is not None:
+ state = {
+ "model_state_dict": self.model.state_dict(),
+ "optim_state_dict": self.optimizer.state_dict(),
+ "iterations": self.iterations,
+ "losses": self.losses,
+ }
+ checkpoint_file_name = Path("/tmp/models") / "last.pth"
+ torch.save(state, checkpoint_file_name)
+ self.train_worker.quit()
+ self.model_config.checkpoint = checkpoint_file_name
+
+ def prepare_for_start_inference(self):
+ self.start_training_button.setEnabled(False)
+ self.stop_training_button.setEnabled(False)
+ self.threshold_line.setEnabled(False)
+ self.bandwidth_line.setEnabled(False)
+ self.radio_button_nucleus.setEnabled(False)
+ self.radio_button_cell.setEnabled(False)
+ self.min_size_line.setEnabled(False)
+ self.start_inference_button.setEnabled(False)
+ self.stop_inference_button.setEnabled(True)
+
+ self.inference_config = InferenceConfig(
+ crop_size=[min(self.napari_dataset.get_spatial_array()) + 16],
+ post_processing="cell"
+ if self.radio_button_cell.isChecked()
+ else "nucleus",
+ )
+
+ self.inference_worker = self.infer()
+ # self.inference_worker.yielded.connect(self.on_yield_infer)
+ self.inference_worker.returned.connect(self.on_return_infer)
+ self.inference_worker.start()
+
+ def prepare_for_stop_inference(self):
+ self.start_training_button.setEnabled(True)
+ self.stop_training_button.setEnabled(True)
+ self.threshold_line.setEnabled(True)
+ self.bandwidth_line.setEnabled(True)
+ self.radio_button_nucleus.setEnabled(True)
+ self.radio_button_cell.setEnabled(True)
+ self.min_size_line.setEnabled(True)
+ self.start_inference_button.setEnabled(True)
+ self.stop_inference_button.setEnabled(True)
+ if self.napari_dataset.get_num_samples() == 0:
+ self.threshold_line.setText(str(round(self.thresholds[0], 3)))
+ self.bandwidth_line.setText(str(round(self.band_widths[0], 3)))
+ self.min_size_line.setText(str(round(self.min_sizes[0], 3)))
+
+ @thread_worker
+ def infer(self):
+ for layer in self.viewer.layers:
+ if f"{layer}" == self.raw_selector.currentText():
+ raw_image_layer = layer
+ break
+
+ self.thresholds = (
+ [None] * self.napari_dataset.get_num_samples()
+ if self.napari_dataset.get_num_samples() != 0
+ else [None] * 1
+ )
+ if (
+ self.inference_config.bandwidth is None
+ and len(self.band_widths) == 0
+ ):
+ self.band_widths = (
+ [0.5 * self.experiment_config.object_size]
+ * self.napari_dataset.get_num_samples()
+ if self.napari_dataset.get_num_samples() != 0
+ else [0.5 * self.experiment_config.object_size]
+ )
+
+ if self.inference_config.min_size is None and len(self.min_sizes) == 0:
+ if self.napari_dataset.get_num_spatial_dims() == 2:
+ self.min_sizes = (
+ [
+ int(
+ 0.1
+ * np.pi
+ * (self.experiment_config.object_size**2)
+ / 4
+ )
+ ]
+ * self.napari_dataset.get_num_samples()
+ if self.napari_dataset.get_num_samples() != 0
+ else [
+ int(
+ 0.1
+ * np.pi
+ * (self.experiment_config.object_size**2)
+ / 4
+ )
+ ]
+ )
+ elif (
+ self.napari_dataset.get_num_spatial_dims() == 3
+ and len(self.min_sizes) == 0
+ ):
+ self.min_sizes = (
+ [
+ int(
+ 0.1
+ * 4.0
+ / 3.0
+ * np.pi
+ * (self.experiment_config.object_size**3)
+ / 8
+ )
+ ]
+ * self.napari_dataset.get_num_samples()
+ if self.napari_dataset.get_num_samples() != 0
+ else [
+ int(
+ 0.1
+ * 4.0
+ / 3.0
+ * np.pi
+ * (self.experiment_config.object_size**3)
+ / 8
+ )
+ ]
+ )
+
+ # set in eval mode
+ self.model.eval()
+ self.model.set_infer(
+ p_salt_pepper=self.inference_config.p_salt_pepper,
+ num_infer_iterations=self.inference_config.num_infer_iterations,
+ device=self.device,
+ )
+
+ if self.napari_dataset.get_num_spatial_dims() == 2:
+ crop_size_tuple = (self.inference_config.crop_size[0],) * 2
+
+ elif self.napari_dataset.get_num_spatial_dims() == 3:
+ crop_size_tuple = (self.inference_config.crop_size[0],) * 3
+
+ input_shape = gp.Coordinate(
+ (
+ 1,
+ self.napari_dataset.get_num_channels()
+ if self.napari_dataset.get_num_channels() != 0
+ else 1,
+ *crop_size_tuple,
+ )
+ )
+
+ if self.napari_dataset.get_num_channels() == 0:
+ output_shape = gp.Coordinate(
+ self.model(
+ torch.zeros((1, *crop_size_tuple), dtype=torch.float32).to(
+ self.device
+ )
+ ).shape
+ )
+ else:
+ output_shape = gp.Coordinate(
+ self.model(
+ torch.zeros(
+ (1, 1, *crop_size_tuple), dtype=torch.float32
+ ).to(self.device)
+ ).shape
+ )
+
+ voxel_size = (
+ (1,) * 2
+ if self.napari_dataset.get_num_spatial_dims() == 2
+ else (1,) * 3
+ )
+
+ input_size = gp.Coordinate(input_shape[2:]) * gp.Coordinate(voxel_size)
+ output_size = gp.Coordinate(output_shape[2:]) * gp.Coordinate(
+ voxel_size
+ )
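+        # half of (input size - output size): the spatial context the network
+        # consumes on each side, used below to reflect-pad the raw data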
+ context = (input_size - output_size) // 2
+ raw = gp.ArrayKey("RAW")
+ prediction = gp.ArrayKey("PREDICT")
+ scan_request = gp.BatchRequest()
+
+ # scan_request.add(raw, input_size)
+ scan_request[raw] = gp.Roi(
+ (-8,) * (self.napari_dataset.get_num_spatial_dims()),
+ crop_size_tuple,
+ )
+ scan_request.add(prediction, output_size)
+ predict = gp.torch.Predict(
+ self.model,
+ inputs={"x": raw},
+ outputs={0: prediction},
+ array_specs={prediction: gp.ArraySpec(voxel_size=voxel_size)},
+ )
+
+ pipeline = NapariImageSource(
+ image=raw_image_layer,
+ key=raw,
+ spec=gp.ArraySpec(
+ gp.Roi(
+ (0,) * self.napari_dataset.get_num_spatial_dims(),
+ raw_image_layer.data.shape[
+ -self.napari_dataset.get_num_spatial_dims() :
+ ],
+ ),
+ voxel_size=(1,) * self.napari_dataset.get_num_spatial_dims(),
+ ),
+ spatial_dims=self.napari_dataset.get_spatial_dims(),
+ )
+
+ if self.napari_dataset.get_num_samples() == 0:
+ pipeline += (
+ gp.Pad(raw, context, mode="reflect")
+ + gp.Unsqueeze([raw], 0)
+ + predict
+ + gp.Scan(scan_request)
+ )
+ else:
+ pipeline += (
+ gp.Pad(raw, context, mode="reflect")
+ + predict
+ + gp.Scan(scan_request)
+ )
+
+ request = gp.BatchRequest()
+ request.add(
+ prediction,
+ raw_image_layer.data.shape[
+ -self.napari_dataset.get_num_spatial_dims() :
+ ],
+ )
+
+ # Obtain Embeddings
+ print("Predicting Embeddings ...")
+
+ with gp.build(pipeline):
+ batch = pipeline.request_batch(request)
+
+ embeddings = batch.arrays[prediction].data
+ embeddings_centered = np.zeros_like(embeddings)
+ foreground_mask = np.zeros_like(embeddings[:, 0:1, ...], dtype=bool)
+ colormaps = ["red", "green", "blue"]
+
+ # Obtain Object Centered Embeddings
+ for sample in tqdm(range(embeddings.shape[0])):
+ embeddings_sample = embeddings[sample]
+ embeddings_std = embeddings_sample[-1, ...]
+ embeddings_mean = embeddings_sample[
+ np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ...
+ ].copy()
+ threshold = threshold_otsu(embeddings_std)
+
+ self.thresholds[sample] = threshold
+ binary_mask = embeddings_std < threshold
+ foreground_mask[sample] = binary_mask[np.newaxis, ...]
+ embeddings_centered_sample = embeddings_sample.copy()
+ embeddings_mean_masked = (
+ binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean
+ )
+            # 2D case: two offset channels plus one uncertainty channel
+            if embeddings_centered_sample.shape[0] == 3:
+ c_x = embeddings_mean_masked[0, 0]
+ c_y = embeddings_mean_masked[0, 1]
+ c_x = c_x[c_x != 0].mean()
+ c_y = c_y[c_y != 0].mean()
+ embeddings_centered_sample[0] -= c_x
+ embeddings_centered_sample[1] -= c_y
+            # 3D case: three offset channels plus one uncertainty channel
+            elif embeddings_centered_sample.shape[0] == 4:
+ c_x = embeddings_mean_masked[0, 0]
+ c_y = embeddings_mean_masked[0, 1]
+ c_z = embeddings_mean_masked[0, 2]
+ c_x = c_x[c_x != 0].mean()
+ c_y = c_y[c_y != 0].mean()
+ c_z = c_z[c_z != 0].mean()
+ embeddings_centered_sample[0] -= c_x
+ embeddings_centered_sample[1] -= c_y
+ embeddings_centered_sample[2] -= c_z
+
+ embeddings_centered[sample] = embeddings_centered_sample
+
+ embeddings_layers = [
+ (
+ embeddings_centered[:, i : i + 1, ...].copy(),
+ {
+ "name": "Offset ("
+ + "zyx"[self.napari_dataset.get_num_spatial_dims() - i]
+ + ")"
+ if i < self.napari_dataset.get_num_spatial_dims()
+ else "Uncertainty",
+ "colormap": colormaps[
+ self.napari_dataset.get_num_spatial_dims() - i
+ ]
+ if i < self.napari_dataset.get_num_spatial_dims()
+ else "gray",
+ "blending": "additive",
+ },
+ "image",
+ )
+ for i in range(self.napari_dataset.get_num_spatial_dims() + 1)
+ ]
+ print("Clustering Objects in the obtained Foreground Mask ...")
+ detection = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16)
+ for sample in tqdm(range(embeddings.shape[0])):
+ embeddings_sample = embeddings[sample]
+ embeddings_std = embeddings_sample[-1, ...]
+ embeddings_mean = embeddings_sample[
+ np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ...
+ ].copy()
+
+ detection_sample = mean_shift_segmentation(
+ embeddings_mean,
+ embeddings_std,
+ bandwidth=self.band_widths[sample],
+ min_size=self.inference_config.min_size,
+ reduction_probability=self.inference_config.reduction_probability,
+ threshold=self.thresholds[sample],
+ seeds=None,
+ )
+ detection[sample, 0, ...] = detection_sample
+
+ print("Converting Detections to Segmentations ...")
+ segmentation = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16)
+ if self.radio_button_cell.isChecked():
+ for sample in tqdm(range(embeddings.shape[0])):
+ segmentation_sample = detection[sample, 0].copy()
+ distance_foreground = dtedt(segmentation_sample == 0)
+ expanded_mask = (
+ distance_foreground < self.inference_config.grow_distance
+ )
+ distance_background = dtedt(expanded_mask)
+ segmentation_sample[
+ distance_background < self.inference_config.shrink_distance
+ ] = 0
+ segmentation[sample, 0, ...] = segmentation_sample
+ elif self.radio_button_nucleus.isChecked():
+ raw_image = raw_image_layer.data
+ for sample in tqdm(range(embeddings.shape[0])):
+ segmentation_sample = detection[sample, 0]
+ if (
+ self.napari_dataset.get_num_samples() == 0
+ and self.napari_dataset.get_num_channels() == 0
+ ):
+ raw_image_sample = raw_image
+ elif (
+ self.napari_dataset.get_num_samples() != 0
+ and self.napari_dataset.get_num_channels() == 0
+ ):
+ raw_image_sample = raw_image[sample]
+ elif (
+ self.napari_dataset.get_num_samples() == 0
+ and self.napari_dataset.get_num_channels() != 0
+ ):
+ raw_image_sample = raw_image[0]
+ else:
+ raw_image_sample = raw_image[sample, 0]
+
+ ids = np.unique(segmentation_sample)
+ ids = ids[ids != 0]
+
+ for id_ in ids:
+ segmentation_id_mask = segmentation_sample == id_
+ if self.napari_dataset.get_num_spatial_dims() == 2:
+ y, x = np.where(segmentation_id_mask)
+ y_min, y_max, x_min, x_max = (
+ np.min(y),
+ np.max(y),
+ np.min(x),
+ np.max(x),
+ )
+ elif self.napari_dataset.get_num_spatial_dims() == 3:
+ z, y, x = np.where(segmentation_id_mask)
+ z_min, z_max, y_min, y_max, x_min, x_max = (
+ np.min(z),
+ np.max(z),
+ np.min(y),
+ np.max(y),
+ np.min(x),
+ np.max(x),
+ )
+ raw_image_masked = raw_image_sample[segmentation_id_mask]
+ threshold = threshold_otsu(raw_image_masked)
+ mask = segmentation_id_mask & (
+ raw_image_sample > threshold
+ )
+
+ if self.napari_dataset.get_num_spatial_dims() == 2:
+ mask_small = binary_fill_holes(
+ mask[y_min : y_max + 1, x_min : x_max + 1]
+ )
+ mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small
+ y, x = np.where(mask)
+ segmentation[sample, 0, y, x] = id_
+ elif self.napari_dataset.get_num_spatial_dims() == 3:
+ mask_small = binary_fill_holes(
+ mask[
+ z_min : z_max + 1,
+ y_min : y_max + 1,
+ x_min : x_max + 1,
+ ]
+ )
+ mask[
+ z_min : z_max + 1,
+ y_min : y_max + 1,
+ x_min : x_max + 1,
+ ] = mask_small
+ z, y, x = np.where(mask)
+ segmentation[sample, 0, z, y, x] = id_
+
+ print("Removing small objects ...")
+
+ # size filter - remove small objects
+ for sample in tqdm(range(embeddings.shape[0])):
+ segmentation[sample, 0, ...] = size_filter(
+ segmentation[sample, 0], self.min_sizes[sample]
+ )
+ return (
+ embeddings_layers
+ + [(foreground_mask, {"name": "Foreground Mask"}, "labels")]
+ + [(detection, {"name": "Detection"}, "labels")]
+ + [(segmentation, {"name": "Segmentation"}, "labels")]
+ )
+
+ def on_return_infer(self, layers):
+ for data, metadata, layer_type in layers:
+ if layer_type == "image":
+ self.viewer.add_image(data, **metadata)
+ elif layer_type == "labels":
+ self.viewer.add_labels(data.astype(int), **metadata)
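+ # Show only the final segmentation by default; intermediate layers stay loaded but hidden.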
+ self.viewer.layers["Offset (x)"].visible = False
+ self.viewer.layers["Offset (y)"].visible = False
+ self.viewer.layers["Uncertainty"].visible = False
+ self.viewer.layers["Foreground Mask"].visible = False
+ self.viewer.layers["Detection"].visible = False
+ self.viewer.layers["Segmentation"].visible = True
+ self.inference_worker.quit()
+ self.prepare_for_stop_inference()
diff --git a/src/napari_cellulus/widgets/__init__.py b/src/napari_cellulus/widgets/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py
deleted file mode 100644
index 09e1dc5..0000000
--- a/src/napari_cellulus/widgets/_widget.py
+++ /dev/null
@@ -1,765 +0,0 @@
-import os
-import time
-
-import gunpowder as gp
-import napari
-import numpy as np
-import pyqtgraph as pg
-import torch
-from cellulus.criterions import get_loss
-from cellulus.models import get_model
-from cellulus.train import train_iteration
-from cellulus.utils.mean_shift import mean_shift_segmentation
-from magicgui import magic_factory
-
-# widget stuff
-from napari.qt.threading import thread_worker
-from qtpy.QtCore import Qt
-from qtpy.QtWidgets import (
- QCheckBox,
- QComboBox,
- QGridLayout,
- QGroupBox,
- QHBoxLayout,
- QLabel,
- QLineEdit,
- QMainWindow,
- QPushButton,
- QScrollArea,
- QVBoxLayout,
- QWidget,
-)
-from scipy.ndimage import distance_transform_edt as dtedt
-from skimage.filters import threshold_otsu
-from superqt import QCollapsible
-
-from ..dataset import NapariDataset
-from ..gp.nodes.napari_image_source import NapariImageSource
-
-# local package imports
-from ..gui_helpers import layer_choice_widget
-
-############ GLOBALS ###################
-time_now = 0
-_train_config = None
-_model_config = None
-_segment_config = None
-_model = None
-_optimizer = None
-_scheduler = None
-_dataset = None
-
-
-class SegmentationWidget(QMainWindow):
- def __init__(self, napari_viewer):
- super().__init__()
- self.widget = QWidget()
- self.scroll = QScrollArea()
- self.viewer = napari_viewer
-
- # initialize train_config and model_config
-
- self.train_config = None
- self.model_config = None
- self.segment_config = None
-
- # initialize losses and iterations
- self.losses = []
- self.iterations = []
-
- # initialize mode. this will change to 'training' and 'segmenting' later
- self.mode = "configuring"
-
- # initialize UI components
- method_description_label = QLabel(
- 'Unsupervised Learning of Object-Centric Embeddings '
- 'for Cell Instance Segmentation in Microscopy Images. '
- 'If you are using this in your research, please cite us. '
- 'https://github.com/funkelab/cellulus'
- )
-
- # specify layout
- outer_layout = QVBoxLayout()
-
- # Initialize object size widget
- object_size_label = QLabel(self)
- object_size_label.setText("Object Size [px]:")
- self.object_size_line = QLineEdit(self)
- self.object_size_line.setText("30")
- device_label = QLabel(self)
- device_label.setText("Device")
- self.device_combo_box = QComboBox(self)
- self.device_combo_box.addItem("cpu")
- self.device_combo_box.addItem("cuda:0")
- self.device_combo_box.addItem("mps")
- self.device_combo_box.setCurrentText("mps")
-
- grid_layout = QGridLayout()
- grid_layout.addWidget(object_size_label, 0, 0)
- grid_layout.addWidget(self.object_size_line, 0, 1)
- grid_layout.addWidget(device_label, 1, 0)
- grid_layout.addWidget(self.device_combo_box, 1, 1)
-
- # Initialize train configs widget
- collapsible_train_configs = QCollapsible("Train Configs", self)
- collapsible_train_configs.addWidget(self.create_train_configs_widget)
-
- # Initialize model configs widget
- collapsible_model_configs = QCollapsible("Model Configs", self)
- collapsible_model_configs.addWidget(self.create_model_configs_widget)
-
- # Initialize loss/iterations widget
-
- self.losses_widget = pg.PlotWidget()
- self.losses_widget.setBackground((37, 41, 49))
- styles = {"color": "white", "font-size": "16px"}
- self.losses_widget.setLabel("left", "Loss", **styles)
- self.losses_widget.setLabel("bottom", "Iterations", **styles)
- # self.canvas = MplCanvas(self, width=5, height=10, dpi=100)
- # canvas_layout = QVBoxLayout()
- # canvas_layout.addWidget(self.canvas)
- # if len(self.iterations) == 0:
- # self.loss_plot = self.canvas.axes.plot([], [], label="")[0]
- # self.canvas.axes.legend()
- # self.canvas.axes.set_xlabel("Iterations")
- # self.canvas.axes.set_ylabel("Loss")
- # plot_container_widget = QWidget()
- # plot_container_widget.setLayout(canvas_layout)
-
- # Initialize Layer Choice
- self.raw_selector = layer_choice_widget(
- self.viewer, annotation=napari.layers.Image, name="raw"
- )
-
- # Initialize Checkboxes
- self.s_checkbox = QCheckBox("s")
- self.c_checkbox = QCheckBox("c")
- self.t_checkbox = QCheckBox("t")
- self.z_checkbox = QCheckBox("z")
- self.y_checkbox = QCheckBox("y")
- self.x_checkbox = QCheckBox("x")
-
- axis_layout = QHBoxLayout()
- axis_layout.addWidget(self.s_checkbox)
- axis_layout.addWidget(self.c_checkbox)
- axis_layout.addWidget(self.t_checkbox)
- axis_layout.addWidget(self.z_checkbox)
- axis_layout.addWidget(self.y_checkbox)
- axis_layout.addWidget(self.x_checkbox)
- axis_selector = QGroupBox("Axis Names:")
- axis_selector.setLayout(axis_layout)
-
- # Initialize Train Button
- self.train_button = QPushButton("Train", self)
- self.train_button.clicked.connect(self.prepare_for_training)
-
- # Initialize Save and Load Widget
- collapsible_save_load_widget = QCollapsible(
- "Save and Load Model", self
- )
- save_load_layout = QHBoxLayout()
- save_model_button = QPushButton("Save Model", self)
- load_model_button = QPushButton("Load Model", self)
- save_load_layout.addWidget(save_model_button)
- save_load_layout.addWidget(load_model_button)
- save_load_widget = QWidget()
- save_load_widget.setLayout(save_load_layout)
- collapsible_save_load_widget.addWidget(save_load_widget)
-
- # Initialize Segment Configs widget
- collapsible_segment_configs = QCollapsible("Inference Configs", self)
- collapsible_segment_configs.addWidget(
- self.create_segment_configs_widget
- )
-
- # Initialize Segment Button
- self.segment_button = QPushButton("Segment", self)
- self.segment_button.clicked.connect(self.prepare_for_segmenting)
-
- # Initialize progress bar
- # self.pbar = QProgressBar(self)
-
- # Initialize Feedback Button
- self.feedback_label = QLabel(
- 'Please share any feedback here.'
- )
-
- # Add all components to outer_layout
-
- outer_layout.addWidget(method_description_label)
- outer_layout.addLayout(grid_layout)
- # outer_layout.addWidget(global_params_widget)
- outer_layout.addWidget(collapsible_train_configs)
- outer_layout.addWidget(collapsible_model_configs)
- # outer_layout.addWidget(plot_container_widget)
- outer_layout.addWidget(self.losses_widget)
- outer_layout.addWidget(self.raw_selector.native)
- outer_layout.addWidget(axis_selector)
- outer_layout.addWidget(self.train_button)
- outer_layout.addWidget(collapsible_save_load_widget)
- outer_layout.addWidget(collapsible_segment_configs)
- outer_layout.addWidget(self.segment_button)
- outer_layout.addWidget(self.feedback_label)
- outer_layout.setSpacing(20)
- self.widget.setLayout(outer_layout)
-
- self.scroll.setWidget(self.widget)
- self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
- self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
- self.scroll.setWidgetResizable(True)
-
- self.setFixedWidth(400)
- self.setCentralWidget(self.scroll)
-
- @property
- def create_train_configs_widget(self):
- @magic_factory(call_button="Save")
- def train_configs_widget(
- crop_size: int = 252,
- batch_size: int = 8,
- max_iterations: int = 100_000,
- initial_learning_rate: float = 4e-5,
- temperature: float = 10.0,
- regularizer_weight: float = 1e-5,
- reduce_mean: bool = True,
- density: float = 0.1,
- kappa: float = 10.0,
- num_workers: int = 8,
- control_point_spacing: int = 64,
- control_point_jitter: float = 2.0,
- ):
- global _train_config
- # Specify what should happen when 'Save' button is pressed
- _train_config = {
- "crop_size": crop_size,
- "batch_size": batch_size,
- "max_iterations": max_iterations,
- "initial_learning_rate": initial_learning_rate,
- "temperature": temperature,
- "regularizer_weight": regularizer_weight,
- "reduce_mean": reduce_mean,
- "density": density,
- "kappa": kappa,
- "num_workers": num_workers,
- "control_point_spacing": control_point_spacing,
- "control_point_jitter": control_point_jitter,
- }
-
- if not hasattr(self, "__create_train_configs_widget"):
- self.__create_train_configs_widget = train_configs_widget()
- self.__create_train_configs_widget_native = (
- self.__create_train_configs_widget.native
- )
- return self.__create_train_configs_widget_native
-
- @property
- def create_model_configs_widget(self):
- @magic_factory(call_button="Save")
- def model_configs_widget(
- num_fmaps: int = 24,
- fmap_inc_factor: int = 3,
- features_in_last_layer: int = 64,
- downsampling_factors: int = 2,
- downsampling_layers: int = 1,
- initialize: bool = True,
- ):
- global _model_config
- # Specify what should happen when 'Save' button is pressed
- _model_config = {
- "num_fmaps": num_fmaps,
- "fmap_inc_factor": fmap_inc_factor,
- "features_in_last_layer": features_in_last_layer,
- "downsampling_factors": downsampling_factors,
- "downsampling_layers": downsampling_layers,
- "initialize": initialize,
- }
-
- if not hasattr(self, "__create_model_configs_widget"):
- self.__create_model_configs_widget = model_configs_widget()
- self.__create_model_configs_widget_native = (
- self.__create_model_configs_widget.native
- )
- return self.__create_model_configs_widget_native
-
- @property
- def create_segment_configs_widget(self):
- @magic_factory(call_button="Save")
- def segment_configs_widget(
- crop_size: int = 252,
- p_salt_pepper: float = 0.01,
- num_infer_iterations: int = 16,
- bandwidth: int = 7,
- reduction_probability: float = 0.1,
- min_size: int = 25,
- grow_distance: int = 3,
- shrink_distance: int = 6,
- ):
- global _segment_config
- # Specify what should happen when 'Save' button is pressed
- _segment_config = {
- "crop_size": crop_size,
- "p_salt_pepper": p_salt_pepper,
- "num_infer_iterations": num_infer_iterations,
- "bandwidth": bandwidth,
- "reduction_probability": reduction_probability,
- "min_size": min_size,
- "grow_distance": grow_distance,
- "shrink_distance": shrink_distance,
- }
-
- if not hasattr(self, "__create_segment_configs_widget"):
- self.__create_segment_configs_widget = segment_configs_widget()
- self.__create_segment_configs_widget_native = (
- self.__create_segment_configs_widget.native
- )
- return self.__create_segment_configs_widget_native
-
- def get_selected_axes(self):
- names = []
- for name, checkbox in zip(
- "sctzyx",
- [
- self.s_checkbox,
- self.c_checkbox,
- self.t_checkbox,
- self.z_checkbox,
- self.y_checkbox,
- self.x_checkbox,
- ],
- ):
- if checkbox.isChecked():
- names.append(name)
- return names
-
- def prepare_for_training(self):
-
- # check if train_config object exists
- global _train_config, _model_config, _model, _optimizer
-
- if _train_config is None:
- # set default values
- _train_config = {
- "crop_size": 252,
- "batch_size": 8,
- "max_iterations": 100_000,
- "initial_learning_rate": 4e-5,
- "temperature": 10.0,
- "regularizer_weight": 1e-5,
- "reduce_mean": True,
- "density": 0.1,
- "kappa": 10.0,
- "num_workers": 8,
- "control_point_spacing": 64,
- "control_point_jitter": 2.0,
- }
-
- # check if model_config object exists
- if _model_config is None:
- _model_config = {
- "num_fmaps": 24,
- "fmap_inc_factor": 3,
- "features_in_last_layer": 64,
- "downsampling_factors": 2,
- "downsampling_layers": 1,
- "initialize": True,
- }
-
- print(self.sender())
- self.update_mode(self.sender())
-
- if self.mode == "training":
- self.worker = self.train_napari()
- self.worker.yielded.connect(self.on_yield)
- self.worker.returned.connect(self.on_return)
- self.worker.start()
- elif self.mode == "configuring":
- state = {
- "iteration": self.iterations[-1],
- "model_state_dict": _model.state_dict(),
- "optim_state_dict": _optimizer.state_dict(),
- "iterations": self.iterations,
- "losses": self.losses,
- }
- filename = os.path.join("models", "last.pth")
- torch.save(state, filename)
- self.worker.quit()
-
- def update_mode(self, sender):
-
- if self.train_button.text() == "Train" and sender == self.train_button:
- self.train_button.setText("Pause")
- self.mode = "training"
- self.segment_button.setEnabled(False)
- elif (
- self.train_button.text() == "Pause" and sender == self.train_button
- ):
- self.train_button.setText("Train")
- self.mode = "configuring"
- self.segment_button.setEnabled(True)
- elif (
- self.segment_button.text() == "Segment"
- and sender == self.segment_button
- ):
- self.segment_button.setText("Pause")
- self.mode = "segmenting"
- self.train_button.setEnabled(False)
- elif (
- self.segment_button.text() == "Pause"
- and sender == self.segment_button
- ):
- self.segment_button.setText("Segment")
- self.mode = "configuring"
- self.train_button.setEnabled(True)
-
- @thread_worker
- def train_napari(self):
-
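- # Background training worker: yields (iteration, loss) after every step so on_yield can update the loss plot without blocking the GUI.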
- global _train_config, _model_config, _model, _scheduler, _optimizer, _dataset
-
- # Turn layer into dataset
- _dataset = NapariDataset(
- layer=self.raw_selector.value,
- axis_names=self.get_selected_axes(),
- crop_size=_train_config["crop_size"],
- control_point_spacing=_train_config["control_point_spacing"],
- control_point_jitter=_train_config["control_point_jitter"],
- )
-
- if not os.path.exists("models"):
- os.makedirs("models")
-
- # create train dataloader
- dataloader = torch.utils.data.DataLoader(
- dataset=_dataset,
- batch_size=_train_config["batch_size"],
- drop_last=True,
- num_workers=_train_config["num_workers"],
- pin_memory=True,
- )
-
- downsampling_factors = [
- [
- _model_config["downsampling_factors"],
- ]
- * _dataset.get_num_spatial_dims()
- ] * _model_config["downsampling_layers"]
-
- # set model
- _model = get_model(
- in_channels=_dataset.get_num_channels(),
- out_channels=_dataset.get_num_spatial_dims(),
- num_fmaps=_model_config["num_fmaps"],
- fmap_inc_factor=_model_config["fmap_inc_factor"],
- features_in_last_layer=_model_config["features_in_last_layer"],
- downsampling_factors=[
- tuple(factor) for factor in downsampling_factors
- ],
- num_spatial_dims=_dataset.get_num_spatial_dims(),
- )
-
- # set device
- device = torch.device(self.device_combo_box.currentText())
-
- _model = _model.to(device)
-
- # initialize model weights
- if _model_config["initialize"]:
- for _name, layer in _model.named_modules():
- if isinstance(layer, torch.nn.modules.conv._ConvNd):
- torch.nn.init.kaiming_normal_(
- layer.weight, nonlinearity="relu"
- )
-
- # set loss
- criterion = get_loss(
- regularizer_weight=_train_config["regularizer_weight"],
- temperature=_train_config["temperature"],
- kappa=_train_config["kappa"],
- density=_train_config["density"],
- num_spatial_dims=_dataset.get_num_spatial_dims(),
- reduce_mean=_train_config["reduce_mean"],
- device=device,
- )
-
- # set optimizer
- _optimizer = torch.optim.Adam(
- _model.parameters(),
- lr=_train_config["initial_learning_rate"],
- )
-
- # set scheduler:
-
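- # Polynomial decay: lr scale = (1 - iteration / max_iterations) ** 0.9.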
- def lambda_(iteration):
- return pow(
- (1 - ((iteration) / _train_config["max_iterations"])), 0.9
- )
-
- # resume training
- if len(self.iterations) == 0:
- start_iteration = 0
- else:
- start_iteration = self.iterations[-1]
-
- if os.path.exists("models/last.pth"):
- print("Resuming model from 'models/last.pth'")
- state = torch.load("models/last.pth", map_location=device)
- start_iteration = state["iteration"] + 1
- self.iterations = state["iterations"]
- self.losses = state["losses"]
- _model.load_state_dict(state["model_state_dict"], strict=True)
- _optimizer.load_state_dict(state["optim_state_dict"])
-
- # call `train_iteration`
- for iteration, batch in zip(
- range(start_iteration, _train_config["max_iterations"]),
- dataloader,
- ):
- _scheduler = torch.optim.lr_scheduler.LambdaLR(
- _optimizer, lr_lambda=lambda_, last_epoch=iteration - 1
- )
-
- train_loss, prediction = train_iteration(
- batch,
- model=_model,
- criterion=criterion,
- optimizer=_optimizer,
- device=device,
- )
- _scheduler.step()
- yield (iteration, train_loss)
-
- def on_yield(self, step_data):
- if self.mode == "training":
- global time_now
- iteration, loss = step_data
- current_time = time.time()
- time_elapsed = current_time - time_now
- time_now = current_time
- print(
- f"iteration {iteration}, loss {loss}, seconds/iteration {time_elapsed}"
- )
- self.iterations.append(iteration)
- self.losses.append(loss)
- self.update_canvas()
- elif self.mode == "segmenting":
- print(step_data)
- # self.pbar.setValue(step_data)
-
- def update_canvas(self):
- self.losses_widget.plot(self.iterations, self.losses)
-
- # self.loss_plot.set_xdata(self.iterations)
- # self.loss_plot.set_ydata(self.losses)
- # self.canvas.axes.relim()
- # self.canvas.axes.autoscale_view()
- # self.canvas.draw()
-
- def on_return(self, layers):
- # Describes what happens once the segmentation worker has returned
-
- for data, metadata, layer_type in layers:
- if layer_type == "image":
- self.viewer.add_image(data, **metadata)
- elif layer_type == "labels":
- self.viewer.add_labels(data.astype(int), **metadata)
- self.update_mode(self.segment_button)
-
- def prepare_for_segmenting(self):
- global _segment_config
- # check if segment_config exists
- if _segment_config is None:
- _segment_config = {
- "crop_size": 252,
- "p_salt_pepper": 0.01,
- "num_infer_iterations": 16,
- "bandwidth": None,
- "reduction_probability": 0.1,
- "min_size": None,
- "grow_distance": 3,
- "shrink_distance": 6,
- }
-
- # update mode
- self.update_mode(self.sender())
-
- if self.mode == "segmenting":
- self.worker = self.segment_napari()
- self.worker.yielded.connect(self.on_yield)
- self.worker.returned.connect(self.on_return)
- self.worker.start()
- elif self.mode == "configuring":
- self.worker.quit()
-
- @thread_worker
- def segment_napari(self):
- global _segment_config, _model, _dataset
-
- raw = self.raw_selector.value
-
- if _segment_config["bandwidth"] is None:
- _segment_config["bandwidth"] = int(
- 0.5 * float(self.object_size_line.text())
- )
- if _segment_config["min_size"] is None:
- _segment_config["min_size"] = int(
- 0.1 * np.pi * (float(self.object_size_line.text()) ** 2) / 4
- )
- _model.eval()
-
- num_spatial_dims = _dataset.num_spatial_dims
- num_channels = _dataset.num_channels
- spatial_dims = _dataset.spatial_dims
- num_samples = _dataset.num_samples
-
- print(
- f"Num spatial dims {num_spatial_dims} num channels {num_channels} spatial_dims {spatial_dims} num_samples {num_samples}"
- )
-
- crop_size = (_segment_config["crop_size"],) * num_spatial_dims
- device = self.device_combo_box.currentText()
-
- num_channels_temp = 1 if num_channels == 0 else num_channels
-
- voxel_size = gp.Coordinate((1,) * num_spatial_dims)
- _model.set_infer(
- p_salt_pepper=_segment_config["p_salt_pepper"],
- num_infer_iterations=_segment_config["num_infer_iterations"],
- device=device,
- )
-
- input_shape = gp.Coordinate((1, num_channels_temp, *crop_size))
- output_shape = gp.Coordinate(
- _model(
- torch.zeros(
- (1, num_channels_temp, *crop_size), dtype=torch.float32
- ).to(device)
- ).shape
- )
- input_size = (
- gp.Coordinate(input_shape[-num_spatial_dims:]) * voxel_size
- )
- output_size = (
- gp.Coordinate(output_shape[-num_spatial_dims:]) * voxel_size
- )
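- # Half of the input/output size difference is the spatial context consumed by the network on each side.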
- context = (input_size - output_size) / 2
-
- raw_key = gp.ArrayKey("RAW")
- prediction_key = gp.ArrayKey("PREDICT")
- scan_request = gp.BatchRequest()
- scan_request.add(raw_key, input_size)
- scan_request.add(prediction_key, output_size)
-
- predict = gp.torch.Predict(
- _model,
- inputs={"raw": raw_key},
- outputs={0: prediction_key},
- array_specs={prediction_key: gp.ArraySpec(voxel_size=voxel_size)},
- )
- pipeline = NapariImageSource(
- raw,
- raw_key,
- gp.ArraySpec(
- gp.Roi(
- (0,) * num_spatial_dims,
- raw.data.shape[-num_spatial_dims:],
- ),
- voxel_size=voxel_size,
- ),
- spatial_dims,
- )
-
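- # Pad by the context and unsqueeze any missing sample/channel axes so the network always receives (batch, channel, *spatial) input.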
- if num_samples == 0 and num_channels == 0:
- pipeline += (
- gp.Pad(raw_key, context)
- + gp.Unsqueeze([raw_key], 0)
- + gp.Unsqueeze([raw_key], 0)
- + predict
- + gp.Scan(scan_request)
- )
- elif num_samples != 0 and num_channels == 0:
- pipeline += (
- gp.Pad(raw_key, context)
- + gp.Unsqueeze([raw_key], 1)
- + predict
- + gp.Scan(scan_request)
- )
- elif num_samples == 0 and num_channels != 0:
- pipeline += (
- gp.Pad(raw_key, context)
- + gp.Unsqueeze([raw_key], 0)
- + predict
- + gp.Scan(scan_request)
- )
- elif num_samples != 0 and num_channels != 0:
- pipeline += (
- gp.Pad(raw_key, context) + predict + gp.Scan(scan_request)
- )
-
- # request to pipeline for ROI of whole image/volume
- request = gp.BatchRequest()
- request.add(prediction_key, raw.data.shape[-num_spatial_dims:])
- counter = 0
- with gp.build(pipeline):
- batch = pipeline.request_batch(request)
- yield counter
- counter += 0.1
-
- prediction = batch.arrays[prediction_key].data
- colormaps = ["red", "green", "blue"]
- prediction_layers = [
- (
- prediction[:, i : i + 1, ...].copy(),
- {
- "name": "offset-" + "zyx"[num_spatial_dims - i]
- if i < num_spatial_dims
- else "std",
- "colormap": colormaps[num_spatial_dims - i]
- if i < num_spatial_dims
- else "gray",
- "blending": "additive",
- },
- "image",
- )
- for i in range(num_spatial_dims + 1)
- ]
-
- foreground = np.zeros_like(prediction[:, 0:1, ...], dtype=bool)
- for sample in range(prediction.shape[0]):
- embeddings = prediction[sample]
- embeddings_std = embeddings[-1, ...]
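- # The last prediction channel is the per-pixel embedding std; pixels with std below the Otsu threshold are treated as foreground.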
- thresh = threshold_otsu(embeddings_std)
- print(f"Threshold for sample {sample} is {thresh}")
- binary_mask = embeddings_std < thresh
- foreground[sample, 0, ...] = binary_mask
-
- labels = np.zeros_like(prediction[:, 0:1, ...], dtype=np.uint64)
- for sample in range(prediction.shape[0]):
- embeddings = prediction[sample]
- embeddings_std = embeddings[-1, ...]
- embeddings_mean = embeddings[np.newaxis, :num_spatial_dims, ...]
- segmentation = mean_shift_segmentation(
- embeddings_mean,
- embeddings_std,
- _segment_config["bandwidth"],
- _segment_config["min_size"],
- _segment_config["reduction_probability"],
- )
- labels[sample, 0, ...] = segmentation
-
- pp_labels = np.zeros_like(prediction[:, 0:1, ...], dtype=np.uint64)
- for sample in range(prediction.shape[0]):
- segmentation = labels[sample, 0]
- distance_foreground = dtedt(segmentation == 0)
- expanded_mask = (
- distance_foreground < _segment_config["grow_distance"]
- )
- distance_background = dtedt(expanded_mask)
- segmentation[
- distance_background < _segment_config["shrink_distance"]
- ] = 0
- pp_labels[sample, 0, ...] = segmentation
- return (
- prediction_layers
- + [(foreground, {"name": "Foreground"}, "labels")]
- + [(labels, {"name": "Segmentation"}, "labels")]
- + [(pp_labels, {"name": "Post Processed"}, "labels")]
- )