Skip to content

Commit

Permalink
initial code
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Jul 16, 2024
1 parent 3daa7d1 commit 4ac6f80
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions src/ibldsp/waveform_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import scipy
import pandas as pd
import numpy as np
from pathlib import Path
from numpy.lib.format import open_memmap
from joblib import Parallel, delayed, cpu_count

Expand Down Expand Up @@ -472,3 +473,94 @@ def extract_wfs_cbin(
chan_map = np.ones((max_wf * nu, nc), np.int16) * -1
chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)]
np.savez(channels_fn, channels=chan_map)

class WaveformsLoader:
    """
    Read-only access to the waveform files found in a single directory.

    Four files are expected inside ``data_dir``:

    - ``waveforms.traces.npy``: array of shape
      ``(num_labels, max_wf, num_channels, spike_length_samples)``
    - ``waveforms.templates.npy``: array of shape
      ``(num_labels, num_channels, spike_length_samples)``
    - ``waveforms.table.pqt``: parquet table with (at least) the columns
      ``cluster`` and ``peak_channel``. It is assumed to hold ``max_wf``
      consecutive rows per cluster, in the same cluster order as the first
      axis of the two arrays above.
    - ``waveforms.channels.npz``: archive containing a ``channels`` array

    ``num_labels`` is the number of distinct values in the ``cluster`` column.

    Attributes set by the constructor:

    - ``table``: the parquet table, with an extra ``wf_number`` column
      counting 0..max_wf-1 within each cluster
    - ``labels``: distinct cluster labels in order of first appearance in
      the table; position ``i`` corresponds to ``traces[i]``/``templates[i]``
    - ``num_labels``: number of distinct cluster labels
    - ``total_wfs``: number of table rows whose ``peak_channel`` is not null
    - ``traces`` / ``templates``: read-only memory maps of the ``.npy`` files
    - ``channels``: the ``channels`` array of the ``.npz`` archive
    """

    def __init__(
        self,
        data_dir,
        max_wf=256,
        trough_offset=42,
        spike_length_samples=128,
        num_channels=40,
        wfs_dtype=np.float32
    ):
        """
        Open the waveform files located in ``data_dir``.

        :param data_dir: directory containing the four waveform files
        :param max_wf: number of waveform slots per cluster (second axis of
            the traces array, and number of table rows per cluster)
        :param trough_offset: stored on the instance only; not used by this
            class (presumably the trough sample within a waveform - TODO confirm)
        :param spike_length_samples: length of the last axis of the arrays
        :param num_channels: length of the channel axis of the arrays
        :param wfs_dtype: stored on the instance only. The dtype of the
            returned arrays is the one recorded in the ``.npy`` headers.
        :raises FileNotFoundError: if one of the four files is missing
        :raises ValueError: if the table or the arrays do not have the size
            implied by ``max_wf``, ``num_channels`` and ``spike_length_samples``
        """
        self.data_dir = Path(data_dir)
        self.max_wf = max_wf
        self.trough_offset = trough_offset
        self.spike_length_samples = spike_length_samples
        self.num_channels = num_channels
        self.wfs_dtype = wfs_dtype

        self.traces_fp = self.data_dir.joinpath("waveforms.traces.npy")
        self.templates_fp = self.data_dir.joinpath("waveforms.templates.npy")
        self.table_fp = self.data_dir.joinpath("waveforms.table.pqt")
        self.channels_fp = self.data_dir.joinpath("waveforms.channels.npz")

        # Explicit exceptions rather than `assert`: assertions are stripped
        # when Python runs with -O, which would hide a missing file.
        for fp in (self.traces_fp, self.templates_fp, self.table_fp, self.channels_fp):
            if not fp.exists():
                raise FileNotFoundError(
                    f"{fp.name} file missing! (looked in {self.data_dir})"
                )

        # ingest parquet table
        self.table = pd.read_parquet(self.table_fp)
        self.num_labels = self.table["cluster"].nunique()
        self.labels = np.array(self.table["cluster"].unique())
        self.total_wfs = int(self.table["peak_channel"].notna().sum())

        expected_rows = self.max_wf * self.num_labels
        if len(self.table) != expected_rows:
            raise ValueError(
                f"{self.table_fp.name} has {len(self.table)} rows, expected "
                f"max_wf * num_labels = {self.max_wf} * {self.num_labels} "
                f"= {expected_rows}"
            )
        self.table["wf_number"] = np.tile(np.arange(self.max_wf), self.num_labels)

        traces_shape = (self.num_labels, max_wf, num_channels, spike_length_samples)
        templates_shape = (self.num_labels, num_channels, spike_length_samples)

        self.traces = self._open_npy(self.traces_fp, traces_shape)
        self.templates = self._open_npy(self.templates_fp, templates_shape)

        # NOTE(review): allow_pickle=True lets numpy unpickle object arrays,
        # which can execute arbitrary code when the file is untrusted. It is
        # kept for backward compatibility; consider disabling it if the
        # `channels` array is always a plain numeric array.
        with np.load(self.channels_fp, allow_pickle=True) as npz:
            self.channels = npz["channels"]

    @staticmethod
    def _open_npy(path, expected_shape):
        """Memory-map a ``.npy`` file read-only and check its shape."""
        # np.load honours the .npy header (offset, dtype, shape). A raw
        # np.memmap would read the header bytes as if they were samples.
        arr = np.load(path, mmap_mode="r")
        if arr.shape != tuple(expected_shape):
            raise ValueError(
                f"{Path(path).name} has shape {arr.shape}, "
                f"expected {tuple(expected_shape)}"
            )
        return arr

    def _label_rows(self, labels):
        """Translate cluster labels into positions along the first array axis."""
        row_of = {
            label: row for row, label in enumerate(np.asarray(self.labels).tolist())
        }
        requested = np.asarray(labels).tolist()
        missing = [label for label in requested if label not in row_of]
        if missing:
            raise IndexError(
                f"cluster label(s) not found in waveforms table: {missing}"
            )
        return np.array([row_of[label] for label in requested], dtype=np.intp)

    def load_waveforms(self, labels=None, indices=None, return_info=True):
        """
        Load waveforms for the requested clusters and waveform slots.

        :param labels: cluster label or 1-D sequence of cluster labels
            (values of the ``cluster`` column). Defaults to all labels.
        :param indices: waveform slot or 1-D sequence of slots in
            ``[0, max_wf)``. Defaults to all slots.
        :param return_info: if True, also return the matching table rows
        :return: array of shape
            ``(len(labels), len(indices), num_channels, spike_length_samples)``
            ordered like ``labels`` and ``indices``; when ``return_info`` is
            True, a tuple ``(array, table)``. The table rows keep the order
            of the underlying table, so they line up with the array only
            when ``labels`` follow table order and ``indices`` are ascending.
        :raises IndexError: if a requested label is not in the table
        """
        if labels is None:
            labels = self.labels
        else:
            labels = np.atleast_1d(np.asarray(labels))

        if indices is None:
            indices = np.arange(self.max_wf)
        else:
            indices = np.atleast_1d(np.asarray(indices))
            if indices.size == 0:
                # an empty selection must still be an integer index array
                indices = indices.astype(np.intp)

        # Cluster labels are not necessarily 0..n-1, so they cannot be used
        # directly as indices into the first axis of `traces`.
        rows = self._label_rows(labels)
        wfs = self.traces[rows][:, indices, :, :]

        if not return_info:
            return wfs

        info = self.table[self.table["cluster"].isin(labels)].copy()
        info = info[info["wf_number"].isin(indices)]
        return wfs, info

    def random_waveforms(
        self,
        labels=None,
        num_random_labels=None,
        num_random_waveforms=None,
        return_info=True,
        seed=None
    ):
        """
        Load a random subset of waveforms.

        :param labels: cluster labels to load from. If None, distinct labels
            are drawn at random.
        :param num_random_labels: number of labels to draw when ``labels`` is
            None (default 10, capped at the number of available labels)
        :param num_random_waveforms: number of distinct waveform slots to draw
            per cluster (capped at ``max_wf``). If None, all slots are loaded.
            Slots are drawn from ``[0, max_wf)`` regardless of whether the
            corresponding table row has a non-null ``peak_channel``.
        :param return_info: if True, also return the matching table rows
        :param seed: seed for ``numpy.random.default_rng``
        :return: same as :meth:`load_waveforms`
        :raises ValueError: if both ``labels`` and ``num_random_labels`` are set
        """
        rg = np.random.default_rng(seed=seed)

        if labels is None:
            n_labels = 10 if num_random_labels is None else num_random_labels
            n_labels = min(n_labels, len(self.labels))
            # Draw positions without replacement and sort them, so that the
            # selected labels are distinct and follow table order.
            rows = np.sort(rg.choice(len(self.labels), n_labels, replace=False))
            labels = self.labels[rows]
        elif num_random_labels is not None:
            raise ValueError("labels and num_random_labels cannot both be set")

        indices = None
        if num_random_waveforms is not None:
            n_wfs = min(num_random_waveforms, self.max_wf)
            indices = np.sort(rg.choice(self.max_wf, n_wfs, replace=False))

        return self.load_waveforms(
            labels=labels, indices=indices, return_info=return_info
        )











0 comments on commit 4ac6f80

Please sign in to comment.