From b7b802de0bfc189d11ee56d3bdc5f7d9f3ca1f9e Mon Sep 17 00:00:00 2001 From: naterenegar Date: Wed, 30 Oct 2024 12:21:28 -0400 Subject: [PATCH 1/3] Added three hidden parameters to the parameters list. Added a warning when the last iteration of matching pursuit is reached but spikes are still detected, along with a parameter to increase the number of matching pursuit iterations. --- kilosort/parameters.py | 45 +++++++++++++++++++++++++++++++++++ kilosort/spikedetect.py | 5 +++- kilosort/template_matching.py | 8 +++++-- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/kilosort/parameters.py b/kilosort/parameters.py index c54b10aa..a6bb2a2b 100644 --- a/kilosort/parameters.py +++ b/kilosort/parameters.py @@ -298,6 +298,16 @@ """ }, + 'max_peels': { + 'gui_name': 'max peels', 'type': int, 'min': 1, 'max': 10000, 'exclude': [], + 'default': 100, 'step': 'spike detection', + 'description': + """ + Number of iterations to do over each batch of data in the matching + pursuit step. More iterations should detect more overlapping spikes. + """ + }, + 'templates_from_data': { 'gui_name': 'templates from data', 'type': bool, 'min': None, 'max': None, 'exclude': [], 'default': True, 'step': 'spike detection', @@ -308,6 +318,28 @@ """ }, + 'loc_range': { + 'gui_name': 'loc range', 'type': list, 'min': None, 'max': None, + 'exclude': [], 'default': [4, 5], 'step': 'spike detection', + 'description': + """ + Number of channels and time steps, respectively, to use for local + maximum detection when detecting spikes to compute universal + templates from data (only used if templates_from_data is True). + """ + }, + + 'long_range': { + 'gui_name': 'loc range', 'type': list, 'min': None, 'max': None, + 'exclude': [], 'default': [6, 30], 'step': 'spike detection', + 'description': + """ + Number of channels and time steps, respectively, to use for peak + isolation when detecting spikes to compute universal templates from + data (only used if templates_from_data is True). + """ + }, + 'n_templates': { 'gui_name': 'n templates', 'type': int, 'min': 1, 'max': np.inf, 'exclude': [], 'default': 6, 'step': 'spike detection', @@ -384,6 +416,19 @@ }, + 'drift_smoothing': { + 'gui_name': 'drift smoothing', 'type': list, 'min': None, 'max': None, + 'exclude': [], 'default': [0.5, 0.5, 0.5], 'step': 'preprocessing', + 'description': + """ + Amount of gaussian smoothing to apply to the spatiotemporal drift + estimation, for correlation, time (units of registration blocks), + and y (units of batches) axes. The y smoothing has no effect + for `nblocks = 1`. Adjusting smoothing for the correlation axis + is not recommended. + """ + }, + ### POSTPROCESSING 'duplicate_spike_ms': { 'gui_name': 'duplicate spike ms', 'type': float, 'min': 0, 'max': np.inf, diff --git a/kilosort/spikedetect.py b/kilosort/spikedetect.py index 30df0f9e..01ea6eee 100644 --- a/kilosort/spikedetect.py +++ b/kilosort/spikedetect.py @@ -49,13 +49,16 @@ def extract_snippets(X, nt, twav_min, Th_single_ch, loc_range=[4,5], def extract_wPCA_wTEMP(ops, bfile, nt=61, twav_min=20, Th_single_ch=6, nskip=25, device=torch.device('cuda')): + loc_range = ops['settings']['loc_range'] + long_range = ops['settings']['long_range'] clips = np.zeros((500000,nt), 'float32') i = 0 for j in range(0, bfile.n_batches, nskip): X = bfile.padded_batch_to_torch(j, ops) clips_new = extract_snippets(X, nt=nt, twav_min=twav_min, - Th_single_ch=Th_single_ch, device=device) + Th_single_ch=Th_single_ch, device=device, + loc_range=loc_range,long_range=long_range) nnew = len(clips_new) diff --git a/kilosort/template_matching.py b/kilosort/template_matching.py index 66b31503..ab876e9e 100644 --- a/kilosort/template_matching.py +++ b/kilosort/template_matching.py @@ -140,8 +140,8 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')): Xres = X.clone() lam = 20 - - for t in range(100): + max_peels = ops['settings']['max_peels'] + for t in range(max_peels): # Cf = 2 * B - nm.unsqueeze(-1) Cf = torch.relu(B)**2 /nm.unsqueeze(-1) #a = 1 + lam @@ -163,6 +163,10 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')): if len(xs)==0: #print('iter %d'%t) break + elif len(xs) > 0 and t == max_peels - 1: + logger.debug(f'Reached last iteration of matching pursuit with {len(xs)} spikes detected.' + 'Consider increasing the \'max_peels\' parameter' + ) iX = xs[:,:1] iY = imax[iX] From 6682999d8cb74527fd65857da259a2f9efd3de2a Mon Sep 17 00:00:00 2001 From: naterenegar Date: Wed, 30 Oct 2024 12:30:41 -0400 Subject: [PATCH 2/3] needed to modify gui parameter construction to properly parse the default parameter settings for loc range and long range --- kilosort/gui/settings_box.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kilosort/gui/settings_box.py b/kilosort/gui/settings_box.py index 209d9f5c..b5daf11e 100644 --- a/kilosort/gui/settings_box.py +++ b/kilosort/gui/settings_box.py @@ -304,6 +304,8 @@ def set_cached_field_values(self): # List of floats gets cached as list of strings, so # have to convert back. d = str([float(s) for s in v]) + elif k == 'loc_range' or k == 'long_range': + d = str([int(s) for s in v]) else: d = str(v) else: From d0d98c1c9c2707ce77b73fd993ef553f348d87e4 Mon Sep 17 00:00:00 2001 From: naterenegar Date: Wed, 30 Oct 2024 12:49:24 -0400 Subject: [PATCH 3/3] adjusted formatting of matching pursuit warning --- kilosort/template_matching.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kilosort/template_matching.py b/kilosort/template_matching.py index ab876e9e..ce7a3633 100644 --- a/kilosort/template_matching.py +++ b/kilosort/template_matching.py @@ -164,9 +164,8 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')): #print('iter %d'%t) break elif len(xs) > 0 and t == max_peels - 1: - logger.debug(f'Reached last iteration of matching pursuit with {len(xs)} spikes detected.' - 'Consider increasing the \'max_peels\' parameter' - ) + logger.debug(f'Reached last iteration of matching pursuit with {len(xs)} spikes detected.') + logger.debug(f'Consider increasing the \'max_peels\' parameter. Current value = {max_peels}') iX = xs[:,:1] iY = imax[iX]