Skip to content

Commit

Permalink
Update canvas
Browse files Browse the repository at this point in the history
  • Loading branch information
lmanan committed Oct 27, 2023
1 parent f92e357 commit f0831ae
Showing 1 changed file with 54 additions and 41 deletions.
95 changes: 54 additions & 41 deletions src/napari_cellulus/widgets/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gunpowder as gp
import napari
import numpy as np
import pyqtgraph as pg
import torch
from cellulus.criterions import get_loss
from cellulus.models import get_model
Expand All @@ -13,8 +14,7 @@

# widget stuff
from napari.qt.threading import thread_worker
from qtpy.QtCore import Qt, QUrl
from qtpy.QtGui import QDesktopServices
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QCheckBox,
QComboBox,
Expand All @@ -23,6 +23,7 @@
QHBoxLayout,
QLabel,
QLineEdit,
QMainWindow,
QPushButton,
QScrollArea,
QVBoxLayout,
Expand All @@ -36,7 +37,7 @@
from ..gp.nodes.napari_image_source import NapariImageSource

# local package imports
from ..gui_helpers import MplCanvas, layer_choice_widget
from ..gui_helpers import layer_choice_widget

############ GLOBALS ###################
time_now = 0
Expand All @@ -49,10 +50,11 @@
_dataset = None


class SegmentationWidget(QScrollArea):
class SegmentationWidget(QMainWindow):
def __init__(self, napari_viewer):
super().__init__()
self.widget = QWidget()
self.scroll = QScrollArea()
self.viewer = napari_viewer

# initialize train_config and model_config
Expand Down Expand Up @@ -95,9 +97,6 @@ def __init__(self, napari_viewer):
grid_layout.addWidget(device_label, 1, 0)
grid_layout.addWidget(self.device_combo_box, 1, 1)

# global_params_widget = QWidget("")
# global_params_widget.setLayout(grid_layout)

# Initialize train configs widget
collapsible_train_configs = QCollapsible("Train Configs", self)
collapsible_train_configs.addWidget(self.create_train_configs_widget)
Expand All @@ -107,16 +106,22 @@ def __init__(self, napari_viewer):
collapsible_model_configs.addWidget(self.create_model_configs_widget)

# Initialize loss/iterations widget
self.canvas = MplCanvas(self, width=5, height=5, dpi=100)
canvas_layout = QVBoxLayout()
canvas_layout.addWidget(self.canvas)
if len(self.iterations) == 0:
self.loss_plot = self.canvas.axes.plot([], [], label="")[0]
self.canvas.axes.legend()
self.canvas.axes.set_xlabel("Iterations")
self.canvas.axes.set_ylabel("Loss")
plot_container_widget = QWidget()
plot_container_widget.setLayout(canvas_layout)

self.losses_widget = pg.PlotWidget()
self.losses_widget.setBackground((37, 41, 49))
styles = {"color": "white", "font-size": "16px"}
self.losses_widget.setLabel("left", "Loss", **styles)
self.losses_widget.setLabel("bottom", "Iterations", **styles)
# self.canvas = MplCanvas(self, width=5, height=10, dpi=100)
# canvas_layout = QVBoxLayout()
# canvas_layout.addWidget(self.canvas)
# if len(self.iterations) == 0:
# self.loss_plot = self.canvas.axes.plot([], [], label="")[0]
# self.canvas.axes.legend()
# self.canvas.axes.set_xlabel("Iterations")
# self.canvas.axes.set_ylabel("Loss")
# plot_container_widget = QWidget()
# plot_container_widget.setLayout(canvas_layout)

# Initialize Layer Choice
self.raw_selector = layer_choice_widget(
Expand Down Expand Up @@ -146,9 +151,17 @@ def __init__(self, napari_viewer):
self.train_button.clicked.connect(self.prepare_for_training)

# Initialize Save and Load Widget
# collapsible_save_load_widget = QCollapsible("Save/Load", self)
# collapsible_save_load_widget.addWidget(self.save_widget.native)
# collapsible.save_load_widget.addWidget(self.load_widget.native)
collapsible_save_load_widget = QCollapsible(
"Save and Load Model", self
)
save_load_layout = QHBoxLayout()
save_model_button = QPushButton("Save Model", self)
load_model_button = QPushButton("Load Model", self)
save_load_layout.addWidget(save_model_button)
save_load_layout.addWidget(load_model_button)
save_load_widget = QWidget()
save_load_widget.setLayout(save_load_layout)
collapsible_save_load_widget.addWidget(save_load_widget)

# Initialize Segment Configs widget
collapsible_segment_configs = QCollapsible("Inference Configs", self)
Expand All @@ -164,13 +177,8 @@ def __init__(self, napari_viewer):
# self.pbar = QProgressBar(self)

# Initialize Feedback Button
self.feedback_button = QPushButton("Feedback!", self)
self.feedback_button.clicked.connect(
lambda: QDesktopServices.openUrl(
QUrl(
"https://github.com/funkelab/napari-cellulus/issues/new/choose"
)
)
self.feedback_label = QLabel(
'<small>Please share any feedback <a href="https://github.com/funkelab/napari-cellulus/issues/new/choose" style="color:gray;">here</a>.</small>'
)

# Add all components to outer_layout
Expand All @@ -180,22 +188,25 @@ def __init__(self, napari_viewer):
# outer_layout.addWidget(global_params_widget)
outer_layout.addWidget(collapsible_train_configs)
outer_layout.addWidget(collapsible_model_configs)
outer_layout.addWidget(plot_container_widget)
# outer_layout.addWidget(plot_container_widget)
outer_layout.addWidget(self.losses_widget)
outer_layout.addWidget(self.raw_selector.native)
outer_layout.addWidget(axis_selector)
outer_layout.addWidget(self.train_button)
# outer_layout.addWidget(collapsible_save_load_widget)
outer_layout.addWidget(collapsible_save_load_widget)
outer_layout.addWidget(collapsible_segment_configs)
outer_layout.addWidget(self.segment_button)
outer_layout.addWidget(self.feedback_button)

outer_layout.addWidget(self.feedback_label)
outer_layout.setSpacing(20)
self.widget.setLayout(outer_layout)
self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.setWidgetResizable(True)
self.setWidget(self.widget)

self.scroll.setWidget(self.widget)
self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.scroll.setWidgetResizable(True)

self.setFixedWidth(400)
self.setCentralWidget(self.scroll)

@property
def create_train_configs_widget(self):
Expand Down Expand Up @@ -272,7 +283,7 @@ def create_segment_configs_widget(self):
@magic_factory(call_button="Save")
def segment_configs_widget(
crop_size: int = 252,
p_salt_pepper: float = 0.1,
p_salt_pepper: float = 0.01,
num_infer_iterations: int = 16,
bandwidth: int = 7,
reduction_probability: float = 0.1,
Expand Down Expand Up @@ -534,11 +545,13 @@ def on_yield(self, step_data):
# self.pbar.setValue(step_data)

def update_canvas(self):
self.loss_plot.set_xdata(self.iterations)
self.loss_plot.set_ydata(self.losses)
self.canvas.axes.relim()
self.canvas.axes.autoscale_view()
self.canvas.draw()
self.losses_widget.plot(self.iterations, self.losses)

# self.loss_plot.set_xdata(self.iterations)
# self.loss_plot.set_ydata(self.losses)
# self.canvas.axes.relim()
# self.canvas.axes.autoscale_view()
# self.canvas.draw()

def on_return(self, layers):
# Describes what happens once segment button has completed
Expand Down

0 comments on commit f0831ae

Please sign in to comment.