diff --git a/src/ibldsp/waveform_extraction.py b/src/ibldsp/waveform_extraction.py index 4bd5576..692ac9e 100644 --- a/src/ibldsp/waveform_extraction.py +++ b/src/ibldsp/waveform_extraction.py @@ -3,6 +3,7 @@ import scipy import pandas as pd import numpy as np +from pathlib import Path from numpy.lib.format import open_memmap from joblib import Parallel, delayed, cpu_count @@ -472,3 +473,94 @@ def extract_wfs_cbin( chan_map = np.ones((max_wf * nu, nc), np.int16) * -1 chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)] np.savez(channels_fn, channels=chan_map) + +class WaveformsLoader: + + def __init__( + self, + data_dir, + max_wf=256, + trough_offset=42, + spike_length_samples=128, + num_channels=40, + wfs_dtype=np.float32 + ): + + self.data_dir = Path(data_dir) + self.max_wf = max_wf + self.trough_offset = trough_offset + self.spike_length_samples = spike_length_samples + self.num_channels = num_channels + self.wfs_dtype = wfs_dtype + + self.traces_fp = self.data_dir.joinpath("waveforms.traces.npy") + self.templates_fp = self.data_dir.joinpath("waveforms.templates.npy") + self.table_fp = self.data_dir.joinpath("waveforms.table.pqt") + self.channels_fp = self.data_dir.joinpath("waveforms.channels.npz") + + assert self.traces_fp.exists(), "waveforms.traces.npy file missing!" + assert self.templates_fp.exists(), "waveforms.templates.npy file missing!" + assert self.table_fp.exists(), "waveforms.table.pqt file missing!" + assert self.channels_fp.exists(), "waveforms.channels.npz file missing!" + + # ingest parquet table + self.table = pd.read_parquet(self.table_fp) + self.num_labels = self.table["cluster"].nunique() + self.labels = np.array(self.table["cluster"].unique()) + self.total_wfs = sum(~self.table["peak_channel"].isna()) + self.table["wf_number"] = np.tile(np.arange(self.max_wf), self.num_labels) + + traces_shape = (self.num_labels, max_wf, num_channels, spike_length_samples) + templates_shape = (self.num_labels, num_channels, spike_length_samples) + + self.traces = np.memmap(self.traces_fp, dtype=wfs_dtype, shape=traces_shape) + self.templates = np.memmap(self.templates_fp, dtype=np.float32, shape=templates_shape) + self.channels = np.load(self.channels_fp, allow_pickle="True")["channels"] + + def load_waveforms(self, labels=None, indices=None, return_info=True): + + if labels is None: + labels = self.labels + if indices is None: + indices = np.arange(self.max_wf) + + wfs = self.traces[np.array(labels)][:, indices, :, :] + + if return_info: + _table = self.table[self.table["cluster"].isin(labels)].copy() + _table = _table[_table["wf_number"].isin(indices)] + return wfs, _table + + return wfs + + def random_waveforms( + self, + labels=None, + num_random_labels=None, + num_random_waveforms=None, + return_info=True, + seed=None + ): + + rg = np.random.default_rng(seed=seed) + + if labels is None: + if num_random_labels is None: + labels = rg.choice(self.labels, 10) + else: + labels = rg.choice(self.labels, num_random_labels) + else: + assert num_random_labels is None, "labels and num_random_labels cannot both be set" + + wfs = self.traces[np.array(labels)][:, :, :, :] + + + + + + + + + + + \ No newline at end of file