Fix minor linting issues
stes committed Oct 20, 2024
1 parent 7b60c03 commit 6a38d8e
Showing 10 changed files with 23 additions and 16 deletions.
4 changes: 2 additions & 2 deletions cebra/__init__.py
@@ -92,11 +92,11 @@ def __getattr__(key):

         return CEBRA
     elif key == "KNNDecoder":
-        from cebra.integrations.sklearn.decoder import KNNDecoder
+        from cebra.integrations.sklearn.decoder import KNNDecoder  # noqa: F811

         return KNNDecoder
     elif key == "L1LinearRegressor":
-        from cebra.integrations.sklearn.decoder import L1LinearRegressor
+        from cebra.integrations.sklearn.decoder import L1LinearRegressor  # noqa: F811

         return L1LinearRegressor
     elif not key.startswith("_"):
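
The ``# noqa: F811`` markers silence flake8's "redefinition of unused name" warning, which fires here because the same names are presumably imported earlier in the module (e.g. under a ``typing.TYPE_CHECKING`` guard) and are imported again inside the lazy ``__getattr__`` hook. A minimal sketch of this PEP 562 lazy-import pattern, with an illustrative package and module name:

    # mypackage/__init__.py -- lazy attribute resolution via PEP 562 (sketch)
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Type checkers see this import; at runtime it is skipped entirely.
        from mypackage.decoder import KNNDecoder


    def __getattr__(key):
        """Resolve heavy attributes on first access instead of at import time."""
        if key == "KNNDecoder":
            # Re-importing the name shadows the TYPE_CHECKING import -> F811.
            from mypackage.decoder import KNNDecoder  # noqa: F811
            return KNNDecoder
        raise AttributeError(f"module {__name__!r} has no attribute {key!r}")
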
13 changes: 9 additions & 4 deletions cebra/data/helper.py
@@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment:
     For each dataset, the data and labels to align the data on is provided.
-    1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``).
-    2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``.
-    3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets.
-    4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``.
+    1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to
+       the labels of the reference dataset (``ref_label``) are selected and used to sample
+       from the dataset to align (``data``).
+    2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number
+       of samples ``subsample``.
+    3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`,
+       on those subsampled datasets.
+    4. The resulting orthogonal matrix ``_transform`` can be used to map the original ``data``
+       to the ``ref_data``.
     Note:
         ``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number
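
The rewrapped docstring describes the alignment recipe; a compact NumPy/SciPy sketch of steps 2-4 follows. This is not the class implementation: it assumes step 1 already put the rows of ``data`` in correspondence with the rows of ``ref_data``, and the function and variable names are illustrative.

    import numpy as np
    import scipy.linalg


    def procrustes_align(ref_data, data, subsample=500):
        """Sketch of steps 2-4: subsample paired rows, fit the map, apply it."""
        rng = np.random.default_rng(0)
        n_pairs = min(len(ref_data), len(data))
        idx = rng.choice(n_pairs, size=min(n_pairs, subsample), replace=False)
        # R minimizes ||data[idx] @ R - ref_data[idx]||_F over orthogonal R.
        R, _ = scipy.linalg.orthogonal_procrustes(data[idx], ref_data[idx])
        return data @ R  # map the full dataset into the reference space
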
3 changes: 2 additions & 1 deletion cebra/data/load.py
@@ -663,7 +663,8 @@ def load(
         - if no key is provided, the first data structure found upon iteration of the collection will be loaded;
         - if a key is provided, it needs to correspond to an existing item of the collection;
         - if a key is provided, the data value accessed needs to be a data structure;
-        - the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones.
+        - the function loads data for only one data structure, even if the file contains more. The function can be
+          called again with the corresponding key to get the other ones.

     Args:
         file: The path to the given file to load, in a supported format.
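
The wrapped bullet explains the one-structure-per-call behavior. A hedged usage sketch, assuming the loader is exposed as ``cebra.load_data`` and that ``.npz`` archives count as multi-structure collections (file and key names are illustrative):

    import numpy as np
    import cebra

    # A collection holding two data structures.
    np.savez("session.npz",
             neural=np.random.rand(100, 30),
             behavior=np.random.rand(100, 2))

    neural = cebra.load_data("session.npz", key="neural")      # one structure per call
    behavior = cebra.load_data("session.npz", key="behavior")  # second call for the rest
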
4 changes: 1 addition & 3 deletions cebra/data/single_session.py
@@ -350,7 +350,6 @@ def __post_init__(self):
         # here might be sub-optimal. The final behavior should be determined after
         # e.g. integrating the FAISS dataloader back in.
         super().__post_init__()
-        index = self.index.to(self.device)

         if self.conditional != "time_delta":
             raise NotImplementedError(
@@ -360,8 +359,7 @@ def __post_init__(self):
         self.time_distribution = cebra.distributions.TimeContrastive(
             time_offset=self.time_offset,
             num_samples=len(self.dataset.neural),
-            device=self.device,
-        )
+            device=self.device)
         self.behavior_distribution = cebra.distributions.TimedeltaDistribution(
             self.dataset.continuous_index, self.time_offset, device=self.device)

3 changes: 2 additions & 1 deletion cebra/grid_search.py
@@ -138,7 +138,8 @@ def fit_models(self,
                 to fit the CEBRA models on. The models are then trained using temporal contrastive learning
                 (CEBRA-Time).
                 An example of a valid ``datasets`` value could be:
-                ``datasets={"dataset1": neural_data, "dataset2": (neurald_data, continuous_data, discrete_data), "dataset3": (neural_data2, continuous_data2)}``.
+                ``datasets={"dataset1": neural_data, "dataset2": (neural_data, continuous_data, discrete_data),
+                "dataset3": (neural_data2, continuous_data2)}``.
             params: Dict of parameter values provided by the user, either as a single value, for
                 fixed hyperparameter values, or with a list of values for hyperparameters to optimize.
                 If the value is a list of a single element, the hyperparameter is considered as fixed.
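
For context, a hedged sketch of how a ``datasets`` dictionary like the documented one feeds into a grid search; the data and parameter values are illustrative, and the exact ``params`` keys depend on the CEBRA version:

    import numpy as np
    import cebra.grid_search

    neural_data = np.random.rand(1000, 30)
    continuous_data = np.random.rand(1000, 3)

    grid_search = cebra.grid_search.GridSearch()
    grid_search.fit_models(
        datasets={"dataset1": neural_data,                      # CEBRA-Time
                  "dataset2": (neural_data, continuous_data)},  # CEBRA-Behavior
        params={"learning_rate": [1e-3, 1e-2],  # list -> optimized over
                "output_dimension": 8},         # single value -> fixed
        models_dir="saved_models",
    )
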
2 changes: 1 addition & 1 deletion cebra/helper.py
@@ -152,7 +152,7 @@ def _requires_package_version(function):

     @wraps(function)
     def wrapper(*args, patched_version=None, **kwargs):
-        if patched_version != None:
+        if patched_version is not None:
             installed_version = pkg_resources.parse_version(
                 patched_version)  # Use the patched version if provided
         else:
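
The fix replaces ``!= None`` with the identity test PEP 8 prescribes (flake8 flags the comparison form as E711). The difference is observable whenever equality is overloaded; a small self-contained illustration:

    class AlwaysEqual:
        """Pathological but legal: claims equality with everything."""

        def __eq__(self, other):
            return True


    obj = AlwaysEqual()
    print(obj != None)      # False -- __eq__/__ne__ hijack the comparison
    print(obj is not None)  # True  -- identity against the None singleton
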
2 changes: 1 addition & 1 deletion cebra/models/criterions.py
@@ -36,6 +36,7 @@
 from typing import Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch import nn


@@ -212,7 +213,6 @@ def __init__(self,
             self.max_inverse_temperature = math.inf
         else:
             self.max_inverse_temperature = 1.0 / min_temperature
-        start_tempearture = float(temperature)
         log_inverse_temperature = torch.tensor(
             math.log(1.0 / float(temperature)))
         self.log_inverse_temperature = nn.Parameter(log_inverse_temperature)
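
The deleted ``start_tempearture`` assignment was dead code; the surviving lines show the actual mechanism: the inverse temperature is learned in log space so it stays positive under gradient descent, and ``max_inverse_temperature`` caps it to enforce the ``min_temperature`` floor. A minimal sketch of that pattern (the clamp placement in ``forward`` is an assumption, not copied from the file):

    import math

    import torch
    from torch import nn


    class LearnableTemperature(nn.Module):
        """Sketch: log-space learnable temperature with a lower bound."""

        def __init__(self, temperature=1.0, min_temperature=0.1):
            super().__init__()
            self.max_inverse_temperature = 1.0 / min_temperature
            self.log_inverse_temperature = nn.Parameter(
                torch.tensor(math.log(1.0 / temperature)))

        def forward(self, similarities):
            # exp(.) keeps the inverse temperature positive; the clamp caps it
            # so the effective temperature never drops below min_temperature.
            inverse_temperature = torch.clamp(self.log_inverse_temperature.exp(),
                                              max=self.max_inverse_temperature)
            return similarities * inverse_temperature
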
2 changes: 1 addition & 1 deletion cebra/models/multiobjective.py
@@ -94,7 +94,7 @@ def is_valid(self, mode):
         Returns:
             ``True`` for a valid representation, ``False`` otherwise.
         """
-        return mode in _ALL
+        return mode in _ALL  # noqa: F821

     def __init__(
         self,
4 changes: 3 additions & 1 deletion cebra/registry.py
@@ -115,7 +115,9 @@ def get_options(
     ):
         instance = cls.get_instance(module)
         if expand_parametrized:
-            filter_ = lambda k, v: True
+
+            def filter_(k, v):
+                return True
         else:

             class _Filter(set):
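
Replacing the assigned lambda with a ``def`` addresses flake8's E731 ("do not assign a lambda expression, use a def"); the two are behaviorally identical, but the ``def`` carries a real ``__name__`` and can hold a docstring, which helps tracebacks and debugging:

    # Before: flagged as E731
    filter_ = lambda k, v: True

    # After: same behavior, better introspection
    def filter_(k, v):
        return True

    print(filter_.__name__)  # 'filter_' rather than '<lambda>'
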
2 changes: 1 addition & 1 deletion cebra/solver/supervised.py
@@ -61,7 +61,7 @@ def fit(self,
         step_idx = 0
         while True:
             for _, batch in enumerate(loader):
-                stats = self.step(batch)
+                _ = self.step(batch)
                 self._log_checkpoint(num_steps, loader, valid_loader)
                 step_idx += 1
                 if step_idx >= num_steps:
