Skip to content

Commit

Permalink
Expose n_bins argument from align_embeddings (#25)
Browse files Browse the repository at this point in the history
* Expose n_bins argument from align_embeddings

* Fix docs

* Add sklearn helper functions to public docs
  • Loading branch information
stes authored Jul 12, 2023
1 parent e011694 commit d37e7f9
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
33 changes: 25 additions & 8 deletions cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _consistency_scores(
Args:
embeddings: List of embedding matrices.
dataset_ids: List of dataset IDs associated with each embedding. Multiple embeddings can be
associated to the same dataset.
associated to the same dataset.
Returns:
List of the consistencies for each pair of embeddings (first element) and
Expand Down Expand Up @@ -145,6 +145,7 @@ def _consistency_datasets(
embeddings: List[Union[npt.NDArray, torch.Tensor]],
dataset_ids: Optional[List[Union[int, str, float]]],
labels: List[Union[npt.NDArray, torch.Tensor]],
num_discretization_bins: int = 100
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
"""Compute consistency between embeddings from different datasets.
Expand All @@ -158,9 +159,14 @@ def _consistency_datasets(
Args:
embeddings: List of embedding matrices.
dataset_ids: List of dataset IDs associated with each embedding. Multiple embeddings can be
associated to the same dataset.
associated to the same dataset.
labels: List of labels corresponding to each embedding and to use for alignment
between them.
num_discretization_bins: Number of bins (distinct values) used to discretize the common labels. The discretized
labels are used for embedding alignment. Also see the ``n_bins`` argument of
:py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
parameter is used internally. This argument is only used if ``labels``
is not ``None`` and the given labels are continuous and not already discrete.
Returns:
A list of scores obtained between embeddings from different datasets (first element),
Expand Down Expand Up @@ -203,7 +209,7 @@ def _consistency_datasets(

# NOTE(celia): uses the default normalized=True; n_bins is set by num_discretization_bins (default 100)
aligned_embeddings = cebra_sklearn_helpers.align_embeddings(
embeddings, labels)
embeddings, labels, n_bins=num_discretization_bins)
scores, pairs = _consistency_scores(aligned_embeddings,
datasets=dataset_ids)
between_dataset = [p[0] != p[1] for p in pairs]
Expand Down Expand Up @@ -303,6 +309,7 @@ def consistency_score(
between: Optional[Literal["datasets", "runs"]] = None,
labels: Optional[List[Union[npt.NDArray, torch.Tensor]]] = None,
dataset_ids: Optional[List[Union[int, str, float]]] = None,
num_discretization_bins: int = 100
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
"""Compute the consistency score between embeddings, either between runs or between datasets.
Expand All @@ -320,6 +327,12 @@ def consistency_score(
*Consistency between runs* means the consistency between embeddings obtained from multiple models
trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings
obtained from models trained on **different datasets**, such as different animals, sessions, etc.
num_discretization_bins: Number of bins (distinct values) used to discretize the common labels. The discretized
labels are used for embedding alignment. Also see the ``n_bins`` argument of
:py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
parameter is used internally. This argument is only used if ``labels``
is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels
are continuous and not already discrete.
Returns:
The list of scores computed between the embeddings (first return value), the list of pairs corresponding
Expand Down Expand Up @@ -356,12 +369,16 @@ def consistency_score(
if labels is not None:
raise ValueError(
f"No labels should be provided for between-runs consistency.")
scores, pairs, datasets = _consistency_runs(embeddings=embeddings,
dataset_ids=dataset_ids)
scores, pairs, datasets = _consistency_runs(
embeddings=embeddings,
dataset_ids=dataset_ids,
)
elif between == "datasets":
scores, pairs, datasets = _consistency_datasets(embeddings=embeddings,
dataset_ids=dataset_ids,
labels=labels)
scores, pairs, datasets = _consistency_datasets(
embeddings=embeddings,
dataset_ids=dataset_ids,
labels=labels,
num_discretization_bins=num_discretization_bins)
else:
raise NotImplementedError(
f"Invalid comparison, got between={between}, expects either datasets or runs."
Expand Down
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ these components in other contexts and research code bases.
api/sklearn/cebra
api/sklearn/metrics
api/sklearn/decoder
api/sklearn/helpers



Expand Down
7 changes: 7 additions & 0 deletions docs/source/api/sklearn/helpers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Helper functions
----------------

.. automodule:: cebra.integrations.sklearn.helpers
:show-inheritance:
:members:

0 comments on commit d37e7f9

Please sign in to comment.