Skip to content

Commit

Permalink
fixing bug with net_avg and flow_threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Aug 15, 2024
1 parent 6993bee commit dc02341
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions cellpose_napari/_dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@ def _deco(func):
@thread_worker
@no_grad()
def run_cellpose(image, model_type, custom_model, channels, channel_axis, diameter,
net_avg, resample, cellprob_threshold,
model_match_threshold, do_3D, stitch_threshold):
resample, cellprob_threshold,
flow_threshold, do_3D, stitch_threshold):
from cellpose import models

flow_threshold = (31.0 - model_match_threshold) / 10.
if model_match_threshold==0.0:
if flow_threshold==0.0:
flow_threshold = 0.0
logger.debug('flow_threshold=0 => no masks thrown out due to model mismatch')
logger.debug(f'computing masks with cellprob_threshold={cellprob_threshold}, flow_threshold={flow_threshold}')
Expand All @@ -74,7 +73,6 @@ def run_cellpose(image, model_type, custom_model, channels, channel_axis, diamet
channels=channels,
channel_axis=channel_axis,
diameter=diameter,
net_avg=net_avg,
resample=resample,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold,
Expand All @@ -101,15 +99,15 @@ def compute_diameter(image, channels, model_type):
return diam

@thread_worker
def compute_masks(masks_orig, flows_orig, cellprob_threshold, model_match_threshold):
def compute_masks(masks_orig, flows_orig, cellprob_threshold, flow_threshold):
import cv2
from cellpose.utils import fill_holes_and_remove_small_masks
from cellpose.dynamics import get_masks
from cellpose.transforms import resize_image

#print(flows_orig[3].shape, flows_orig[2].shape, masks_orig.shape)
flow_threshold = (31.0 - model_match_threshold) / 10.
if model_match_threshold==0.0:
flow_threshold = (31.0 - flow_threshold) / 10.
if flow_threshold==0.0:
flow_threshold = 0.0
logger.debug('flow_threshold=0 => no masks thrown out due to model mismatch')
logger.debug(f'computing masks with cellprob_threshold={cellprob_threshold}, flow_threshold={flow_threshold}')
Expand All @@ -131,9 +129,8 @@ def compute_masks(masks_orig, flows_orig, cellprob_threshold, model_match_thresh
compute_diameter_shape = dict(widget_type='PushButton', text='compute diameter from shape layer', tooltip='create shape layer with circles and/or squares, select above, and diameter will be estimated from it'),
compute_diameter_button = dict(widget_type='PushButton', text='compute diameter from image', tooltip='cellpose model will estimate diameter from image using specified channels'),
cellprob_threshold = dict(widget_type='FloatSlider', name='cellprob_threshold', value=0.0, min=-8.0, max=8.0, step=0.2, tooltip='cell probability threshold (set lower to get more cells and larger cells)'),
model_match_threshold = dict(widget_type='FloatSlider', name='model_match_threshold', value=27.0, min=0.0, max=30.0, step=0.2, tooltip='threshold on gradient match to accept a mask (set lower to get more cells)'),
flow_threshold = dict(widget_type='FloatSlider', name='flow_threshold', value=0.4, min=0.0, max=3.0, step=0.05, tooltip='threshold on gradient match to accept a mask (set higher to get more cells, or to zero to turn off)'),
compute_masks_button = dict(widget_type='PushButton', text='recompute last masks with new cellprob + model match', enabled=False),
net_average = dict(widget_type='CheckBox', text='average 4 nets', value=True, tooltip='average 4 different fit networks (default) or if not checked run only 1 network (fast)'),
resample_dynamics = dict(widget_type='CheckBox', text='resample dynamics', value=False, tooltip='if False, mask estimation with dynamics run on resized image with diameter=30; if True, flows are resized to original image size before dynamics and mask estimation (turn on for more smooth masks)'),
process_3D = dict(widget_type='CheckBox', text='process stack as 3D', value=False, tooltip='use default 3D processing where flows in X, Y, and Z are computed and dynamics run in 3D to create masks'),
stitch_threshold_3D = dict(widget_type='LineEdit', label='stitch threshold slices', value=0, tooltip='across time or Z, stitch together masks with IoU threshold of "stitch threshold" to create 3D segmentation'),
Expand All @@ -153,9 +150,8 @@ def widget(#label_logo,
compute_diameter_shape,
compute_diameter_button,
cellprob_threshold,
model_match_threshold,
flow_threshold,
compute_masks_button,
net_average,
resample_dynamics,
process_3D,
stitch_threshold_3D,
Expand Down Expand Up @@ -256,10 +252,9 @@ def _new_segmentation(segmentation):
max(0, optional_nuclear_channel)],
channel_axis=widget.channel_axis,
diameter=float(diameter),
net_avg=net_average,
resample=resample_dynamics,
cellprob_threshold=cellprob_threshold,
model_match_threshold=model_match_threshold,
flow_threshold=flow_threshold,
do_3D=(process_3D and float(stitch_threshold_3D)==0 and image_layer.ndim>2),
stitch_threshold=float(stitch_threshold_3D) if image_layer.ndim>2 else 0.0)
cp_worker.returned.connect(_new_segmentation)
Expand Down Expand Up @@ -290,7 +285,7 @@ def _compute_masks(e: Any):
mask_worker = compute_masks(widget.masks_orig,
widget.flows_orig,
widget.cellprob_threshold.value,
widget.model_match_threshold.value)
widget.flow_threshold.value)
mask_worker.returned.connect(update_masks)
mask_worker.start()

Expand Down

0 comments on commit dc02341

Please sign in to comment.