diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index a9d312a..09e1dc5 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -4,6 +4,7 @@ import gunpowder as gp import napari import numpy as np +import pyqtgraph as pg import torch from cellulus.criterions import get_loss from cellulus.models import get_model @@ -13,8 +14,7 @@ # widget stuff from napari.qt.threading import thread_worker -from qtpy.QtCore import Qt, QUrl -from qtpy.QtGui import QDesktopServices +from qtpy.QtCore import Qt from qtpy.QtWidgets import ( QCheckBox, QComboBox, @@ -23,6 +23,7 @@ QHBoxLayout, QLabel, QLineEdit, + QMainWindow, QPushButton, QScrollArea, QVBoxLayout, @@ -36,7 +37,7 @@ from ..gp.nodes.napari_image_source import NapariImageSource # local package imports -from ..gui_helpers import MplCanvas, layer_choice_widget +from ..gui_helpers import layer_choice_widget ############ GLOBALS ################### time_now = 0 @@ -49,10 +50,11 @@ _dataset = None -class SegmentationWidget(QScrollArea): +class SegmentationWidget(QMainWindow): def __init__(self, napari_viewer): super().__init__() self.widget = QWidget() + self.scroll = QScrollArea() self.viewer = napari_viewer # initialize train_config and model_config @@ -95,9 +97,6 @@ def __init__(self, napari_viewer): grid_layout.addWidget(device_label, 1, 0) grid_layout.addWidget(self.device_combo_box, 1, 1) - # global_params_widget = QWidget("") - # global_params_widget.setLayout(grid_layout) - # Initialize train configs widget collapsible_train_configs = QCollapsible("Train Configs", self) collapsible_train_configs.addWidget(self.create_train_configs_widget) @@ -107,16 +106,22 @@ def __init__(self, napari_viewer): collapsible_model_configs.addWidget(self.create_model_configs_widget) # Initialize loss/iterations widget - self.canvas = MplCanvas(self, width=5, height=5, dpi=100) - canvas_layout = QVBoxLayout() - canvas_layout.addWidget(self.canvas) - if len(self.iterations) == 0: - self.loss_plot = self.canvas.axes.plot([], [], label="")[0] - self.canvas.axes.legend() - self.canvas.axes.set_xlabel("Iterations") - self.canvas.axes.set_ylabel("Loss") - plot_container_widget = QWidget() - plot_container_widget.setLayout(canvas_layout) + + self.losses_widget = pg.PlotWidget() + self.losses_widget.setBackground((37, 41, 49)) + styles = {"color": "white", "font-size": "16px"} + self.losses_widget.setLabel("left", "Loss", **styles) + self.losses_widget.setLabel("bottom", "Iterations", **styles) + # self.canvas = MplCanvas(self, width=5, height=10, dpi=100) + # canvas_layout = QVBoxLayout() + # canvas_layout.addWidget(self.canvas) + # if len(self.iterations) == 0: + # self.loss_plot = self.canvas.axes.plot([], [], label="")[0] + # self.canvas.axes.legend() + # self.canvas.axes.set_xlabel("Iterations") + # self.canvas.axes.set_ylabel("Loss") + # plot_container_widget = QWidget() + # plot_container_widget.setLayout(canvas_layout) # Initialize Layer Choice self.raw_selector = layer_choice_widget( @@ -146,9 +151,17 @@ def __init__(self, napari_viewer): self.train_button.clicked.connect(self.prepare_for_training) # Initialize Save and Load Widget - # collapsible_save_load_widget = QCollapsible("Save/Load", self) - # collapsible_save_load_widget.addWidget(self.save_widget.native) - # collapsible.save_load_widget.addWidget(self.load_widget.native) + collapsible_save_load_widget = QCollapsible( + "Save and Load Model", self + ) + save_load_layout = QHBoxLayout() + save_model_button = QPushButton("Save Model", self) + load_model_button = QPushButton("Load Model", self) + save_load_layout.addWidget(save_model_button) + save_load_layout.addWidget(load_model_button) + save_load_widget = QWidget() + save_load_widget.setLayout(save_load_layout) + collapsible_save_load_widget.addWidget(save_load_widget) # Initialize Segment Configs widget collapsible_segment_configs = QCollapsible("Inference Configs", self) @@ -164,13 +177,8 @@ def __init__(self, napari_viewer): # self.pbar = QProgressBar(self) # Initialize Feedback Button - self.feedback_button = QPushButton("Feedback!", self) - self.feedback_button.clicked.connect( - lambda: QDesktopServices.openUrl( - QUrl( - "https://github.com/funkelab/napari-cellulus/issues/new/choose" - ) - ) + self.feedback_label = QLabel( + 'Please share any feedback here.' ) # Add all components to outer_layout @@ -180,22 +188,25 @@ def __init__(self, napari_viewer): # outer_layout.addWidget(global_params_widget) outer_layout.addWidget(collapsible_train_configs) outer_layout.addWidget(collapsible_model_configs) - outer_layout.addWidget(plot_container_widget) + # outer_layout.addWidget(plot_container_widget) + outer_layout.addWidget(self.losses_widget) outer_layout.addWidget(self.raw_selector.native) outer_layout.addWidget(axis_selector) outer_layout.addWidget(self.train_button) - # outer_layout.addWidget(collapsible_save_load_widget) + outer_layout.addWidget(collapsible_save_load_widget) outer_layout.addWidget(collapsible_segment_configs) outer_layout.addWidget(self.segment_button) - outer_layout.addWidget(self.feedback_button) - + outer_layout.addWidget(self.feedback_label) outer_layout.setSpacing(20) self.widget.setLayout(outer_layout) - self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - self.setWidgetResizable(True) - self.setWidget(self.widget) + + self.scroll.setWidget(self.widget) + self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.scroll.setWidgetResizable(True) + self.setFixedWidth(400) + self.setCentralWidget(self.scroll) @property def create_train_configs_widget(self): @@ -272,7 +283,7 @@ def create_segment_configs_widget(self): @magic_factory(call_button="Save") def segment_configs_widget( crop_size: int = 252, - p_salt_pepper: float = 0.1, + p_salt_pepper: float = 0.01, num_infer_iterations: int = 16, bandwidth: int = 7, reduction_probability: float = 0.1, @@ -534,11 +545,13 @@ def on_yield(self, step_data): # self.pbar.setValue(step_data) def update_canvas(self): - self.loss_plot.set_xdata(self.iterations) - self.loss_plot.set_ydata(self.losses) - self.canvas.axes.relim() - self.canvas.axes.autoscale_view() - self.canvas.draw() + self.losses_widget.plot(self.iterations, self.losses) + + # self.loss_plot.set_xdata(self.iterations) + # self.loss_plot.set_ydata(self.losses) + # self.canvas.axes.relim() + # self.canvas.axes.autoscale_view() + # self.canvas.draw() def on_return(self, layers): # Describes what happens once segment button has completed