Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 9, 2024
1 parent 1a71ac7 commit de8ac30
Show file tree
Hide file tree
Showing 24 changed files with 96 additions and 96 deletions.
4 changes: 2 additions & 2 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


class TestSinkhorn:
@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("jit", [False, True])
@pytest.mark.parametrize("eps", [None, 1e-2, 1e-1])
def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool):
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, xy: Geom_t, eps: Optional[float
assert isinstance(solver.xy, PointCloud)
np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL)

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("alpha", [0.1, 0.9])
def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None:
thresh, eps = 5e-2, 1e-1
Expand Down
34 changes: 17 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _close_figure():
plt.close()


@pytest.fixture()
@pytest.fixture
def x() -> Geom_t:
rng = np.random.RandomState(0)
n = 20 # number of points in the first distribution
Expand All @@ -51,7 +51,7 @@ def x() -> Geom_t:
return jnp.asarray(xs)


@pytest.fixture()
@pytest.fixture
def y() -> Geom_t:
rng = np.random.RandomState(1)
n2 = 30 # number of points in the second distribution
Expand All @@ -63,7 +63,7 @@ def y() -> Geom_t:
return jnp.asarray(xt)


@pytest.fixture()
@pytest.fixture
def xy() -> Tuple[Geom_t, Geom_t]:
rng = np.random.RandomState(2)
n = 20 # number of points in the first distribution
Expand All @@ -83,36 +83,36 @@ def xy() -> Tuple[Geom_t, Geom_t]:
return jnp.asarray(ys), jnp.asarray(yt)


@pytest.fixture()
@pytest.fixture
def ab() -> Tuple[np.ndarray, np.ndarray]:
rng = np.random.RandomState(42)
return rng.normal(size=(20, 2)), rng.normal(size=(30, 4))


@pytest.fixture()
@pytest.fixture
def x_cost(x: Geom_t) -> jnp.ndarray:
return ((x[:, None, :] - x[None, ...]) ** 2).sum(-1)


@pytest.fixture()
@pytest.fixture
def y_cost(y: Geom_t) -> jnp.ndarray:
return ((y[:, None, :] - y[None, ...]) ** 2).sum(-1)


@pytest.fixture()
@pytest.fixture
def xy_cost(xy: Geom_t) -> jnp.ndarray:
x, y = xy
return ((x[:, None, :] - y[None, ...]) ** 2).sum(-1)


@pytest.fixture()
@pytest.fixture
def adata_x(x: Geom_t) -> AnnData:
rng = np.random.RandomState(43)
pc = rng.normal(size=(len(x), 4))
return AnnData(X=np.asarray(x, dtype=float), obsm={"X_pca": pc})


@pytest.fixture()
@pytest.fixture
def adata_y(y: Geom_t) -> AnnData:
rng = np.random.RandomState(44)
pc = rng.normal(size=(len(y), 4))
Expand All @@ -126,7 +126,7 @@ def creat_prob(n: int, *, uniform: bool = False, seed: Optional[int] = None) ->
return jnp.asarray(a)


@pytest.fixture()
@pytest.fixture
def adata_time() -> AnnData:
rng = np.random.RandomState(42)

Expand Down Expand Up @@ -156,7 +156,7 @@ def adata_time() -> AnnData:
return adata


@pytest.fixture()
@pytest.fixture
def gt_temporal_adata() -> AnnData:
adata = _gt_temporal_adata.copy()
# TODO(michalk8): remove both lines once data has been regenerated
Expand All @@ -165,7 +165,7 @@ def gt_temporal_adata() -> AnnData:
return adata


@pytest.fixture()
@pytest.fixture
def adata_space_rotate() -> AnnData:
rng = np.random.RandomState(31)
grid = _make_grid(10)
Expand All @@ -182,15 +182,15 @@ def adata_space_rotate() -> AnnData:
return adata


@pytest.fixture()
@pytest.fixture
def adata_mapping() -> AnnData:
grid = _make_grid(10)
adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17, cat_key="covariate", num_categories=3)
sc.pp.pca(adataref, n_comps=30)
return ad.concat([adataref, adata1, adata2], label="batch", join="outer", index_unique="-")


@pytest.fixture()
@pytest.fixture
def adata_translation() -> AnnData:
rng = np.random.RandomState(31)
adatas = [AnnData(X=csr_matrix(rng.normal(size=(100, 60)))) for _ in range(3)]
Expand All @@ -202,7 +202,7 @@ def adata_translation() -> AnnData:
return adata


