Flipping features #19

Open · wants to merge 1 commit into master
spikedetekt2/core/main.py (47 changes: 31 additions & 16 deletions)

@@ -129,6 +129,21 @@ def save_features(experiment, **prm):
        features = project_pcs(waveform, pcs)
        spikes.features_masks[i,:,0] = features.ravel()

    # Flip the PCs so that the mean of each feature is positive.
    # First, compute the mean features over a subsample of the spikes.
    features_mean = spikes.features_masks[::step, :, 0].mean(axis=0)
    to_flip = features_mean < 0
    to_flip_ind = np.nonzero(to_flip)[0]
    if len(to_flip_ind) > 0:
        # Flip the features whose mean is negative.
        spikes.features_masks[:, to_flip_ind, 0] *= -1
        pcs = experiment.channel_groups[chgrp]._node.pca_waveforms
        # Find the channels to flip.
        channels_to_flip = to_flip[0:pcs.shape[-1]:npcs]
        channels_to_flip_ind = np.nonzero(channels_to_flip)[0]
        # Flip the corresponding PCs so that future projections stay consistent.
        pcs[..., channels_to_flip_ind] *= -1
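
The reason the patch flips the features and the matching principal components together: projection is linear, so negating a PC negates exactly the corresponding feature. A minimal numpy sketch of that invariant (toy shapes and names, not the spikedetekt2 data layout):

import numpy as np

# Toy setup: 3 principal components over a 5-sample waveform.
rng = np.random.RandomState(0)
waveform = rng.randn(5)
pcs = rng.randn(3, 5)              # one PC per row

features = np.dot(pcs, waveform)   # linear projection onto the PCs

# Negating one PC negates the matching feature and nothing else,
# which is why features and PCs must be flipped in lockstep.
pcs_flipped = pcs.copy()
pcs_flipped[1] *= -1
features_flipped = np.dot(pcs_flipped, waveform)

assert np.allclose(features_flipped[1], -features[1])
assert np.allclose(features_flipped[[0, 2]], features[[0, 2]])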


# -----------------------------------------------------------------------------
# File logger
@@ -147,7 +162,7 @@ def close_file_logger(LOGGER_FILE):
# -----------------------------------------------------------------------------
# Main loop
# -----------------------------------------------------------------------------
def run(raw_data=None, experiment=None, prm=None, probe=None,
        _debug=False, convert_only=False):
    """This main function takes raw data (either as a RawReader, or a path
    to a filename, or an array) and executes the main algorithm (filtering,
@@ -181,19 +196,19 @@ def run(raw_data=None, experiment=None, prm=None, probe=None,

    # Get the bandpass filter.
    filter = bandpass_filter(**prm)

    if not (convert_only and first_chunk_detected):
        # Compute the strong threshold across excerpts uniformly
        # scattered across the whole recording.
        threshold = get_threshold(raw_data, filter=filter,
                                  channels=probe.channels, **prm)
        assert not np.isnan(threshold.weak).any()
        assert not np.isnan(threshold.strong).any()
        debug("Threshold: " + str(threshold))

    # Debug module.
    diagnostics_script_path = prm.get('diagnostics_script_path', None)

    # Progress bar.
    progress_bar = ProgressReporter(period=30.)
    nspikes = 0
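
Aside: the diff does not show get_threshold's internals, but the SpikeDetekt family of detectors is known for deriving weak and strong thresholds from a robust noise estimate on filtered excerpts. A sketch of that idea (the helper name and factor values here are illustrative assumptions, not this file's actual get_threshold):

import numpy as np

def estimate_thresholds(filtered_excerpts, weak_factor=2.0, strong_factor=4.5):
    # Robust per-channel noise estimate: median(|x|) / 0.6745
    # approximates the standard deviation of Gaussian noise.
    noise_std = np.median(np.abs(filtered_excerpts), axis=0) / 0.6745
    return weak_factor * noise_std, strong_factor * noise_std

# Usage on fake data: 10000 samples x 4 channels of filtered noise.
excerpts = np.random.randn(10000, 4)
weak, strong = estimate_thresholds(excerpts)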
@@ -235,44 +250,44 @@ def run(raw_data=None, experiment=None, prm=None, probe=None,
        chunk_low = decimate(chunk_raw)
        chunk_low_keep = chunk_low[i//16:j//16,:]
        experiment.recordings[chunk.recording].low.append(convert_dtype(chunk_low_keep, np.int16))

        if not (convert_only and first_chunk_detected):
            # Apply the thresholds.
            chunk_detect, chunk_threshold = apply_threshold(chunk_fil,
                                                            threshold=threshold, **prm)

            # Remove dead channels.
            dead = np.setdiff1d(np.arange(nchannels), probe.channels)
            chunk_detect[:,dead] = 0
            chunk_threshold.strong[:,dead] = 0
            chunk_threshold.weak[:,dead] = 0

            # Find the connected components (strong threshold). Returns a
            # list of Component instances.
            components = connected_components(
                chunk_strong=chunk_threshold.strong,
                chunk_weak=chunk_threshold.weak,
                probe_adjacency_list=probe.adjacency_list,
                chunk=chunk, **prm)

            # Now we extract the spike waveform in each component.
            waveforms = extract_waveforms(chunk_detect=chunk_detect,
                threshold=threshold, chunk_fil=chunk_fil, chunk_raw=chunk_raw,
                probe=probe, components=components, **prm)

            # DEBUG module: execute the debug script.
            if diagnostics_script_path:
                execfile(diagnostics_script_path)

            # Log the number of spikes in the chunk.
            nspikes += len(waveforms)

            # Add the waveforms, sorted by increasing fractional time.
            for waveform in sorted(waveforms):
                add_waveform(experiment, waveform)

            first_chunk_detected = True

        # Update the progress bar.
        progress_bar.update(rec/float(nrecs) + (float(s_end) / (nsamples*nrecs)),
                            '%d spikes found.' % (nspikes))
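
Two side notes on steps in this last hunk. First, chunk_low[i//16:j//16,:] implies that the low-rate copy is downsampled by a factor of 16. The local decimate() helper is not shown in this diff; a hypothetical stand-in using scipy.signal.decimate:

import numpy as np
from scipy.signal import decimate as decimate_sketch

# Anti-alias filter, then downsample a (nsamples, nchannels) chunk
# by 16 along the time axis (illustration only, not the local helper).
chunk_raw = np.random.randn(20000, 32)
chunk_low = decimate_sketch(chunk_raw, 16, axis=0, zero_phase=True)
assert chunk_low.shape == (1250, 32)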
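
Second, the strong/weak two-threshold detection works like a flood fill: components are grown over every sample crossing the weak threshold, but kept only if they contain at least one strong-threshold crossing. A minimal single-channel sketch of that idea (the real connected_components is multichannel and uses the probe adjacency graph):

import numpy as np

def two_threshold_events(signal, weak, strong):
    """Return (start, end) pairs of weak-threshold runs that also
    contain at least one strong-threshold crossing."""
    above = np.abs(signal) > weak
    # Locate the boundaries of each run of weak-threshold samples.
    edges = np.diff(above.astype(int))
    starts = np.nonzero(edges == 1)[0] + 1
    ends = np.nonzero(edges == -1)[0] + 1
    if above[0]:
        starts = np.r_[0, starts]
    if above[-1]:
        ends = np.r_[ends, len(signal)]
    # Keep only the runs that also cross the strong threshold.
    return [(s, e) for s, e in zip(starts, ends)
            if np.any(np.abs(signal[s:e]) > strong)]

# A weak-only bump (rejected) and a genuine spike (kept).
sig = np.array([0., 1.2, 0., 0., 3.0, 5.0, 2.0, 0.])
print(two_threshold_events(sig, weak=1.0, strong=4.0))  # [(4, 7)]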