Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding _annotation_mapping in AnalysisMixin #585

Merged
merged 85 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 80 commits
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
e6651fd
_celltype_mapping draft
Jul 19, 2023
68c1a01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
cdc3bc5
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Jul 19, 2023
c1acb13
exposing celltype_mapping in SpatialMapping and SpatialAlignment
Jul 20, 2023
4324afa
fix ruff ?
Jul 20, 2023
2567c0b
fix conflict
Jul 20, 2023
abcb8fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2023
c89236f
fixes for mypy and ruff
Jul 20, 2023
8a9138d
renamin, adding function to protocol
Jul 20, 2023
4539312
ruff fix?
Jul 21, 2023
d8f7714
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2023
954956c
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Jul 26, 2023
4944db6
adding cell_transition_kwargs
Jul 31, 2023
c2a2466
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2023
268e2c5
Merge branch 'main' into add/celltype_mapping
Jul 31, 2023
d10f82f
ruff and mypy fix ?
Jul 31, 2023
6bb5e20
Merge branch 'main' into add/celltype_mapping
Aug 12, 2023
e63eda8
Merge branch 'main' into add/celltype_mapping
Aug 23, 2023
17d48b0
anno_mapping changes
Sep 15, 2023
051270b
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Sep 27, 2023
6e8442e
Merge branch 'add/celltype_mapping' of https://github.com/theislab/mo…
Sep 27, 2023
86ed240
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Oct 10, 2023
c770d1d
Merge branch 'main' into add/celltype_mapping
giovp Oct 10, 2023
956d7c6
Merge branch 'add/celltype_mapping' of https://github.com/theislab/mo…
Oct 19, 2023
713e8e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2023
b2a97c2
removed source and target groups
Oct 19, 2023
e81a681
merge
Oct 19, 2023
b44ec15
Merge branch 'main' into add/celltype_mapping
Oct 19, 2023
4f5a2b0
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Oct 20, 2023
f53b8be
Merge branch 'main' into add/celltype_mapping
giovp Oct 24, 2023
189d468
key_added logic
Nov 2, 2023
47b8ec5
key_added logic
Nov 2, 2023
d434bfa
Merge branch 'add/celltype_mapping' of https://github.com/theislab/mo…
Nov 2, 2023
7ac7e26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2023
02109f0
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Nov 16, 2023
5f5bf26
anno_map progress
Nov 30, 2023
77605cf
Merge branch 'add/celltype_mapping' of https://github.com/theislab/mo…
Nov 30, 2023
94c43b2
Merge branch 'main' into add/celltype_mapping
Nov 30, 2023
f7accd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2023
7478927
Merge branch 'add/celltype_mapping' of https://github.com/theislab/mo…
Nov 30, 2023
c408f37
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Dec 1, 2023
69c3e76
mp fix, ap and tp added
Dec 7, 2023
878a79e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2023
2366a32
make sum for temporal problem work
giovp Dec 11, 2023
5bd6696
fix max for temporal
giovp Dec 11, 2023
5b3d508
before merge
Dec 11, 2023
4651972
Merge branch 'main' into add/celltype_mapping
giovp Dec 11, 2023
09a3920
merge tp fix and added in cross_modality
Dec 12, 2023
415c40a
Merge branch 'add/celltype_mapping' of https://github.com/theislab/mo…
Dec 12, 2023
338e75a
passing batch key ap
Dec 12, 2023
5a6d538
cleaned arguments, forward in mp
Dec 14, 2023
da8eac1
fix for cross_modality and general label handling in max
Dec 19, 2023
10a79fe
handling annottion labels
Dec 19, 2023
d3f435b
update
giovp Dec 19, 2023
3a37762
update
giovp Dec 19, 2023
72e545d
update
giovp Dec 19, 2023
284c83d
update
Dec 20, 2023
fc7aba0
update
giovp Dec 20, 2023
e26266d
shorter returns
Dec 26, 2023
96dd1d5
fix for temporal cell_transition
Jan 7, 2024
039ce1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2024
4c694ee
gt for annotation tests
Jan 9, 2024
d6968f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
112af8e
only passing tests
Jan 10, 2024
f370bc1
Merge branch 'add/celltype_mapping' of https://github.com/theislab/mo…
Jan 10, 2024
e0efc2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
c055c8e
Merge branch 'main' into add/celltype_mapping
ArinaDanilina Jan 10, 2024
01637c8
ruff type annotation
Jan 10, 2024
2d9b73e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
6a5ad0b
Revert "ruff type annotation"
Jan 10, 2024
1cb646d
revert
Jan 10, 2024
93b2e8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
4d3c8db
fully passing tests
Jan 17, 2024
e2b9c5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 17, 2024
e1608b0
some mypy fixes
Jan 18, 2024
ab89a42
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2024
57c2649
ruff typing
Jan 18, 2024
2b8aa48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2024
1586794
docstrings
Jan 18, 2024
e13ef40
lint docstrings
Jan 18, 2024
a775e70
unexpose scale_by_marginals and edits
Jan 19, 2024
93b561f
Update src/moscot/problems/cross_modality/_mixins.py
ArinaDanilina Jan 19, 2024
8fb3776
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
557139d
returns -> Returns
Jan 19, 2024
0cd1ae5
returns -> Returns
Jan 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ ignore = [
"D107",
# Missing docstring in magic method
"D105",
# Use `X | Y` for type annotations
"UP007",
]
line-length = 120
select = [
Expand Down
156 changes: 133 additions & 23 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import types
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Union,
)

