Merge pull request #802 from int-brain-lab/develop
2.38.0
k1o0 authored Jul 8, 2024
2 parents 7b279d2 + 91ba20d commit 3e80794
Showing 32 changed files with 1,254 additions and 574 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ibllib_ci.yml
@@ -39,7 +39,7 @@ jobs:
     - name: Flake8
       run: |
         python -m flake8
-        python -m flake8 --select D --ignore E ibllib/qc/camera.py
+        python -m flake8 --select D --ignore E ibllib/qc/camera.py ibllib/qc/task_metrics.py
     - name: Brainbox tests
       run: |
         cd brainbox
7 changes: 5 additions & 2 deletions brainbox/behavior/training.py
@@ -796,8 +796,11 @@ def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='re
         block_idx = trials.probabilityLeft == block

         contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
-        reaction_time = np.vectorize(lambda x: np.nanmedian((trials[stim_off_type] - trials[stim_on_type])
-                                     [(x == signed_contrast) & block_idx]))(contrasts)
+        reaction_time = np.vectorize(
+            lambda x: np.nanmedian((trials[stim_off_type] - trials[stim_on_type])[(x == signed_contrast) & block_idx]),
+            otypes=[float]
+        )(contrasts)

         if compute_ci:
             ci = np.full((contrasts.size, 2), np.nan)
             for i, x in enumerate(contrasts):
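A note on the change above: without `otypes`, `np.vectorize` infers its output dtype by evaluating the function on the first element, so it raises on empty input (e.g. a block with no trials at any contrast). Passing `otypes=[float]` makes the empty case well defined. A minimal standalone sketch of the failure mode, independent of the ibllib code:

import numpy as np

double = np.vectorize(lambda x: 2.0 * x)
# double(np.array([]))  # ValueError: cannot call `vectorize` on size 0 inputs
#                       # unless `otypes` is set

double = np.vectorize(lambda x: 2.0 * x, otypes=[float])
print(double(np.array([])))  # -> [] (float64), no error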
10 changes: 5 additions & 5 deletions brainbox/behavior/wheel.py
@@ -88,8 +88,8 @@ def velocity(re_ts, re_pos):
    for line in traceback.format_stack():
        print(line.strip())

-    msg = 'brainbox.behavior.wheel.velocity has been deprecated. Use velocity_filtered instead.'
-    warnings.warn(msg, DeprecationWarning)
+    msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
+    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    dp = np.diff(re_pos)
@@ -153,8 +153,8 @@ def velocity_smoothed(pos, freq, smooth_size=0.03):
    for line in traceback.format_stack():
        print(line.strip())

-    msg = 'brainbox.behavior.wheel.velocity_smoothed has been deprecated. Use velocity_filtered instead.'
-    warnings.warn(msg, DeprecationWarning)
+    msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
+    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    # Define our smoothing window with an area of 1 so the units won't be changed
@@ -188,7 +188,7 @@ def last_movement_onset(t, vel, event_time):
        print(line.strip())

    msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
-    warnings.warn(msg, DeprecationWarning)
+    warnings.warn(msg, FutureWarning)
    logging.getLogger(__name__).warning(msg)

    # Look back from timestamp
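The swap from `DeprecationWarning` to `FutureWarning` in the three functions above changes who sees the message: the default warning filters hide `DeprecationWarning` unless the calling code runs in `__main__` (or under a test runner), whereas `FutureWarning` is always displayed, so library users get the removal notice too. A small illustration with a hypothetical stand-in function:

import warnings

def legacy_velocity():  # hypothetical stand-in for the deprecated wheel functions
    # FutureWarning is shown by default; a DeprecationWarning here would be
    # silenced when called from library (non-__main__) code.
    warnings.warn('legacy_velocity will soon be removed. Use velocity_filtered instead.',
                  FutureWarning)

legacy_velocity()  # prints the warning even without any -W command-line options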
38 changes: 21 additions & 17 deletions brainbox/io/one.py
@@ -40,6 +40,7 @@

 SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
 CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
+WAVEFORMS_ATTRIBUTES = ['templates']


 def load_lfp(eid, one=None, dataset_types=None, **kwargs):
@@ -128,6 +129,10 @@ def _channels_alf2bunch(channels, brain_regions=None):
        'axial_um': channels['localCoordinates'][:, 1],
        'lateral_um': channels['localCoordinates'][:, 0],
    }
+    # if there are some extra keys, they will carry over to the new dictionary
+    for k in channels:
+        if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']:
+            channels_[k] = channels[k]
    if brain_regions:
        channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym']
    return channels_
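The added loop means any channel field that is not remapped (and is not one of the consumed source keys) is copied through to the output dictionary. A self-contained sketch of that carry-over, with illustrative key names:

# 'localCoordinates' is consumed by the remap above; 'rawInd' is an extra key
channels = {'localCoordinates': [[16, 20]], 'rawInd': [0]}
channels_ = {'axial_um': [20], 'lateral_um': [16]}  # already remapped
consumed = ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']
for k in channels:
    if k not in list(channels_.keys()) + consumed:
        channels_[k] = channels[k]
print(channels_)  # {'axial_um': [20], 'lateral_um': [16], 'rawInd': [0]}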
@@ -851,14 +856,14 @@ def _load_object(self, *args, **kwargs):
     @staticmethod
     def _get_attributes(dataset_types):
         """returns attributes to load for spikes and clusters objects"""
-        if dataset_types is None:
-            return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES
-        else:
-            spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
-            cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
-            spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
-            cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
-            return spike_attributes, cluster_attributes
+        dataset_types = [] if dataset_types is None else dataset_types
+        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
+        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
+        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
+        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
+        waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl]
+        waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
+        return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}

     def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
         """
@@ -891,14 +896,15 @@ def get_version(self, spike_sorter='pykilosort'):
         return dset[0]['version'] if len(dset) else 'unknown'

     def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None,
-                                      missing='raise', **kwargs):
+                                      attribute=None, missing='raise', **kwargs):
         """
         Downloads an ALF object
         :param obj: object name, str between 'spikes', 'clusters' or 'channels'
         :param spike_sorter: (defaults to 'pykilosort')
         :param dataset_types: list of extra dataset types, for example ['spikes.samples']
         :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
         :param kwargs: additional arguments to be passed to one.api.One.load_object
+        :param attribute: list of attributes to load for the object
         :param missing: 'raise' (default) or 'ignore'
         :return:
         """
@@ -907,8 +913,7 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_
         self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
         collection = collection or self.collection
         _logger.debug(f"loading spike sorting object {obj} from {collection}")
-        spike_attributes, cluster_attributes = self._get_attributes(dataset_types)
-        attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes}
+        attributes = self._get_attributes(dataset_types)
         try:
             self.files[obj] = self.one.load_object(
                 self.eid, obj=obj, attribute=attributes.get(obj, None),
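`_get_attributes` now returns a single dict keyed by object name, so the caller fetches the attribute list for whichever object it is downloading with one `attributes.get(obj, None)` lookup, falling back to `None` (no restriction) for other objects. A standalone sketch of the merge logic under the same defaults (the helper name and structure are illustrative, not the ibllib API):

SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
WAVEFORMS_ATTRIBUTES = ['templates']

def get_attributes(dataset_types=None):
    """Merge extra dataset types into the defaults, keyed by object name."""
    dataset_types = dataset_types or []
    out = {}
    for obj, defaults in [('spikes', SPIKES_ATTRIBUTES),
                          ('clusters', CLUSTERS_ATTRIBUTES),
                          ('waveforms', WAVEFORMS_ATTRIBUTES)]:
        extra = [d.split('.')[1] for d in dataset_types if d.startswith(obj + '.')]
        out[obj] = sorted(set(defaults + extra))
    return out

attributes = get_attributes(['spikes.samples'])
print(attributes['spikes'])              # ['amps', 'clusters', 'depths', 'samples', 'times']
print(attributes.get('channels', None))  # None -> no attribute restriction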
@@ -986,11 +991,10 @@ def load_channels(self, **kwargs):
         """
         # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
         self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
-        if 'electrodeSites' in self.files:
-            channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
-        else:  # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology
-            self.download_spike_sorting_object(obj='channels', **kwargs)
-            channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
+        self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs)
+        channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
+        if 'electrodeSites' in self.files:  # if common dict keys, electrodeSites prevails
+            channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
         if 'brainLocationIds_ccf_2017' not in channels:
             _logger.debug(f"loading channels from alyx for {self.files['channels']}")
             _channels, self.histology = _load_channel_locations_traj(
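The rewritten branch loads the spike-sorting channels first, then merges the `electrodeSites` object over them with the dict union operator (Python 3.9+): for keys present in both operands the right-hand side wins, which is what the "electrodeSites prevails" comment refers to. A minimal demonstration with made-up values:

channels = {'atlas_id': [0, 0], 'rawInd': [0, 1]}  # from the spike sorting folder
electrode_sites = {'atlas_id': [512, 512]}         # histology-aligned duplicate key
merged = channels | electrode_sites
print(merged)  # {'atlas_id': [512, 512], 'rawInd': [0, 1]} -- right operand wins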
@@ -1000,7 +1004,7 @@ def load_channels(self, **kwargs):
         else:
             channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
             self.histology = 'alf'
-        return channels
+        return Bunch(channels)

     def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
         """
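Wrapping the return value in `Bunch(channels)` keeps the return type consistent: the dict union above yields a plain dict, and `_channels_alf2bunch` builds one too. Assuming `iblutil.util.Bunch` (a dict subclass with attribute access), callers can then use either style:

from iblutil.util import Bunch

channels = Bunch({'atlas_id': [512], 'acronym': ['CA1']})
print(channels['acronym'])  # dict-style access
print(channels.acronym)     # attribute-style access to the same data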
70 changes: 20 additions & 50 deletions brainbox/processing.py
@@ -1,17 +1,16 @@
-'''
-Processes data from one form into another, e.g. taking spike times and binning them into
-non-overlapping bins and convolving spike times with a gaussian kernel.
-'''
+"""Process data from one form into another.
+
+For example, taking spike times and binning them into non-overlapping bins and convolving spike
+times with a gaussian kernel.
+"""

 import numpy as np
 import pandas as pd
 from scipy import interpolate, sparse
 from brainbox import core
-from iblutil.numerical import bincount2D as _bincount2D
+from iblutil.numerical import bincount2D
 from iblutil.util import Bunch
 import logging
-import warnings
-import traceback

 _logger = logging.getLogger(__name__)
@@ -118,35 +117,6 @@ def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zer
     return syncd


-def bincount2D(x, y, xbin=0, ybin=0, xlim=None, ylim=None, weights=None):
-    """
-    Computes a 2D histogram by aggregating values in a 2D array.
-    :param x: values to bin along the 2nd dimension (c-contiguous)
-    :param y: values to bin along the 1st dimension
-    :param xbin:
-        scalar: bin size along 2nd dimension
-        0: aggregate according to unique values
-        array: aggregate according to exact values (count reduce operation)
-    :param ybin:
-        scalar: bin size along 1st dimension
-        0: aggregate according to unique values
-        array: aggregate according to exact values (count reduce operation)
-    :param xlim: (optional) 2 values (array or list) that restrict range along 2nd dimension
-    :param ylim: (optional) 2 values (array or list) that restrict range along 1st dimension
-    :param weights: (optional) defaults to None, weights to apply to each value for aggregation
-    :return: 3 numpy arrays MAP [ny,nx] image, xscale [nx], yscale [ny]
-    """
-    for line in traceback.format_stack():
-        print(line.strip())
-    warning_text = """Deprecation warning: bincount2D() is now a part of iblutil.
-        brainbox.processing.bincount2D is deprecated and will be removed in
-        future versions. Please replace imports with iblutil.numerical.bincount2D."""
-    _logger.warning(warning_text)
-    warnings.warn(warning_text, DeprecationWarning)
-    return _bincount2D(x, y, xbin, ybin, xlim, ylim, weights)
-
-
 def compute_cluster_average(spike_clusters, spike_var):
     """
     Quickish way to compute the average of some quantity across spikes in each cluster given
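With the deprecated wrapper removed, imports should point at `iblutil.numerical.bincount2D` directly; the un-aliased `from iblutil.numerical import bincount2D` kept at the top of `processing.py` appears to let the old `brainbox.processing.bincount2D` name keep resolving in the meantime. A usage sketch based on the docstring deleted above (the arrays are made up):

import numpy as np
from iblutil.numerical import bincount2D

spike_times = np.array([0.1, 0.2, 0.25, 0.9])
spike_clusters = np.array([0, 1, 1, 0])

# ybin=0 aggregates over the unique cluster ids; xbin=0.5 makes 0.5 s time bins
counts, tscale, cscale = bincount2D(spike_times, spike_clusters, xbin=0.5, ybin=0)
print(counts.shape)  # (n_clusters, n_time_bins) image, per the removed docstring
print(cscale)        # unique cluster ids along the first dimension, e.g. [0 1]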
@@ -197,7 +167,7 @@ def bin_spikes(spikes, binsize, interval_indices=False):


 def get_units_bunch(spks_b, *args):
-    '''
+    """
     Returns a bunch, where the bunch keys are keys from `spks` with labels of spike information
     (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for
     each unit: these arrays are ordered and can be indexed by unit id.
@@ -223,18 +193,18 @@ def get_units_bunch(spks_b, *args):
    --------
    1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
    bunch.
-    >>> import brainbox as bb
-    >>> import alf.io as aio
+    >>> from brainbox import processing
+    >>> import one.alf.io as alfio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
-    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-    >>> units_b = bb.processing.get_units_bunch(spks_b)
+    >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
+    >>> units_b = processing.get_units_bunch(spks_b)
    # Get amplitudes for unit 4.
    >>> amps = units_b['amps']['4']
    TODO add computation time estimate?
-    '''
+    """

    # Initialize `units`
    units_b = Bunch()
@@ -261,7 +231,7 @@


 def filter_units(units_b, t, **kwargs):
-    '''
+    """
     Filters units according to some parameters. **kwargs are the keyword parameters used to filter
     the units.
@@ -299,24 +269,24 @@ def filter_units(units_b, t, **kwargs):
    Examples
    --------
    1) Filter units according to the default parameters.
-    >>> import brainbox as bb
-    >>> import alf.io as aio
+    >>> from brainbox import processing
+    >>> import one.alf.io as alfio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Get a spikes bunch, units bunch, and filter the units.
-    >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
-    >>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
+    >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
+    >>> units_b = processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
    >>> T = spks_b['times'][-1] - spks_b['times'][0]
-    >>> filtered_units = bb.processing.filter_units(units_b, T)
+    >>> filtered_units = processing.filter_units(units_b, T)
    2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and a max false
    positive rate of 0.2, given a refractory period of 2 ms.
-    >>> filtered_units = bb.processing.filter_units(units_b, T, min_amp=0, min_fr=1)
+    >>> filtered_units = processing.filter_units(units_b, T, min_amp=0, min_fr=1)
    TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
    are in `clstrs_b['metrics']`
-    '''
+    """

    # Set params
    params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002}  # defaults
6 changes: 6 additions & 0 deletions brainbox/tests/test_behavior.py
@@ -124,6 +124,12 @@ def test_get_movement_onset(self):
         with self.assertRaises(ValueError):
             wheel.get_movement_onset(intervals, np.random.permutation(self.trials['feedback_times']))

+    def test_velocity_deprecation(self):
+        """Ensure brainbox.behavior.wheel.velocity is removed."""
+        from datetime import datetime
+        self.assertTrue(datetime.today() < datetime(2024, 8, 1),
+                        'remove brainbox.behavior.wheel.velocity, velocity_smoothed and last_movement_onset')
+

 class TestTraining(unittest.TestCase):
     def setUp(self):
16 changes: 1 addition & 15 deletions brainbox/tests/test_processing.py
@@ -1,7 +1,6 @@
 from brainbox import processing, core
 import unittest
 import numpy as np
-import datetime


 class TestProcessing(unittest.TestCase):
@@ -63,15 +62,6 @@ def test_sync(self):
         self.assertTrue(times2.min() >= resamp2.times.min())
         self.assertTrue(times2.max() <= resamp2.times.max())

-    def test_bincount2D_deprecation(self):
-        # Timer to remove bincount2D (now in iblutil)
-        # Once this test fails:
-        # - Remove the bincount2D method in processing.py
-        # - Remove the import from iblutil at the top of that file
-        # - Delete this test
-        if datetime.datetime.now() > datetime.datetime(2024, 6, 30):
-            raise NotImplementedError
-
     def test_compute_cluster_averag(self):
         # Create fake data for 3 clusters
         clust1 = np.ones(40)
@@ -104,10 +94,6 @@ def test_compute_cluster_averag(self):
         self.assertTrue(np.all(count == (40, 40, 50)))


-def test_get_unit_bunches():
-    pass
-
-
-if __name__ == "__main__":
+if __name__ == '__main__':
     np.random.seed(0)
     unittest.main(exit=False)