Skip to content

Commit

Permalink
Update remove_duplicates to avoid unsafe casting
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Mar 1, 2024
1 parent e5272fc commit bec8e6f
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions kilosort/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
import torch


@njit
@njit("(int64[:], int32[:], int32)")
def remove_duplicates(spike_times, spike_clusters, dt=15):
'''Removes same-cluster spikes that occur within `dt` samples.'''
keep = np.zeros_like(spike_times, bool_)
cluster_t0 = {}
for (i,t), c in zip(enumerate(spike_times), spike_clusters):
t0 = cluster_t0.get(c, t-dt)
for i in range(spike_times.size):
t = spike_times[i]
c = spike_clusters[i]
if c in cluster_t0:
t0 = cluster_t0[c]
else:
t0 = t - dt

if t >= (t0 + dt):
# Separate spike, reset t0 and keep spike
cluster_t0[c] = t
Expand Down

0 comments on commit bec8e6f

Please sign in to comment.