@pytest.fixture()
@pytest.fixture
def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]:
rng = np.random.RandomState(15)
adata_src = adata_translation[adata_translation.obs.batch != "0"].copy()
Expand All @@ -212,7 +212,7 @@ def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]:
return adata_src, adata_tgt


@pytest.fixture()
@pytest.fixture
def adata_anno(
problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"],
) -> Union[AnnData, Tuple[AnnData, AnnData]]:
Expand Down Expand Up @@ -258,7 +258,7 @@ def adata_anno(
return adata


@pytest.fixture()
@pytest.fixture
def gt_tm_annotation() -> np.ndarray:
tm = np.zeros((10, 15))
for i in range(10):
Expand Down
10 changes: 5 additions & 5 deletions tests/datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@


class TestSimulateData:
@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("n_distributions", [2, 4])
@pytest.mark.parametrize("key", ["batch", "day"])
def test_n_distributions(self, n_distributions: int, key: str):
adata = simulate_data(n_distributions=n_distributions, key=key)
assert key in adata.obs.columns
assert adata.obs[key].nunique() == n_distributions

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("obs_to_add", [{"celltype": 2}, {"celltype": 5, "cluster": 4}])
def test_obs_to_add(self, obs_to_add: Mapping[str, int]):
adata = simulate_data(obs_to_add=obs_to_add)
Expand All @@ -26,7 +26,7 @@ def test_obs_to_add(self, obs_to_add: Mapping[str, int]):
assert colname in adata.obs.columns
assert adata.obs[colname].nunique() == k

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("spatial_dim", [None, 2, 3])
def test_quad_term_spatial(self, spatial_dim: Optional[int]):
kwargs = {}
Expand All @@ -40,7 +40,7 @@ def test_quad_term_spatial(self, spatial_dim: Optional[int]):
else:
assert adata.obsm["spatial"].shape[1] == spatial_dim

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("n_intBCs", [None, 4, 7])
@pytest.mark.parametrize("barcode_dim", [None, 5, 8])
def test_quad_term_barcode(self, n_intBCs: Optional[int], barcode_dim: Optional[int]):
Expand All @@ -63,7 +63,7 @@ def test_quad_term_barcode(self, n_intBCs: Optional[int], barcode_dim: Optional[
else:
assert len(np.unique(adata.obsm["barcode"])) <= n_intBCs

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("n_initial_nodes", [None, 4, 7])
@pytest.mark.parametrize("n_distributions", [3, 6])
def test_quad_term_tree(self, n_initial_nodes: Optional[int], n_distributions: int):
Expand Down
8 changes: 4 additions & 4 deletions tests/plotting/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
DPI = 40


@pytest.fixture()
@pytest.fixture
def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
plot_vars = {
"transition_matrix": gt_temporal_adata.uns["cell_transition_10_105_forward"],
Expand All @@ -38,7 +38,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
return gt_temporal_adata


@pytest.fixture()
@pytest.fixture
def adata_pl_push(adata_time: AnnData) -> AnnData:
rng = np.random.RandomState(0)
plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
Expand All @@ -57,7 +57,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData:
return adata_time


@pytest.fixture()
@pytest.fixture
def adata_pl_pull(adata_time: AnnData) -> AnnData:
rng = np.random.RandomState(0)
plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
Expand All @@ -75,7 +75,7 @@ def adata_pl_pull(adata_time: AnnData) -> AnnData:
return adata_time


@pytest.fixture()
@pytest.fixture
def adata_pl_sankey(adata_time: AnnData) -> AnnData:
rng = np.random.RandomState(0)
celltypes = ["A", "B", "C", "D", "E"]
Expand Down
12 changes: 6 additions & 6 deletions tests/problems/base/test_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_sc_pipeline(self, adata_time: AnnData):
assert problem[key].solution is problem.solutions[key]

@pytest.mark.parametrize("scale", [True, False])
@pytest.mark.fast()
@pytest.mark.fast
def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scale: bool):
subproblem = OTProblem(adata_time, adata_tgt=adata_time.copy())
xy_callback_kwargs = {"n_comps": 5, "scale": scale}
Expand All @@ -88,7 +88,7 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal
assert isinstance(problem.problems, dict)
spy.assert_called_with("xy", subproblem.adata_src, subproblem.adata_tgt, **xy_callback_kwargs)

@pytest.mark.fast()
@pytest.mark.fast
def test_custom_callback_lin(self, adata_time: AnnData, mocker: MockerFixture):
expected_keys = [(0, 1), (1, 2)]
spy = mocker.spy(TestCompoundProblem, "xy_callback")
Expand All @@ -106,7 +106,7 @@ def test_custom_callback_lin(self, adata_time: AnnData, mocker: MockerFixture):

assert spy.call_count == len(expected_keys)

@pytest.mark.fast()
@pytest.mark.fast
def test_custom_callback_quad(self, adata_time: AnnData, mocker: MockerFixture):
expected_keys = [(0, 1), (1, 2)]
spy_x = mocker.spy(TestCompoundProblem, "x_callback")
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_different_passings_linear(self, adata_with_cost_matrix: AnnData):
np.testing.assert_allclose(gt.matrix, p1_tmap, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(gt.matrix, p2_tmap, rtol=RTOL, atol=ATOL)

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)])
def test_prepare_cost(self, adata_time: AnnData, cost: Tuple[str, Any]):
problem = Problem(adata=adata_time)
Expand All @@ -179,7 +179,7 @@ def test_prepare_cost(self, adata_time: AnnData, cost: Tuple[str, Any]):
assert isinstance(problem[0, 1].x.cost, cost[1])
assert isinstance(problem[0, 1].y.cost, cost[1])

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)])
def test_prepare_cost_with_callback(self, adata_time: AnnData, cost: Tuple[str, Any]):
problem = Problem(adata=adata_time)
Expand All @@ -196,7 +196,7 @@ def test_prepare_cost_with_callback(self, adata_time: AnnData, cost: Tuple[str,
assert isinstance(problem[0, 1].x.cost, cost[1])
assert isinstance(problem[0, 1].y.cost, cost[1])

@pytest.mark.fast()
@pytest.mark.fast
def test_prepare_different_costs(self, adata_time: AnnData):
problem = Problem(adata=adata_time)
problem = problem.prepare(
Expand Down
2 changes: 1 addition & 1 deletion tests/problems/base/test_general_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData):

assert isinstance(prob.solution, BaseDiscreteSolverOutput)

@pytest.mark.fast()
@pytest.mark.fast
def test_output(self, adata_x: AnnData, x: Geom_t):
problem = OTProblem(adata_x)
problem._solution = MockSolverOutput(x * x.T)
Expand Down
4 changes: 2 additions & 2 deletions tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tests._utils import Geom_t


@pytest.fixture()
@pytest.fixture
def adata_with_cost_matrix(adata_x: Geom_t, adata_y: Geom_t) -> AnnData:
adata = ad.concat([adata_x, adata_y], label="batch", index_unique="-")
C = pairwise_distances(adata_x.obsm["X_pca"], adata_y.obsm["X_pca"]) ** 2
Expand All @@ -19,7 +19,7 @@ def adata_with_cost_matrix(adata_x: Geom_t, adata_y: Geom_t) -> AnnData:
return adata


@pytest.fixture()
@pytest.fixture
def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
adata = adata_time[adata_time.obs["time"].isin([0, 1])].copy()
rng = np.random.RandomState(42)
Expand Down
4 changes: 2 additions & 2 deletions tests/problems/cross_modality/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_translate_alternative(
trans_backward = tp.translate(source=src, target=tgt, forward=False, alternative_attr=alternative_attr)
assert trans_backward.shape == adata_src[adata_src.obs["batch"] == "1"].obsm["X_pca"].shape

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("normalize", [True, False])
def test_cell_transition_pipeline(
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_cell_transition_pipeline(
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(result1, result2)

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
@pytest.mark.parametrize("batch_size", [3, 7, None])
Expand Down
4 changes: 2 additions & 2 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


class TestTranslationProblem:
@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
@pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
@pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}])
Expand All @@ -50,7 +50,7 @@ def test_prepare_dummy_policy(
assert tp[prob_key].shape == (2 * n_obs, n_obs)
np.testing.assert_array_equal(tp._policy._cat, prob_key)

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
@pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
@pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}])
Expand Down
2 changes: 1 addition & 1 deletion tests/problems/generic/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from anndata import AnnData


@pytest.fixture()
@pytest.fixture
def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
adata = adata_time[adata_time.obs["time"].isin([0, 1])].copy()
rng = np.random.RandomState(42)
Expand Down
2 changes: 1 addition & 1 deletion tests/problems/generic/test_conditional_neural_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class TestGENOTLinProblem:
@pytest.mark.fast()
@pytest.mark.fast
def test_prepare(self, adata_time: ad.AnnData):
problem = GENOTLinProblem(adata=adata_time)
problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"})
Expand Down
Loading

0 comments on commit de8ac30

Please sign in to comment.