From 780ecc06fa1b1b7a2a1015f4d3901a955ff9fd00 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 13 Jan 2015 14:04:47 +0100 Subject: [PATCH] WIP: flipping features. --- spikedetekt2/core/main.py | 47 ++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/spikedetekt2/core/main.py b/spikedetekt2/core/main.py index ad8e9b0..1e95a74 100644 --- a/spikedetekt2/core/main.py +++ b/spikedetekt2/core/main.py @@ -129,6 +129,21 @@ def save_features(experiment, **prm): features = project_pcs(waveform, pcs) spikes.features_masks[i,:,0] = features.ravel() + # Flip the PCs. + # First, compute the mean features. + features_mean = spikes.features_masks[::step, :, 0].mean(axis=0) + to_flip = features_mean < 0 + to_flip_ind = np.nonzero(to_flip)[0] + if len(to_flip_ind) > 0: + # Flip the features. + spikes.features_masks[:, to_flip_ind, 0] *= -1 + pcs = experiment.channel_groups[chgrp]._node.pca_waveforms + # Find the channels to flip. + channels_to_flip = to_flip[0:pcs.shape[-1]:npcs] + channels_to_flip_ind = np.nonzero(channels_to_flip)[0] + # Flip the PCs. + pcs[..., channels_to_flip_ind] *= -1 + # ----------------------------------------------------------------------------- # File logger @@ -147,7 +162,7 @@ def close_file_logger(LOGGER_FILE): # ----------------------------------------------------------------------------- # Main loop # ----------------------------------------------------------------------------- -def run(raw_data=None, experiment=None, prm=None, probe=None, +def run(raw_data=None, experiment=None, prm=None, probe=None, _debug=False, convert_only=False): """This main function takes raw data (either as a RawReader, or a path to a filename, or an array) and executes the main algorithm (filtering, @@ -181,11 +196,11 @@ def run(raw_data=None, experiment=None, prm=None, probe=None, # Get the bandpass filter. filter = bandpass_filter(**prm) - + if not (convert_only and first_chunk_detected): # Compute the strong threshold across excerpts uniformly scattered across the # whole recording. - threshold = get_threshold(raw_data, filter=filter, + threshold = get_threshold(raw_data, filter=filter, channels=probe.channels, **prm) assert not np.isnan(threshold.weak).any() assert not np.isnan(threshold.strong).any() @@ -193,7 +208,7 @@ def run(raw_data=None, experiment=None, prm=None, probe=None, # Debug module. diagnostics_script_path = prm.get('diagnostics_script_path', None) - + # Progress bar. progress_bar = ProgressReporter(period=30.) nspikes = 0 @@ -235,44 +250,44 @@ def run(raw_data=None, experiment=None, prm=None, probe=None, chunk_low = decimate(chunk_raw) chunk_low_keep = chunk_low[i//16:j//16,:] experiment.recordings[chunk.recording].low.append(convert_dtype(chunk_low_keep, np.int16)) - + if not (convert_only and first_chunk_detected): # Apply thresholds. - chunk_detect, chunk_threshold = apply_threshold(chunk_fil, + chunk_detect, chunk_threshold = apply_threshold(chunk_fil, threshold=threshold, **prm) - + # Remove dead channels. dead = np.setdiff1d(np.arange(nchannels), probe.channels) chunk_detect[:,dead] = 0 chunk_threshold.strong[:,dead] = 0 chunk_threshold.weak[:,dead] = 0 - + # Find connected component (strong threshold). Return list of # Component instances. components = connected_components( - chunk_strong=chunk_threshold.strong, - chunk_weak=chunk_threshold.weak, + chunk_strong=chunk_threshold.strong, + chunk_weak=chunk_threshold.weak, probe_adjacency_list=probe.adjacency_list, chunk=chunk, **prm) - + # Now we extract the spike in each component. waveforms = extract_waveforms(chunk_detect=chunk_detect, - threshold=threshold, chunk_fil=chunk_fil, chunk_raw=chunk_raw, + threshold=threshold, chunk_fil=chunk_fil, chunk_raw=chunk_raw, probe=probe, components=components, **prm) # DEBUG module. # Execute the debug script. if diagnostics_script_path: execfile(diagnostics_script_path) - + # Log number of spikes in the chunk. nspikes += len(waveforms) - + # We sort waveforms by increasing order of fractional time. [add_waveform(experiment, waveform) for waveform in sorted(waveforms)] - + first_chunk_detected = True - + # Update the progress bar. progress_bar.update(rec/float(nrecs) + (float(s_end) / (nsamples*nrecs)), '%d spikes found.' % (nspikes))