Expand Down Expand Up @@ -45,8 +46,8 @@ class AnalysisMixinProtocol(Protocol[K, B]):

adata: AnnData
_policy: SubsetPolicy[K]
solutions: Dict[Tuple[K, K], BaseSolverOutput]
problems: Dict[Tuple[K, K], B]
solutions: dict[tuple[K, K], BaseSolverOutput]
problems: dict[tuple[K, K], B]

def _apply(
self,
Expand All @@ -61,15 +62,15 @@ def _apply(
...

def _interpolate_transport(
self: "AnalysisMixinProtocol[K, B]",
path: Sequence[Tuple[K, K]],
self: AnalysisMixinProtocol[K, B],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
) -> LinearOperator:
...

def _flatten(
self: "AnalysisMixinProtocol[K, B]",
data: Dict[K, ArrayLike],
self: AnalysisMixinProtocol[K, B],
data: dict[K, ArrayLike],
*,
key: Optional[str],
) -> ArrayLike:
Expand All @@ -83,8 +84,20 @@ def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
"""Pull distribution."""
...

def _cell_transition(
giovp marked this conversation as resolved.
Show resolved Hide resolved
self: AnalysisMixinProtocol[K, B],
source: K,
target: K,
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
aggregation_mode: Literal["annotation", "cell"] = "annotation",
key_added: Optional[str] = _constants.CELL_TRANSITION,
**kwargs: Any,
) -> pd.DataFrame:
...

def _cell_transition_online(
self: "AnalysisMixinProtocol[K, B]",
self: AnalysisMixinProtocol[K, B],
key: Optional[str],
source: K,
target: K,
Expand All @@ -99,6 +112,20 @@ def _cell_transition_online(
) -> pd.DataFrame:
...

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
forward: bool,
source: K,
target: K,
key: str | None = None,
other_adata: Optional[str] = None,
scale_by_marginals: bool = True,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
...


class AnalysisMixin(Generic[K, B]):
"""Base Analysis Mixin."""
Expand All @@ -122,7 +149,6 @@ def _cell_transition(
)
if aggregation_mode == "cell" and source_groups is None and target_groups is None:
raise ValueError("At least one of `source_groups` and `target_group` must be specified.")

_check_argument_compatibility_cell_transition(
source_annotation=source_groups,
target_annotation=target_groups,
Expand Down Expand Up @@ -179,13 +205,13 @@ def _cell_transition_online(
)
df_source = _get_df_cell_transition(
self.adata,
[source_annotation_key, target_annotation_key],
[source_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key],
key,
source,
)
df_target = _get_df_cell_transition(
self.adata if other_adata is None else other_adata,
[source_annotation_key, target_annotation_key],
[target_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key],
key if other_adata is None else other_key,
target,
)
Expand Down Expand Up @@ -273,6 +299,90 @@ def _cell_transition_online(
forward=forward,
)

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
source: K,
target: K,
key: str | None = None,
forward: bool = True,
other_adata: str | None = None,
scale_by_marginals: bool = True,
batch_size: int | None = None,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
if mapping_mode == "sum":
cell_transition_kwargs = dict(cell_transition_kwargs)
cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell
cell_transition_kwargs.setdefault("key", key)
cell_transition_kwargs.setdefault("source", source)
cell_transition_kwargs.setdefault("target", target)
cell_transition_kwargs.setdefault("other_adata", other_adata)
cell_transition_kwargs.setdefault("forward", not forward)
if forward:
cell_transition_kwargs.setdefault("source_groups", annotation_label)
cell_transition_kwargs.setdefault("target_groups", None)
axis = 0 # rows
else:
cell_transition_kwargs.setdefault("source_groups", None)
cell_transition_kwargs.setdefault("target_groups", annotation_label)
axis = 1 # columns
out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs)
return out.idxmax(axis=axis).to_frame(name=annotation_label)
if mapping_mode == "max":
out = []
if forward:
source_df = _get_df_cell_transition(
self.adata,
annotation_keys=[annotation_label],
filter_key=key,
filter_value=source,
)
out_len = self.solutions[(source, target)].shape[1]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch: ArrayLike = self.push(
source=source,
target=target,
data=None,
subset=(batch, batch_size),
normalize=True,
return_all=False,
scale_by_marginals=scale_by_marginals,
split_mass=True,
key_added=None,
)
v = np.array(tm_batch.argmax(1))
out.extend(source_df[annotation_label][v[i]] for i in range(len(v)))

else:
target_df = _get_df_cell_transition(
self.adata if other_adata is None else other_adata,
annotation_keys=[annotation_label],
filter_key=key,
filter_value=target,
)
out_len = self.solutions[(source, target)].shape[0]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch: ArrayLike = self.pull( # type: ignore[no-redef]
source=source,
target=target,
data=None,
subset=(batch, batch_size),
normalize=True,
return_all=False,
scale_by_marginals=scale_by_marginals,
split_mass=True,
key_added=None,
)
v = np.array(tm_batch.argmax(1))
out.extend(target_df[annotation_label][v[i]] for i in range(len(v)))
categories = pd.Categorical(out)
return pd.DataFrame(categories, columns=[annotation_label])
raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")

ArinaDanilina marked this conversation as resolved.
Show resolved Hide resolved
def _sample_from_tmap(
self: AnalysisMixinProtocol[K, B],
source: K,
Expand All @@ -284,7 +394,7 @@ def _sample_from_tmap(
account_for_unbalancedness: bool = False,
interpolation_parameter: Optional[Numeric_t] = None,
seed: Optional[int] = None,
) -> Tuple[List[Any], List[ArrayLike]]:
) -> tuple[list[Any], list[ArrayLike]]:
rng = np.random.RandomState(seed)
if account_for_unbalancedness and interpolation_parameter is None:
raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.")
Expand Down Expand Up @@ -321,7 +431,7 @@ def _sample_from_tmap(

rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples)
rows, counts = np.unique(rows_sampled, return_counts=True)
all_cols_sampled: List[str] = []
all_cols_sampled: list[str] = []
for batch in range(0, len(rows), batch_size):
rows_batch = rows[batch : batch + batch_size]
counts_batch = counts[batch : batch + batch_size]
Expand Down Expand Up @@ -354,7 +464,7 @@ def _sample_from_tmap(
def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
# TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
path: Sequence[Tuple[K, K]],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
**_: Any,
) -> LinearOperator:
Expand All @@ -365,7 +475,7 @@ def _interpolate_transport(
fst, *rest = path
return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
tmp = np.full(len(self.adata), np.nan)
for k, v in data.items():
mask = self.adata.obs[key] == k
Expand All @@ -377,8 +487,8 @@ def _annotation_aggregation_transition(
source: K,
target: K,
annotation_key: str,
annotations_1: List[Any],
annotations_2: List[Any],
annotations_1: list[Any],
annotations_2: list[Any],
df: pd.DataFrame,
tm: pd.DataFrame,
forward: bool,
Expand Down Expand Up @@ -413,8 +523,8 @@ def _cell_aggregation_transition(
target: str,
annotation_key: str,
# TODO(MUCDK): unused variables, del below
annotations_1: List[Any],
annotations_2: List[Any],
annotations_1: list[Any],
annotations_2: list[Any],
df_1: pd.DataFrame,
df_2: pd.DataFrame,
tm: pd.DataFrame,
Expand Down Expand Up @@ -450,9 +560,9 @@ def compute_feature_correlation(
obs_key: str,
corr_method: Literal["pearson", "spearman"] = "pearson",
significance_method: Literal["fisher", "perm_test"] = "fisher",
annotation: Optional[Dict[str, Iterable[str]]] = None,
annotation: Optional[dict[str, Iterable[str]]] = None,
layer: Optional[str] = None,
features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None,
features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None,
confidence_level: float = 0.95,
n_perms: int = 1000,
seed: Optional[int] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _check_argument_compatibility_cell_transition(
raise ValueError("Unable to infer distributions, missing `adata` and `key`.")
if forward and target_annotation is None:
raise ValueError("No target annotation provided.")
if not forward and source_annotation is None:
if aggregation_mode == "annotation" and (not forward and source_annotation is None):
raise ValueError("No source annotation provided.")
if (aggregation_mode == "annotation") and (source_annotation is None or target_annotation is None):
raise ValueError(
Expand Down
60 changes: 59 additions & 1 deletion src/moscot/problems/cross_modality/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional
import types
from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Optional

import pandas as pd

Expand All @@ -24,6 +25,9 @@ class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]):
def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...

def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...


class CrossModalityTranslationMixin(AnalysisMixin[K, B]):
"""Cross modality translation analysis mixin class."""
Expand Down Expand Up @@ -183,6 +187,60 @@ def cell_transition( # type: ignore[misc]
key_added=key_added,
)

def annotation_mapping( # type: ignore[misc]
self: CrossModalityTranslationMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
forward: bool,
source: str = "src",
target: str = "tgt",
scale_by_marginals: bool = True,
ArinaDanilina marked this conversation as resolved.
Show resolved Hide resolved
other_adata: Optional[str] = None,
ArinaDanilina marked this conversation as resolved.
Show resolved Hide resolved
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
"""Transfer annotations between distributions.

This function transfers annotation labels (e.g. cell types) between groups of cells.
ArinaDanilina marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
mapping_mode
How to decide which label to transfer. Valid options are:

- ``'max'`` - pick the label of the annotated cell with the highest mapping weight.
ArinaDanilina marked this conversation as resolved.
Show resolved Hide resolved
- ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and
pick the label with the highest transition weight.
annotation_label
Key in :attr:`~anndata.AnnData.obs` where the annotation is stored.
forward
If :obj:`True`, transfer the annotations from ``source`` to ``target``.
source
Key identifying the source distribution.
target
Key identifying the target distribution.
scale_by_marginals
Whether to scale by the source :term:`marginals`.
other_adata
ArinaDanilina marked this conversation as resolved.
Show resolved Hide resolved
The second :class:`anndata.AnnData` if present.
cell_transition_kwargs
Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.

Returns
-------
:class:`~pandas.DataFrame`. - returns the DataFrame of transferred annotations.
ArinaDanilina marked this conversation as resolved.
Show resolved Hide resolved
"""
return self._annotation_mapping(
mapping_mode=mapping_mode,
annotation_label=annotation_label,
source=source,
target=target,
key=self.batch_key,
forward=forward,
other_adata=self.adata_tgt if other_adata is None else other_adata,
scale_by_marginals=scale_by_marginals,
cell_transition_kwargs=cell_transition_kwargs,
)

@property
def batch_key(self) -> Optional[str]:
"""Batch key in :attr:`~anndata.AnnData.obs`."""
Expand Down
Loading