Skip to content

Commit

Permalink
Code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Jun 26, 2023
1 parent 4ba06c9 commit a6d2fd1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
8 changes: 3 additions & 5 deletions cebra/data/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
#
import copy
import warnings
from typing import List, Optional, Union

import joblib
import numpy as np
import numpy.typing as npt
import scipy.linalg
import torch
import warnings


def _require_numpy_array(array: Union[npt.NDArray, torch.Tensor]):
Expand Down Expand Up @@ -190,15 +190,13 @@ def fit(
f"should be larger than the 'subsample' "
f"parameter ({self.subsample}). Ignoring subsampling and "
f"computing alignment on the full dataset instead, which will "
f"give better results."
)
f"give better results.")
else:
if self.subsample < 1000:
warnings.warn(
"This function is experimental when the subsample dimension "
"is less than 1000. You can probably use the whole dataset "
"for alignment by setting subsample=None."
)
"for alignment by setting subsample=None.")

idc = np.random.choice(len(X), self.subsample)
X = X[idc]
Expand Down
30 changes: 13 additions & 17 deletions tests/test_data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def test_orthogonal_alignment_shapes(ref_data, data, ref_labels, labels):
assert _does_shape_match(data, aligned_embedding)

# Test with non-default parameters
alignment_model = cebra_data_helper.OrthogonalProcrustesAlignment(
top_k=10)
alignment_model = cebra_data_helper.OrthogonalProcrustesAlignment(top_k=10)

aligned_embedding = alignment_model.fit_transform(ref_data, data,
ref_labels, labels)
assert _does_shape_match(data, aligned_embedding), (data.shape, aligned_embedding.shape)
assert _does_shape_match(data, aligned_embedding), (data.shape,
aligned_embedding.shape)


@pytest.mark.parametrize("ref_data,data,ref_labels,labels,match",
Expand Down Expand Up @@ -156,16 +156,14 @@ def test_orthogonal_alignment_without_labels():
aligned_embedding_without_labels = alignment_model.transform(
embedding_100_4d_2)

assert np.allclose(aligned_embedding,
aligned_embedding_without_labels)
assert np.allclose(aligned_embedding, aligned_embedding_without_labels)


@pytest.mark.parametrize("seed", [483, 425, 166, 672, 123])
def test_orthogonal_alignment(seed):
np.random.seed(seed)
embedding_100_4d = np.random.uniform(0, 1, (1000, 4))
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4,
random_state=seed)
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4, random_state=seed)
labels_100_1d = np.random.uniform(0, 1, (1000, 1))

alignment_model = cebra_data_helper.OrthogonalProcrustesAlignment()
Expand All @@ -175,14 +173,14 @@ def test_orthogonal_alignment(seed):
orthogonal_matrix),
ref_label=labels_100_1d,
label=labels_100_1d)
assert np.allclose(aligned_embedding, embedding_100_4d, atol = 0.03)
assert np.allclose(aligned_embedding, embedding_100_4d, atol=0.03)

# and without labels
aligned_embedding = alignment_model.fit_transform(ref_data=embedding_100_4d,
data=np.dot(
embedding_100_4d,
orthogonal_matrix))
assert np.allclose(aligned_embedding, embedding_100_4d, atol = 0.03)
assert np.allclose(aligned_embedding, embedding_100_4d, atol=0.03)


def _initialize_embedding_ensembling_data():
Expand Down Expand Up @@ -276,8 +274,7 @@ def test_embeddings_ensembling_without_labels():
embeddings=[embedding_100_4d, embedding_100_4d_2], labels=[None, None])
joint_embedding_without_labels = cebra_data_helper.ensemble_embeddings(
embeddings=[embedding_100_4d, embedding_100_4d_2])
assert np.allclose(joint_embedding,
joint_embedding_without_labels)
assert np.allclose(joint_embedding, joint_embedding_without_labels)


@pytest.mark.parametrize("embeddings,labels,n_jobs,match",
Expand All @@ -290,16 +287,15 @@ def test_invalid_embedding_ensembling(embeddings, labels, n_jobs, match):
n_jobs=n_jobs,
)


@pytest.mark.parametrize("seed", [483, 426, 166, 674, 123])
def test_embedding_ensembling(seed):
np.random.seed(seed)
embedding_100_4d = np.random.uniform(0, 1, (100, 4))
labels_100_1d = np.random.uniform(0, 1, (100, 1))
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4,
random_state=seed)
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4, random_state=seed)
orthogonal_matrix_2 = scipy.stats.ortho_group.rvs(dim=4,
random_state=seed +
1)
random_state=seed + 1)

embedding_100_4d_2 = np.dot(embedding_100_4d, orthogonal_matrix)
embedding_100_4d_3 = np.dot(embedding_100_4d, orthogonal_matrix_2)
Expand All @@ -309,11 +305,11 @@ def test_embedding_ensembling(seed):
joint_embedding = cebra_data_helper.ensemble_embeddings(
embeddings=[embedding_100_4d, embedding_100_4d_2, embedding_100_4d_3],
labels=labels)
assert np.allclose(joint_embedding, embedding_100_4d, atol = 0.05)
assert np.allclose(joint_embedding, embedding_100_4d, atol=0.05)

joint_embedding = cebra_data_helper.ensemble_embeddings(
embeddings=[embedding_100_4d, embedding_100_4d_2, embedding_100_4d_3])
assert np.allclose(joint_embedding, embedding_100_4d, atol = 0.05)
assert np.allclose(joint_embedding, embedding_100_4d, atol=0.05)


@pytest.mark.benchmark
Expand Down

0 comments on commit a6d2fd1

Please sign in to comment.