Updated waveform extraction #31

Merged
merged 11 commits
Apr 18, 2024
2 changes: 2 additions & 0 deletions release_notes.md
@@ -1,4 +1,6 @@
# 0.10.0
## 0.10.3 2024-04-18
- Patch fixing memory leaks in the `waveform_extraction` module.
## 0.10.2 2024-04-10
- Add `waveform_extraction` module to `ibldsp`. This includes the `extract_wfs_array` and `extract_wfs_cbin` methods.
- Add code for performing subsample shifts of waveforms.
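For orientation, here is a minimal usage sketch of the `extract_wfs_cbin` entry point added in 0.10.2 and patched in this release; the file paths and spike-sorting arrays are hypothetical, and the keyword arguments mirror the updated signature in the `waveform_extraction.py` diff below.

```python
# Minimal usage sketch (hypothetical paths and spike-sorting outputs).
from pathlib import Path

import numpy as np

from ibldsp.waveform_extraction import extract_wfs_cbin

cbin_file = Path("/data/probe00/raw.ap.cbin")     # compressed ephys recording
output_dir = Path("/data/probe00/waveforms")      # results are written here
output_dir.mkdir(parents=True, exist_ok=True)

spike_samples = np.load("spikes.samples.npy")     # spike times in samples
spike_clusters = np.load("spikes.clusters.npy")   # cluster id per spike
# assumed cluster-to-peak-channel mapping, expanded to one entry per spike
spike_channels = np.load("clusters.channels.npy")[spike_clusters]

extract_wfs_cbin(
    cbin_file,
    output_dir,
    spike_samples,
    spike_clusters,
    spike_channels,
    max_wf=256,                # waveforms kept per unit
    trough_offset=42,          # samples before the trough
    spike_length_samples=128,  # total samples per waveform
)
```

The extraction writes `waveforms.traces.npy`, `waveforms.templates.npy`, `waveforms.table.pqt`, and `waveforms.channels.npz` to `output_dir`, as described in the docstring further down.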
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@

setuptools.setup(
name="ibl-neuropixel",
version="0.10.2",
version="0.10.3",
author="The International Brain Laboratory",
description="Collection of tools for Neuropixel 1.0 and 2.0 probes data",
long_description=long_description,
247 changes: 157 additions & 90 deletions src/ibldsp/waveform_extraction.py
@@ -1,16 +1,19 @@
import logging

import scipy
import pandas as pd
import numpy as np
from numpy.lib.format import open_memmap
import neuropixel
import spikeglx

from joblib import Parallel, delayed, cpu_count

import neuropixel
import spikeglx
from ibldsp.voltage import detect_bad_channels, interpolate_bad_channels, car
from ibldsp.fourier import fshift
from ibldsp.utils import make_channel_index

logger = logging.getLogger(__name__)


def extract_wfs_array(
arr,
@@ -83,7 +86,12 @@ def _get_channel_labels(sr, num_snippets=20, verbose=True):
if verbose:
from tqdm import trange

start = (np.linspace(100, int(sr.rl) - 100, num_snippets) * sr.fs).astype(int)
# for most recordings we take a 100 s buffer on either end, but account for shorter recordings
buffer_left_right = np.minimum(100, sr.rl * 0.03)
start = (
np.linspace(buffer_left_right, int(sr.rl) - buffer_left_right, num_snippets)
* sr.fs
).astype(int)
end = start + int(sr.fs)

_channel_labels = np.zeros((384, num_snippets), int)
@@ -101,10 +109,9 @@ def _get_channel_labels(sr, num_snippets=20, verbose=True):

def _make_wfs_table(
sr,
spike_times,
spike_samples,
spike_clusters,
spike_channels,
chunksize_t=10,
max_wf=256,
trough_offset=42,
spike_length_samples=128,
@@ -118,8 +125,8 @@ def _make_wfs_table(
"""
# exclude spikes without a buffer on either end
# of recording
allowed_idx = (spike_times > trough_offset) & (
spike_times < sr.ns - (spike_length_samples - trough_offset)
allowed_idx = (spike_samples > trough_offset) & (
spike_samples < sr.ns - (spike_length_samples - trough_offset)
)

rng = np.random.default_rng(seed=2024) # numpy 1.23.5
@@ -136,7 +143,7 @@ def _make_wfs_table(
nspikes = u_spikeidx.shape[0]
unit_nspikes[i] = nspikes
# uniformly select up to 500 spikes
u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes))
u_wf_idx = rng.choice(u_spikeidx, min(max_wf, nspikes), replace=False)
unit_wf_idx[u, : min(max_wf, nspikes)] = u_wf_idx

# all wf indices in order
@@ -145,13 +152,12 @@ def _make_wfs_table(
wf_idx = wf_idx[np.nonzero(wf_idx)[0][0]:]

# get sample times, clusters, channels

wf_flat = pd.DataFrame(
{
"indices": np.arange(wf_idx.shape[0]),
"samples": spike_times[wf_idx].astype(int),
"clusters": spike_clusters[wf_idx].astype(int),
"channels": spike_channels[wf_idx].astype(int),
"index": np.arange(wf_idx.shape[0]),
"sample": spike_samples[wf_idx].astype(int),
"cluster": spike_clusters[wf_idx].astype(int),
"peak_channel": spike_channels[wf_idx].astype(int),
}
)

@@ -176,6 +182,9 @@ def write_wfs_chunk(
Parallel job to extract waveforms from chunk `i_chunk` of a recording `sr` and
write them to the correct spot in the output .npy file `wfs_fn`.
"""
if len(wf_flat) == 0:
return

my_sr = spikeglx.Reader(cbin)
s0, s1 = sr_sl

@@ -197,13 +206,13 @@ def write_wfs_chunk(
else:
offset = trough_offset

sample = wf_flat["samples"].astype(int) + offset - i_chunk * chunksize_samples
peak_channel = wf_flat["channels"]
sample = wf_flat["sample"].astype(int) + offset - i_chunk * chunksize_samples
peak_channel = wf_flat["peak_channel"]

df = pd.DataFrame({"sample": sample, "peak_channel": peak_channel})

snip = my_sr[
s0 - offset: s1 + spike_length_samples - trough_offset, : -my_sr.nsync
s0 - offset:s1 + spike_length_samples - trough_offset, :-my_sr.nsync
]
snip0 = interpolate_bad_channels(
fshift(
@@ -216,98 +225,109 @@ def write_wfs_chunk(
# car
snip1 = np.full((my_sr.nc, snip0.shape[1]), np.nan)
snip1[:-1, :] = car_func(snip0)
wfs_mmap[wf_flat["indices"], :, :] = extract_wfs_array(
wfs_mmap[wf_flat["index"], :, :] = extract_wfs_array(
snip1.T, df, channel_neighbors
)[0]
wfs_mmap.flush()


def extract_wfs_cbin(
cbin_file,
output_file,
spike_times,
output_dir,
spike_samples,
spike_clusters,
spike_channels,
h=None,
wf_extract_params=None,
nprocesses=None,
channel_labels=None,
max_wf=256,
trough_offset=42,
spike_length_samples=128,
chunksize_samples=int(3000),
n_jobs=None,
):
"""
Given a cbin file and locations of spikes, extract waveforms for each unit, compute
the templates, and save to `output_file`.

If `output_file=Path("/path/to/example_clusters.npy")`, this array will be of shape
`(num_units, max_wf, nc, spike_length_samples)` where by default `max_wf=256, nc=40,
spike_length_samples=128`.

The file "path/to/example_clusters_templates.npy" will also be generated, of shape
`(num_units, nc, spike_length_samples)`, where the median across waveforms is taken
for each unit.

The parquet file "path/to/example_clusters.pqt" contains the samples and max channels
of each waveform, indexed by unit.
the templates, and save the results in `output_dir`. The waveforms come from chunks
of raw data which are phase-corrected to account for the ADC, high-pass filtered in
time with an order-3 Butterworth filter (300 Hz cutoff), and common-average
referenced in the spatial dimension.

The following files will be generated:
- waveforms.traces.npy: `(num_units, max_wf, nc, spike_length_samples)`
This file contains the lightly processed waveforms indexed by cluster in the first
dimension. By default `max_wf=256, nc=40, spike_length_samples=128`.

- waveforms.templates.npy: `(num_units, nc, spike_length_samples)`
This file contains the median across individual waveforms for each unit.

- waveforms.channels.npz: `(num_units * max_wf, nc)`
The i'th row contains the ordered indices of the `nc`-channel neighborhood used
to extract the i'th waveform. A row of -1s means the waveform is missing because
the unit it was supposed to come from has fewer than `max_wf` spikes in the
recording (the channel indices are integers, so NaN cannot be used).

- waveforms.table.pqt: `num_units * max_wf` rows
For each waveform, gives the absolute sample number from the recording (i.e.
where to find it in `spikes.samples`), peak channel, cluster, and linear index.
A row of NaNs implies that the waveform is missing because the unit it was supposed
to come from has fewer than `max_wf` spikes in total.
"""
if h is None:
h = neuropixel.trace_header()

if wf_extract_params is None:
wf_extract_params = {
"max_wf": 256,
"trough_offset": 42,
"spike_length_samples": 128,
"chunksize_t": 10,
}

output_path = output_file.parent

max_wf = wf_extract_params["max_wf"]
trough_offset = wf_extract_params["trough_offset"]
spike_length_samples = wf_extract_params["spike_length_samples"]
chunksize_t = wf_extract_params["chunksize_t"]
n_jobs = n_jobs or int(cpu_count() / 2)

sr = spikeglx.Reader(cbin_file)
chunksize_samples = chunksize_t * 30_000
s0_arr = np.arange(0, sr.ns, chunksize_samples)
s1_arr = s0_arr + chunksize_samples
s1_arr[-1] = sr.ns

# selects spikes from throughout the recording for each unit
wf_flat, unit_ids = _make_wfs_table(
sr, spike_times, spike_clusters, spike_channels, **wf_extract_params
sr,
spike_samples,
spike_clusters,
spike_channels,
max_wf,
trough_offset,
spike_length_samples,
)
num_chunks = s0_arr.shape[0]
print(f"Chunk size: {chunksize_t}")
print(f"Num chunks: {num_chunks}")

print("Running channel detection")
channel_labels = _get_channel_labels(sr)
logger.info(f"Chunk size samples: {chunksize_samples}")
logger.info(f"Num chunks: {num_chunks}")

logger.info("Running channel detection")
if channel_labels is None:
channel_labels = _get_channel_labels(sr)

nwf = wf_flat["samples"].shape[0]
nwf = len(wf_flat)
nu = unit_ids.shape[0]
print(f"Extracting {nwf} waveforms from {nu} units")
logger.info(f"Extracting {nwf} waveforms from {nu} units")

# get channel geometry
geom = np.c_[h["x"], h["y"]]
channel_neighbors = make_channel_index(geom)
nc = channel_neighbors.shape[1]

fn = output_path.joinpath("_wf_extract_intermediate.npy")
# this intermediate memmap is written to in parallel
# the waveforms are ordered only by their chronological position
# in the recording, as we are reading them in time chunks
int_fn = output_dir.joinpath("_wf_extract_intermediate.npy")
wfs = open_memmap(
fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
int_fn, mode="w+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
)

slices = [
slice(
*(np.searchsorted(wf_flat["samples"], [s0_arr[i], s1_arr[i]]).astype(int))
)
slice(*(np.searchsorted(wf_flat["sample"], [s0_arr[i], s1_arr[i]]).astype(int)))
for i in range(num_chunks)
]

nprocesses = nprocesses or int(cpu_count() - cpu_count() / 4)
_ = Parallel(n_jobs=nprocesses)(
_ = Parallel(n_jobs=n_jobs)(
delayed(write_wfs_chunk)(
i,
cbin_file,
fn,
int_fn,
wfs.shape,
h,
channel_labels,
Expand All @@ -321,34 +341,81 @@ def extract_wfs_cbin(
for i in range(num_chunks)
)

wfs = open_memmap(
fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
)
# bookkeeping
wfs_by_unit = np.full(
(nu, max_wf, nc, spike_length_samples), np.nan, dtype=np.float16
# output files
traces_fn = output_dir.joinpath("waveforms.traces.npy")
templates_fn = output_dir.joinpath("waveforms.templates.npy")
table_fn = output_dir.joinpath("waveforms.table.pqt")
channels_fn = output_dir.joinpath("waveforms.channels.npz")

## rearrange and save traces by unit
# store medians across waveforms
wfs_templates = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32)
# create waveform output file (~2-3 GB)
traces_by_unit = open_memmap(
traces_fn,
mode="w+",
shape=(nu, max_wf, nc, spike_length_samples),
dtype=np.float16,
)
wfs_medians = np.full((nu, nc, spike_length_samples), np.nan, dtype=np.float32)
print("Computing templates")
for i, u in enumerate(unit_ids):
_wfs_unit = wfs[wf_flat["clusters"] == u]
nwf_u = _wfs_unit.shape[0]
wfs_by_unit[i, : min(max_wf, nwf_u), :, :] = _wfs_unit.astype(np.float16)
wfs_medians[i, :, :] = np.nanmedian(_wfs_unit, axis=0)
logger.info("Writing to output files")

df = pd.DataFrame(
for i, u in enumerate(unit_ids):
idx = np.where(wf_flat["cluster"] == u)[0]
nwf_u = idx.shape[0]
# reopening these memmaps on each iteration
# forces Python to clean up each large array it loads
# and prevent a memory leak
wfs = open_memmap(
int_fn, mode="r+", shape=(nwf, nc, spike_length_samples), dtype=np.float32
)
traces_by_unit = open_memmap(
traces_fn,
mode="r+",
shape=(nu, max_wf, nc, spike_length_samples),
dtype=np.float16,
)
# write up to 256 waveforms and leave the rest of dimensions 1-3 as NaNs
traces_by_unit[i, : min(max_wf, nwf_u), :, :] = wfs[idx].astype(np.float16)
traces_by_unit.flush()
# populate this array in memory as it's 256x smaller
wfs_templates[i, :, :] = np.nanmedian(wfs[idx], axis=0)

# cleanup intermediate file
int_fn.unlink()

# save templates
np.save(templates_fn, wfs_templates)

# add in dummy rows and order by unit, and then sample
unit_counts = wf_flat.groupby("cluster")["sample"].count().reset_index(name="count")
unit_counts["missing"] = 256 - unit_counts["count"]
missing_wf = unit_counts[unit_counts["missing"] > 0]
total_missing = sum(missing_wf.missing)
extra_rows = pd.DataFrame(
{
"sample": wf_flat["samples"],
"peak_channel": wf_flat["channels"],
"cluster": wf_flat["clusters"],
"sample": [np.nan] * total_missing,
"peak_channel": [np.nan] * total_missing,
"index": [np.nan] * total_missing,
"cluster": sum(
[[row["cluster"]] * row["missing"] for _, row in missing_wf.iterrows()],
[],
),
}
)
df = df.sort_values(["cluster", "sample"]).set_index(["cluster", "sample"])

np.save(output_file, wfs_by_unit)
# medians
avg_file = output_file.parent.joinpath(output_file.stem + "_templates.npy")
np.save(avg_file, wfs_medians)
df.to_parquet(output_file.with_suffix(".pqt"))

fn.unlink()
save_df = pd.concat([wf_flat, extra_rows])
# now the waveforms are arranged by cluster, and then in time
# these match dimensions 0 and 1 of waveforms.traces.npy
save_df.sort_values(["cluster", "sample"], inplace=True)
save_df.to_parquet(table_fn)

# save channel map for each waveform
# these values are now reordered so that they match the pqt
# and the traces file
peak_channel = np.nan_to_num(save_df["peak_channel"].to_numpy(), nan=-1).astype(
np.int16
)
dummy_idx = np.where(peak_channel >= 0)[0]
# leave "missing" waveforms as -1 since we can't have NaN with int dtype
chan_map = np.ones((max_wf * nu, nc), np.int16) * -1
chan_map[dummy_idx] = channel_neighbors[peak_channel[dummy_idx].astype(int)]
np.savez(channels_fn, channels=chan_map)
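
To complement the docstring above, here is a short sketch of reading back the four output files; `output_dir` is the same hypothetical directory passed to `extract_wfs_cbin`, and the shapes follow the documented defaults.

```python
# Sketch of loading the extract_wfs_cbin outputs (hypothetical output_dir).
from pathlib import Path

import numpy as np
import pandas as pd

output_dir = Path("/data/probe00/waveforms")

# (num_units, max_wf, nc, spike_length_samples) float16 traces, NaN-padded per unit
traces = np.load(output_dir / "waveforms.traces.npy", mmap_mode="r")
# (num_units, nc, spike_length_samples) median template per unit
templates = np.load(output_dir / "waveforms.templates.npy")
# num_units * max_wf rows, ordered by cluster and then by sample
table = pd.read_parquet(output_dir / "waveforms.table.pqt")
# (num_units * max_wf, nc) channel neighborhoods, -1 for missing waveforms
channels = np.load(output_dir / "waveforms.channels.npz")["channels"]

# e.g. median waveform of the first unit on its extraction channels
first_template = templates[0]
```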