From b282aab40a1d36fa495187e5f53ea65c799c6f4f Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 18 Sep 2023 10:30:20 -0400 Subject: [PATCH 01/40] Refactor dataset.py --- src/napari_cellulus/dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 6608ba1..60aee88 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -1,14 +1,13 @@ import math -from typing import Tuple, List +from typing import List, Tuple import gunpowder as gp +from cellulus.datasets import DatasetMetaData from napari.layers import Image from torch.utils.data import IterableDataset from .gp.nodes.napari_image_source import NapariImageSource -from cellulus.datasets import DatasetMetaData - class NapariDataset(IterableDataset): # type: ignore def __init__( From 21b4c7ff5d49b473a95dfbf1c03972c85601216c Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 18 Sep 2023 10:33:12 -0400 Subject: [PATCH 02/40] Refactor widget.py --- src/napari_cellulus/widgets/_widget.py | 38 ++++++++++++++++---------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 666d176..8fec053 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -6,8 +6,8 @@ Replace code below according to your needs. """ -import dataclasses import contextlib +import dataclasses # python built in libraries from pathlib import Path @@ -33,16 +33,16 @@ NavigationToolbar2QT as NavigationToolbar, ) from napari.qt.threading import FunctionWorker, thread_worker +from qtpy.QtCore import QUrl +from qtpy.QtGui import QDesktopServices from qtpy.QtWidgets import ( + QCheckBox, + QGroupBox, + QHBoxLayout, QPushButton, QVBoxLayout, QWidget, - QCheckBox, - QHBoxLayout, - QGroupBox, ) -from qtpy.QtGui import QDesktopServices -from qtpy.QtCore import QUrl from superqt import QCollapsible from tqdm import tqdm @@ -236,12 +236,12 @@ def __init__(self, napari_viewer): ) layout.addWidget(self.raw_selector.native) - self.s_checkbox = QCheckBox('s') - self.c_checkbox = QCheckBox('c') - self.t_checkbox = QCheckBox('t') - self.z_checkbox = QCheckBox('z') - self.y_checkbox = QCheckBox('y') - self.x_checkbox = QCheckBox('x') + self.s_checkbox = QCheckBox("s") + self.c_checkbox = QCheckBox("c") + self.t_checkbox = QCheckBox("t") + self.z_checkbox = QCheckBox("z") + self.y_checkbox = QCheckBox("y") + self.x_checkbox = QCheckBox("x") axis_layout = QHBoxLayout() axis_layout.addWidget(self.s_checkbox) @@ -630,13 +630,21 @@ def update_progress_plot(self): def get_selected_axes(self): names = [] - for name, checkbox in zip("sctzyx", [self.s_checkbox, self.c_checkbox, self.t_checkbox, - self.z_checkbox, self.y_checkbox, self.x_checkbox]): + for name, checkbox in zip( + "sctzyx", + [ + self.s_checkbox, + self.c_checkbox, + self.t_checkbox, + self.z_checkbox, + self.y_checkbox, + self.x_checkbox, + ], + ): if checkbox.isChecked(): names.append(name) return names - def start_training_loop(self): self.reset_training_state(keep_stats=True) training_stats = get_training_stats() From 619e2f9b930bec9438972030df3339e92e6b09e6 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 18 Sep 2023 10:36:04 -0400 Subject: [PATCH 03/40] Change default crop size to 252 --- src/napari_cellulus/widgets/_widget.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 
8fec053..f91bb64 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -97,9 +97,11 @@ def get_train_config(**kwargs): return _train_config -@magic_factory(call_button="Save") +@magic_factory( + call_button="Save", device={"choices": ["cpu", "cuda:0", "mps"]} +) def train_config_widget( - crop_size: list[int] = [256, 256], + crop_size: list[int] = [252, 252], batch_size: int = 8, max_iterations: int = 100_000, initial_learning_rate: float = 4e-5, From d2ff8a81978a0dad59aff5a5a9128353509c60d9 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 18 Sep 2023 10:36:30 -0400 Subject: [PATCH 04/40] Generalize device --- src/napari_cellulus/widgets/_widget.py | 37 +++++++++++++++----------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index f91bb64..594b334 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -115,6 +115,7 @@ def train_config_widget( num_workers: int = 8, control_point_spacing: int = 64, control_point_jitter: float = 2.0, + device="cpu", ): get_train_config( crop_size=crop_size, @@ -131,6 +132,7 @@ def train_config_widget( num_workers=num_workers, control_point_spacing=control_point_spacing, control_point_jitter=control_point_jitter, + device=device, ) @@ -165,6 +167,10 @@ def get_training_state(dataset: Optional[NapariDataset] = None): global _scheduler global _model_config global _train_config + + # set device + device = torch.device(_train_config.device) + if _model_config is None: # TODO: deal with hard coded defaults _model_config = ModelConfig(num_fmaps=24, fmap_inc_factor=3) @@ -182,7 +188,7 @@ def get_training_state(dataset: Optional[NapariDataset] = None): tuple(factor) for factor in _model_config.downsampling_factors ], num_spatial_dims=dataset.get_num_spatial_dims(), - ).cuda() + ).to(device) # Weight initialization # TODO: move weight initialization to funlib.learn.torch @@ -379,11 +385,18 @@ def async_segment( raw.data = raw.data.astype(np.float32) global _model + assert ( _model is not None ), "You must train a model before running inference" model = _model + # set in eval mode + model.eval() + + # device + device = torch.device(_train_config.device) + num_spatial_dims = len(raw.data.shape) - 2 num_channels = raw.data.shape[1] @@ -391,6 +404,7 @@ def async_segment( model.set_infer( p_salt_pepper=p_salt_pepper, num_infer_iterations=num_infer_iterations, + device=device, ) # prediction crop size is the size of the scanned tiles to be provided to the model @@ -405,7 +419,7 @@ def async_segment( *crop_size, ), dtype=torch.float32, - ).cuda() + ).to(device) ).shape ) @@ -800,7 +814,9 @@ def train_cellulus( axis_names, iteration=0, ): + train_config = get_train_config() + # Turn layer into dataset: train_dataset = NapariDataset( raw, @@ -811,12 +827,6 @@ def train_cellulus( ) model, optimizer, scheduler = get_training_state(train_dataset) - # TODO: How to display profiling stats - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - model = model.to(device) model.train() train_dataloader = torch.utils.data.DataLoader( @@ -835,15 +845,11 @@ def train_cellulus( density=train_config.density, num_spatial_dims=train_dataset.get_num_spatial_dims(), reduce_mean=train_config.reduce_mean, + device=train_config.device, ) - def train_iteration( - batch, - model, - criterion, - optimizer, - ): - prediction = model(batch.cuda()) + def 
train_iteration(batch, model, criterion, optimizer, device): + prediction = model(batch.to(device)) loss = criterion(prediction) loss = loss.mean() optimizer.zero_grad() @@ -864,6 +870,7 @@ def train_iteration( model=model, criterion=criterion, optimizer=optimizer, + device=train_config.device, ) scheduler.step() From f0801622338d1a5934b88de1519623e5750d0c0f Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 18 Sep 2023 10:37:00 -0400 Subject: [PATCH 05/40] Update version --- src/napari_cellulus/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_cellulus/__init__.py b/src/napari_cellulus/__init__.py index 60ccee5..c1d6450 100644 --- a/src/napari_cellulus/__init__.py +++ b/src/napari_cellulus/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.1" +__version__ = "0.0.2" from ._sample_data import tissuenet_sample from .widgets._widget import ( From 9fce242a0b5b4a7ed3dda75fd514e10a0bbacf7c Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 18 Sep 2023 10:38:13 -0400 Subject: [PATCH 06/40] Refactor setup.cfg --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index e99c5a7..7280387 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,6 @@ install_requires = matplotlib torch gunpowder - python_requires = >=3.8 From a408821825c01fe7d11c9f534ea300db94f45f71 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 20 Sep 2023 10:22:03 -0400 Subject: [PATCH 07/40] Add meta_data.py without asserts checking for sample axis --- src/napari_cellulus/dataset.py | 9 ++++++-- src/napari_cellulus/meta_data.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 src/napari_cellulus/meta_data.py diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 60aee88..045d509 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -2,11 +2,11 @@ from typing import List, Tuple import gunpowder as gp -from cellulus.datasets import DatasetMetaData from napari.layers import Image from torch.utils.data import IterableDataset from .gp.nodes.napari_image_source import NapariImageSource +from .meta_data import NapariDatasetMetaData class NapariDataset(IterableDataset): # type: ignore @@ -110,7 +110,9 @@ def __yield_sample(self): yield sample[self.raw].data[0] def __read_meta_data(self): - meta_data = DatasetMetaData(self.layer.data.shape, self.axis_names) + meta_data = NapariDatasetMetaData( + self.layer.data.shape, self.axis_names + ) self.num_dims = meta_data.num_dims self.num_spatial_dims = meta_data.num_spatial_dims @@ -119,6 +121,9 @@ def __read_meta_data(self): self.sample_dim = meta_data.sample_dim self.channel_dim = meta_data.channel_dim self.time_dim = meta_data.time_dim + print( + f"{self.num_dims}, {self.num_spatial_dims}, {self.num_channels}, {self.num_samples}, {self.sample_dim}, {self.channel_dim}, {self.time_dim}" + ) def get_num_channels(self): return self.num_channels diff --git a/src/napari_cellulus/meta_data.py b/src/napari_cellulus/meta_data.py new file mode 100644 index 0000000..9ae1a17 --- /dev/null +++ b/src/napari_cellulus/meta_data.py @@ -0,0 +1,37 @@ +from typing import Tuple + +import zarr + +from cellulus.configs import DatasetConfig + + +class NapariDatasetMetaData: + def __init__(self, shape, axis_names): + self.num_dims = len(axis_names) + self.num_spatial_dims: int = 0 + self.num_samples: int = 0 + self.num_channels: int = 0 + self.sample_dim = None + self.channel_dim = None + self.time_dim = None + self.spatial_array: Tuple[int, ...] 
= () + for dim, axis_name in enumerate(axis_names): + if axis_name == "s": + self.sample_dim = dim + self.num_samples = shape[dim] + elif axis_name == "c": + self.channel_dim = dim + self.num_channels = shape[dim] + elif axis_name == "t": + self.num_spatial_dims += 1 + self.time_dim = dim + elif axis_name == "z": + self.num_spatial_dims += 1 + self.spatial_array += (shape[dim],) + elif axis_name == "y": + self.num_spatial_dims += 1 + self.spatial_array += (shape[dim],) + elif axis_name == "x": + self.num_spatial_dims += 1 + self.spatial_array += (shape[dim],) + From 522bafeb9ece2aa72124ef7826314b3decf2b4e5 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 20 Sep 2023 15:02:45 -0400 Subject: [PATCH 08/40] Specify all cases for pipeline --- src/napari_cellulus/dataset.py | 92 ++++++++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 14 deletions(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 045d509..415dd82 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -75,22 +75,86 @@ def __iter__(self): def __setup_pipeline(self): self.raw = gp.ArrayKey("RAW") + # treat all dimensions as spatial, with a voxel size = 1 + voxel_size = gp.Coordinate((1,) * self.num_dims) + offset = gp.Coordinate((0,) * self.num_dims) + shape = gp.Coordinate(self.layer.data.shape) + raw_spec = gp.ArraySpec( + roi=gp.Roi(offset, voxel_size * shape), + dtype=float, + interpolatable=True, + voxel_size=voxel_size, + ) - self.pipeline = ( - NapariImageSource(self.layer, self.raw) - + gp.RandomLocation() - + gp.ElasticAugment( - control_point_spacing=(self.control_point_spacing,) - * self.num_spatial_dims, - jitter_sigma=(self.control_point_jitter,) - * self.num_spatial_dims, - rotation_interval=(0, math.pi / 2), - scale_interval=(0.9, 1.1), - subsample=4, - spatial_dims=self.num_spatial_dims, + if self.num_channels == 0 and self.num_samples == 0: + self.pipeline = ( + NapariImageSource(self.layer, self.raw, raw_spec) + + gp.RandomLocation() + + gp.ElasticAugment( + control_point_spacing=(self.control_point_spacing,) + * self.num_spatial_dims, + jitter_sigma=(self.control_point_jitter,) + * self.num_spatial_dims, + rotation_interval=(0, math.pi / 2), + scale_interval=(0.9, 1.1), + subsample=4, + spatial_dims=self.num_spatial_dims, + ) + + gp.Unsqueeze([self.raw], 0) + + gp.Unsqueeze([self.raw], 0) + # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) + ) + elif self.num_channels == 0 and self.num_samples != 0: + self.pipeline = ( + NapariImageSource(self.layer, self.raw, raw_spec) + + gp.Unsqueeze([self.raw], 1) + + gp.RandomLocation() + + gp.ElasticAugment( + control_point_spacing=(self.control_point_spacing,) + * self.num_spatial_dims, + jitter_sigma=(self.control_point_jitter,) + * self.num_spatial_dims, + rotation_interval=(0, math.pi / 2), + scale_interval=(0.9, 1.1), + subsample=4, + spatial_dims=self.num_spatial_dims, + ) + + gp.Unsqueeze([self.raw], 1) + # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) + ) + elif self.num_channels != 0 and self.num_samples == 0: + self.pipeline = ( + NapariImageSource(self.layer, self.raw, raw_spec) + + gp.RandomLocation() + + gp.ElasticAugment( + control_point_spacing=(self.control_point_spacing,) + * self.num_spatial_dims, + jitter_sigma=(self.control_point_jitter,) + * self.num_spatial_dims, + rotation_interval=(0, math.pi / 2), + scale_interval=(0.9, 1.1), + subsample=4, + spatial_dims=self.num_spatial_dims, + ) + + gp.Unsqueeze([self.raw], 0) + # + 
gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) + ) + elif self.num_channels != 0 and self.num_samples != 0: + self.pipeline = ( + NapariImageSource(self.layer, self.raw, raw_spec) + + gp.RandomLocation() + + gp.ElasticAugment( + control_point_spacing=(self.control_point_spacing,) + * self.num_spatial_dims, + jitter_sigma=(self.control_point_jitter,) + * self.num_spatial_dims, + rotation_interval=(0, math.pi / 2), + scale_interval=(0.9, 1.1), + subsample=4, + spatial_dims=self.num_spatial_dims, + ) + # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) ) - # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) - ) def __yield_sample(self): """An infinite generator of crops.""" From 9c5b603e21c154dc5957531d51810841dc4b7bd6 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 20 Sep 2023 15:03:10 -0400 Subject: [PATCH 09/40] Specify all cases for request --- src/napari_cellulus/dataset.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 415dd82..ab17277 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -163,13 +163,33 @@ def __yield_sample(self): while True: # request one sample, all channels, plus crop dimensions request = gp.BatchRequest() - request[self.raw] = gp.ArraySpec( - roi=gp.Roi( - (0,) * self.num_dims, - (1, self.num_channels, *self.crop_size), + if self.num_channels == 0 and self.num_samples == 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * (self.num_dims), + self.crop_size, + ) + ) + elif self.num_channels == 0 and self.num_samples != 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * (self.num_dims), (1, *self.crop_size) + ) + ) + elif self.num_channels != 0 and self.num_samples == 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * (self.num_dims), + (self.num_channels, *self.crop_size), + ) + ) + elif self.num_channels != 0 and self.num_samples != 0: + request[self.raw] = gp.ArraySpec( + roi=gp.Roi( + (0,) * (self.num_dims), + (1, self.num_channels, *self.crop_size), + ) ) - ) - sample = self.pipeline.request_batch(request) yield sample[self.raw].data[0] From 751bb38967d08eb197f9d59f2714322bb77a50b4 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 20 Sep 2023 15:05:28 -0400 Subject: [PATCH 10/40] Return num_channels = 1 if no channel dim provided --- src/napari_cellulus/dataset.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index ab17277..8cba5a1 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -205,12 +205,9 @@ def __read_meta_data(self): self.sample_dim = meta_data.sample_dim self.channel_dim = meta_data.channel_dim self.time_dim = meta_data.time_dim - print( - f"{self.num_dims}, {self.num_spatial_dims}, {self.num_channels}, {self.num_samples}, {self.sample_dim}, {self.channel_dim}, {self.time_dim}" - ) def get_num_channels(self): - return self.num_channels + return 1 if self.num_channels == 0 else self.num_channels def get_num_spatial_dims(self): return self.num_spatial_dims From 199e956d96315d024409448dccb21385b9760567 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 20 Sep 2023 15:06:03 -0400 Subject: [PATCH 11/40] Update imports --- src/napari_cellulus/meta_data.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/napari_cellulus/meta_data.py b/src/napari_cellulus/meta_data.py index 
9ae1a17..40c2889 100644 --- a/src/napari_cellulus/meta_data.py +++ b/src/napari_cellulus/meta_data.py @@ -1,9 +1,5 @@ from typing import Tuple -import zarr - -from cellulus.configs import DatasetConfig - class NapariDatasetMetaData: def __init__(self, shape, axis_names): @@ -34,4 +30,3 @@ def __init__(self, shape, axis_names): elif axis_name == "x": self.num_spatial_dims += 1 self.spatial_array += (shape[dim],) - From 80cc22812252d999060bef8e7f616909b8fdf7ad Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 20 Sep 2023 15:07:05 -0400 Subject: [PATCH 12/40] Update NapariImageSource class --- .../gp/nodes/napari_image_source.py | 48 +------------------ 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/src/napari_cellulus/gp/nodes/napari_image_source.py b/src/napari_cellulus/gp/nodes/napari_image_source.py index a3abc4c..3999bfa 100644 --- a/src/napari_cellulus/gp/nodes/napari_image_source.py +++ b/src/napari_cellulus/gp/nodes/napari_image_source.py @@ -2,7 +2,6 @@ import gunpowder as gp from gunpowder.array_spec import ArraySpec -from gunpowder.profiling import Timing from napari.layers import Image @@ -23,9 +22,7 @@ def __init__( self.array_spec = self._read_metadata(image) else: self.array_spec = spec - self.image = gp.Array( - self._remove_leading_dims(image.data), self.array_spec - ) + self.image = gp.Array(image.data.astype(float), self.array_spec) self.key = key def setup(self): @@ -33,48 +30,5 @@ def setup(self): def provide(self, request): output = gp.Batch() - - timing_provide = Timing(self, "provide") - timing_provide.start() - output[self.key] = self.image.crop(request[self.key].roi) - - timing_provide.stop() - - output.profiling_stats.add(timing_provide) - return output - - def _remove_leading_dims(self, data): - while data.shape[0] == 1: - data = data[0] - return data - - def _read_metadata(self, image): - # offset assumed to be in world coordinates - # TODO: read from metadata - data_shape = image.data.shape - # strip leading singleton dimensions (2D data is often given a leading singleton 3rd dimension) - while data_shape[0] == 1: - data_shape = data_shape[1:] - axes = image.metadata.get("axes") - if axes is not None: - ndims = len(axes) - assert ndims <= len( - data_shape - ), f"{axes} incompatible with shape: {data_shape}" - else: - ndims = len(data_shape) - - offset = gp.Coordinate(image.metadata.get("offset", (0,) * ndims)) - voxel_size = gp.Coordinate( - image.metadata.get("resolution", (1,) * ndims) - ) - shape = gp.Coordinate(image.data.shape[-offset.dims :]) - - return gp.ArraySpec( - roi=gp.Roi(offset, voxel_size * shape), - dtype=image.dtype, - interpolatable=True, - voxel_size=voxel_size, - ) From 8922bb33dd6a948439e0c6671c2b77ddcabfdd5e Mon Sep 17 00:00:00 2001 From: lmanan Date: Thu, 21 Sep 2023 10:21:26 -0400 Subject: [PATCH 13/40] Use qtpy instead of PyQt5 --- src/napari_cellulus/gui_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_cellulus/gui_helpers.py b/src/napari_cellulus/gui_helpers.py index 84fd39f..a4aa620 100644 --- a/src/napari_cellulus/gui_helpers.py +++ b/src/napari_cellulus/gui_helpers.py @@ -6,7 +6,7 @@ NavigationToolbar2QT as NavigationToolbar, ) from matplotlib.figure import Figure -from PyQt5 import QtWidgets +from qtpy import QtWidgets class MplCanvas(FigureCanvasQTAgg): From 94124495ef5ea036abc68b836deb953861f806ca Mon Sep 17 00:00:00 2001 From: lmanan Date: Thu, 21 Sep 2023 10:21:51 -0400 Subject: [PATCH 14/40] Specify dtype as np.float32 --- src/napari_cellulus/dataset.py | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 8cba5a1..f89ae0e 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -2,6 +2,7 @@ from typing import List, Tuple import gunpowder as gp +import numpy as np from napari.layers import Image from torch.utils.data import IterableDataset @@ -81,7 +82,7 @@ def __setup_pipeline(self): shape = gp.Coordinate(self.layer.data.shape) raw_spec = gp.ArraySpec( roi=gp.Roi(offset, voxel_size * shape), - dtype=float, + dtype=np.float32, interpolatable=True, voxel_size=voxel_size, ) From 3b8699e23a4d61c845dec621d6b974818afa38c6 Mon Sep 17 00:00:00 2001 From: lmanan Date: Thu, 21 Sep 2023 10:23:05 -0400 Subject: [PATCH 15/40] Specify dtype as np.float32 --- src/napari_cellulus/gp/nodes/napari_image_source.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/napari_cellulus/gp/nodes/napari_image_source.py b/src/napari_cellulus/gp/nodes/napari_image_source.py index 3999bfa..1dd8c12 100644 --- a/src/napari_cellulus/gp/nodes/napari_image_source.py +++ b/src/napari_cellulus/gp/nodes/napari_image_source.py @@ -1,6 +1,7 @@ from typing import Optional import gunpowder as gp +import numpy as np from gunpowder.array_spec import ArraySpec from napari.layers import Image @@ -10,7 +11,7 @@ class NapariImageSource(gp.BatchProvider): A gunpowder interface to a napari Image Args: image (Image): - The napari Image to pull data from + The napari image layer to pull data from key (``gp.ArrayKey``): The key to provide data into """ @@ -22,7 +23,7 @@ def __init__( self.array_spec = self._read_metadata(image) else: self.array_spec = spec - self.image = gp.Array(image.data.astype(float), self.array_spec) + self.image = gp.Array(image.data.astype(np.float32), self.array_spec) self.key = key def setup(self): From df5389f376c5bfa7bfdecee262a79a954b6381df Mon Sep 17 00:00:00 2001 From: lmanan Date: Thu, 21 Sep 2023 10:25:05 -0400 Subject: [PATCH 16/40] Enumerate all cases for pipeline --- src/napari_cellulus/widgets/_widget.py | 136 +++++++++++++++++++++---- 1 file changed, 115 insertions(+), 21 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 594b334..91660ea 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -51,6 +51,7 @@ # local package imports from ..gui_helpers import MplCanvas, layer_choice_widget +from ..meta_data import NapariDatasetMetaData @dataclasses.dataclass @@ -382,8 +383,8 @@ def async_segment( bandwidth: int, min_size: int, ) -> List[napari.types.LayerDataTuple]: - raw.data = raw.data.astype(np.float32) + raw.data = raw.data.astype(np.float32) global _model assert ( @@ -397,8 +398,14 @@ def async_segment( # device device = torch.device(_train_config.device) - num_spatial_dims = len(raw.data.shape) - 2 - num_channels = raw.data.shape[1] + axis_names = self.get_selected_axes() + meta_data = NapariDatasetMetaData(raw.data.shape, axis_names) + + num_spatial_dims = meta_data.num_spatial_dims + num_channels = meta_data.num_channels + + if meta_data.num_channels == 0: + num_channels = 1 voxel_size = gp.Coordinate((1,) * num_spatial_dims) model.set_infer( @@ -423,8 +430,13 @@ def async_segment( ).shape ) - input_size = gp.Coordinate(input_shape[2:]) * voxel_size - output_size = gp.Coordinate(output_shape[2:]) * voxel_size + input_size = ( + gp.Coordinate(input_shape[-num_spatial_dims:]) * voxel_size + ) + output_size = ( + 
gp.Coordinate(output_shape[-num_spatial_dims:]) + * voxel_size + ) context = (input_size - output_size) / 2 @@ -443,27 +455,110 @@ def async_segment( prediction_key: gp.ArraySpec(voxel_size=voxel_size) }, ) + if meta_data.num_samples == 0 and meta_data.num_channels == 0: + pipeline = ( + NapariImageSource( + raw, + raw_key, + gp.ArraySpec( + gp.Roi( + (0,) * num_spatial_dims, + raw.data.shape[-num_spatial_dims:], + ), + voxel_size=voxel_size, + ), + ) + + gp.Pad(raw_key, context) + + gp.Unsqueeze([raw_key], 0) + + gp.Unsqueeze([raw_key], 0) + + predict + + gp.Scan(scan_request) + ) + elif ( + meta_data.num_samples != 0 and meta_data.num_channels == 0 + ): + pipeline = ( + NapariImageSource( + raw, + raw_key, + gp.ArraySpec( + gp.Roi( + (0,) * num_spatial_dims, + raw.data.shape[-num_spatial_dims:], + ), + voxel_size=voxel_size, + ), + ) + + gp.Pad(raw_key, context) + + gp.Unsqueeze([raw_key], 1) + + predict + + gp.Scan(scan_request) + ) + elif ( + meta_data.num_samples == 0 and meta_data.num_channels != 0 + ): + pipeline = ( + NapariImageSource( + raw, + raw_key, + gp.ArraySpec( + gp.Roi( + (0,) * num_spatial_dims, + raw.data.shape[-num_spatial_dims:], + ), + voxel_size=voxel_size, + ), + ) + + gp.Pad(raw_key, context) + + gp.Unsqueeze([raw_key], 0) + + predict + + gp.Scan(scan_request) + ) - pipeline = ( - NapariImageSource( - raw, - raw_key, - gp.ArraySpec( - gp.Roi( - (0,) * num_spatial_dims, raw.data.shape[2:] + elif ( + meta_data.num_samples != 0 and meta_data.num_channels == 0 + ): + pipeline = ( + NapariImageSource( + raw, + raw_key, + gp.ArraySpec( + gp.Roi( + (0,) * num_spatial_dims, + raw.data.shape[-num_spatial_dims:], + ), + voxel_size=voxel_size, ), - voxel_size=voxel_size, - ), + ) + + gp.Pad(raw_key, context) + + gp.Unsqueeze([raw_key], 1) + + predict + + gp.Scan(scan_request) ) - + gp.Pad(raw_key, context) - + predict - + gp.Scan(scan_request) - ) + elif ( + meta_data.num_samples != 0 and meta_data.num_channels != 0 + ): + pipeline = ( + NapariImageSource( + raw, + raw_key, + gp.ArraySpec( + gp.Roi( + (0,) * num_spatial_dims, + raw.data.shape[-num_spatial_dims:], + ), + voxel_size=voxel_size, + ), + ) + + gp.Pad(raw_key, context) + + predict + + gp.Scan(scan_request) + ) # request to pipeline for ROI of whole image/volume request = gp.BatchRequest() - request.add(raw_key, raw.data.shape[2:]) - request.add(prediction_key, raw.data.shape[2:]) + request.add(raw_key, raw.data.shape[-num_spatial_dims:]) + request.add(prediction_key, raw.data.shape[-num_spatial_dims:]) with gp.build(pipeline): batch = pipeline.request_batch(request) @@ -489,7 +584,6 @@ def async_segment( labels = np.zeros_like( prediction[:, 0:1, ...].data, dtype=np.uint64 ) - num_spatial_dims = len(prediction.data.shape) - 2 for sample in tqdm(range(prediction.data.shape[0])): embeddings = prediction[sample] From 0af93d2b39d943cec9cb03ad784763e8e0ab1aba Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 26 Sep 2023 09:42:05 -0400 Subject: [PATCH 17/40] Update default value of density --- src/napari_cellulus/widgets/_widget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 91660ea..7a7875f 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -106,7 +106,7 @@ def train_config_widget( batch_size: int = 8, max_iterations: int = 100_000, initial_learning_rate: float = 4e-5, - density: float = 0.2, + density: float = 0.1, kappa: float = 10.0, temperature: float = 10.0, 
regularizer_weight: float = 1e-5, From e40370711e297fabd4b18f429037abe67a807c8b Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 26 Sep 2023 14:14:31 -0400 Subject: [PATCH 18/40] Comment out elastic augment --- src/napari_cellulus/dataset.py | 45 ---------------------------------- 1 file changed, 45 deletions(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index f89ae0e..95ad536 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -1,4 +1,3 @@ -import math from typing import List, Tuple import gunpowder as gp @@ -91,70 +90,26 @@ def __setup_pipeline(self): self.pipeline = ( NapariImageSource(self.layer, self.raw, raw_spec) + gp.RandomLocation() - + gp.ElasticAugment( - control_point_spacing=(self.control_point_spacing,) - * self.num_spatial_dims, - jitter_sigma=(self.control_point_jitter,) - * self.num_spatial_dims, - rotation_interval=(0, math.pi / 2), - scale_interval=(0.9, 1.1), - subsample=4, - spatial_dims=self.num_spatial_dims, - ) + gp.Unsqueeze([self.raw], 0) + gp.Unsqueeze([self.raw], 0) - # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) ) elif self.num_channels == 0 and self.num_samples != 0: self.pipeline = ( NapariImageSource(self.layer, self.raw, raw_spec) + gp.Unsqueeze([self.raw], 1) + gp.RandomLocation() - + gp.ElasticAugment( - control_point_spacing=(self.control_point_spacing,) - * self.num_spatial_dims, - jitter_sigma=(self.control_point_jitter,) - * self.num_spatial_dims, - rotation_interval=(0, math.pi / 2), - scale_interval=(0.9, 1.1), - subsample=4, - spatial_dims=self.num_spatial_dims, - ) + gp.Unsqueeze([self.raw], 1) - # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) ) elif self.num_channels != 0 and self.num_samples == 0: self.pipeline = ( NapariImageSource(self.layer, self.raw, raw_spec) + gp.RandomLocation() - + gp.ElasticAugment( - control_point_spacing=(self.control_point_spacing,) - * self.num_spatial_dims, - jitter_sigma=(self.control_point_jitter,) - * self.num_spatial_dims, - rotation_interval=(0, math.pi / 2), - scale_interval=(0.9, 1.1), - subsample=4, - spatial_dims=self.num_spatial_dims, - ) + gp.Unsqueeze([self.raw], 0) - # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) ) elif self.num_channels != 0 and self.num_samples != 0: self.pipeline = ( NapariImageSource(self.layer, self.raw, raw_spec) + gp.RandomLocation() - + gp.ElasticAugment( - control_point_spacing=(self.control_point_spacing,) - * self.num_spatial_dims, - jitter_sigma=(self.control_point_jitter,) - * self.num_spatial_dims, - rotation_interval=(0, math.pi / 2), - scale_interval=(0.9, 1.1), - subsample=4, - spatial_dims=self.num_spatial_dims, - ) - # + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims) ) def __yield_sample(self): From fee80c8932ff33232de3cd3cd7a7cea9cba5a52b Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 26 Sep 2023 14:15:19 -0400 Subject: [PATCH 19/40] Correct visualization of offsets --- src/napari_cellulus/widgets/_widget.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 7a7875f..f0d6e40 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -429,7 +429,6 @@ def async_segment( ).to(device) ).shape ) - input_size = ( gp.Coordinate(input_shape[-num_spatial_dims:]) * voxel_size ) @@ -568,17 +567,20 @@ def async_segment( ( prediction[:, i : i + 1, 
...].copy(), { - "name": "offset-" + "zyx"[num_channels - i] - if i < num_channels + "name": "offset-" + + "zyx"[meta_data.num_spatial_dims - i] + if i < meta_data.num_spatial_dims else "std", - "colormap": colormaps[num_channels - i] - if i < num_channels + "colormap": colormaps[ + meta_data.num_spatial_dims - i + ] + if i < meta_data.num_spatial_dims else "gray", "blending": "additive", }, "image", ) - for i in range(num_channels + 1) + for i in range(meta_data.num_spatial_dims + 1) ] labels = np.zeros_like( From 652a576e9b0856397250176f09c47a3f692fb353 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 2 Oct 2023 08:18:34 +0200 Subject: [PATCH 20/40] Redesign plugin --- src/napari_cellulus/__init__.py | 14 +- src/napari_cellulus/napari.yaml | 20 +- src/napari_cellulus/widgets/_widget.py | 979 +------------------------ 3 files changed, 47 insertions(+), 966 deletions(-) diff --git a/src/napari_cellulus/__init__.py b/src/napari_cellulus/__init__.py index c1d6450..47a9747 100644 --- a/src/napari_cellulus/__init__.py +++ b/src/napari_cellulus/__init__.py @@ -1,15 +1,5 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" from ._sample_data import tissuenet_sample -from .widgets._widget import ( - TrainWidget, - model_config_widget, - train_config_widget, -) -__all__ = ( - "tissuenet_sample", - "train_config_widget", - "model_config_widget", - "TrainWidget", -) +__all__ = ("tissuenet_sample",) diff --git a/src/napari_cellulus/napari.yaml b/src/napari_cellulus/napari.yaml index 36f1d03..1c22723 100644 --- a/src/napari_cellulus/napari.yaml +++ b/src/napari_cellulus/napari.yaml @@ -5,23 +5,13 @@ contributions: - id: napari-cellulus.tissuenet_sample python_name: napari_cellulus._sample_data:tissuenet_sample title: Load sample data from Cellulus - - id: napari-cellulus.train_config - python_name: napari_cellulus.widgets._widget:train_config_widget - title: Make the training config widget - - id: napari-cellulus.model_config - python_name: napari_cellulus.widgets._widget:model_config_widget - title: Make the model config widget - - id: napari-cellulus.train_widget - python_name: napari_cellulus.widgets._widget:TrainWidget - title: Make the train widget + - id: napari-cellulus.SegmentationWidget + python_name: napari_cellulus.widgets._widget:SegmentationWidget + title: Cellulus sample_data: - command: napari-cellulus.tissuenet_sample display_name: Cellulus key: tissuenet_sample widgets: - - command: napari-cellulus.train_config - display_name: Train Config Widget - - command: napari-cellulus.model_config - display_name: Model Config Widget - - command: napari-cellulus.train_widget - display_name: Train Widget + - command: napari-cellulus.SegmentationWidget + display_name: Cellulus diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index f0d6e40..0a251bd 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -1,991 +1,92 @@ -""" -This module is an example of a barebones QWidget plugin for napari +from typing import List -It implements the Widget specification. -see: https://napari.org/stable/plugins/guides.html?#widgets - -Replace code below according to your needs. 
-""" -import contextlib -import dataclasses - -# python built in libraries -from pathlib import Path -from typing import List, Optional - -# github repo libraries -import gunpowder as gp - -# pip installed libraries import napari -import numpy as np -import torch -from cellulus.configs.model_config import ModelConfig -from cellulus.configs.train_config import TrainConfig -from cellulus.criterions import get_loss -from cellulus.models import get_model -from cellulus.utils.mean_shift import mean_shift_segmentation from magicgui import magic_factory -from magicgui.widgets import Container - -# widget stuff -from matplotlib.backends.backend_qt5agg import ( - NavigationToolbar2QT as NavigationToolbar, -) from napari.qt.threading import FunctionWorker, thread_worker -from qtpy.QtCore import QUrl -from qtpy.QtGui import QDesktopServices from qtpy.QtWidgets import ( - QCheckBox, - QGroupBox, - QHBoxLayout, + QGridLayout, + QLabel, + QLineEdit, QPushButton, + QScrollArea, QVBoxLayout, - QWidget, ) from superqt import QCollapsible -from tqdm import tqdm - -from ..dataset import NapariDataset -from ..gp.nodes.napari_image_source import NapariImageSource - -# local package imports -from ..gui_helpers import MplCanvas, layer_choice_widget -from ..meta_data import NapariDatasetMetaData - - -@dataclasses.dataclass -class TrainingStats: - iteration: int = 0 - losses: list[float] = dataclasses.field(default_factory=list) - iterations: list[int] = dataclasses.field(default_factory=list) - def reset(self): - self.iteration = 0 - self.losses = [] - self.iterations = [] - def load(self, other): - self.iteration = other.iteration - self.losses = other.losses - self.iterations = other.iterations - - -################################## GLOBALS #################################### -_train_config: Optional[TrainConfig] = None -_model_config: Optional[ModelConfig] = None -_model: Optional[torch.nn.Module] = None -_optimizer: Optional[torch.optim.Optimizer] = None -_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None -_training_stats: TrainingStats = TrainingStats() - - -def get_training_stats(): - global _training_stats - return _training_stats - - -def get_train_config(**kwargs): - global _train_config - # set dataset configs to None - kwargs["train_data_config"] = None - kwargs["validate_data_config"] = None - if _train_config is None: - _train_config = TrainConfig(**kwargs) - elif len(kwargs) > 0: - for k, v in kwargs.items(): - _train_config.__setattr__(k, v) - return _train_config - - -@magic_factory( - call_button="Save", device={"choices": ["cpu", "cuda:0", "mps"]} -) -def train_config_widget( - crop_size: list[int] = [252, 252], - batch_size: int = 8, - max_iterations: int = 100_000, - initial_learning_rate: float = 4e-5, - density: float = 0.1, - kappa: float = 10.0, - temperature: float = 10.0, - regularizer_weight: float = 1e-5, - reduce_mean: bool = True, - save_model_every: int = 1_000, - save_snapshot_every: int = 1_000, - num_workers: int = 8, - control_point_spacing: int = 64, - control_point_jitter: float = 2.0, - device="cpu", -): - get_train_config( - crop_size=crop_size, - batch_size=batch_size, - max_iterations=max_iterations, - initial_learning_rate=initial_learning_rate, - density=density, - kappa=kappa, - temperature=temperature, - regularizer_weight=regularizer_weight, - reduce_mean=reduce_mean, - save_model_every=save_model_every, - save_snapshot_every=save_snapshot_every, - num_workers=num_workers, - control_point_spacing=control_point_spacing, - 
control_point_jitter=control_point_jitter, - device=device, - ) - - -def get_model_config(**kwargs): - global _model_config - if _model_config is None: - _model_config = ModelConfig(**kwargs) - elif len(kwargs) > 0: - for k, v in kwargs.items(): - _model_config.__setattr__(k, v) - return _model_config - - -@magic_factory -def model_config_widget( - num_fmaps: int = 256, - fmap_inc_factor: int = 3, - features_in_last_layer: int = 64, - downsampling_factors: list[list[int]] = [[2, 2]], -): - get_model_config( - num_fmaps=num_fmaps, - fmap_inc_factor=fmap_inc_factor, - features_in_last_layer=features_in_last_layer, - downsampling_factors=downsampling_factors, - ) - - -def get_training_state(dataset: Optional[NapariDataset] = None): - global _model - global _optimizer - global _scheduler - global _model_config - global _train_config - - # set device - device = torch.device(_train_config.device) - - if _model_config is None: - # TODO: deal with hard coded defaults - _model_config = ModelConfig(num_fmaps=24, fmap_inc_factor=3) - if _train_config is None: - _train_config = get_train_config() - if _model is None: - # Build model - _model = get_model( - in_channels=dataset.get_num_channels(), - out_channels=dataset.get_num_spatial_dims(), - num_fmaps=_model_config.num_fmaps, - fmap_inc_factor=_model_config.fmap_inc_factor, - features_in_last_layer=_model_config.features_in_last_layer, - downsampling_factors=[ - tuple(factor) for factor in _model_config.downsampling_factors - ], - num_spatial_dims=dataset.get_num_spatial_dims(), - ).to(device) - - # Weight initialization - # TODO: move weight initialization to funlib.learn.torch - for _name, layer in _model.named_modules(): - if isinstance(layer, torch.nn.modules.conv._ConvNd): - torch.nn.init.kaiming_normal_( - layer.weight, nonlinearity="relu" - ) - - _optimizer = torch.optim.Adam( - _model.parameters(), - lr=_train_config.initial_learning_rate, - ) - - def lambda_(iteration): - return pow((1 - ((iteration) / _train_config.max_iterations)), 0.9) - - _scheduler = torch.optim.lr_scheduler.LambdaLR( - _optimizer, lr_lambda=lambda_ - ) - return (_model, _optimizer, _scheduler) - - -class TrainWidget(QWidget): +class SegmentationWidget(QScrollArea): def __init__(self, napari_viewer): - # basic initialization - self.viewer = napari_viewer super().__init__() + self.viewer = napari_viewer - # initialize state variables - self.__training_generator = None - - # Widget layout - layout = QVBoxLayout() - - # add loss/iterations widget - self.progress_plot = MplCanvas(self, width=5, height=3, dpi=100) - toolbar = NavigationToolbar(self.progress_plot, self) - progress_plot_layout = QVBoxLayout() - progress_plot_layout.addWidget(toolbar) - progress_plot_layout.addWidget(self.progress_plot) - self.loss_plot = None - self.val_plot = None - plot_container_widget = QWidget() - plot_container_widget.setLayout(progress_plot_layout) - layout.addWidget(plot_container_widget) - - # add raw layer choice - self.raw_selector = layer_choice_widget( - self.viewer, - annotation=napari.layers.Image, - name="raw", + # define components + logo_path = "" + self.logo_label = QLabel(f'
<img src="{logo_path}">Cellulus') self.method_description_label = QLabel( 'Unsupervised Learning of Object-Centric Embeddings<br>for Cell Instance Segmentation in Microscopy Images.<br>If you are using this in your research, please cite us.' ) self.download_data_label = QLabel("Download Data
") + self.data_dir_label = QLabel("Data Directory") + self.data_dir_pushbutton = QPushButton("Browse") + self.data_dir_pushbutton.setMaximumWidth(280) + # self.data_dir_pushbutton.clicked.connect(self._prepare_data_dir) - self.axis_selector = QGroupBox("Axis Names:") - self.axis_selector.setLayout(axis_layout) - layout.addWidget(self.axis_selector) + self.object_size_label = QLabel("Rough Object size [px]") + self.object_size_edit = QLineEdit("30") - # add buttons - self.train_button = QPushButton("Train!", self) - self.train_button.clicked.connect(self.train) - layout.addWidget(self.train_button) + # define layout + outer_layout = QVBoxLayout() - # add save and load widgets - collapsable_save_load_widget = QCollapsible("Save/Load", self) - collapsable_save_load_widget.addWidget(self.save_widget.native) - collapsable_save_load_widget.addWidget(self.load_widget.native) - - layout.addWidget(collapsable_save_load_widget) + # inner layout + grid_0 = QGridLayout() + grid_0.addWidget(self.logo_label, 0, 0, 1, 1) + grid_0.addWidget(self.method_description_label, 0, 1, 1, 1) + grid_0.setSpacing(10) # Add segment widget - collapsable_segment_widget = QCollapsible("Segment", self) - collapsable_segment_widget.addWidget(self.segment_widget) - layout.addWidget(collapsable_segment_widget) - - # add feedback button - self.feedback_button = QPushButton("Feedback!", self) - self.feedback_button.clicked.connect( - lambda: QDesktopServices.openUrl( - QUrl( - "https://github.com/funkelab/napari-cellulus/issues/new/choose" - ) - ) - ) - layout.addWidget(self.feedback_button) - - # activate layout - self.setLayout(layout) - - # Widget state - self.model = None - self.reset_training_state() - - # TODO: Can we do this better? - # connect napari events - self.viewer.layers.events.inserted.connect( - self.__segment_widget.raw.reset_choices - ) - self.viewer.layers.events.removed.connect( - self.__segment_widget.raw.reset_choices - ) - - # handle button activations and deactivations - # buttons: save, load, (train/pause), segment - self.save_button = self.__save_widget.call_button.native - self.load_button = self.__load_widget.call_button.native - self.segment_button = self.__segment_widget.call_button.native - self.segment_button.clicked.connect( - lambda: self.set_buttons("segmenting") - ) - - self.set_buttons("initial") + collapsible_0 = QCollapsible("Inference", self) + collapsible_0.addWidget(self.segment_widget) - def set_buttons(self, state: str): - if state == "training": - self.train_button.setText("Pause!") - self.train_button.setEnabled(True) - self.save_button.setText("Stop training to save!") - self.save_button.setEnabled(False) - self.load_button.setText("Stop training to load!") - self.load_button.setEnabled(False) - self.segment_button.setText("Stop training to segment!") - self.segment_button.setEnabled(False) - if state == "paused": - self.train_button.setText("Train!") - self.train_button.setEnabled(True) - self.save_button.setText("Save") - self.save_button.setEnabled(True) - self.load_button.setText("Load") - self.load_button.setEnabled(True) - self.segment_button.setText("Segment") - self.segment_button.setEnabled(True) - if state == "segmenting": - self.train_button.setText("Can't train while segmenting!") - self.train_button.setEnabled(False) - self.save_button.setText("Can't save while segmenting!") - self.save_button.setEnabled(False) - self.load_button.setText("Can't load while segmenting!") - self.load_button.setEnabled(False) - self.segment_button.setText("Segmenting...") - 
self.segment_button.setEnabled(False) - if state == "initial": - self.train_button.setText("Train!") - self.train_button.setEnabled(True) - self.save_button.setText("No state to Save!") - self.save_button.setEnabled(False) - self.load_button.setText( - "Load data and test data before loading an old model!" - ) - self.load_button.setEnabled(False) - self.segment_button.setText("Segment") - self.segment_button.setEnabled(True) + outer_layout.addLayout(grid_0) + outer_layout.addWidget(collapsible_0) + self.setLayout(outer_layout) + self.setFixedWidth(500) @property def segment_widget(self): @magic_factory(call_button="Segment") def segment( raw: napari.layers.Image, - crop_size: list[int] = [252, 252], + crop_size: int = 252, p_salt_pepper: float = 0.1, num_infer_iterations: int = 16, bandwidth: int = 7, - min_size: int = 10, + min_size: int = 25, ) -> FunctionWorker[List[napari.types.LayerDataTuple]]: - # TODO: do this better? @thread_worker( connect={"returned": lambda: self.set_buttons("paused")}, progress={"total": 0, "desc": "Segmenting"}, ) def async_segment( raw: napari.layers.Image, - crop_size: list[int], + crop_size: int, p_salt_pepper: float, num_infer_iterations: int, bandwidth: int, min_size: int, ) -> List[napari.types.LayerDataTuple]: - raw.data = raw.data.astype(np.float32) - global _model - - assert ( - _model is not None - ), "You must train a model before running inference" - model = _model - - # set in eval mode - model.eval() - - # device - device = torch.device(_train_config.device) - - axis_names = self.get_selected_axes() - meta_data = NapariDatasetMetaData(raw.data.shape, axis_names) - - num_spatial_dims = meta_data.num_spatial_dims - num_channels = meta_data.num_channels - - if meta_data.num_channels == 0: - num_channels = 1 - - voxel_size = gp.Coordinate((1,) * num_spatial_dims) - model.set_infer( + return async_segment( + raw, + crop_size=crop_size, p_salt_pepper=p_salt_pepper, num_infer_iterations=num_infer_iterations, - device=device, - ) - - # prediction crop size is the size of the scanned tiles to be provided to the model - input_shape = gp.Coordinate((1, num_channels, *crop_size)) - - output_shape = gp.Coordinate( - model( - torch.zeros( - ( - 1, - num_channels, - *crop_size, - ), - dtype=torch.float32, - ).to(device) - ).shape - ) - input_size = ( - gp.Coordinate(input_shape[-num_spatial_dims:]) * voxel_size - ) - output_size = ( - gp.Coordinate(output_shape[-num_spatial_dims:]) - * voxel_size - ) - - context = (input_size - output_size) / 2 - - raw_key = gp.ArrayKey("RAW") - prediction_key = gp.ArrayKey("PREDICT") - - scan_request = gp.BatchRequest() - scan_request.add(raw_key, input_size) - scan_request.add(prediction_key, output_size) - - predict = gp.torch.Predict( - model, - inputs={"raw": raw_key}, - outputs={0: prediction_key}, - array_specs={ - prediction_key: gp.ArraySpec(voxel_size=voxel_size) - }, + bandwidth=bandwidth, + min_size=min_size, ) - if meta_data.num_samples == 0 and meta_data.num_channels == 0: - pipeline = ( - NapariImageSource( - raw, - raw_key, - gp.ArraySpec( - gp.Roi( - (0,) * num_spatial_dims, - raw.data.shape[-num_spatial_dims:], - ), - voxel_size=voxel_size, - ), - ) - + gp.Pad(raw_key, context) - + gp.Unsqueeze([raw_key], 0) - + gp.Unsqueeze([raw_key], 0) - + predict - + gp.Scan(scan_request) - ) - elif ( - meta_data.num_samples != 0 and meta_data.num_channels == 0 - ): - pipeline = ( - NapariImageSource( - raw, - raw_key, - gp.ArraySpec( - gp.Roi( - (0,) * num_spatial_dims, - raw.data.shape[-num_spatial_dims:], - ), - 
voxel_size=voxel_size, - ), - ) - + gp.Pad(raw_key, context) - + gp.Unsqueeze([raw_key], 1) - + predict - + gp.Scan(scan_request) - ) - elif ( - meta_data.num_samples == 0 and meta_data.num_channels != 0 - ): - pipeline = ( - NapariImageSource( - raw, - raw_key, - gp.ArraySpec( - gp.Roi( - (0,) * num_spatial_dims, - raw.data.shape[-num_spatial_dims:], - ), - voxel_size=voxel_size, - ), - ) - + gp.Pad(raw_key, context) - + gp.Unsqueeze([raw_key], 0) - + predict - + gp.Scan(scan_request) - ) - - elif ( - meta_data.num_samples != 0 and meta_data.num_channels == 0 - ): - pipeline = ( - NapariImageSource( - raw, - raw_key, - gp.ArraySpec( - gp.Roi( - (0,) * num_spatial_dims, - raw.data.shape[-num_spatial_dims:], - ), - voxel_size=voxel_size, - ), - ) - + gp.Pad(raw_key, context) - + gp.Unsqueeze([raw_key], 1) - + predict - + gp.Scan(scan_request) - ) - - elif ( - meta_data.num_samples != 0 and meta_data.num_channels != 0 - ): - pipeline = ( - NapariImageSource( - raw, - raw_key, - gp.ArraySpec( - gp.Roi( - (0,) * num_spatial_dims, - raw.data.shape[-num_spatial_dims:], - ), - voxel_size=voxel_size, - ), - ) - + gp.Pad(raw_key, context) - + predict - + gp.Scan(scan_request) - ) - # request to pipeline for ROI of whole image/volume - request = gp.BatchRequest() - request.add(raw_key, raw.data.shape[-num_spatial_dims:]) - request.add(prediction_key, raw.data.shape[-num_spatial_dims:]) - with gp.build(pipeline): - batch = pipeline.request_batch(request) - - prediction = batch.arrays[prediction_key].data - colormaps = ["red", "green", "blue"] - prediction_layers = [ - ( - prediction[:, i : i + 1, ...].copy(), - { - "name": "offset-" - + "zyx"[meta_data.num_spatial_dims - i] - if i < meta_data.num_spatial_dims - else "std", - "colormap": colormaps[ - meta_data.num_spatial_dims - i - ] - if i < meta_data.num_spatial_dims - else "gray", - "blending": "additive", - }, - "image", - ) - for i in range(meta_data.num_spatial_dims + 1) - ] - - labels = np.zeros_like( - prediction[:, 0:1, ...].data, dtype=np.uint64 - ) - - for sample in tqdm(range(prediction.data.shape[0])): - embeddings = prediction[sample] - embeddings_std = embeddings[-1, ...] - embeddings_mean = embeddings[ - np.newaxis, :num_spatial_dims, ... - ] - segmentation = mean_shift_segmentation( - embeddings_mean, - embeddings_std, - bandwidth=bandwidth, - min_size=min_size, - ) - labels[ - sample, - 0, - ..., - ] = segmentation - return prediction_layers + [ - (labels, {"name": "Segmentation"}, "labels") - ] - - return async_segment( - raw, - crop_size=crop_size, - p_salt_pepper=p_salt_pepper, - num_infer_iterations=num_infer_iterations, - bandwidth=bandwidth, - min_size=min_size, - ) if not hasattr(self, "__segment_widget"): self.__segment_widget = segment() self.__segment_widget_native = self.__segment_widget.native return self.__segment_widget_native - - @property - def save_widget(self): - # TODO: block buttons on call. This shouldn't take long, but other operations such - # as continuing to train should be blocked until this is done. 
- def on_return(): - self.set_buttons("paused") - - @magic_factory(call_button="Save") - def save(path: Path = Path("checkpoint.pt")) -> FunctionWorker[None]: - @thread_worker( - connect={"returned": lambda: on_return()}, - progress={"total": 0, "desc": "Saving"}, - ) - def async_save(path: Path = Path("checkpoint.pt")) -> None: - model, optimizer, scheduler = get_training_state() - training_stats = get_training_stats() - torch.save( - ( - model.state_dict(), - optimizer.state_dict(), - scheduler.state_dict(), - training_stats, - ), - path, - ) - - return async_save(path) - - if not hasattr(self, "__save_widget"): - self.__save_widget = save() - - return self.__save_widget - - @property - def load_widget(self): - # TODO: block buttons on call. This shouldn't take long, but other operations such - # as continuing to train should be blocked until this is done. - def on_return(): - self.update_progress_plot() - self.set_buttons("paused") - - @magic_factory(call_button="Load") - def load(path: Path = Path("checkpoint.pt")) -> FunctionWorker[None]: - @thread_worker( - connect={"returned": on_return}, - progress={"total": 0, "desc": "Saving"}, - ) - def async_load(path: Path = Path("checkpoint.pt")) -> None: - model, optimizer, scheduler = get_training_state() - training_stats = get_training_stats() - state_dicts = torch.load( - path, - ) - model.load_state_dict(state_dicts[0]) - optimizer.load_state_dict(state_dicts[1]) - scheduler.load_state_dict(state_dicts[2]) - training_stats.load(state_dicts[3]) - - return async_load(path) - - if not hasattr(self, "__load_widget"): - self.__load_widget = load() - - return self.__load_widget - - @property - def training(self) -> bool: - try: - return self.__training - except AttributeError: - return False - - @training.setter - def training(self, training: bool): - self.__training = training - if training: - if self.__training_generator is None: - self.start_training_loop() - assert self.__training_generator is not None - self.__training_generator.resume() - self.set_buttons("training") - else: - if self.__training_generator is not None: - self.__training_generator.send("stop") - # button state handled by on_return - - def reset_training_state(self, keep_stats=False): - if self.__training_generator is not None: - self.__training_generator.quit() - self.__training_generator = None - if not keep_stats: - training_stats = get_training_stats() - training_stats.reset() - if self.loss_plot is None: - self.loss_plot = self.progress_plot.axes.plot( - [], - [], - label="Training Loss", - )[0] - self.progress_plot.axes.legend() - self.progress_plot.axes.set_title("Training Progress") - self.progress_plot.axes.set_xlabel("Iterations") - self.progress_plot.axes.set_ylabel("Loss") - self.update_progress_plot() - - def update_progress_plot(self): - training_stats = get_training_stats() - self.loss_plot.set_xdata(training_stats.iterations) - self.loss_plot.set_ydata(training_stats.losses) - self.progress_plot.axes.relim() - self.progress_plot.axes.autoscale_view() - with contextlib.suppress(np.linalg.LinAlgError): - # matplotlib seems to throw a LinAlgError on draw sometimes. Not sure - # why yet. Seems to only happen when initializing models without any - # layers loaded. No idea whats going wrong. - # For now just avoid drawing. 
Seems to work as soon as there is data to plot - self.progress_plot.draw() - - def get_selected_axes(self): - names = [] - for name, checkbox in zip( - "sctzyx", - [ - self.s_checkbox, - self.c_checkbox, - self.t_checkbox, - self.z_checkbox, - self.y_checkbox, - self.x_checkbox, - ], - ): - if checkbox.isChecked(): - names.append(name) - return names - - def start_training_loop(self): - self.reset_training_state(keep_stats=True) - training_stats = get_training_stats() - - self.__training_generator = self.train_cellulus( - self.raw_selector.value, - self.get_selected_axes(), - iteration=training_stats.iteration, - ) - self.__training_generator.yielded.connect(self.on_yield) - self.__training_generator.returned.connect(self.on_return) - self.__training_generator.start() - - def train(self): - self.training = not self.training - - def snapshot(self): - self.__training_generator.send("snapshot") - self.training = True - - def spatial_dims(self, ndims): - return ["time", "z", "y", "x"][-ndims:] - - def create_train_widget(self, viewer): - # inputs: - raw = layer_choice_widget( - viewer, - annotation=napari.layers.Image, - name="raw", - ) - train_widget = Container(widgets=[raw]) - - return train_widget - - def on_yield(self, step_data): - iteration, loss, *layers = step_data - if len(layers) > 0: - self.add_layers(layers) - if iteration is not None and loss is not None: - training_stats = get_training_stats() - training_stats.iteration = iteration - training_stats.iterations.append(iteration) - training_stats.losses.append(loss) - self.update_progress_plot() - - def on_return(self, weights_path: Path): - """ - Update model to use provided returned weights - """ - global _model - global _optimizer - global _scheduler - assert ( - _model is not None - and _optimizer is not None - and _scheduler is not None - ) - model_state_dict, optim_state_dict, scheduler_state_dict = torch.load( - weights_path - ) - _model.load_state_dict(model_state_dict) - _optimizer.load_state_dict(optim_state_dict) - _scheduler.load_state_dict(scheduler_state_dict) - self.reset_training_state(keep_stats=True) - self.set_buttons("paused") - - def add_layers(self, layers): - viewer_axis_labels = self.viewer.dims.axis_labels - - for data, metadata, layer_type in layers: - # then try to update the viewer layer with that name. 
- name = metadata.pop("name") - axes = metadata.pop("axes") - overwrite = metadata.pop("overwrite", False) - slices = metadata.pop("slices", None) - shape = metadata.pop("shape", None) - - # handle viewer axes if still default numerics - # TODO: Support using xarray axis labels as soon as napari does - if len(set(viewer_axis_labels).intersection(set(axes))) == 0: - spatial_axes = [ - axis for axis in axes if axis not in ["batch", "channel"] - ] - assert ( - len(viewer_axis_labels) - len(spatial_axes) <= 1 - ), f"Viewer has axes: {viewer_axis_labels}, but we expect ((channels), {spatial_axes})" - viewer_axis_labels = ( - ("channels", *spatial_axes) - if len(viewer_axis_labels) > len(spatial_axes) - else spatial_axes - ) - self.viewer.dims.axis_labels = viewer_axis_labels - - batch_dim = axes.index("batch") if "batch" in axes else -1 - assert batch_dim in [ - -1, - 0, - ], "Batch dim must be first" - if batch_dim == 0: - data = data[0] - - if slices is not None and shape is not None: - # strip channel dimension from slices and shape - slices = (slice(None, None), *slices[1:]) - shape = (data.shape[0], *shape[1:]) - - # create new data array with filled in chunk - full_data = np.zeros(shape, dtype=data.dtype) - full_data[slices] = data - - else: - slices = tuple(slice(None, None) for _ in data.shape) - full_data = data - - try: - # add to existing layer - layer = self.viewer.layers[name] - - if overwrite: - layer.data[slices] = data - layer.refresh() - else: - # concatenate along batch dimension - layer.data = np.concatenate( - [ - layer.data.reshape(-1, *full_data.shape), - full_data.reshape(-1, *full_data.shape).astype( - layer.data.dtype - ), - ], - axis=0, - ) - # make first dimension "batch" if it isn't - if not overwrite and viewer_axis_labels[0] != "batch": - viewer_axis_labels = ("batch", *viewer_axis_labels) - self.viewer.dims.axis_labels = viewer_axis_labels - - except KeyError: # layer not in the viewer - # TODO: Support defining layer axes as soon as napari does - if layer_type == "image": - self.viewer.add_image(full_data, name=name, **metadata) - elif layer_type == "labels": - self.viewer.add_labels( - full_data.astype(int), name=name, **metadata - ) - - @thread_worker - def train_cellulus( - self, - raw, - axis_names, - iteration=0, - ): - - train_config = get_train_config() - - # Turn layer into dataset: - train_dataset = NapariDataset( - raw, - axis_names, - crop_size=train_config.crop_size, - control_point_spacing=train_config.control_point_spacing, - control_point_jitter=train_config.control_point_jitter, - ) - model, optimizer, scheduler = get_training_state(train_dataset) - - model.train() - - train_dataloader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_size=train_config.batch_size, - drop_last=True, - num_workers=train_config.num_workers, - pin_memory=True, - ) - - # set loss - criterion = get_loss( - regularizer_weight=train_config.regularizer_weight, - temperature=train_config.temperature, - kappa=train_config.kappa, - density=train_config.density, - num_spatial_dims=train_dataset.get_num_spatial_dims(), - reduce_mean=train_config.reduce_mean, - device=train_config.device, - ) - - def train_iteration(batch, model, criterion, optimizer, device): - prediction = model(batch.to(device)) - loss = criterion(prediction) - loss = loss.mean() - optimizer.zero_grad() - loss.backward() - optimizer.step() - return loss.item(), prediction - - mode = yield (None, None) - # call `train_iteration` - for iteration, batch in tqdm( - zip( - range(iteration, 
train_config.max_iterations), - train_dataloader, - ) - ): - train_loss, prediction = train_iteration( - batch.float(), - model=model, - criterion=criterion, - optimizer=optimizer, - device=train_config.device, - ) - scheduler.step() - - if mode is None: - mode = yield ( - iteration, - train_loss, - ) - - elif mode == "stop": - checkpoint = Path(f"/tmp/checkpoints/{iteration}.pt") - if not checkpoint.parent.exists(): - checkpoint.parent.mkdir(parents=True) - torch.save( - ( - model.state_dict(), - optimizer.state_dict(), - scheduler.state_dict(), - ), - checkpoint, - ) - return checkpoint From a1b3341236f8348e0a520c92f89e6987c202aef7 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 15 Oct 2023 16:35:26 -0400 Subject: [PATCH 21/40] Add url to github repo --- src/napari_cellulus/widgets/_widget.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 0a251bd..1828799 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -20,10 +20,8 @@ def __init__(self, napari_viewer): self.viewer = napari_viewer # define components - logo_path = "" - self.logo_label = QLabel(f'
Cellulus
') self.method_description_label = QLabel( - 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.
' + 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' ) self.download_data_label = QLabel("
Download Data
") @@ -40,7 +38,6 @@ def __init__(self, napari_viewer): # inner layout grid_0 = QGridLayout() - grid_0.addWidget(self.logo_label, 0, 0, 1, 1) grid_0.addWidget(self.method_description_label, 0, 1, 1, 1) grid_0.setSpacing(10) From f7d7e3b235cfb7769cea57dcadb3646e28d641eb Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 15 Oct 2023 18:20:44 -0400 Subject: [PATCH 22/40] Add collapsible config widgets --- src/napari_cellulus/widgets/_widget.py | 74 +++++++++++++++++++++----- 1 file changed, 62 insertions(+), 12 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 1828799..fdb27bb 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -6,8 +6,6 @@ from qtpy.QtWidgets import ( QGridLayout, QLabel, - QLineEdit, - QPushButton, QScrollArea, QVBoxLayout, ) @@ -24,15 +22,6 @@ def __init__(self, napari_viewer): 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' ) - self.download_data_label = QLabel("
Download Data
") - self.data_dir_label = QLabel("Data Directory") - self.data_dir_pushbutton = QPushButton("Browse") - self.data_dir_pushbutton.setMaximumWidth(280) - # self.data_dir_pushbutton.clicked.connect(self._prepare_data_dir) - - self.object_size_label = QLabel("Rough Object size [px]") - self.object_size_edit = QLineEdit("30") - # define layout outer_layout = QVBoxLayout() @@ -41,14 +30,75 @@ def __init__(self, napari_viewer): grid_0.addWidget(self.method_description_label, 0, 1, 1, 1) grid_0.setSpacing(10) + # Add train configs widget + collapsible_train_configs = QCollapsible("Train Configs", self) + collapsible_train_configs.addWidget(self.create_train_configs_widget) + + # Add model configs widget + collapsible_model_configs = QCollapsible("Model Configs", self) + collapsible_model_configs.addWidget(self.create_model_configs_widget) + # Add segment widget collapsible_0 = QCollapsible("Inference", self) collapsible_0.addWidget(self.segment_widget) outer_layout.addLayout(grid_0) + outer_layout.addWidget(collapsible_train_configs) + outer_layout.addWidget(collapsible_model_configs) outer_layout.addWidget(collapsible_0) self.setLayout(outer_layout) - self.setFixedWidth(500) + self.setFixedWidth(400) + + @property + def create_train_configs_widget(self): + @magic_factory( + call_button="Save", device={"choices": ["cuda:0", "cpu", "mps"]} + ) + def train_configs_widget( + crop_size: int = 252, + batch_size: int = 8, + max_iterations: int = 100_000, + initial_learning_rate: float = 4e-5, + temperature: float = 10.0, + regularizer_weight: float = 1e-5, + reduce_mean: bool = True, + density: float = 0.1, + kappa: float = 10.0, + save_model_every: int = 1e3, + save_snapshot_every: int = 1e3, + num_workers: int = 8, + device="mps", + ): + # Specify what should happen when 'Save' button is pressed + pass + + if not hasattr(self, "__create_train_configs_widget"): + self.__create_train_configs_widget = train_configs_widget() + self.__create_train_configs_widget_native = ( + self.__create_train_configs_widget.native + ) + return self.__create_train_configs_widget_native + + @property + def create_model_configs_widget(self): + @magic_factory(call_button="Save") + def model_configs_widget( + num_fmaps: int = 256, + fmap_inc_factor: int = 3, + features_in_last_layer: int = 64, + downsampling_factors: int = 2, + downsampling_layers: int = 1, + initialize: bool = True, + ): + # Specify what should happen when 'Save' button is pressed + pass + + if not hasattr(self, "__create_model_configs_widget"): + self.__create_model_configs_widget = model_configs_widget() + self.__create_model_configs_widget_native = ( + self.__create_model_configs_widget.native + ) + return self.__create_model_configs_widget_native @property def segment_widget(self): From be38c4d22c0f07fcdb9db619866f5c42d9a08d80 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 16 Oct 2023 00:22:13 -0400 Subject: [PATCH 23/40] Make crop_size int type --- src/napari_cellulus/dataset.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 95ad536..9fafac4 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List import gunpowder as gp import numpy as np @@ -14,7 +14,7 @@ def __init__( self, layer: Image, axis_names: List[str], - crop_size: Tuple[int, ...], + crop_size: int, control_point_spacing: int, control_point_jitter: float, ): @@ -57,17 +57,11 @@ def __init__( 
self.layer = layer self.axis_names = axis_names - self.crop_size = crop_size self.control_point_spacing = control_point_spacing self.control_point_jitter = control_point_jitter self.__read_meta_data() - - assert len(crop_size) == self.num_spatial_dims, ( - f'"crop_size" must have the same dimension as the ' - f'spatial(temporal) dimensions of the "{self.layer.name}"' - f"layer which is {self.num_spatial_dims}, but it is {crop_size}" - ) - + print(f"Number of spatial dims is {self.num_spatial_dims}") + self.crop_size = (crop_size,) * self.num_spatial_dims self.__setup_pipeline() def __iter__(self): From 0e9f86bdeda11d35fe2f2a1e2d95623f39d4ad5c Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 16 Oct 2023 00:22:42 -0400 Subject: [PATCH 24/40] Update train code --- src/napari_cellulus/widgets/_widget.py | 274 +++++++++++++++++++++++-- 1 file changed, 255 insertions(+), 19 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index fdb27bb..a4189dc 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -1,52 +1,112 @@ +import os from typing import List import napari +import torch +from cellulus.criterions import get_loss +from cellulus.models import get_model +from cellulus.train import train_iteration from magicgui import magic_factory + +# widget stuff +from matplotlib.backends.backend_qt5agg import ( + NavigationToolbar2QT as NavigationToolbar, +) from napari.qt.threading import FunctionWorker, thread_worker +from PyQt5.QtCore import Qt from qtpy.QtWidgets import ( - QGridLayout, + QCheckBox, + QGroupBox, + QHBoxLayout, QLabel, + QPushButton, QScrollArea, QVBoxLayout, + QWidget, ) from superqt import QCollapsible +from ..dataset import NapariDataset + +# local package imports +from ..gui_helpers import MplCanvas, layer_choice_widget + class SegmentationWidget(QScrollArea): def __init__(self, napari_viewer): super().__init__() + self.widget = QWidget() self.viewer = napari_viewer # define components - self.method_description_label = QLabel( + method_description_label = QLabel( 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' ) # define layout outer_layout = QVBoxLayout() - # inner layout - grid_0 = QGridLayout() - grid_0.addWidget(self.method_description_label, 0, 1, 1, 1) - grid_0.setSpacing(10) - - # Add train configs widget + # Initialize train configs widget collapsible_train_configs = QCollapsible("Train Configs", self) collapsible_train_configs.addWidget(self.create_train_configs_widget) - # Add model configs widget + # Initialize model configs widget collapsible_model_configs = QCollapsible("Model Configs", self) collapsible_model_configs.addWidget(self.create_model_configs_widget) - # Add segment widget - collapsible_0 = QCollapsible("Inference", self) - collapsible_0.addWidget(self.segment_widget) + # Initialize loss/iterations widget + self.progress_plot = MplCanvas(self, width=5, height=3, dpi=100) + toolbar = NavigationToolbar(self.progress_plot, self) + progress_plot_layout = QVBoxLayout() + progress_plot_layout.addWidget(toolbar) + progress_plot_layout.addWidget(self.progress_plot) + self.loss_plot = None + self.val_plot = None + plot_container_widget = QWidget() + plot_container_widget.setLayout(progress_plot_layout) + + # Initialize Layer Choice + self.raw_selector = layer_choice_widget( + self.viewer, annotation=napari.layers.Image, name="raw" + ) + + # Initialize Checkboxes + self.s_checkbox = QCheckBox("s") + self.c_checkbox = QCheckBox("c") + self.t_checkbox = QCheckBox("t") + self.z_checkbox = QCheckBox("z") + self.y_checkbox = QCheckBox("y") + self.x_checkbox = QCheckBox("x") + + axis_layout = QHBoxLayout() + axis_layout.addWidget(self.s_checkbox) + axis_layout.addWidget(self.c_checkbox) + axis_layout.addWidget(self.t_checkbox) + axis_layout.addWidget(self.z_checkbox) + axis_layout.addWidget(self.y_checkbox) + axis_layout.addWidget(self.x_checkbox) + axis_selector = QGroupBox("Axis Names:") + axis_selector.setLayout(axis_layout) - outer_layout.addLayout(grid_0) + # Initialize Train Button + train_button = QPushButton("Train!", self) + train_button.clicked.connect(self.prepare_for_training) + + # Add everything to outer_layout + outer_layout.addWidget(method_description_label) outer_layout.addWidget(collapsible_train_configs) outer_layout.addWidget(collapsible_model_configs) - outer_layout.addWidget(collapsible_0) - self.setLayout(outer_layout) + outer_layout.addWidget(plot_container_widget) + outer_layout.addWidget(self.raw_selector.native) + outer_layout.addWidget(axis_selector) + outer_layout.addWidget(train_button) + + outer_layout.setSpacing(20) + self.widget.setLayout(outer_layout) + self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.setWidgetResizable(True) + self.setWidget(self.widget) self.setFixedWidth(400) @property @@ -64,13 +124,27 @@ def train_configs_widget( reduce_mean: bool = True, density: float = 0.1, kappa: float = 10.0, - save_model_every: int = 1e3, - save_snapshot_every: int = 1e3, num_workers: int = 8, + control_point_spacing: int = 64, + control_point_jitter: float = 2.0, device="mps", ): # Specify what should happen when 'Save' button is pressed - pass + self.train_config = { + "crop_size": crop_size, + "batch_size": batch_size, + "max_iterations": max_iterations, + "initial_learning_rate": initial_learning_rate, + "temperature": temperature, + "regularizer_weight": regularizer_weight, + "reduce_mean": reduce_mean, + "density": density, + "kappa": kappa, + "num_workers": num_workers, + "control_point_spacing": control_point_spacing, + "control_point_jitter": 
control_point_jitter, + "device": device, + } if not hasattr(self, "__create_train_configs_widget"): self.__create_train_configs_widget = train_configs_widget() @@ -91,7 +165,14 @@ def model_configs_widget( initialize: bool = True, ): # Specify what should happen when 'Save' button is pressed - pass + self.model_config = { + "num_fmaps": num_fmaps, + "fmap_inc_factor": fmap_inc_factor, + "features_in_last_layer": features_in_last_layer, + "downsampling_factors": downsampling_factors, + "downsampling_layers": downsampling_layers, + "initialize": initialize, + } if not hasattr(self, "__create_model_configs_widget"): self.__create_model_configs_widget = model_configs_widget() @@ -137,3 +218,158 @@ def async_segment( self.__segment_widget = segment() self.__segment_widget_native = self.__segment_widget.native return self.__segment_widget_native + + def get_selected_axes(self): + names = [] + for name, checkbox in zip( + "sctzyx", + [ + self.s_checkbox, + self.c_checkbox, + self.t_checkbox, + self.z_checkbox, + self.y_checkbox, + self.x_checkbox, + ], + ): + if checkbox.isChecked(): + names.append(name) + return names + + def prepare_for_training(self): + + # check if train_config object exists + # check if model_config object exists + + if self.train_config is None: + pass # TODO + if self.model_config is None: + pass # TODO + + self.__training_generator = self.train_napari(iteration=0) # TODO + self.__training_generator.yielded.connect(self.on_yield) + self.__training_generator.returned.connect(self.on_return) + self.__training_generator.start() + + @thread_worker + def train_napari(self, iteration=0): + + # Turn layer into dataset + train_dataset = NapariDataset( + layer=self.raw_selector.value, + axis_names=self.get_selected_axes(), + crop_size=self.train_config["crop_size"], + control_point_spacing=self.train_config["control_point_spacing"], + control_point_jitter=self.train_config["control_point_jitter"], + ) + + if not os.path.exists("models"): + os.makedirs("models") + + # create train dataloader + train_dataloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=self.train_config["batch_size"], + drop_last=True, + num_workers=self.train_config["num_workers"], + pin_memory=True, + ) + + downsampling_factors = [ + [ + self.model_config["downsampling_factors"], + ] + * train_dataset.get_num_spatial_dims() + ] * self.model_config["downsampling_layers"] + + # set model + model = get_model( + in_channels=train_dataset.get_num_channels(), + out_channels=train_dataset.get_num_spatial_dims(), + num_fmaps=self.model_config["num_fmaps"], + fmap_inc_factor=self.model_config["fmap_inc_factor"], + features_in_last_layer=self.model_config["features_in_last_layer"], + downsampling_factors=[ + tuple(factor) for factor in downsampling_factors + ], + num_spatial_dims=train_dataset.get_num_spatial_dims(), + ) + + # set device + device = torch.device(self.train_config["device"]) + + model = model.to(device) + + # initialize model weights + if self.model_config["initialize"]: + for _name, layer in model.named_modules(): + if isinstance(layer, torch.nn.modules.conv._ConvNd): + torch.nn.init.kaiming_normal_( + layer.weight, nonlinearity="relu" + ) + + # set loss + criterion = get_loss( + regularizer_weight=self.train_config["regularizer_weight"], + temperature=self.train_config["temperature"], + kappa=self.train_config["kappa"], + density=self.train_config["density"], + num_spatial_dims=train_dataset.get_num_spatial_dims(), + reduce_mean=self.train_config["reduce_mean"], + device=device, + 
) + + # set optimizer + optimizer = torch.optim.Adam( + model.parameters(), + lr=self.train_config["initial_learning_rate"], + ) + + # set scheduler: + + def lambda_(iteration): + return pow( + (1 - ((iteration) / self.train_config["max_iterations"])), 0.9 + ) + + # resume training + start_iteration = 0 + + # TODO + # if self.model_config.checkpoint is None: + # pass + # else: + # print(f"Resuming model from {self.model_config.checkpoint}") + # state = torch.load(self.model_config.checkpoint, map_location=device) + # start_iteration = state["iteration"] + 1 + # lowest_loss = state["lowest_loss"] + # model.load_state_dict(state["model_state_dict"], strict=True) + # optimizer.load_state_dict(state["optim_state_dict"]) + # logger.data = state["logger_data"] + + # call `train_iteration` + for iteration, batch in zip( + range(start_iteration, self.train_config["max_iterations"]), + train_dataloader, + ): + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 + ) + + train_loss, prediction = train_iteration( + batch, + model=model, + criterion=criterion, + optimizer=optimizer, + device=device, + ) + scheduler.step() + yield (iteration, train_loss) + + def on_yield(self, step_data): + # TODO + iteration, loss = step_data + print(iteration, loss) + + def on_return(self): + pass # TODO From 4524f172d488b1daefd47f6e422a3d64cfa357ad Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 16 Oct 2023 10:35:27 -0400 Subject: [PATCH 25/40] Add action in case the configs aren't changed --- src/napari_cellulus/widgets/_widget.py | 41 +++++++++++++++++++++----- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index a4189dc..45fd6ec 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -38,12 +38,17 @@ def __init__(self, napari_viewer): self.widget = QWidget() self.viewer = napari_viewer - # define components + # initialize train_config and model_config + + self.train_config = None + self.model_config = None + + # initialize UI components method_description_label = QLabel( 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' ) - # define layout + # specify layout outer_layout = QVBoxLayout() # Initialize train configs widget @@ -92,7 +97,7 @@ def __init__(self, napari_viewer): train_button = QPushButton("Train!", self) train_button.clicked.connect(self.prepare_for_training) - # Add everything to outer_layout + # Add all components to outer_layout outer_layout.addWidget(method_description_label) outer_layout.addWidget(collapsible_train_configs) outer_layout.addWidget(collapsible_model_configs) @@ -239,12 +244,34 @@ def get_selected_axes(self): def prepare_for_training(self): # check if train_config object exists - # check if model_config object exists - if self.train_config is None: - pass # TODO + # set default values + self.train_config = { + "crop_size": 252, + "batch_size": 8, + "max_iterations": 100_000, + "initial_learning_rate": 4e-5, + "temperature": 10.0, + "regularizer_weight": 1e-5, + "reduce_mean": True, + "density": 0.1, + "kappa": 10.0, + "num_workers": 8, + "control_point_spacing": 64, + "control_point_jitter": 2.0, + "device": "mps", + } + + # check if model_config object exists if self.model_config is None: - pass # TODO + self.model_config = { + "num_fmaps": 256, + "fmap_inc_factor": 3, + "features_in_last_layer": 64, + "downsampling_factors": 2, + "downsampling_layers": 1, + "initialize": True, + } self.__training_generator = self.train_napari(iteration=0) # TODO self.__training_generator.yielded.connect(self.on_yield) From 3454542d5d4308f9e96ce1adc70c3b2f5486ac74 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 16 Oct 2023 16:40:47 -0400 Subject: [PATCH 26/40] Enable training and pausing --- src/napari_cellulus/widgets/_widget.py | 122 +++++++++++++++++-------- 1 file changed, 85 insertions(+), 37 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 45fd6ec..c79bf91 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -43,6 +43,13 @@ def __init__(self, napari_viewer): self.train_config = None self.model_config = None + # initialize losses and iterations + self.losses = [] + self.iterations = [] + + # initialize mode. this will change to 'training' and 'inferring' later + self.mode = "configuring" + # initialize UI components method_description_label = QLabel( 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' @@ -60,15 +67,21 @@ def __init__(self, napari_viewer): collapsible_model_configs.addWidget(self.create_model_configs_widget) # Initialize loss/iterations widget - self.progress_plot = MplCanvas(self, width=5, height=3, dpi=100) - toolbar = NavigationToolbar(self.progress_plot, self) - progress_plot_layout = QVBoxLayout() - progress_plot_layout.addWidget(toolbar) - progress_plot_layout.addWidget(self.progress_plot) - self.loss_plot = None - self.val_plot = None + self.canvas = MplCanvas(self, width=5, height=3, dpi=100) + toolbar = NavigationToolbar(self.canvas, self) + canvas_layout = QVBoxLayout() + canvas_layout.addWidget(toolbar) + canvas_layout.addWidget(self.canvas) + if len(self.iterations) == 0: + self.loss_plot = self.canvas.axes.plot( + [], [], label="Training Loss" + )[0] + self.canvas.axes.legend() + self.canvas.axes.set_title("Training Progress") + self.canvas.axes.set_xlabel("Iterations") + self.canvas.axes.set_ylabel("Loss") plot_container_widget = QWidget() - plot_container_widget.setLayout(progress_plot_layout) + plot_container_widget.setLayout(canvas_layout) # Initialize Layer Choice self.raw_selector = layer_choice_widget( @@ -94,8 +107,8 @@ def __init__(self, napari_viewer): axis_selector.setLayout(axis_layout) # Initialize Train Button - train_button = QPushButton("Train!", self) - train_button.clicked.connect(self.prepare_for_training) + self.train_button = QPushButton("Train!", self) + self.train_button.clicked.connect(self.prepare_for_training) # Add all components to outer_layout outer_layout.addWidget(method_description_label) @@ -104,7 +117,7 @@ def __init__(self, napari_viewer): outer_layout.addWidget(plot_container_widget) outer_layout.addWidget(self.raw_selector.native) outer_layout.addWidget(axis_selector) - outer_layout.addWidget(train_button) + outer_layout.addWidget(self.train_button) outer_layout.setSpacing(20) self.widget.setLayout(outer_layout) @@ -273,13 +286,36 @@ def prepare_for_training(self): "initialize": True, } - self.__training_generator = self.train_napari(iteration=0) # TODO - self.__training_generator.yielded.connect(self.on_yield) - self.__training_generator.returned.connect(self.on_return) - self.__training_generator.start() + self.update_mode() + + if self.mode == "training": + self.worker = self.train_napari() + self.worker.yielded.connect(self.on_yield) + self.worker.returned.connect(self.on_return) + self.worker.start() + elif self.mode == "configuring": + state = { + "iteration": self.iterations[-1], + "model_state_dict": self.model.state_dict(), + "optim_state_dict": self.optimizer.state_dict(), + "iterations": self.iterations, + "losses": self.losses, + } + filename = os.path.join("models", "last.pth") + torch.save(state, filename) + self.worker.quit() + + def update_mode(self): + + if self.train_button.text() == "Train!": + self.train_button.setText("Pause!") + self.mode = "training" + elif self.train_button.text() == "Pause!": + self.train_button.setText("Train!") + self.mode = "configuring" @thread_worker - def train_napari(self, iteration=0): + def train_napari(self): # Turn layer into dataset train_dataset = NapariDataset( @@ -310,7 +346,7 @@ def train_napari(self, iteration=0): ] * self.model_config["downsampling_layers"] # set model - model = get_model( + self.model = get_model( in_channels=train_dataset.get_num_channels(), out_channels=train_dataset.get_num_spatial_dims(), num_fmaps=self.model_config["num_fmaps"], @@ -325,11 +361,11 @@ def train_napari(self, iteration=0): # set device 
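# The device name used below is taken verbatim from the train config ("cpu", "cuda:0"
# or "mps" are the choices exposed in the widget). A more defensive variant could fall
# back to the CPU when the requested backend is unavailable; a sketch (the helper name
# is illustrative and not part of this plugin):
#
#     def pick_device(requested: str) -> torch.device:
#         if requested.startswith("cuda") and torch.cuda.is_available():
#             return torch.device(requested)
#         if requested == "mps" and torch.backends.mps.is_available():
#             return torch.device("mps")
#         return torch.device("cpu")
#
#     device = pick_device(self.train_config["device"])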
device = torch.device(self.train_config["device"]) - model = model.to(device) + self.model = self.model.to(device) # initialize model weights if self.model_config["initialize"]: - for _name, layer in model.named_modules(): + for _name, layer in self.model.named_modules(): if isinstance(layer, torch.nn.modules.conv._ConvNd): torch.nn.init.kaiming_normal_( layer.weight, nonlinearity="relu" @@ -347,8 +383,8 @@ def train_napari(self, iteration=0): ) # set optimizer - optimizer = torch.optim.Adam( - model.parameters(), + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.train_config["initial_learning_rate"], ) @@ -360,19 +396,21 @@ def lambda_(iteration): ) # resume training - start_iteration = 0 - - # TODO - # if self.model_config.checkpoint is None: - # pass - # else: - # print(f"Resuming model from {self.model_config.checkpoint}") - # state = torch.load(self.model_config.checkpoint, map_location=device) - # start_iteration = state["iteration"] + 1 - # lowest_loss = state["lowest_loss"] - # model.load_state_dict(state["model_state_dict"], strict=True) - # optimizer.load_state_dict(state["optim_state_dict"]) - # logger.data = state["logger_data"] + if len(self.iterations) == 0: + start_iteration = 0 + else: + start_iteration = self.iterations[-1] + + if not os.path.exists("models/last.pth"): + pass + else: + print("Resuming model from 'models/last.pth'") + state = torch.load("models/last.pth", map_location=device) + start_iteration = state["iteration"] + 1 + self.iterations = state["iterations"] + self.losses = state["losses"] + self.model.load_state_dict(state["model_state_dict"], strict=True) + self.optimizer.load_state_dict(state["optim_state_dict"]) # call `train_iteration` for iteration, batch in zip( @@ -380,14 +418,14 @@ def lambda_(iteration): train_dataloader, ): scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 + self.optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 ) train_loss, prediction = train_iteration( batch, - model=model, + model=self.model, criterion=criterion, - optimizer=optimizer, + optimizer=self.optimizer, device=device, ) scheduler.step() @@ -397,6 +435,16 @@ def on_yield(self, step_data): # TODO iteration, loss = step_data print(iteration, loss) + self.iterations.append(iteration) + self.losses.append(loss) + self.update_canvas() + + def update_canvas(self): + self.loss_plot.set_xdata(self.iterations) + self.loss_plot.set_ydata(self.losses) + self.canvas.axes.relim() + self.canvas.axes.autoscale_view() + self.canvas.draw() def on_return(self): pass # TODO From 4dd43ef0d63739bb0959e4316d2f3b491c3d75e5 Mon Sep 17 00:00:00 2001 From: lmanan Date: Tue, 17 Oct 2023 14:02:19 -0400 Subject: [PATCH 27/40] Remove title from figure --- src/napari_cellulus/widgets/_widget.py | 86 ++++++++++++++++---------- 1 file changed, 53 insertions(+), 33 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index c79bf91..7b2be19 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -1,5 +1,4 @@ import os -from typing import List import napari import torch @@ -12,13 +11,14 @@ from matplotlib.backends.backend_qt5agg import ( NavigationToolbar2QT as NavigationToolbar, ) -from napari.qt.threading import FunctionWorker, thread_worker +from napari.qt.threading import thread_worker from PyQt5.QtCore import Qt from qtpy.QtWidgets import ( QCheckBox, QGroupBox, QHBoxLayout, QLabel, + QLineEdit, QPushButton, QScrollArea, 
QVBoxLayout, @@ -58,6 +58,16 @@ def __init__(self, napari_viewer): # specify layout outer_layout = QVBoxLayout() + # Initialize object size widget + object_size_label = QLabel(self) + object_size_label.setText("Object Size [px]:") + object_size_line = QLineEdit(self) + hbox_layout = QHBoxLayout() + hbox_layout.addWidget(object_size_label) + hbox_layout.addWidget(object_size_line) + object_size_box = QGroupBox("") + object_size_box.setLayout(hbox_layout) + # Initialize train configs widget collapsible_train_configs = QCollapsible("Train Configs", self) collapsible_train_configs.addWidget(self.create_train_configs_widget) @@ -77,7 +87,6 @@ def __init__(self, napari_viewer): [], [], label="Training Loss" )[0] self.canvas.axes.legend() - self.canvas.axes.set_title("Training Progress") self.canvas.axes.set_xlabel("Iterations") self.canvas.axes.set_ylabel("Loss") plot_container_widget = QWidget() @@ -110,14 +119,28 @@ def __init__(self, napari_viewer): self.train_button = QPushButton("Train!", self) self.train_button.clicked.connect(self.prepare_for_training) + # Initialize Model Configs widget + collapsible_inference_configs = QCollapsible("Inference Configs", self) + collapsible_inference_configs.addWidget( + self.create_inference_configs_widget + ) + + # Initialize Segment Button + self.segment_button = QPushButton("Segment!", self) + self.segment_button.clicked.connect(self.prepare_for_segmenting) + # Add all components to outer_layout + outer_layout.addWidget(method_description_label) + outer_layout.addWidget(object_size_box) outer_layout.addWidget(collapsible_train_configs) outer_layout.addWidget(collapsible_model_configs) outer_layout.addWidget(plot_container_widget) outer_layout.addWidget(self.raw_selector.native) outer_layout.addWidget(axis_selector) outer_layout.addWidget(self.train_button) + outer_layout.addWidget(collapsible_inference_configs) + outer_layout.addWidget(self.segment_button) outer_layout.setSpacing(20) self.widget.setLayout(outer_layout) @@ -200,42 +223,36 @@ def model_configs_widget( return self.__create_model_configs_widget_native @property - def segment_widget(self): - @magic_factory(call_button="Segment") - def segment( - raw: napari.layers.Image, + def create_inference_configs_widget(self): + @magic_factory(call_button="Save") + def inference_configs_widget( crop_size: int = 252, p_salt_pepper: float = 0.1, num_infer_iterations: int = 16, bandwidth: int = 7, + reduction_probability: float = 0.1, min_size: int = 25, - ) -> FunctionWorker[List[napari.types.LayerDataTuple]]: - @thread_worker( - connect={"returned": lambda: self.set_buttons("paused")}, - progress={"total": 0, "desc": "Segmenting"}, + grow_distance: int = 3, + shrink_distance: int = 6, + ): + # Specify what should happen when 'Save' button is pressed + self.inference_config = { + "crop_size": crop_size, + "p_salt_pepper": p_salt_pepper, + "num_infer_iterations": num_infer_iterations, + "bandwidth": bandwidth, + "reduction_probability": reduction_probability, + "min_size": min_size, + "grow_distance": grow_distance, + "shrink_distance": shrink_distance, + } + + if not hasattr(self, "__create_inference_configs_widget"): + self.__create_inference_configs_widget = inference_configs_widget() + self.__create_inference_configs_widget_native = ( + self.__create_inference_configs_widget.native ) - def async_segment( - raw: napari.layers.Image, - crop_size: int, - p_salt_pepper: float, - num_infer_iterations: int, - bandwidth: int, - min_size: int, - ) -> List[napari.types.LayerDataTuple]: - - return 
async_segment( - raw, - crop_size=crop_size, - p_salt_pepper=p_salt_pepper, - num_infer_iterations=num_infer_iterations, - bandwidth=bandwidth, - min_size=min_size, - ) - - if not hasattr(self, "__segment_widget"): - self.__segment_widget = segment() - self.__segment_widget_native = self.__segment_widget.native - return self.__segment_widget_native + return self.__create_inference_configs_widget_native def get_selected_axes(self): names = [] @@ -448,3 +465,6 @@ def update_canvas(self): def on_return(self): pass # TODO + + def prepare_for_segmenting(self): + pass From 02429325d14225f435dd5737f658a86a0b191fc5 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 18 Oct 2023 16:33:56 -0400 Subject: [PATCH 28/40] Add spatial_dims in dataset.py --- src/napari_cellulus/dataset.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 9fafac4..60840d1 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -82,27 +82,35 @@ def __setup_pipeline(self): if self.num_channels == 0 and self.num_samples == 0: self.pipeline = ( - NapariImageSource(self.layer, self.raw, raw_spec) + NapariImageSource( + self.layer, self.raw, raw_spec, self.spatial_dims + ) + gp.RandomLocation() + gp.Unsqueeze([self.raw], 0) + gp.Unsqueeze([self.raw], 0) ) elif self.num_channels == 0 and self.num_samples != 0: self.pipeline = ( - NapariImageSource(self.layer, self.raw, raw_spec) + NapariImageSource( + self.layer, self.raw, raw_spec, self.spatial_dims + ) + gp.Unsqueeze([self.raw], 1) + gp.RandomLocation() + gp.Unsqueeze([self.raw], 1) ) elif self.num_channels != 0 and self.num_samples == 0: self.pipeline = ( - NapariImageSource(self.layer, self.raw, raw_spec) + NapariImageSource( + self.layer, self.raw, raw_spec, self.spatial_dims + ) + gp.RandomLocation() + gp.Unsqueeze([self.raw], 0) ) elif self.num_channels != 0 and self.num_samples != 0: self.pipeline = ( - NapariImageSource(self.layer, self.raw, raw_spec) + NapariImageSource( + self.layer, self.raw, raw_spec, self.spatial_dims + ) + gp.RandomLocation() ) @@ -155,6 +163,7 @@ def __read_meta_data(self): self.sample_dim = meta_data.sample_dim self.channel_dim = meta_data.channel_dim self.time_dim = meta_data.time_dim + self.spatial_dims = meta_data.spatial_dims def get_num_channels(self): return 1 if self.num_channels == 0 else self.num_channels From 55c444becf144930c52e6aff5e8fa0779de8d41c Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 18 Oct 2023 16:35:15 -0400 Subject: [PATCH 29/40] Add normalize import --- .../gp/nodes/napari_image_source.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/napari_cellulus/gp/nodes/napari_image_source.py b/src/napari_cellulus/gp/nodes/napari_image_source.py index 1dd8c12..a13e996 100644 --- a/src/napari_cellulus/gp/nodes/napari_image_source.py +++ b/src/napari_cellulus/gp/nodes/napari_image_source.py @@ -1,7 +1,6 @@ -from typing import Optional - import gunpowder as gp import numpy as np +from csbdeep.utils import normalize from gunpowder.array_spec import ArraySpec from napari.layers import Image @@ -17,13 +16,19 @@ class NapariImageSource(gp.BatchProvider): """ def __init__( - self, image: Image, key: gp.ArrayKey, spec: Optional[ArraySpec] = None + self, image: Image, key: gp.ArrayKey, spec: ArraySpec, spatial_dims ): - if spec is None: - self.array_spec = self._read_metadata(image) - else: - self.array_spec = spec - self.image = gp.Array(image.data.astype(np.float32), 
self.array_spec) + self.array_spec = spec + self.image = gp.Array( + normalize( + image.data.astype(np.float32), + pmin=1, + pmax=99.8, + axis=spatial_dims, + ), + self.array_spec, + ) + self.spatial_dims = spatial_dims self.key = key def setup(self): From e506d4480dffb9caf901490be1ebf93001df97d5 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 18 Oct 2023 16:35:40 -0400 Subject: [PATCH 30/40] Add spatial_dims --- src/napari_cellulus/meta_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/napari_cellulus/meta_data.py b/src/napari_cellulus/meta_data.py index 40c2889..7e8751e 100644 --- a/src/napari_cellulus/meta_data.py +++ b/src/napari_cellulus/meta_data.py @@ -11,6 +11,7 @@ def __init__(self, shape, axis_names): self.channel_dim = None self.time_dim = None self.spatial_array: Tuple[int, ...] = () + self.spatial_dims = () for dim, axis_name in enumerate(axis_names): if axis_name == "s": self.sample_dim = dim @@ -24,9 +25,12 @@ def __init__(self, shape, axis_names): elif axis_name == "z": self.num_spatial_dims += 1 self.spatial_array += (shape[dim],) + self.spatial_dims += (-3,) elif axis_name == "y": self.num_spatial_dims += 1 self.spatial_array += (shape[dim],) + self.spatial_dims += (-2,) elif axis_name == "x": self.num_spatial_dims += 1 self.spatial_array += (shape[dim],) + self.spatial_dims += (-1,) From b760d0ae49f35369bb070b0d8bdeb97854a22488 Mon Sep 17 00:00:00 2001 From: lmanan Date: Wed, 18 Oct 2023 16:36:16 -0400 Subject: [PATCH 31/40] Add segment code - producing embeddings takes quite long. --- src/napari_cellulus/widgets/_widget.py | 299 +++++++++++++++++++++---- 1 file changed, 258 insertions(+), 41 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 7b2be19..ed703cd 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -1,10 +1,13 @@ import os +import gunpowder as gp import napari +import numpy as np import torch from cellulus.criterions import get_loss from cellulus.models import get_model from cellulus.train import train_iteration +from cellulus.utils.mean_shift import mean_shift_segmentation from magicgui import magic_factory # widget stuff @@ -12,21 +15,24 @@ NavigationToolbar2QT as NavigationToolbar, ) from napari.qt.threading import thread_worker -from PyQt5.QtCore import Qt +from qtpy.QtCore import Qt from qtpy.QtWidgets import ( QCheckBox, QGroupBox, QHBoxLayout, QLabel, QLineEdit, + QProgressBar, QPushButton, QScrollArea, QVBoxLayout, QWidget, ) +from scipy.ndimage import distance_transform_edt as dtedt from superqt import QCollapsible from ..dataset import NapariDataset +from ..gp.nodes.napari_image_source import NapariImageSource # local package imports from ..gui_helpers import MplCanvas, layer_choice_widget @@ -42,12 +48,13 @@ def __init__(self, napari_viewer): self.train_config = None self.model_config = None + self.segment_config = None # initialize losses and iterations self.losses = [] self.iterations = [] - # initialize mode. this will change to 'training' and 'inferring' later + # initialize mode. 
this will change to 'training' and 'segmentring' later self.mode = "configuring" # initialize UI components @@ -61,10 +68,10 @@ def __init__(self, napari_viewer): # Initialize object size widget object_size_label = QLabel(self) object_size_label.setText("Object Size [px]:") - object_size_line = QLineEdit(self) + self.object_size_line = QLineEdit(self) hbox_layout = QHBoxLayout() hbox_layout.addWidget(object_size_label) - hbox_layout.addWidget(object_size_line) + hbox_layout.addWidget(self.object_size_line) object_size_box = QGroupBox("") object_size_box.setLayout(hbox_layout) @@ -116,19 +123,22 @@ def __init__(self, napari_viewer): axis_selector.setLayout(axis_layout) # Initialize Train Button - self.train_button = QPushButton("Train!", self) + self.train_button = QPushButton("Train", self) self.train_button.clicked.connect(self.prepare_for_training) # Initialize Model Configs widget - collapsible_inference_configs = QCollapsible("Inference Configs", self) - collapsible_inference_configs.addWidget( - self.create_inference_configs_widget + collapsible_segment_configs = QCollapsible("Inference Configs", self) + collapsible_segment_configs.addWidget( + self.create_segment_configs_widget ) # Initialize Segment Button - self.segment_button = QPushButton("Segment!", self) + self.segment_button = QPushButton("Segment", self) self.segment_button.clicked.connect(self.prepare_for_segmenting) + # Initialize progress bar + self.pbar = QProgressBar(self) + # Add all components to outer_layout outer_layout.addWidget(method_description_label) @@ -139,8 +149,9 @@ def __init__(self, napari_viewer): outer_layout.addWidget(self.raw_selector.native) outer_layout.addWidget(axis_selector) outer_layout.addWidget(self.train_button) - outer_layout.addWidget(collapsible_inference_configs) + outer_layout.addWidget(collapsible_segment_configs) outer_layout.addWidget(self.segment_button) + outer_layout.addWidget(self.pbar) outer_layout.setSpacing(20) self.widget.setLayout(outer_layout) @@ -223,9 +234,9 @@ def model_configs_widget( return self.__create_model_configs_widget_native @property - def create_inference_configs_widget(self): + def create_segment_configs_widget(self): @magic_factory(call_button="Save") - def inference_configs_widget( + def segment_configs_widget( crop_size: int = 252, p_salt_pepper: float = 0.1, num_infer_iterations: int = 16, @@ -236,7 +247,7 @@ def inference_configs_widget( shrink_distance: int = 6, ): # Specify what should happen when 'Save' button is pressed - self.inference_config = { + self.segment_config = { "crop_size": crop_size, "p_salt_pepper": p_salt_pepper, "num_infer_iterations": num_infer_iterations, @@ -247,12 +258,12 @@ def inference_configs_widget( "shrink_distance": shrink_distance, } - if not hasattr(self, "__create_inference_configs_widget"): - self.__create_inference_configs_widget = inference_configs_widget() - self.__create_inference_configs_widget_native = ( - self.__create_inference_configs_widget.native + if not hasattr(self, "__create_segment_configs_widget"): + self.__create_segment_configs_widget = segment_configs_widget() + self.__create_segment_configs_widget_native = ( + self.__create_segment_configs_widget.native ) - return self.__create_inference_configs_widget_native + return self.__create_segment_configs_widget_native def get_selected_axes(self): names = [] @@ -302,8 +313,7 @@ def prepare_for_training(self): "downsampling_layers": 1, "initialize": True, } - - self.update_mode() + self.update_mode(self.sender()) if self.mode == "training": self.worker = 
self.train_napari() @@ -322,20 +332,34 @@ def prepare_for_training(self): torch.save(state, filename) self.worker.quit() - def update_mode(self): + def update_mode(self, sender): - if self.train_button.text() == "Train!": - self.train_button.setText("Pause!") + if self.train_button.text() == "Train" and sender == self.train_button: + self.train_button.setText("Pause") self.mode = "training" - elif self.train_button.text() == "Pause!": - self.train_button.setText("Train!") + elif ( + self.train_button.text() == "Pause" and sender == self.train_button + ): + self.train_button.setText("Train") + self.mode = "configuring" + elif ( + self.segment_button.text() == "Segment" + and sender == self.segment_button + ): + self.segment_button.setText("Pause") + self.mode = "segmenting" + elif ( + self.segment_button.text() == "Pause" + and sender == self.segment_button + ): + self.segment_button.setText("Segment") self.mode = "configuring" @thread_worker def train_napari(self): # Turn layer into dataset - train_dataset = NapariDataset( + self.dataset = NapariDataset( layer=self.raw_selector.value, axis_names=self.get_selected_axes(), crop_size=self.train_config["crop_size"], @@ -347,8 +371,8 @@ def train_napari(self): os.makedirs("models") # create train dataloader - train_dataloader = torch.utils.data.DataLoader( - dataset=train_dataset, + self.dataloader = torch.utils.data.DataLoader( + dataset=self.dataset, batch_size=self.train_config["batch_size"], drop_last=True, num_workers=self.train_config["num_workers"], @@ -359,20 +383,20 @@ def train_napari(self): [ self.model_config["downsampling_factors"], ] - * train_dataset.get_num_spatial_dims() + * self.dataset.get_num_spatial_dims() ] * self.model_config["downsampling_layers"] # set model self.model = get_model( - in_channels=train_dataset.get_num_channels(), - out_channels=train_dataset.get_num_spatial_dims(), + in_channels=self.dataset.get_num_channels(), + out_channels=self.dataset.get_num_spatial_dims(), num_fmaps=self.model_config["num_fmaps"], fmap_inc_factor=self.model_config["fmap_inc_factor"], features_in_last_layer=self.model_config["features_in_last_layer"], downsampling_factors=[ tuple(factor) for factor in downsampling_factors ], - num_spatial_dims=train_dataset.get_num_spatial_dims(), + num_spatial_dims=self.dataset.get_num_spatial_dims(), ) # set device @@ -394,7 +418,7 @@ def train_napari(self): temperature=self.train_config["temperature"], kappa=self.train_config["kappa"], density=self.train_config["density"], - num_spatial_dims=train_dataset.get_num_spatial_dims(), + num_spatial_dims=self.dataset.get_num_spatial_dims(), reduce_mean=self.train_config["reduce_mean"], device=device, ) @@ -432,7 +456,7 @@ def lambda_(iteration): # call `train_iteration` for iteration, batch in zip( range(start_iteration, self.train_config["max_iterations"]), - train_dataloader, + self.dataloader, ): scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 @@ -449,12 +473,15 @@ def lambda_(iteration): yield (iteration, train_loss) def on_yield(self, step_data): - # TODO - iteration, loss = step_data - print(iteration, loss) - self.iterations.append(iteration) - self.losses.append(loss) - self.update_canvas() + if self.mode == "training": + iteration, loss = step_data + print(iteration, loss) + self.iterations.append(iteration) + self.losses.append(loss) + self.update_canvas() + elif self.mode == "segmenting": + print(step_data) + self.pbar.setValue(step_data) def update_canvas(self): 
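# Redraws the loss curve in place: push the accumulated iterations/losses into the
# existing Line2D, rescale the axes to the new data, then repaint the canvas. The same
# matplotlib pattern outside the widget (names here are illustrative):
#
#     line, = axes.plot([], [])
#     line.set_xdata(iterations)
#     line.set_ydata(losses)
#     axes.relim()
#     axes.autoscale_view()
#     canvas.draw()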
self.loss_plot.set_xdata(self.iterations) @@ -467,4 +494,194 @@ def on_return(self): pass # TODO def prepare_for_segmenting(self): - pass + # check if segment_config exists + if self.segment_config is None: + self.segment_config = { + "crop_size": 252, + "p_salt_pepper": 0.1, + "num_infer_iterations": 16, + "bandwidth": None, + "reduction_probability": 0.1, + "min_size": None, + "grow_distance": 3, + "shrink_distance": 6, + } + + # update mode + self.update_mode(self.sender()) + + if self.mode == "segmenting": + self.worker = self.segment_napari() + self.worker.yielded.connect(self.on_yield) + self.worker.returned.connect(self.on_return) + self.worker.start() + elif self.mode == "configuring": + self.worker.quit() + + @thread_worker + def segment_napari(self): + raw = self.raw_selector.value + + if self.segment_config["bandwidth"] is None: + self.segment_config["bandwidth"] = int( + 0.5 * float(self.object_size_line.text()) + ) + if self.segment_config["min_size"] is None: + self.segment_config["min_size"] = int( + 0.1 * np.pi * (float(self.object_size_line.text()) ** 2) / 4 + ) + self.model.eval() + + num_spatial_dims = self.dataset.num_spatial_dims + num_channels = self.dataset.num_channels + spatial_dims = self.dataset.spatial_dims + num_samples = self.dataset.num_samples + + print( + f"Num spatial dims {num_spatial_dims} num channels {num_channels} spatial_dims {spatial_dims} num_samples {num_samples}" + ) + + crop_size = (self.segment_config["crop_size"],) * num_spatial_dims + device = self.train_config["device"] + + if num_channels == 0: + num_channels = 1 + + voxel_size = gp.Coordinate((1,) * num_spatial_dims) + self.model.set_infer( + p_salt_pepper=self.segment_config["p_salt_pepper"], + num_infer_iterations=self.segment_config["num_infer_iterations"], + device=device, + ) + + print(f"Current device is {device}") + + input_shape = gp.Coordinate((1, num_channels, *crop_size)) + output_shape = gp.Coordinate( + self.model( + torch.zeros( + (1, num_channels, *crop_size), dtype=torch.float32 + ).to(device) + ).shape + ) + input_size = ( + gp.Coordinate(input_shape[-num_spatial_dims:]) * voxel_size + ) + output_size = ( + gp.Coordinate(output_shape[-num_spatial_dims:]) * voxel_size + ) + context = (input_size - output_size) / 2 + + raw_key = gp.ArrayKey("RAW") + prediction_key = gp.ArrayKey("PREDICT") + scan_request = gp.BatchRequest() + scan_request.add(raw_key, input_size) + scan_request.add(prediction_key, output_size) + + predict = gp.torch.Predict( + self.model, + inputs={"raw": raw_key}, + outputs={0: prediction_key}, + array_specs={prediction_key: gp.ArraySpec(voxel_size=voxel_size)}, + ) + pipeline = NapariImageSource( + raw, + raw_key, + gp.ArraySpec( + gp.Roi( + (0,) * num_spatial_dims, + raw.data.shape[-num_spatial_dims:], + ), + voxel_size=voxel_size, + ), + spatial_dims, + ) + + if num_samples == 0 and num_channels == 0: + pipeline += ( + gp.Pad(raw_key, context) + + gp.Unsqueeze([raw_key], 0) + + gp.Unsqueeze([raw_key], 0) + + predict + + gp.Scan(scan_request) + ) + elif num_samples != 0 and num_channels == 0: + pipeline += ( + gp.Pad(raw_key, context) + + gp.Unsqueeze([raw_key], 1) + + predict + + gp.Scan(scan_request) + ) + elif num_samples == 0 and num_channels != 0: + pipeline += ( + gp.Pad(raw_key, context) + + gp.Unsqueeze([raw_key], 0) + + predict + + gp.Scan(scan_request) + ) + elif num_samples != 0 and num_channels != 0: + pipeline += ( + gp.Pad(raw_key, context) + predict + gp.Scan(scan_request) + ) + + # request to pipeline for ROI of whole image/volume + request 
= gp.BatchRequest() + request.add(prediction_key, raw.data.shape[-num_spatial_dims:]) + counter = 0 + with gp.build(pipeline): + batch = pipeline.request_batch(request) + yield counter + counter += 0.1 + + prediction = batch.arrays[prediction_key].data + colormaps = ["red", "green", "blue"] + prediction_layers = [ + ( + prediction[:, i : i + 1, ...].copy(), + { + "name": "offset-" + "zyx"[num_spatial_dims - i] + if i < num_spatial_dims + else "std", + "colormap": colormaps[num_spatial_dims - i] + if i < num_spatial_dims + else "gray", + "blending": "additive", + }, + "image", + ) + for i in range(num_spatial_dims + 1) + ] + + labels = np.zeros_like(prediction[:, 0:1, ...].data, dtype=np.uint64) + for sample in range(num_samples): + embeddings = prediction[sample] + embeddings_std = embeddings[-1, ...] + embeddings_mean = embeddings[np.newaxis, :num_spatial_dims, ...] + segmentation = mean_shift_segmentation( + embeddings_mean, + embeddings_std, + self.segment_config["bandwidth"], + self.segment_config["min_size"], + self.segment_config["reduction_probability"], + ) + labels[sample, 0, ...] = segmentation + + pp_labels = np.zeros_like( + prediction[:, 0:1, ...].data, dtype=np.uint64 + ) + for sample in range(num_samples): + segmentation = labels[sample, 0] + distance_foreground = dtedt(segmentation == 0) + expanded_mask = ( + distance_foreground < self.inference_config["grow_distance"] + ) + distance_background = dtedt(expanded_mask) + segmentation[ + distance_background < self.inference_config["shrink_distance"] + ] = 0 + pp_labels[sample, 0, ...] = segmentation + return ( + prediction_layers + + [(labels, {"name": "Segmentation"}, "labels")] + + [(pp_labels, {"name": "Post Processed"}, "labels")] + ) From 855f2e65fb5036125f4a1123dd85e7a4baf9b1e5 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 22 Oct 2023 18:47:19 -0400 Subject: [PATCH 32/40] profiling plugin - Currently it takes 1.9 seconds per iteration which is slower than previous version --- src/napari_cellulus/widgets/_widget.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index ed703cd..eddb12a 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -1,4 +1,5 @@ import os +import time import gunpowder as gp import napari @@ -37,6 +38,9 @@ # local package imports from ..gui_helpers import MplCanvas, layer_choice_widget +############ GLOBALS ################### +time_now = 0 + class SegmentationWidget(QScrollArea): def __init__(self, napari_viewer): @@ -474,8 +478,14 @@ def lambda_(iteration): def on_yield(self, step_data): if self.mode == "training": + global time_now iteration, loss = step_data - print(iteration, loss) + current_time = time.time() + time_elapsed = current_time - time_now + time_now = current_time + print( + f"iteration {iteration}, loss {loss}, seconds/iteration {time_elapsed}" + ) self.iterations.append(iteration) self.losses.append(loss) self.update_canvas() From da38c881be9bf0ae6e2d5cb414facc9601ed086c Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 22 Oct 2023 20:40:39 -0400 Subject: [PATCH 33/40] Add device as a global UI config param --- src/napari_cellulus/widgets/_widget.py | 244 +++++++++++++++---------- 1 file changed, 148 insertions(+), 96 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index eddb12a..85081d4 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ 
b/src/napari_cellulus/widgets/_widget.py @@ -1,3 +1,4 @@ +import dataclasses import os import time @@ -12,18 +13,17 @@ from magicgui import magic_factory # widget stuff -from matplotlib.backends.backend_qt5agg import ( - NavigationToolbar2QT as NavigationToolbar, -) from napari.qt.threading import thread_worker -from qtpy.QtCore import Qt +from qtpy.QtCore import Qt, QUrl +from qtpy.QtGui import QDesktopServices from qtpy.QtWidgets import ( QCheckBox, + QComboBox, + QGridLayout, QGroupBox, QHBoxLayout, QLabel, QLineEdit, - QProgressBar, QPushButton, QScrollArea, QVBoxLayout, @@ -38,8 +38,33 @@ # local package imports from ..gui_helpers import MplCanvas, layer_choice_widget + +@dataclasses.dataclass +class TrainingStats: + iteration: int = 0 + losses: list[float] = dataclasses.field(default_factory=list) + iterations: list[int] = dataclasses.field(default_factory=list) + + def reset(self): + self.iteration = 0 + self.losses = [] + self.iterations = [] + + def load(self, other): + self.iteration = other.iteration + self.losses = other.losses + self.iterations = other.iterations + + ############ GLOBALS ################### time_now = 0 +_train_config = None +_model_config = None +_segment_config = None +_model = None +_optimizer = None +_scheduler = None +_dataset = None class SegmentationWidget(QScrollArea): @@ -73,11 +98,23 @@ def __init__(self, napari_viewer): object_size_label = QLabel(self) object_size_label.setText("Object Size [px]:") self.object_size_line = QLineEdit(self) - hbox_layout = QHBoxLayout() - hbox_layout.addWidget(object_size_label) - hbox_layout.addWidget(self.object_size_line) - object_size_box = QGroupBox("") - object_size_box.setLayout(hbox_layout) + self.object_size_line.setText("30") + device_label = QLabel(self) + device_label.setText("Device") + self.device_combo_box = QComboBox(self) + self.device_combo_box.addItem("cpu") + self.device_combo_box.addItem("cuda:0") + self.device_combo_box.addItem("mps") + self.device_combo_box.setCurrentText("mps") + + grid_layout = QGridLayout() + grid_layout.addWidget(object_size_label, 0, 0) + grid_layout.addWidget(self.object_size_line, 0, 1) + grid_layout.addWidget(device_label, 1, 0) + grid_layout.addWidget(self.device_combo_box, 1, 1) + + # global_params_widget = QWidget("") + # global_params_widget.setLayout(grid_layout) # Initialize train configs widget collapsible_train_configs = QCollapsible("Train Configs", self) @@ -88,15 +125,11 @@ def __init__(self, napari_viewer): collapsible_model_configs.addWidget(self.create_model_configs_widget) # Initialize loss/iterations widget - self.canvas = MplCanvas(self, width=5, height=3, dpi=100) - toolbar = NavigationToolbar(self.canvas, self) + self.canvas = MplCanvas(self, width=5, height=5, dpi=100) canvas_layout = QVBoxLayout() - canvas_layout.addWidget(toolbar) canvas_layout.addWidget(self.canvas) if len(self.iterations) == 0: - self.loss_plot = self.canvas.axes.plot( - [], [], label="Training Loss" - )[0] + self.loss_plot = self.canvas.axes.plot([], [], label="")[0] self.canvas.axes.legend() self.canvas.axes.set_xlabel("Iterations") self.canvas.axes.set_ylabel("Loss") @@ -130,7 +163,12 @@ def __init__(self, napari_viewer): self.train_button = QPushButton("Train", self) self.train_button.clicked.connect(self.prepare_for_training) - # Initialize Model Configs widget + # Initialize Save and Load Widget + # collapsible_save_load_widget = QCollapsible("Save/Load", self) + # collapsible_save_load_widget.addWidget(self.save_widget.native) + # 
collapsible.save_load_widget.addWidget(self.load_widget.native) + + # Initialize Segment Configs widget collapsible_segment_configs = QCollapsible("Inference Configs", self) collapsible_segment_configs.addWidget( self.create_segment_configs_widget @@ -141,21 +179,33 @@ def __init__(self, napari_viewer): self.segment_button.clicked.connect(self.prepare_for_segmenting) # Initialize progress bar - self.pbar = QProgressBar(self) + # self.pbar = QProgressBar(self) + + # Initialize Feedback Button + self.feedback_button = QPushButton("Feedback!", self) + self.feedback_button.clicked.connect( + lambda: QDesktopServices.openUrl( + QUrl( + "https://github.com/funkelab/napari-cellulus/issues/new/choose" + ) + ) + ) # Add all components to outer_layout outer_layout.addWidget(method_description_label) - outer_layout.addWidget(object_size_box) + outer_layout.addLayout(grid_layout) + # outer_layout.addWidget(global_params_widget) outer_layout.addWidget(collapsible_train_configs) outer_layout.addWidget(collapsible_model_configs) outer_layout.addWidget(plot_container_widget) outer_layout.addWidget(self.raw_selector.native) outer_layout.addWidget(axis_selector) outer_layout.addWidget(self.train_button) + # outer_layout.addWidget(collapsible_save_load_widget) outer_layout.addWidget(collapsible_segment_configs) outer_layout.addWidget(self.segment_button) - outer_layout.addWidget(self.pbar) + outer_layout.addWidget(self.feedback_button) outer_layout.setSpacing(20) self.widget.setLayout(outer_layout) @@ -167,9 +217,7 @@ def __init__(self, napari_viewer): @property def create_train_configs_widget(self): - @magic_factory( - call_button="Save", device={"choices": ["cuda:0", "cpu", "mps"]} - ) + @magic_factory(call_button="Save") def train_configs_widget( crop_size: int = 252, batch_size: int = 8, @@ -183,7 +231,6 @@ def train_configs_widget( num_workers: int = 8, control_point_spacing: int = 64, control_point_jitter: float = 2.0, - device="mps", ): # Specify what should happen when 'Save' button is pressed self.train_config = { @@ -199,7 +246,6 @@ def train_configs_widget( "num_workers": num_workers, "control_point_spacing": control_point_spacing, "control_point_jitter": control_point_jitter, - "device": device, } if not hasattr(self, "__create_train_configs_widget"): @@ -289,9 +335,11 @@ def get_selected_axes(self): def prepare_for_training(self): # check if train_config object exists - if self.train_config is None: + global _train_config, _model_config, _model, _optimizer + + if _train_config is None: # set default values - self.train_config = { + _train_config = { "crop_size": 252, "batch_size": 8, "max_iterations": 100_000, @@ -304,12 +352,11 @@ def prepare_for_training(self): "num_workers": 8, "control_point_spacing": 64, "control_point_jitter": 2.0, - "device": "mps", } # check if model_config object exists - if self.model_config is None: - self.model_config = { + if _model_config is None: + _model_config = { "num_fmaps": 256, "fmap_inc_factor": 3, "features_in_last_layer": 64, @@ -327,8 +374,8 @@ def prepare_for_training(self): elif self.mode == "configuring": state = { "iteration": self.iterations[-1], - "model_state_dict": self.model.state_dict(), - "optim_state_dict": self.optimizer.state_dict(), + "model_state_dict": _model.state_dict(), + "optim_state_dict": _optimizer.state_dict(), "iterations": self.iterations, "losses": self.losses, } @@ -362,55 +409,57 @@ def update_mode(self, sender): @thread_worker def train_napari(self): + global _train_config, _model_config, _model, _scheduler, _optimizer, 
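
All three config panels follow the same magicgui pattern: a magic_factory widget whose "Save" button copies the current field values into a module-level dict. A trimmed sketch with only two of the fields (the real widget has many more):

    from magicgui import magic_factory

    _train_config = None

    @magic_factory(call_button="Save")
    def train_configs_widget(crop_size: int = 252, batch_size: int = 8):
        global _train_config
        # pressing "Save" stores the current values for later use by the training worker
        _train_config = {"crop_size": crop_size, "batch_size": batch_size}

    # instantiate once; the .native Qt widget is what gets added to a collapsible section
    widget = train_configs_widget()
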
_dataset + # Turn layer into dataset - self.dataset = NapariDataset( + _dataset = NapariDataset( layer=self.raw_selector.value, axis_names=self.get_selected_axes(), - crop_size=self.train_config["crop_size"], - control_point_spacing=self.train_config["control_point_spacing"], - control_point_jitter=self.train_config["control_point_jitter"], + crop_size=_train_config["crop_size"], + control_point_spacing=_train_config["control_point_spacing"], + control_point_jitter=_train_config["control_point_jitter"], ) if not os.path.exists("models"): os.makedirs("models") # create train dataloader - self.dataloader = torch.utils.data.DataLoader( - dataset=self.dataset, - batch_size=self.train_config["batch_size"], + dataloader = torch.utils.data.DataLoader( + dataset=_dataset, + batch_size=_train_config["batch_size"], drop_last=True, - num_workers=self.train_config["num_workers"], + num_workers=_train_config["num_workers"], pin_memory=True, ) downsampling_factors = [ [ - self.model_config["downsampling_factors"], + _model_config["downsampling_factors"], ] - * self.dataset.get_num_spatial_dims() - ] * self.model_config["downsampling_layers"] + * _dataset.get_num_spatial_dims() + ] * _model_config["downsampling_layers"] # set model - self.model = get_model( - in_channels=self.dataset.get_num_channels(), - out_channels=self.dataset.get_num_spatial_dims(), - num_fmaps=self.model_config["num_fmaps"], - fmap_inc_factor=self.model_config["fmap_inc_factor"], - features_in_last_layer=self.model_config["features_in_last_layer"], + _model = get_model( + in_channels=_dataset.get_num_channels(), + out_channels=_dataset.get_num_spatial_dims(), + num_fmaps=_model_config["num_fmaps"], + fmap_inc_factor=_model_config["fmap_inc_factor"], + features_in_last_layer=_model_config["features_in_last_layer"], downsampling_factors=[ tuple(factor) for factor in downsampling_factors ], - num_spatial_dims=self.dataset.get_num_spatial_dims(), + num_spatial_dims=_dataset.get_num_spatial_dims(), ) # set device - device = torch.device(self.train_config["device"]) + device = torch.device(self.device_combo_box.currentText()) - self.model = self.model.to(device) + _model = _model.to(device) # initialize model weights - if self.model_config["initialize"]: - for _name, layer in self.model.named_modules(): + if _model_config["initialize"]: + for _name, layer in _model.named_modules(): if isinstance(layer, torch.nn.modules.conv._ConvNd): torch.nn.init.kaiming_normal_( layer.weight, nonlinearity="relu" @@ -418,26 +467,26 @@ def train_napari(self): # set loss criterion = get_loss( - regularizer_weight=self.train_config["regularizer_weight"], - temperature=self.train_config["temperature"], - kappa=self.train_config["kappa"], - density=self.train_config["density"], - num_spatial_dims=self.dataset.get_num_spatial_dims(), - reduce_mean=self.train_config["reduce_mean"], + regularizer_weight=_train_config["regularizer_weight"], + temperature=_train_config["temperature"], + kappa=_train_config["kappa"], + density=_train_config["density"], + num_spatial_dims=_dataset.get_num_spatial_dims(), + reduce_mean=_train_config["reduce_mean"], device=device, ) # set optimizer - self.optimizer = torch.optim.Adam( - self.model.parameters(), - lr=self.train_config["initial_learning_rate"], + _optimizer = torch.optim.Adam( + _model.parameters(), + lr=_train_config["initial_learning_rate"], ) # set scheduler: def lambda_(iteration): return pow( - (1 - ((iteration) / self.train_config["max_iterations"])), 0.9 + (1 - ((iteration) / _train_config["max_iterations"])), 
0.9 ) # resume training @@ -454,26 +503,26 @@ def lambda_(iteration): start_iteration = state["iteration"] + 1 self.iterations = state["iterations"] self.losses = state["losses"] - self.model.load_state_dict(state["model_state_dict"], strict=True) - self.optimizer.load_state_dict(state["optim_state_dict"]) + _model.load_state_dict(state["model_state_dict"], strict=True) + _optimizer.load_state_dict(state["optim_state_dict"]) # call `train_iteration` for iteration, batch in zip( - range(start_iteration, self.train_config["max_iterations"]), - self.dataloader, + range(start_iteration, _train_config["max_iterations"]), + dataloader, ): - scheduler = torch.optim.lr_scheduler.LambdaLR( - self.optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 + _scheduler = torch.optim.lr_scheduler.LambdaLR( + _optimizer, lr_lambda=lambda_, last_epoch=iteration - 1 ) train_loss, prediction = train_iteration( batch, - model=self.model, + model=_model, criterion=criterion, - optimizer=self.optimizer, + optimizer=_optimizer, device=device, ) - scheduler.step() + _scheduler.step() yield (iteration, train_loss) def on_yield(self, step_data): @@ -504,9 +553,10 @@ def on_return(self): pass # TODO def prepare_for_segmenting(self): + global _segment_config # check if segment_config exists - if self.segment_config is None: - self.segment_config = { + if _segment_config is None: + _segment_config = { "crop_size": 252, "p_salt_pepper": 0.1, "num_infer_iterations": 16, @@ -530,37 +580,39 @@ def prepare_for_segmenting(self): @thread_worker def segment_napari(self): + global _segment_config, _model, _dataset + raw = self.raw_selector.value - if self.segment_config["bandwidth"] is None: - self.segment_config["bandwidth"] = int( + if _segment_config["bandwidth"] is None: + _segment_config["bandwidth"] = int( 0.5 * float(self.object_size_line.text()) ) - if self.segment_config["min_size"] is None: - self.segment_config["min_size"] = int( + if _segment_config["min_size"] is None: + _segment_config["min_size"] = int( 0.1 * np.pi * (float(self.object_size_line.text()) ** 2) / 4 ) - self.model.eval() + _model.eval() - num_spatial_dims = self.dataset.num_spatial_dims - num_channels = self.dataset.num_channels - spatial_dims = self.dataset.spatial_dims - num_samples = self.dataset.num_samples + num_spatial_dims = _dataset.num_spatial_dims + num_channels = _dataset.num_channels + spatial_dims = _dataset.spatial_dims + num_samples = _dataset.num_samples print( f"Num spatial dims {num_spatial_dims} num channels {num_channels} spatial_dims {spatial_dims} num_samples {num_samples}" ) - crop_size = (self.segment_config["crop_size"],) * num_spatial_dims - device = self.train_config["device"] + crop_size = (_segment_config["crop_size"],) * num_spatial_dims + device = self.device_combo_box.currentText() if num_channels == 0: num_channels = 1 voxel_size = gp.Coordinate((1,) * num_spatial_dims) - self.model.set_infer( - p_salt_pepper=self.segment_config["p_salt_pepper"], - num_infer_iterations=self.segment_config["num_infer_iterations"], + _model.set_infer( + p_salt_pepper=_segment_config["p_salt_pepper"], + num_infer_iterations=_segment_config["num_infer_iterations"], device=device, ) @@ -568,7 +620,7 @@ def segment_napari(self): input_shape = gp.Coordinate((1, num_channels, *crop_size)) output_shape = gp.Coordinate( - self.model( + _model( torch.zeros( (1, num_channels, *crop_size), dtype=torch.float32 ).to(device) @@ -589,7 +641,7 @@ def segment_napari(self): scan_request.add(prediction_key, output_size) predict = gp.torch.Predict( - 
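
When bandwidth and min_size are left at None they are derived from the "Object Size [px]" field as shown above; for the default object size of 30 px this works out to a bandwidth of 15 px and a min_size of 70, i.e. roughly 10% of the area of a 30 px diameter disc. A small worked sketch:

    import numpy as np

    object_size = 30.0  # value typed into the "Object Size [px]" field
    bandwidth = int(0.5 * object_size)                    # 15
    min_size = int(0.1 * np.pi * (object_size ** 2) / 4)  # int(70.68...) == 70
    print(bandwidth, min_size)
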
self.model, + _model, inputs={"raw": raw_key}, outputs={0: prediction_key}, array_specs={prediction_key: gp.ArraySpec(voxel_size=voxel_size)}, @@ -670,9 +722,9 @@ def segment_napari(self): segmentation = mean_shift_segmentation( embeddings_mean, embeddings_std, - self.segment_config["bandwidth"], - self.segment_config["min_size"], - self.segment_config["reduction_probability"], + _segment_config["bandwidth"], + _segment_config["min_size"], + _segment_config["reduction_probability"], ) labels[sample, 0, ...] = segmentation @@ -683,11 +735,11 @@ def segment_napari(self): segmentation = labels[sample, 0] distance_foreground = dtedt(segmentation == 0) expanded_mask = ( - distance_foreground < self.inference_config["grow_distance"] + distance_foreground < _segment_config["grow_distance"] ) distance_background = dtedt(expanded_mask) segmentation[ - distance_background < self.inference_config["shrink_distance"] + distance_background < _segment_config["shrink_distance"] ] = 0 pp_labels[sample, 0, ...] = segmentation return ( From f07c1101bd01127f431fffeef0a2bec0b1266f13 Mon Sep 17 00:00:00 2001 From: lmanan Date: Sun, 22 Oct 2023 21:30:05 -0400 Subject: [PATCH 34/40] update default value of num_fmaps for plugin --- src/napari_cellulus/widgets/_widget.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 85081d4..43368aa 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -232,8 +232,9 @@ def train_configs_widget( control_point_spacing: int = 64, control_point_jitter: float = 2.0, ): + global _train_config # Specify what should happen when 'Save' button is pressed - self.train_config = { + _train_config = { "crop_size": crop_size, "batch_size": batch_size, "max_iterations": max_iterations, @@ -259,15 +260,16 @@ def train_configs_widget( def create_model_configs_widget(self): @magic_factory(call_button="Save") def model_configs_widget( - num_fmaps: int = 256, + num_fmaps: int = 24, fmap_inc_factor: int = 3, features_in_last_layer: int = 64, downsampling_factors: int = 2, downsampling_layers: int = 1, initialize: bool = True, ): + global _model_config # Specify what should happen when 'Save' button is pressed - self.model_config = { + _model_config = { "num_fmaps": num_fmaps, "fmap_inc_factor": fmap_inc_factor, "features_in_last_layer": features_in_last_layer, @@ -296,8 +298,9 @@ def segment_configs_widget( grow_distance: int = 3, shrink_distance: int = 6, ): + global _segment_config # Specify what should happen when 'Save' button is pressed - self.segment_config = { + _segment_config = { "crop_size": crop_size, "p_salt_pepper": p_salt_pepper, "num_infer_iterations": num_infer_iterations, @@ -357,7 +360,7 @@ def prepare_for_training(self): # check if model_config object exists if _model_config is None: _model_config = { - "num_fmaps": 256, + "num_fmaps": 24, "fmap_inc_factor": 3, "features_in_last_layer": 64, "downsampling_factors": 2, @@ -366,6 +369,8 @@ def prepare_for_training(self): } self.update_mode(self.sender()) + print(_model_config) + if self.mode == "training": self.worker = self.train_napari() self.worker.yielded.connect(self.on_yield) @@ -540,7 +545,7 @@ def on_yield(self, step_data): self.update_canvas() elif self.mode == "segmenting": print(step_data) - self.pbar.setValue(step_data) + # self.pbar.setValue(step_data) def update_canvas(self): self.loss_plot.set_xdata(self.iterations) From 
e194432b1f7d98814228d269f7b9880eab6069c4 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 23 Oct 2023 00:40:48 -0400 Subject: [PATCH 35/40] Update segment button text after completion of segmentation --- src/napari_cellulus/widgets/_widget.py | 45 +++++++++++++++++--------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 43368aa..60ed770 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -30,6 +30,7 @@ QWidget, ) from scipy.ndimage import distance_transform_edt as dtedt +from skimage.filters import threshold_otsu from superqt import QCollapsible from ..dataset import NapariDataset @@ -367,9 +368,9 @@ def prepare_for_training(self): "downsampling_layers": 1, "initialize": True, } - self.update_mode(self.sender()) - print(_model_config) + print(self.sender()) + self.update_mode(self.sender()) if self.mode == "training": self.worker = self.train_napari() @@ -554,8 +555,15 @@ def update_canvas(self): self.canvas.axes.autoscale_view() self.canvas.draw() - def on_return(self): - pass # TODO + def on_return(self, layers): + # Describes what happens once segment button has completed + + for data, metadata, layer_type in layers: + if layer_type == "image": + self.viewer.add_image(data, **metadata) + elif layer_type == "labels": + self.viewer.add_labels(data.astype(int), **metadata) + self.update_mode(self.segment_button) def prepare_for_segmenting(self): global _segment_config @@ -611,8 +619,7 @@ def segment_napari(self): crop_size = (_segment_config["crop_size"],) * num_spatial_dims device = self.device_combo_box.currentText() - if num_channels == 0: - num_channels = 1 + num_channels_temp = 1 if num_channels == 0 else num_channels voxel_size = gp.Coordinate((1,) * num_spatial_dims) _model.set_infer( @@ -621,13 +628,11 @@ def segment_napari(self): device=device, ) - print(f"Current device is {device}") - - input_shape = gp.Coordinate((1, num_channels, *crop_size)) + input_shape = gp.Coordinate((1, num_channels_temp, *crop_size)) output_shape = gp.Coordinate( _model( torch.zeros( - (1, num_channels, *crop_size), dtype=torch.float32 + (1, num_channels_temp, *crop_size), dtype=torch.float32 ).to(device) ).shape ) @@ -719,8 +724,17 @@ def segment_napari(self): for i in range(num_spatial_dims + 1) ] - labels = np.zeros_like(prediction[:, 0:1, ...].data, dtype=np.uint64) - for sample in range(num_samples): + foreground = np.zeros_like(prediction[:, 0:1, ...], dtype=bool) + for sample in range(prediction.shape[0]): + embeddings = prediction[sample] + embeddings_std = embeddings[-1, ...] + thresh = threshold_otsu(embeddings_std) + print(f"Threshold for sample {sample} is {thresh}") + binary_mask = embeddings_std < thresh + foreground[sample, 0, ...] = binary_mask + + labels = np.zeros_like(prediction[:, 0:1, ...], dtype=np.uint64) + for sample in range(prediction.shape[0]): embeddings = prediction[sample] embeddings_std = embeddings[-1, ...] embeddings_mean = embeddings[np.newaxis, :num_spatial_dims, ...] @@ -733,10 +747,8 @@ def segment_napari(self): ) labels[sample, 0, ...] 
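
The new Foreground layer comes from Otsu-thresholding the last channel of the prediction (the per-pixel std of the embeddings): low std is treated as foreground. A minimal sketch of that step:

    import numpy as np
    from skimage.filters import threshold_otsu

    def foreground_from_std(embeddings_std: np.ndarray) -> np.ndarray:
        thresh = threshold_otsu(embeddings_std)
        return embeddings_std < thresh  # below-threshold std is kept as foreground
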
= segmentation - pp_labels = np.zeros_like( - prediction[:, 0:1, ...].data, dtype=np.uint64 - ) - for sample in range(num_samples): + pp_labels = np.zeros_like(prediction[:, 0:1, ...], dtype=np.uint64) + for sample in range(prediction.shape[0]): segmentation = labels[sample, 0] distance_foreground = dtedt(segmentation == 0) expanded_mask = ( @@ -749,6 +761,7 @@ def segment_napari(self): pp_labels[sample, 0, ...] = segmentation return ( prediction_layers + + [(foreground, {"name": "Foreground"}, "labels")] + [(labels, {"name": "Segmentation"}, "labels")] + [(pp_labels, {"name": "Post Processed"}, "labels")] ) From 321b7696df4fa1e88db145e462e1b466e0b8a5fe Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 23 Oct 2023 01:10:02 -0400 Subject: [PATCH 36/40] Disable segment button while training, and vice-versa --- src/napari_cellulus/widgets/_widget.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 60ed770..9dc8db6 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -394,23 +394,27 @@ def update_mode(self, sender): if self.train_button.text() == "Train" and sender == self.train_button: self.train_button.setText("Pause") self.mode = "training" + self.segment_button.setEnabled(False) elif ( self.train_button.text() == "Pause" and sender == self.train_button ): self.train_button.setText("Train") self.mode = "configuring" + self.segment_button.setEnabled(True) elif ( self.segment_button.text() == "Segment" and sender == self.segment_button ): self.segment_button.setText("Pause") self.mode = "segmenting" + self.train_button.setEnabled(False) elif ( self.segment_button.text() == "Pause" and sender == self.segment_button ): self.segment_button.setText("Segment") self.mode = "configuring" + self.train_button.setEnabled(True) From a730c99fe78ea7ca6e93b5ee55c7c65d89fee18f Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 23 Oct 2023 01:27:26 -0400 Subject: [PATCH 37/40] Remove class TrainingStats --- src/napari_cellulus/widgets/_widget.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 9dc8db6..bb505ec 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -1,4 +1,3 @@ -import dataclasses import os import time @@ -39,24 +38,6 @@ # local package imports from ..gui_helpers import MplCanvas, layer_choice_widget - -@dataclasses.dataclass -class TrainingStats: - iteration: int = 0 - losses: list[float] = dataclasses.field(default_factory=list) - iterations: list[int] = dataclasses.field(default_factory=list) - - def reset(self): - self.iteration = 0 - self.losses = [] - self.iterations = [] - - def load(self, other): - self.iteration = other.iteration - self.losses = other.losses - self.iterations = other.iterations - - ############ GLOBALS ################### time_now = 0 _train_config = None From 1f33fdbf62b643cfdf892c05c341ddb024b514e8 Mon Sep 17 00:00:00 2001 From: lmanan Date: Mon, 23 Oct 2023 10:45:11 -0400 Subject: [PATCH 38/40] Handle the scenario with non-zero samples and zero channels --- src/napari_cellulus/dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 60840d1..9b3e4ae 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -94,7 +94,6 @@ def 
__setup_pipeline(self): NapariImageSource( self.layer, self.raw, raw_spec, self.spatial_dims ) - + gp.Unsqueeze([self.raw], 1) + gp.RandomLocation() + gp.Unsqueeze([self.raw], 1) ) From f92e357c7a79b6ccdf36ca8a66e8f03bfd54252c Mon Sep 17 00:00:00 2001 From: lmanan Date: Fri, 27 Oct 2023 11:15:08 -0400 Subject: [PATCH 39/40] Update default value of p_salt_pepper --- src/napari_cellulus/widgets/_widget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index bb505ec..a9d312a 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -556,7 +556,7 @@ def prepare_for_segmenting(self): if _segment_config is None: _segment_config = { "crop_size": 252, - "p_salt_pepper": 0.1, + "p_salt_pepper": 0.01, "num_infer_iterations": 16, "bandwidth": None, "reduction_probability": 0.1, From f0831ae67a792044833fd1c89ed0590d5113aabb Mon Sep 17 00:00:00 2001 From: lmanan Date: Fri, 27 Oct 2023 16:57:52 -0400 Subject: [PATCH 40/40] Update canvas --- src/napari_cellulus/widgets/_widget.py | 95 +++++++++++++++----------- 1 file changed, 54 insertions(+), 41 deletions(-) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index a9d312a..09e1dc5 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -4,6 +4,7 @@ import gunpowder as gp import napari import numpy as np +import pyqtgraph as pg import torch from cellulus.criterions import get_loss from cellulus.models import get_model @@ -13,8 +14,7 @@ # widget stuff from napari.qt.threading import thread_worker -from qtpy.QtCore import Qt, QUrl -from qtpy.QtGui import QDesktopServices +from qtpy.QtCore import Qt from qtpy.QtWidgets import ( QCheckBox, QComboBox, @@ -23,6 +23,7 @@ QHBoxLayout, QLabel, QLineEdit, + QMainWindow, QPushButton, QScrollArea, QVBoxLayout, @@ -36,7 +37,7 @@ from ..gp.nodes.napari_image_source import NapariImageSource # local package imports -from ..gui_helpers import MplCanvas, layer_choice_widget +from ..gui_helpers import layer_choice_widget ############ GLOBALS ################### time_now = 0 @@ -49,10 +50,11 @@ _dataset = None -class SegmentationWidget(QScrollArea): +class SegmentationWidget(QMainWindow): def __init__(self, napari_viewer): super().__init__() self.widget = QWidget() + self.scroll = QScrollArea() self.viewer = napari_viewer # initialize train_config and model_config @@ -95,9 +97,6 @@ def __init__(self, napari_viewer): grid_layout.addWidget(device_label, 1, 0) grid_layout.addWidget(self.device_combo_box, 1, 1) - # global_params_widget = QWidget("") - # global_params_widget.setLayout(grid_layout) - # Initialize train configs widget collapsible_train_configs = QCollapsible("Train Configs", self) collapsible_train_configs.addWidget(self.create_train_configs_widget) @@ -107,16 +106,22 @@ def __init__(self, napari_viewer): collapsible_model_configs.addWidget(self.create_model_configs_widget) # Initialize loss/iterations widget - self.canvas = MplCanvas(self, width=5, height=5, dpi=100) - canvas_layout = QVBoxLayout() - canvas_layout.addWidget(self.canvas) - if len(self.iterations) == 0: - self.loss_plot = self.canvas.axes.plot([], [], label="")[0] - self.canvas.axes.legend() - self.canvas.axes.set_xlabel("Iterations") - self.canvas.axes.set_ylabel("Loss") - plot_container_widget = QWidget() - plot_container_widget.setLayout(canvas_layout) + + self.losses_widget = pg.PlotWidget() + 
self.losses_widget.setBackground((37, 41, 49)) + styles = {"color": "white", "font-size": "16px"} + self.losses_widget.setLabel("left", "Loss", **styles) + self.losses_widget.setLabel("bottom", "Iterations", **styles) + # self.canvas = MplCanvas(self, width=5, height=10, dpi=100) + # canvas_layout = QVBoxLayout() + # canvas_layout.addWidget(self.canvas) + # if len(self.iterations) == 0: + # self.loss_plot = self.canvas.axes.plot([], [], label="")[0] + # self.canvas.axes.legend() + # self.canvas.axes.set_xlabel("Iterations") + # self.canvas.axes.set_ylabel("Loss") + # plot_container_widget = QWidget() + # plot_container_widget.setLayout(canvas_layout) # Initialize Layer Choice self.raw_selector = layer_choice_widget( @@ -146,9 +151,17 @@ def __init__(self, napari_viewer): self.train_button.clicked.connect(self.prepare_for_training) # Initialize Save and Load Widget - # collapsible_save_load_widget = QCollapsible("Save/Load", self) - # collapsible_save_load_widget.addWidget(self.save_widget.native) - # collapsible.save_load_widget.addWidget(self.load_widget.native) + collapsible_save_load_widget = QCollapsible( + "Save and Load Model", self + ) + save_load_layout = QHBoxLayout() + save_model_button = QPushButton("Save Model", self) + load_model_button = QPushButton("Load Model", self) + save_load_layout.addWidget(save_model_button) + save_load_layout.addWidget(load_model_button) + save_load_widget = QWidget() + save_load_widget.setLayout(save_load_layout) + collapsible_save_load_widget.addWidget(save_load_widget) # Initialize Segment Configs widget collapsible_segment_configs = QCollapsible("Inference Configs", self) @@ -164,13 +177,8 @@ def __init__(self, napari_viewer): # self.pbar = QProgressBar(self) # Initialize Feedback Button - self.feedback_button = QPushButton("Feedback!", self) - self.feedback_button.clicked.connect( - lambda: QDesktopServices.openUrl( - QUrl( - "https://github.com/funkelab/napari-cellulus/issues/new/choose" - ) - ) + self.feedback_label = QLabel( + 'Please share any feedback here.' 
) # Add all components to outer_layout @@ -180,22 +188,25 @@ def __init__(self, napari_viewer): # outer_layout.addWidget(global_params_widget) outer_layout.addWidget(collapsible_train_configs) outer_layout.addWidget(collapsible_model_configs) - outer_layout.addWidget(plot_container_widget) + # outer_layout.addWidget(plot_container_widget) + outer_layout.addWidget(self.losses_widget) outer_layout.addWidget(self.raw_selector.native) outer_layout.addWidget(axis_selector) outer_layout.addWidget(self.train_button) - # outer_layout.addWidget(collapsible_save_load_widget) + outer_layout.addWidget(collapsible_save_load_widget) outer_layout.addWidget(collapsible_segment_configs) outer_layout.addWidget(self.segment_button) - outer_layout.addWidget(self.feedback_button) - + outer_layout.addWidget(self.feedback_label) outer_layout.setSpacing(20) self.widget.setLayout(outer_layout) - self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.setWidgetResizable(True) - self.setWidget(self.widget) + + self.scroll.setWidget(self.widget) + self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.scroll.setWidgetResizable(True) + self.setFixedWidth(400) + self.setCentralWidget(self.scroll) @property def create_train_configs_widget(self): @@ -272,7 +283,7 @@ def create_segment_configs_widget(self): @magic_factory(call_button="Save") def segment_configs_widget( crop_size: int = 252, - p_salt_pepper: float = 0.1, + p_salt_pepper: float = 0.01, num_infer_iterations: int = 16, bandwidth: int = 7, reduction_probability: float = 0.1, @@ -534,11 +545,13 @@ def on_yield(self, step_data): # self.pbar.setValue(step_data) def update_canvas(self): - self.loss_plot.set_xdata(self.iterations) - self.loss_plot.set_ydata(self.losses) - self.canvas.axes.relim() - self.canvas.axes.autoscale_view() - self.canvas.draw() + self.losses_widget.plot(self.iterations, self.losses) + + # self.loss_plot.set_xdata(self.iterations) + # self.loss_plot.set_ydata(self.losses) + # self.canvas.axes.relim() + # self.canvas.axes.autoscale_view() + # self.canvas.draw() def on_return(self, layers): # Describes what happens once segment button has completed
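
PATCH 40 replaces the matplotlib canvas with a pyqtgraph PlotWidget and plots the running loss directly in update_canvas. A standalone sketch of that widget; the application scaffolding and dummy loss values are illustrative:

    import pyqtgraph as pg
    from qtpy.QtWidgets import QApplication

    app = QApplication([])
    losses_widget = pg.PlotWidget()
    losses_widget.setBackground((37, 41, 49))
    styles = {"color": "white", "font-size": "16px"}
    losses_widget.setLabel("left", "Loss", **styles)
    losses_widget.setLabel("bottom", "Iterations", **styles)
    losses_widget.plot([0, 100, 200, 300], [1.2, 0.9, 0.7, 0.6])  # iterations, losses
    losses_widget.show()
    # app.exec_()  # start the Qt event loop to actually show the window
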