diff --git a/MANIFEST.in b/MANIFEST.in
index f3155af..411b69a 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,3 @@
-include LICENSE
-include README.md
-
+include src/napari_cellulus/sample_data/*.npy
+include src/napari_cellulus/napari.yaml
recursive-exclude * __pycache__
-recursive-exclude * *.py[co]
diff --git a/README.md b/README.md
index 352e0aa..3123f90 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,8 @@
+
A napari plugin for cellulus
+
+
- **[Introduction](#introduction)**
- **[Installation](#installation)**
- **[Getting Started](#getting-started)**
@@ -55,8 +58,10 @@ Run the following commands in a terminal window:
conda activate napari-cellulus
napari
```
+[demo_cellulus.webm](https://github.com/funkelab/napari-cellulus/assets/34229641/35cb09de-c875-487d-9890-86082dcd95b2)
+
+
-Next, select `Cellulus` from the `Plugins` drop-down menu.
### Citation
diff --git a/setup.cfg b/setup.cfg
index ce0031b..8fd5c50 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -46,6 +46,10 @@ package_dir =
[options.packages.find]
where = src
+[options.package_data]
+napari_cellulus = *.npy
+* = *.yaml
+
[options.entry_points]
napari.manifest =
napari-cellulus = napari_cellulus:napari.yaml
@@ -58,7 +62,3 @@ testing =
pytest-qt # https://pytest-qt.readthedocs.io/en/latest/
napari
pyqt5
-
-
-[options.package_data]
-* = *.yaml
diff --git a/src/napari_cellulus/__init__.py b/src/napari_cellulus/__init__.py
index bf80da0..b666730 100644
--- a/src/napari_cellulus/__init__.py
+++ b/src/napari_cellulus/__init__.py
@@ -1,5 +1,5 @@
__version__ = "0.1.0"
-from .sample_data import tissue_net_sample
+from .load_sample_data import load_fluo_c2dl_huh7_sample
-__all__ = ("tissue_net_sample",)
+__all__ = ("load_fluo_c2dl_huh7_sample",)
diff --git a/src/napari_cellulus/datasets/napari_dataset.py b/src/napari_cellulus/datasets/napari_dataset.py
index 3952ab8..0a4b324 100644
--- a/src/napari_cellulus/datasets/napari_dataset.py
+++ b/src/napari_cellulus/datasets/napari_dataset.py
@@ -71,6 +71,7 @@ def __setup_pipeline(self):
)
+ gp.Unsqueeze([self.raw], 0)
+ gp.RandomLocation()
+ + gp.Normalize(self.raw, factor=self.normalization_factor)
)
else:
self.pipeline = (
@@ -81,6 +82,7 @@ def __setup_pipeline(self):
spatial_dims=self.spatial_dims,
)
+ gp.RandomLocation()
+ + gp.Normalize(self.raw, factor=self.normalization_factor)
)
def __iter__(self):
@@ -236,9 +238,20 @@ def sample_coordinates(self):
return anchor_samples, reference_samples
def get_num_anchors(self):
- return int(
- self.density * self.unbiased_shape[0] * self.unbiased_shape[1]
- )
+ if self.num_spatial_dims == 2:
+ return int(
+ self.density * self.unbiased_shape[0] * self.unbiased_shape[1]
+ )
+ elif self.num_spatial_dims == 3:
+ return int(
+ self.density
+ * self.unbiased_shape[0]
+ * self.unbiased_shape[1]
+ * self.unbiased_shape[2]
+ )
def get_num_references(self):
- return int(self.density * self.kappa**2 * np.pi)
+ if self.num_spatial_dims == 2:
+ return int(self.density * np.pi * self.kappa**2)
+ elif self.num_spatial_dims == 3:
+ return int(self.density * 4 / 3 * np.pi * self.kappa**3)
diff --git a/src/napari_cellulus/datasets/napari_image_source.py b/src/napari_cellulus/datasets/napari_image_source.py
index bee332f..2956af2 100644
--- a/src/napari_cellulus/datasets/napari_image_source.py
+++ b/src/napari_cellulus/datasets/napari_image_source.py
@@ -19,6 +19,7 @@ def __init__(
self, image: Image, key: gp.ArrayKey, spec: ArraySpec, spatial_dims
):
self.array_spec = spec
+
self.image = gp.Array(
data=normalize(
image.data.astype(np.float32), 1, 99.8, axis=spatial_dims
diff --git a/src/napari_cellulus/load_sample_data.py b/src/napari_cellulus/load_sample_data.py
new file mode 100644
index 0000000..61c4e45
--- /dev/null
+++ b/src/napari_cellulus/load_sample_data.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+
+import numpy as np
+
+FLUO_C2DL_HUH7_SAMPLE_PATH = (
+ Path(__file__).parent / "sample_data/Fluo-C2DL-Huh7-sample.npy"
+)
+
+
+def load_fluo_c2dl_huh7_sample():
+ raw = np.load(FLUO_C2DL_HUH7_SAMPLE_PATH)
+ num_samples = raw.shape[0]
+ indices = np.random.choice(np.arange(num_samples), 5, replace=False)
+ raw = raw[indices]
+
+ return [
+ (
+ raw,
+ {
+ "name": "Raw",
+ "metadata": {"axes": ["s", "c", "y", "x"]},
+ },
+ "image",
+ )
+ ]
diff --git a/src/napari_cellulus/model.py b/src/napari_cellulus/model.py
new file mode 100644
index 0000000..f262e84
--- /dev/null
+++ b/src/napari_cellulus/model.py
@@ -0,0 +1,48 @@
+import torch
+
+
+class Model(torch.nn.Module):
+ """
+ This class is a wrapper on the model object returned by cellulus.
+ It updates the `forward` function and handles cases when the input raw
+ image is not (S, C, (Z), Y, X) type.
+ """
+
+ def __init__(self, model, selected_axes):
+ super().__init__()
+ self.model = model
+ self.selected_axes = selected_axes
+
+ def forward(self, raw):
+ if "s" in self.selected_axes and "c" in self.selected_axes:
+ pass
+ elif "s" in self.selected_axes and "c" not in self.selected_axes:
+
+ raw = torch.unsqueeze(raw, 1)
+ elif "s" not in self.selected_axes and "c" in self.selected_axes:
+ pass
+ elif "s" not in self.selected_axes and "c" not in self.selected_axes:
+ raw = torch.unsqueeze(raw, 1)
+ return self.model(raw)
+
+ @staticmethod
+ def select_and_add_coordinates(outputs, coordinates):
+ selections = []
+ # outputs.shape = (b, c, h, w) or (b, c, d, h, w)
+ for output, coordinate in zip(outputs, coordinates):
+ if output.ndim == 3:
+ selection = output[:, coordinate[:, 1], coordinate[:, 0]]
+ elif output.ndim == 4:
+ selection = output[
+ :, coordinate[:, 2], coordinate[:, 1], coordinate[:, 0]
+ ]
+ selection = selection.transpose(1, 0)
+ selection += coordinate
+ selections.append(selection)
+
+ # selection.shape = (b, p, c) where p is the number of selected positions
+ return torch.stack(selections, dim=0)
+
+ def set_infer(self, p_salt_pepper, num_infer_iterations, device):
+ self.model.eval()
+ self.model.set_infer(p_salt_pepper, num_infer_iterations, device)
diff --git a/src/napari_cellulus/napari.yaml b/src/napari_cellulus/napari.yaml
index 39ac969..0c1c0ae 100644
--- a/src/napari_cellulus/napari.yaml
+++ b/src/napari_cellulus/napari.yaml
@@ -1,23 +1,17 @@
name: napari-cellulus
-display_name: Cellulus
+display_name: napari-cellulus
contributions:
commands:
- - id: napari-cellulus.tissue_net_sample
- python_name: napari_cellulus.sample_data:tissue_net_sample
- title: Load sample data from Cellulus
- - id: napari-cellulus.fluo_n2dl_hela_sample
- python_name: napari_cellulus.sample_data:fluo_n2dl_hela_sample
- title: Load sample data from Cellulus
+ - id: napari-cellulus.load_fluo_c2dl_huh7_sample
+ python_name: napari_cellulus.load_sample_data:load_fluo_c2dl_huh7_sample
+ title: Load sample data
- id: napari-cellulus.Widget
python_name: napari_cellulus.widget:Widget
title: Cellulus
sample_data:
- - command: napari-cellulus.tissue_net_sample
- display_name: TissueNet
- key: tissue_net_sample
- - command: napari-cellulus.fluo_n2dl_hela_sample
- display_name: Fluo-N2DL-HeLa
- key: fluo_n2dl_hela_sample
+ - command: napari-cellulus.load_fluo_c2dl_huh7_sample
+ display_name: Fluo-C2DL-Huh7
+ key: load_fluo_c2dl_huh7_sample
widgets:
- command: napari-cellulus.Widget
display_name: Cellulus
diff --git a/src/napari_cellulus/sample_data.py b/src/napari_cellulus/sample_data.py
deleted file mode 100644
index 2a4efc0..0000000
--- a/src/napari_cellulus/sample_data.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from pathlib import Path
-
-import numpy as np
-import tifffile
-
-TISSUE_NET_SAMPLE = Path(__file__).parent / "sample_data/tissue_net_sample.npy"
-FLUO_N2DL_HELA = Path(__file__).parent / "sample_data/fluo_n2dl_hela.tif"
-
-
-def fluo_n2dl_hela_sample():
- x = tifffile.imread(FLUO_N2DL_HELA)
- return [
- (
- x,
- {
- "name": "Raw",
- "metadata": {"axes": ["s", "c", "y", "x"]},
- },
- "image",
- )
- ]
-
-
-def tissue_net_sample():
- (x, y) = np.load(TISSUE_NET_SAMPLE, "r")
- x = x.transpose(0, 3, 1, 2)
- y = y.transpose(0, 3, 1, 2).astype(np.uint8)
- return [
- (
- x,
- {
- "name": "Raw",
- "metadata": {"axes": ["s", "c", "y", "x"]},
- },
- "image",
- ),
- (
- y,
- {
- "name": "Labels",
- "metadata": {"axes": ["s", "c", "y", "x"]},
- },
- "Labels",
- ),
- ]
diff --git a/src/napari_cellulus/sample_data/Fluo-C2DL-Huh7-sample.npy b/src/napari_cellulus/sample_data/Fluo-C2DL-Huh7-sample.npy
new file mode 100644
index 0000000..1fe3ad3
Binary files /dev/null and b/src/napari_cellulus/sample_data/Fluo-C2DL-Huh7-sample.npy differ
diff --git a/src/napari_cellulus/sample_data/__init__.py b/src/napari_cellulus/sample_data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif b/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif
deleted file mode 100644
index f28870c..0000000
Binary files a/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif and /dev/null differ
diff --git a/src/napari_cellulus/sample_data/tissue_net_sample.npy b/src/napari_cellulus/sample_data/tissue_net_sample.npy
deleted file mode 100644
index a9191d4..0000000
Binary files a/src/napari_cellulus/sample_data/tissue_net_sample.npy and /dev/null differ
diff --git a/src/napari_cellulus/widget.py b/src/napari_cellulus/widget.py
index bbad8b8..d5f8f46 100644
--- a/src/napari_cellulus/widget.py
+++ b/src/napari_cellulus/widget.py
@@ -21,6 +21,7 @@
QButtonGroup,
QCheckBox,
QComboBox,
+ QFileDialog,
QGridLayout,
QLabel,
QLineEdit,
@@ -29,6 +30,7 @@
QRadioButton,
QScrollArea,
QVBoxLayout,
+ QWidget,
)
from scipy.ndimage import binary_fill_holes
from scipy.ndimage import distance_transform_edt as dtedt
@@ -37,47 +39,7 @@
from .datasets.napari_dataset import NapariDataset
from .datasets.napari_image_source import NapariImageSource
-
-
-class Model(torch.nn.Module):
- def __init__(self, model, selected_axes):
- super().__init__()
- self.model = model
- self.selected_axes = selected_axes
-
- def forward(self, x):
- if "s" in self.selected_axes and "c" in self.selected_axes:
- pass
- elif "s" in self.selected_axes and "c" not in self.selected_axes:
-
- x = torch.unsqueeze(x, 1)
- elif "s" not in self.selected_axes and "c" in self.selected_axes:
- pass
- elif "s" not in self.selected_axes and "c" not in self.selected_axes:
- x = torch.unsqueeze(x, 1)
- return self.model(x)
-
- @staticmethod
- def select_and_add_coordinates(outputs, coordinates):
- selections = []
- # outputs.shape = (b, c, h, w) or (b, c, d, h, w)
- for output, coordinate in zip(outputs, coordinates):
- if output.ndim == 3:
- selection = output[:, coordinate[:, 1], coordinate[:, 0]]
- elif output.ndim == 4:
- selection = output[
- :, coordinate[:, 2], coordinate[:, 1], coordinate[:, 0]
- ]
- selection = selection.transpose(1, 0)
- selection += coordinate
- selections.append(selection)
-
- # selection.shape = (b, c, p) where p is the number of selected positions
- return torch.stack(selections, dim=0)
-
- def set_infer(self, p_salt_pepper, num_infer_iterations, device):
- self.model.eval()
- self.model.set_infer(p_salt_pepper, num_infer_iterations, device)
+from .model import Model
class Widget(QMainWindow):
@@ -85,6 +47,7 @@ def __init__(self, napari_viewer):
super().__init__()
self.viewer = napari_viewer
self.scroll = QScrollArea()
+ self.widget = QWidget()
# initialize outer layout
layout = QVBoxLayout()
@@ -105,10 +68,6 @@ def __init__(self, napari_viewer):
self.set_grid_6()
self.grid_7 = QGridLayout() # feedback
self.set_grid_7()
- self.create_configs() # configs
- self.viewer.dims.events.current_step.connect(
- self.update_inference_widgets
- ) # listen to viewer slider
layout.addLayout(self.grid_0)
layout.addLayout(self.grid_1)
@@ -118,11 +77,18 @@ def __init__(self, napari_viewer):
layout.addLayout(self.grid_5)
layout.addLayout(self.grid_6)
layout.addLayout(self.grid_7)
- self.set_scroll_area(layout)
+ self.widget.setLayout(layout)
+ self.set_scroll_area()
self.viewer.layers.events.inserted.connect(self.update_raw_selector)
self.viewer.layers.events.removed.connect(self.update_raw_selector)
def update_raw_selector(self, event):
+ """
+ Whenever a new image is added or removed by the user,
+ this function is called.
+ It updates the `raw_selector` attribute.
+
+ """
count = 0
for i in range(self.raw_selector.count() - 1, -1, -1):
if self.raw_selector.itemText(i) == f"{event.value}":
@@ -133,6 +99,9 @@ def update_raw_selector(self, event):
self.raw_selector.addItems([f"{event.value}"])
def set_grid_0(self):
+ """
+ Specifies the title of the plugin.
+ """
text_label = QLabel("Cellulus
")
method_description_label = QLabel(
'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.
https://github.com/funkelab/cellulus'
@@ -141,6 +110,9 @@ def set_grid_0(self):
self.grid_0.addWidget(method_description_label, 1, 0, 2, 1)
def set_grid_1(self):
+ """
+ Specifies the device used for training and inference.
+ """
device_label = QLabel(self)
device_label.setText("Device")
self.device_combo_box = QComboBox(self)
@@ -152,6 +124,10 @@ def set_grid_1(self):
self.grid_1.addWidget(self.device_combo_box, 0, 1, 1, 1)
def set_grid_2(self):
+ """
+ Specifies the raw_selector attribute.
+ This is needed to identify which axes the image contains.
+ """
self.raw_selector = QComboBox(self)
for layer in self.viewer.layers:
self.raw_selector.addItem(f"{layer}")
@@ -169,6 +145,9 @@ def set_grid_2(self):
self.grid_2.addWidget(self.x_check_box, 1, 4, 1, 1)
def set_grid_3(self):
+ """
+ Specifies the configuration parameters for training.
+ """
crop_size_label = QLabel(self)
crop_size_label.setText("Crop Size")
self.crop_size_line = QLineEdit(self)
@@ -183,7 +162,7 @@ def set_grid_3(self):
max_iterations_label.setText("Max iterations")
self.max_iterations_line = QLineEdit(self)
self.max_iterations_line.setAlignment(Qt.AlignCenter)
- self.max_iterations_line.setText("100000")
+ self.max_iterations_line.setText("5000")
self.grid_3.addWidget(crop_size_label, 0, 0, 1, 1)
self.grid_3.addWidget(self.crop_size_line, 0, 1, 1, 1)
self.grid_3.addWidget(batch_size_label, 1, 0, 1, 1)
@@ -192,6 +171,9 @@ def set_grid_3(self):
self.grid_3.addWidget(self.max_iterations_line, 2, 1, 1, 1)
def set_grid_4(self):
+ """
+ Specifies the configuration parameters for the model.
+ """
feature_maps_label = QLabel(self)
feature_maps_label.setText("Number of feature maps")
self.feature_maps_line = QLineEdit(self)
@@ -203,59 +185,80 @@ def set_grid_4(self):
self.feature_maps_increase_line.setAlignment(Qt.AlignCenter)
self.feature_maps_increase_line.setText("3")
self.train_model_from_scratch_checkbox = QCheckBox(
- "Train model from scratch"
+ "Train from scratch"
)
-
+ self.train_model_from_scratch_checkbox.stateChanged.connect(
+ self.affect_load_weights
+ )
+ self.load_model_button = QPushButton("Load weights")
+ self.load_model_button.clicked.connect(self.load_weights)
self.train_model_from_scratch_checkbox.setChecked(False)
self.grid_4.addWidget(feature_maps_label, 0, 0, 1, 1)
self.grid_4.addWidget(self.feature_maps_line, 0, 1, 1, 1)
self.grid_4.addWidget(feature_maps_increase_label, 1, 0, 1, 1)
self.grid_4.addWidget(self.feature_maps_increase_line, 1, 1, 1, 1)
self.grid_4.addWidget(
- self.train_model_from_scratch_checkbox, 2, 0, 1, 2
+ self.train_model_from_scratch_checkbox, 2, 0, 1, 1
)
+ self.grid_4.addWidget(self.load_model_button, 2, 1, 1, 1)
def set_grid_5(self):
+ """
+ Specifies the loss widget.
+ """
self.losses_widget = pg.PlotWidget()
self.losses_widget.setBackground((37, 41, 49))
styles = {"color": "white", "font-size": "16px"}
self.losses_widget.setLabel("left", "Loss", **styles)
self.losses_widget.setLabel("bottom", "Iterations", **styles)
self.start_training_button = QPushButton("Start training")
- self.start_training_button.setFixedSize(140, 30)
+ self.start_training_button.setFixedSize(88, 30)
self.stop_training_button = QPushButton("Stop training")
- self.stop_training_button.setFixedSize(140, 30)
+ self.stop_training_button.setFixedSize(88, 30)
+ self.save_weights_button = QPushButton("Save weights")
+ self.save_weights_button.setFixedSize(88, 30)
self.grid_5.addWidget(self.losses_widget, 0, 0, 4, 4)
- self.grid_5.addWidget(self.start_training_button, 5, 0, 1, 2)
- self.grid_5.addWidget(self.stop_training_button, 5, 2, 1, 2)
+ self.grid_5.addWidget(self.start_training_button, 5, 0, 1, 1)
+ self.grid_5.addWidget(self.stop_training_button, 5, 1, 1, 1)
+ self.grid_5.addWidget(self.save_weights_button, 5, 2, 1, 1)
+
self.start_training_button.clicked.connect(
self.prepare_for_start_training
)
self.stop_training_button.clicked.connect(
self.prepare_for_stop_training
)
+ self.save_weights_button.clicked.connect(self.save_weights)
def set_grid_6(self):
+ """
+ Specifies the inference configuration parameters.
+ """
threshold_label = QLabel("Threshold")
self.threshold_line = QLineEdit(self)
+ self.threshold_line.textChanged.connect(self.prepare_thresholds)
self.threshold_line.setAlignment(Qt.AlignCenter)
self.threshold_line.setText(None)
bandwidth_label = QLabel("Bandwidth")
self.bandwidth_line = QLineEdit(self)
self.bandwidth_line.setAlignment(Qt.AlignCenter)
+ self.bandwidth_line.textChanged.connect(self.prepare_bandwidths)
self.radio_button_group = QButtonGroup(self)
self.radio_button_cell = QRadioButton("Cell")
self.radio_button_nucleus = QRadioButton("Nucleus")
self.radio_button_group.addButton(self.radio_button_nucleus)
self.radio_button_group.addButton(self.radio_button_cell)
-
- self.radio_button_nucleus.setChecked(True)
+ self.radio_button_cell.toggled.connect(self.update_post_processing)
+ self.radio_button_nucleus.toggled.connect(self.update_post_processing)
+ self.radio_button_cell.setChecked(True)
self.min_size_label = QLabel("Minimum Size")
self.min_size_line = QLineEdit(self)
self.min_size_line.setAlignment(Qt.AlignCenter)
+ self.min_size_line.textChanged.connect(self.prepare_min_sizes)
+
self.start_inference_button = QPushButton("Start inference")
self.start_inference_button.setFixedSize(140, 30)
self.stop_inference_button = QPushButton("Stop inference")
@@ -263,6 +266,7 @@ def set_grid_6(self):
self.grid_6.addWidget(threshold_label, 0, 0, 1, 1)
self.grid_6.addWidget(self.threshold_line, 0, 1, 1, 1)
+
self.grid_6.addWidget(bandwidth_label, 1, 0, 1, 1)
self.grid_6.addWidget(self.bandwidth_line, 1, 1, 1, 1)
self.grid_6.addWidget(self.radio_button_cell, 2, 0, 1, 1)
@@ -279,22 +283,33 @@ def set_grid_6(self):
)
def set_grid_7(self):
- # Initialize Feedback Button
+ """
+ Specifies the feedback URL.
+ """
+
feedback_label = QLabel(
'Please share any feedback here.'
)
self.grid_7.addWidget(feedback_label, 0, 0, 2, 1)
- def set_scroll_area(self, layout):
- self.scroll.setLayout(layout)
+ def set_scroll_area(self):
+ """
+ Creates a scroll area.
+ In case the main napari window is resized, the scroll area
+ would appear.
+ """
+ self.scroll.setWidget(self.widget)
self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.scroll.setWidgetResizable(True)
- self.setFixedWidth(300)
+ self.setFixedWidth(320)
self.setCentralWidget(self.scroll)
def get_selected_axes(self):
+ """
+ Returns the axes based on which of the checkboxes were selected.
+ """
names = []
for name, check_box in zip(
"sczyx",
@@ -312,57 +327,78 @@ def get_selected_axes(self):
return names
def create_configs(self):
- self.train_config = TrainConfig(
- crop_size=[int(self.crop_size_line.text())],
- batch_size=int(self.batch_size_line.text()),
- max_iterations=int(self.max_iterations_line.text()),
- device=self.device_combo_box.currentText(),
- )
- self.model_config = ModelConfig(
- num_fmaps=int(self.feature_maps_line.text()),
- fmap_inc_factor=int(self.feature_maps_increase_line.text()),
- )
+ """
+ This reads from the various line edits and initializes config objects.
+ """
+ if not hasattr(self, "train_config"):
+ self.train_config = TrainConfig(
+ crop_size=[int(self.crop_size_line.text())],
+ batch_size=int(self.batch_size_line.text()),
+ max_iterations=int(self.max_iterations_line.text()),
+ device=self.device_combo_box.currentText(),
+ )
+ if not hasattr(self, "model_config"):
+ self.model_config = ModelConfig(
+ num_fmaps=int(self.feature_maps_line.text()),
+ fmap_inc_factor=int(self.feature_maps_increase_line.text()),
+ )
+ if not hasattr(self, "experiment_config"):
+ self.experiment_config = ExperimentConfig(
+ train_config=asdict(self.train_config),
+ model_config=asdict(self.model_config),
+ normalization_factor=1.0,
+ )
+ if not hasattr(self, "losses"):
+ self.losses = []
+ if not hasattr(self, "iterations"):
+ self.iterations = []
+ if not hasattr(self, "start_iteration"):
+ self.start_iteration = 0
- self.experiment_config = ExperimentConfig(
- train_config=asdict(self.train_config),
- model_config=asdict(self.model_config),
- )
- self.losses, self.iterations = [], []
- self.start_iteration = 0
self.model_dir = "/tmp/models"
- self.thresholds = []
- self.band_widths = []
- self.min_sizes = []
- if len(self.thresholds) == 0:
- self.threshold_line.setEnabled(False)
- if len(self.band_widths) == 0:
- self.bandwidth_line.setEnabled(False)
- if len(self.min_sizes) == 0:
- self.min_size_line.setEnabled(False)
+ self.threshold_line.setEnabled(False)
+ self.bandwidth_line.setEnabled(False)
+ self.min_size_line.setEnabled(False)
def update_inference_widgets(self, event: Event):
+ """
+ This function listens to which sample the viewer is currently on,
+ and displays the corresponding inference config parameter for the
+ present sample.
+ """
if self.s_check_box.isChecked():
shape = event.value
sample_index = shape[0]
- if len(self.thresholds) == self.napari_dataset.get_num_samples():
- if self.thresholds[sample_index]!=None:
- self.threshold_line.setText(
- str(round(self.thresholds[sample_index], 3))
- )
- if len(self.band_widths) == self.napari_dataset.get_num_samples():
- if self.band_widths[sample_index]!=None:
- self.bandwidth_line.setText(
- str(round(self.band_widths[sample_index], 3))
- )
- if len(self.min_sizes) == self.napari_dataset.get_num_samples():
- if self.min_sizes[sample_index]!=None:
- self.min_size_line.setText(
- str(round(self.min_sizes[sample_index], 3))
- )
+ if (
+ hasattr(self, "thresholds")
+ and self.thresholds[sample_index] is not None
+ ):
+ self.threshold_line.setText(
+ str(round(self.thresholds[sample_index], 3))
+ )
+ if (
+ hasattr(self, "bandwidths")
+ and self.bandwidths[sample_index] is not None
+ ):
+ self.bandwidth_line.setText(
+ str(round(self.bandwidths[sample_index], 3))
+ )
+ if (
+ hasattr(self, "min_sizes")
+ and self.min_sizes[sample_index] is not None
+ ):
+ self.min_size_line.setText(
+ str(round(self.min_sizes[sample_index], 3))
+ )
def prepare_for_start_training(self):
+ """
+ Each time the `train` button is clicked,
+ the inference config line edits and other buttons are disabled.
+ """
self.start_training_button.setEnabled(False)
self.stop_training_button.setEnabled(True)
+ self.save_weights_button.setEnabled(False)
self.threshold_line.setEnabled(False)
self.bandwidth_line.setEnabled(False)
self.radio_button_nucleus.setEnabled(False)
@@ -373,10 +409,48 @@ def prepare_for_start_training(self):
self.train_worker = self.train()
self.train_worker.yielded.connect(self.on_yield_training)
+ self.train_worker.returned.connect(self.prepare_for_stop_training)
self.train_worker.start()
+ def remove_inference_attributes(self):
+ """
+ When training is initiated, then existing attributes such as
+ `embeddings`, `detection` etc are removed.
+ """
+ if hasattr(self, "embeddings"):
+ delattr(self, "embeddings")
+ if hasattr(self, "detection"):
+ delattr(self, "detection")
+ if hasattr(self, "segmentation"):
+ delattr(self, "segmentation")
+ if hasattr(self, "thresholds"):
+ delattr(self, "thresholds")
+ if hasattr(self, "thresholds_last"):
+ delattr(self, "thresholds_last")
+ if hasattr(self, "bandwidths"):
+ delattr(self, "bandwidths")
+ if hasattr(self, "bandwidths_last"):
+ delattr(self, "bandwidths_last")
+ if hasattr(self, "min_sizes"):
+ delattr(self, "min_sizes")
+ if hasattr(self, "min_sizes_last"):
+ delattr(self, "min_sizes_last")
+ if hasattr(self, "post_processing"):
+ delattr(self, "post_processing")
+ if hasattr(self, "post_processing_last"):
+ delattr(self, "post_processing_last")
+
@thread_worker
def train(self):
+ """
+ The main function where training happens!
+ """
+ self.create_configs() # configs
+ self.remove_inference_attributes()
+ self.viewer.dims.events.current_step.connect(
+ self.update_inference_widgets
+ ) # listen to viewer slider
+
for layer in self.viewer.layers:
if f"{layer}" == self.raw_selector.currentText():
raw_image_layer = layer
@@ -420,6 +494,7 @@ def train(self):
# Set device
self.device = torch.device(self.train_config.device)
+
model = Model(
model=model_original, selected_axes=self.get_selected_axes()
)
@@ -448,6 +523,8 @@ def train(self):
lr=self.train_config.initial_learning_rate,
weight_decay=0.01,
)
+ if hasattr(self, "pre_trained_model_checkpoint"):
+ self.model_config.checkpoint = self.pre_trained_model_checkpoint
# Resume training
if self.train_model_from_scratch_checkbox.isChecked():
@@ -474,6 +551,7 @@ def train(self):
)
# Call Train Iteration
+
for iteration, batch in tqdm(
zip(
range(self.start_iteration, self.train_config.max_iterations),
@@ -488,8 +566,12 @@ def train(self):
device=self.device,
)
yield loss, iteration
+ return
def on_yield_training(self, loss_iteration):
+ """
+ The loss plot is updated every training iteration.
+ """
loss, iteration = loss_iteration
print(f"===> Iteration: {iteration}, loss: {loss:.6f}")
self.iterations.append(iteration)
@@ -497,19 +579,23 @@ def on_yield_training(self, loss_iteration):
self.losses_widget.plot(self.iterations, self.losses)
def prepare_for_stop_training(self):
+ """
+ This function defines the sequence of events once training is stopped.
+ """
self.start_training_button.setEnabled(True)
self.stop_training_button.setEnabled(True)
- if len(self.thresholds) == 0:
+ self.save_weights_button.setEnabled(True)
+ if not hasattr(self, "thresholds"):
self.threshold_line.setEnabled(False)
else:
self.threshold_line.setEnabled(True)
- if len(self.band_widths) == 0:
+ if not hasattr(self, "bandwidths"):
self.bandwidth_line.setEnabled(False)
else:
self.bandwidth_line.setEnabled(True)
self.radio_button_nucleus.setEnabled(True)
self.radio_button_cell.setEnabled(True)
- if len(self.min_sizes) == 0:
+ if not hasattr(self, "min_sizes"):
self.min_size_line.setEnabled(False)
else:
self.min_size_line.setEnabled(True)
@@ -528,8 +614,12 @@ def prepare_for_stop_training(self):
self.model_config.checkpoint = checkpoint_file_name
def prepare_for_start_inference(self):
+ """
+ When the inference begins, then training-related buttons are disabled.
+ """
self.start_training_button.setEnabled(False)
self.stop_training_button.setEnabled(False)
+ self.save_weights_button.setEnabled(False)
self.threshold_line.setEnabled(False)
self.bandwidth_line.setEnabled(False)
self.radio_button_nucleus.setEnabled(False)
@@ -546,13 +636,17 @@ def prepare_for_start_inference(self):
)
self.inference_worker = self.infer()
- # self.inference_worker.yielded.connect(self.on_yield_infer)
self.inference_worker.returned.connect(self.on_return_infer)
self.inference_worker.start()
def prepare_for_stop_inference(self):
+ """
+ This function defines the sequence of events which ensue once inference is stopped.
+ """
self.start_training_button.setEnabled(True)
self.stop_training_button.setEnabled(True)
+ self.save_weights_button.setEnabled(True)
+
self.threshold_line.setEnabled(True)
self.bandwidth_line.setEnabled(True)
self.radio_button_nucleus.setEnabled(True)
@@ -562,33 +656,48 @@ def prepare_for_stop_inference(self):
self.stop_inference_button.setEnabled(True)
if self.napari_dataset.get_num_samples() == 0:
self.threshold_line.setText(str(round(self.thresholds[0], 3)))
- self.bandwidth_line.setText(str(round(self.band_widths[0], 3)))
+ self.bandwidth_line.setText(str(round(self.bandwidths[0], 3)))
self.min_size_line.setText(str(round(self.min_sizes[0], 3)))
+ if self.inference_worker is not None:
+ self.inference_worker.quit()
@thread_worker
def infer(self):
+ """
+ The main inference function.
+ """
for layer in self.viewer.layers:
if f"{layer}" == self.raw_selector.currentText():
raw_image_layer = layer
break
- self.thresholds = (
- [None] * self.napari_dataset.get_num_samples()
- if self.napari_dataset.get_num_samples() != 0
- else [None] * 1
- )
- if (
+ if not hasattr(self, "thresholds"):
+ self.thresholds = (
+ [None] * self.napari_dataset.get_num_samples()
+ if self.napari_dataset.get_num_samples() != 0
+ else [None] * 1
+ )
+
+ if not hasattr(self, "thresholds_last"):
+ self.thresholds_last = self.thresholds.copy()
+
+ if not hasattr(self, "bandwidths") and (
self.inference_config.bandwidth is None
- and len(self.band_widths) == 0
):
- self.band_widths = (
+ self.bandwidths = (
[0.5 * self.experiment_config.object_size]
* self.napari_dataset.get_num_samples()
if self.napari_dataset.get_num_samples() != 0
else [0.5 * self.experiment_config.object_size]
)
- if self.inference_config.min_size is None and len(self.min_sizes) == 0:
+ if not hasattr(self, "bandwidths_last"):
+ self.bandwidths_last = self.bandwidths.copy()
+
+ if (
+ not hasattr(self, "min_sizes")
+ and self.inference_config.min_size is None
+ ):
if self.napari_dataset.get_num_spatial_dims() == 2:
self.min_sizes = (
[
@@ -639,7 +748,20 @@ def infer(self):
]
)
+ if not hasattr(self, "min_sizes_last"):
+ self.min_sizes_last = self.min_sizes.copy()
+
+ if not hasattr(self, "post_processing"):
+ self.post_processing = (
+ "cell" if self.radio_button_cell.isChecked() else "nucleus"
+ )
+
+ if not hasattr(self, "post_processing_last"):
+ self.post_processing_last = self.post_processing
+
# set in eval mode
+ self.model = self.model.to(self.device)
+
self.model.eval()
self.model.set_infer(
p_salt_pepper=self.inference_config.p_salt_pepper,
@@ -649,9 +771,14 @@ def infer(self):
if self.napari_dataset.get_num_spatial_dims() == 2:
crop_size_tuple = (self.inference_config.crop_size[0],) * 2
-
+ predicted_crop_size_tuple = (
+ self.inference_config.crop_size[0] - 16,
+ ) * 2
elif self.napari_dataset.get_num_spatial_dims() == 3:
crop_size_tuple = (self.inference_config.crop_size[0],) * 3
+ predicted_crop_size_tuple = (
+ self.inference_config.crop_size[0] - 16,
+ ) * 3
input_shape = gp.Coordinate(
(
@@ -675,7 +802,12 @@ def infer(self):
output_shape = gp.Coordinate(
self.model(
torch.zeros(
- (1, 1, *crop_size_tuple), dtype=torch.float32
+ (
+ 1,
+ self.napari_dataset.get_num_channels(),
+ *crop_size_tuple,
+ ),
+ dtype=torch.float32,
).to(self.device)
).shape
)
@@ -693,17 +825,20 @@ def infer(self):
context = (input_size - output_size) // 2
raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICT")
- scan_request = gp.BatchRequest()
- # scan_request.add(raw, input_size)
+ scan_request = gp.BatchRequest()
scan_request[raw] = gp.Roi(
- (-8,) * (self.napari_dataset.get_num_spatial_dims()),
+ (-8,) * self.napari_dataset.get_num_spatial_dims(),
crop_size_tuple,
)
- scan_request.add(prediction, output_size)
+ scan_request[prediction] = gp.Roi(
+ (0,) * self.napari_dataset.get_num_spatial_dims(),
+ predicted_crop_size_tuple,
+ )
+
predict = gp.torch.Predict(
self.model,
- inputs={"x": raw},
+ inputs={"raw": raw},
outputs={0: prediction},
array_specs={prediction: gp.ArraySpec(voxel_size=voxel_size)},
)
@@ -748,25 +883,30 @@ def infer(self):
# Obtain Embeddings
print("Predicting Embeddings ...")
- with gp.build(pipeline):
- batch = pipeline.request_batch(request)
+ if hasattr(self, "embeddings"):
+ pass
+ else:
+ with gp.build(pipeline):
+ batch = pipeline.request_batch(request)
- embeddings = batch.arrays[prediction].data
- embeddings_centered = np.zeros_like(embeddings)
- foreground_mask = np.zeros_like(embeddings[:, 0:1, ...], dtype=bool)
+ self.embeddings = batch.arrays[prediction].data
+ embeddings_centered = np.zeros_like(self.embeddings)
+ foreground_mask = np.zeros_like(
+ self.embeddings[:, 0:1, ...], dtype=bool
+ )
colormaps = ["red", "green", "blue"]
# Obtain Object Centered Embeddings
- for sample in tqdm(range(embeddings.shape[0])):
- embeddings_sample = embeddings[sample]
+ for sample in tqdm(range(self.embeddings.shape[0])):
+ embeddings_sample = self.embeddings[sample]
embeddings_std = embeddings_sample[-1, ...]
embeddings_mean = embeddings_sample[
np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ...
].copy()
- threshold = threshold_otsu(embeddings_std)
-
- self.thresholds[sample] = threshold
- binary_mask = embeddings_std < threshold
+ if self.thresholds[sample] is None:
+ threshold = threshold_otsu(embeddings_std)
+ self.thresholds[sample] = round(threshold, 3)
+ binary_mask = embeddings_std < self.thresholds[sample]
foreground_mask[sample] = binary_mask[np.newaxis, ...]
embeddings_centered_sample = embeddings_sample.copy()
embeddings_mean_masked = (
@@ -812,31 +952,51 @@ def infer(self):
)
for i in range(self.napari_dataset.get_num_spatial_dims() + 1)
]
+
print("Clustering Objects in the obtained Foreground Mask ...")
- detection = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16)
- for sample in tqdm(range(embeddings.shape[0])):
- embeddings_sample = embeddings[sample]
+ if hasattr(self, "detection"):
+ pass
+ else:
+ self.detection = np.zeros_like(
+ self.embeddings[:, 0:1, ...], dtype=np.uint16
+ )
+ for sample in tqdm(range(self.embeddings.shape[0])):
+ embeddings_sample = self.embeddings[sample]
embeddings_std = embeddings_sample[-1, ...]
embeddings_mean = embeddings_sample[
np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ...
].copy()
- detection_sample = mean_shift_segmentation(
- embeddings_mean,
- embeddings_std,
- bandwidth=self.band_widths[sample],
- min_size=self.inference_config.min_size,
- reduction_probability=self.inference_config.reduction_probability,
- threshold=self.thresholds[sample],
- seeds=None,
- )
- detection[sample, 0, ...] = detection_sample
+ if (
+ self.thresholds[sample] != self.thresholds_last[sample]
+ or self.bandwidths[sample] != self.bandwidths_last[sample]
+ ):
+ detection_sample = mean_shift_segmentation(
+ embeddings_mean,
+ embeddings_std,
+ bandwidth=self.bandwidths[sample],
+ min_size=self.inference_config.min_size,
+ reduction_probability=self.inference_config.reduction_probability,
+ threshold=self.thresholds[sample],
+ seeds=None,
+ )
+ self.detection[sample, 0, ...] = detection_sample
+ self.thresholds_last[sample] = self.thresholds[sample]
+ self.bandwidths_last[sample] = self.bandwidths[sample]
print("Converting Detections to Segmentations ...")
- segmentation = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16)
+ if (
+ hasattr(self, "segmentation")
+ and self.post_processing == self.post_processing_last
+ ):
+ pass
+ else:
+ self.segmentation = np.zeros_like(
+ self.embeddings[:, 0:1, ...], dtype=np.uint16
+ )
if self.radio_button_cell.isChecked():
- for sample in tqdm(range(embeddings.shape[0])):
- segmentation_sample = detection[sample, 0].copy()
+ for sample in tqdm(range(self.embeddings.shape[0])):
+ segmentation_sample = self.detection[sample, 0].copy()
distance_foreground = dtedt(segmentation_sample == 0)
expanded_mask = (
distance_foreground < self.inference_config.grow_distance
@@ -845,11 +1005,11 @@ def infer(self):
segmentation_sample[
distance_background < self.inference_config.shrink_distance
] = 0
- segmentation[sample, 0, ...] = segmentation_sample
+ self.segmentation[sample, 0, ...] = segmentation_sample
elif self.radio_button_nucleus.isChecked():
raw_image = raw_image_layer.data
- for sample in tqdm(range(embeddings.shape[0])):
- segmentation_sample = detection[sample, 0]
+ for sample in tqdm(range(self.embeddings.shape[0])):
+ segmentation_sample = self.detection[sample, 0].copy()
if (
self.napari_dataset.get_num_samples() == 0
and self.napari_dataset.get_num_channels() == 0
@@ -903,7 +1063,7 @@ def infer(self):
)
mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small
y, x = np.where(mask)
- segmentation[sample, 0, y, x] = id_
+ self.segmentation[sample, 0, y, x] = id_
elif self.napari_dataset.get_num_spatial_dims() == 3:
mask_small = binary_fill_holes(
mask[
@@ -918,28 +1078,66 @@ def infer(self):
x_min : x_max + 1,
] = mask_small
z, y, x = np.where(mask)
- segmentation[sample, 0, z, y, x] = id_
+ self.segmentation[sample, 0, z, y, x] = id_
print("Removing small objects ...")
# size filter - remove small objects
- for sample in tqdm(range(embeddings.shape[0])):
- segmentation[sample, 0, ...] = size_filter(
- segmentation[sample, 0], self.min_sizes[sample]
- )
+ for sample in tqdm(range(self.embeddings.shape[0])):
+ if (
+ self.min_sizes[sample] != self.min_sizes_last[sample]
+ or self.post_processing_last != self.post_processing
+ ):
+ self.segmentation[sample, 0, ...] = size_filter(
+ self.segmentation[sample, 0], self.min_sizes[sample]
+ )
+ self.min_sizes_last[sample] = self.min_sizes[sample]
+ self.post_processing_last = self.post_processing
+
return (
embeddings_layers
+ [(foreground_mask, {"name": "Foreground Mask"}, "labels")]
- + [(detection, {"name": "Detection"}, "labels")]
- + [(segmentation, {"name": "Segmentation"}, "labels")]
+ + [(self.detection, {"name": "Detection"}, "labels")]
+ + [(self.segmentation, {"name": "Segmentation"}, "labels")]
)
def on_return_infer(self, layers):
+ """
+ Once inference is over, the old result layers are removed
+ and the new output layers are displayed.
+
+ Args:
+ layers: Tuple
+ (embedding_layers, foreground layer, detection layer, segmentation layer)
+
+ """
+
+ if "Offset (x)" in self.viewer.layers:
+ del self.viewer.layers["Offset (x)"]
+ if "Offset (y)" in self.viewer.layers:
+ del self.viewer.layers["Offset (y)"]
+ if "Offset (z)" in self.viewer.layers:
+ del self.viewer.layers["Offset (z)"]
+ if "Uncertainty" in self.viewer.layers:
+ del self.viewer.layers["Uncertainty"]
+ if "Foreground Mask" in self.viewer.layers:
+ del self.viewer.layers["Foreground Mask"]
+ if "Segmentation" in self.viewer.layers:
+ del self.viewer.layers["Segmentation"]
+ if "Detection" in self.viewer.layers:
+ del self.viewer.layers["Detection"]
+
for data, metadata, layer_type in layers:
if layer_type == "image":
self.viewer.add_image(data, **metadata)
elif layer_type == "labels":
- self.viewer.add_labels(data.astype(int), **metadata)
+ if (
+ self.napari_dataset.get_num_samples() != 0
+ and self.napari_dataset.get_num_channels() != 0
+ ):
+ self.viewer.add_labels(data.astype(int), **metadata)
+ else:
+ self.viewer.add_labels(data[:, 0].astype(int), **metadata)
self.viewer.layers["Offset (x)"].visible = False
self.viewer.layers["Offset (y)"].visible = False
self.viewer.layers["Uncertainty"].visible = False
@@ -948,3 +1146,80 @@ def on_return_infer(self, layers):
self.viewer.layers["Segmentation"].visible = True
self.inference_worker.quit()
self.prepare_for_stop_inference()
+
+ def prepare_thresholds(self):
+ """
+        In case the `Threshold` lineedit is changed by the user,
+        the attribute `thresholds` is updated.
+ """
+ sample_index = self.viewer.dims.current_step[0]
+ self.thresholds[sample_index] = float(self.threshold_line.text())
+
+ def prepare_bandwidths(self):
+ """
+        In case the `Bandwidth` lineedit is changed by the user,
+        the attribute `bandwidths` is updated.
+ """
+ sample_index = self.viewer.dims.current_step[0]
+ self.bandwidths[sample_index] = float(self.bandwidth_line.text())
+
+ def prepare_min_sizes(self):
+ """
+        In case the `Minimum Size` lineedit is changed by the user,
+        the attribute `min_sizes` is updated.
+
+ """
+ sample_index = self.viewer.dims.current_step[0]
+ self.min_sizes[sample_index] = float(self.min_size_line.text())
+
+ def load_weights(self):
+ """
+ Describes sequence of actions, which ensue after `Load Weights` button is pressed
+
+ """
+ file_name, _ = QFileDialog.getOpenFileName(
+ caption="Load Model Weights"
+ )
+ self.pre_trained_model_checkpoint = file_name
+ print(
+ f"Model weights will be loaded from {self.pre_trained_model_checkpoint}"
+ )
+
+    def update_post_processing(self):
+        self.post_processing = (
+            "cell" if self.radio_button_cell.isChecked() else "nucleus"
+        )
+
+ def affect_load_weights(self):
+ """
+        If the `train from scratch` checkbox is selected,
+        the `Load Weights` button is disabled, and vice versa.
+
+ """
+ if self.train_model_from_scratch_checkbox.isChecked():
+ self.load_model_button.setEnabled(False)
+ else:
+ self.load_model_button.setEnabled(True)
+
+ def save_weights(self):
+ """
+ Describes sequence of actions which ensue, after `Save weights` button is pressed
+
+ """
+ checkpoint_file_name, _ = QFileDialog.getSaveFileName(
+ caption="Save Model Weights"
+ )
+ if (
+ hasattr(self, "model")
+ and hasattr(self, "optimizer")
+ and hasattr(self, "iterations")
+ and hasattr(self, "losses")
+ ):
+ state = {
+ "model_state_dict": self.model.state_dict(),
+ "optim_state_dict": self.optimizer.state_dict(),
+ "iterations": self.iterations,
+ "losses": self.losses,
+ }
+ torch.save(state, checkpoint_file_name)
+ print(f"Model weights will be saved at {checkpoint_file_name}")
diff --git a/tox.ini b/tox.ini
deleted file mode 100644
index 0808e73..0000000
--- a/tox.ini
+++ /dev/null
@@ -1,32 +0,0 @@
-# For more information about tox, see https://tox.readthedocs.io/en/latest/
-[tox]
-envlist = py{38,39,310}-{linux,macos,windows}
-isolated_build=true
-
-[gh-actions]
-python =
- 3.8: py38
- 3.9: py39
- 3.10: py310
-
-[gh-actions:env]
-PLATFORM =
- ubuntu-latest: linux
- macos-latest: macos
- windows-latest: windows
-
-[testenv]
-platform =
- macos: darwin
- linux: linux
- windows: win32
-passenv =
- CI
- GITHUB_ACTIONS
- DISPLAY
- XAUTHORITY
- NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
- PYVISTA_OFF_SCREEN
-extras =
- testing
-commands = pytest -v --color=yes --cov=napari_cellulus --cov-report=xml