Skip to content

Commit

Permalink
Merge pull request #585 from theislab/add/celltype_mapping
Browse files Browse the repository at this point in the history
adding _annotation_mapping in AnalysisMixin, temporal, spatial and cross_modality problems
  • Loading branch information
ArinaDanilina authored Jan 19, 2024
2 parents 924c78a + 0cd1ae5 commit 7a4883e
Show file tree
Hide file tree
Showing 11 changed files with 536 additions and 49 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ ignore = [
"D107",
# Missing docstring in magic method
"D105",
# Use `X | Y` for type annotations
"UP007",
]
line-length = 120
select = [
Expand Down
156 changes: 133 additions & 23 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import types
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Union,
)

Expand Down Expand Up @@ -45,8 +46,8 @@ class AnalysisMixinProtocol(Protocol[K, B]):

adata: AnnData
_policy: SubsetPolicy[K]
solutions: Dict[Tuple[K, K], BaseSolverOutput]
problems: Dict[Tuple[K, K], B]
solutions: dict[tuple[K, K], BaseSolverOutput]
problems: dict[tuple[K, K], B]

def _apply(
self,
Expand All @@ -61,15 +62,15 @@ def _apply(
...

def _interpolate_transport(
self: "AnalysisMixinProtocol[K, B]",
path: Sequence[Tuple[K, K]],
self: AnalysisMixinProtocol[K, B],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
) -> LinearOperator:
...

def _flatten(
self: "AnalysisMixinProtocol[K, B]",
data: Dict[K, ArrayLike],
self: AnalysisMixinProtocol[K, B],
data: dict[K, ArrayLike],
*,
key: Optional[str],
) -> ArrayLike:
Expand All @@ -83,8 +84,20 @@ def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
"""Pull distribution."""
...

def _cell_transition(
    self: AnalysisMixinProtocol[K, B],
    source: K,
    target: K,
    source_groups: Str_Dict_t,
    target_groups: Str_Dict_t,
    aggregation_mode: Literal["annotation", "cell"] = "annotation",
    key_added: Optional[str] = _constants.CELL_TRANSITION,
    **kwargs: Any,
) -> pd.DataFrame:
    """Declare the signature of the cell-transition routine.

    This is a protocol member: it only describes the call signature for type checking and contains no logic.
    """
    ...

def _cell_transition_online(
self: "AnalysisMixinProtocol[K, B]",
self: AnalysisMixinProtocol[K, B],
key: Optional[str],
source: K,
target: K,
Expand All @@ -99,6 +112,20 @@ def _cell_transition_online(
) -> pd.DataFrame:
...

def _annotation_mapping(
    self: AnalysisMixinProtocol[K, B],
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    forward: bool,
    source: K,
    target: K,
    key: str | None = None,
    other_adata: Optional[str] = None,
    scale_by_marginals: bool = True,
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Declare the signature of the annotation-mapping routine.

    This is a protocol member: it only describes the call signature for type checking and contains no logic.
    """
    # NOTE(review): the ``AnalysisMixin._annotation_mapping`` implementation in this module places ``forward``
    # after ``key`` (with default ``True``) and additionally accepts ``batch_size``. Pass arguments by keyword,
    # and consider aligning this declaration with the implementation.
    ...


class AnalysisMixin(Generic[K, B]):
"""Base Analysis Mixin."""
Expand All @@ -122,7 +149,6 @@ def _cell_transition(
)
if aggregation_mode == "cell" and source_groups is None and target_groups is None:
raise ValueError("At least one of `source_groups` and `target_group` must be specified.")

_check_argument_compatibility_cell_transition(
source_annotation=source_groups,
target_annotation=target_groups,
Expand Down Expand Up @@ -179,13 +205,13 @@ def _cell_transition_online(
)
df_source = _get_df_cell_transition(
self.adata,
[source_annotation_key, target_annotation_key],
[source_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key],
key,
source,
)
df_target = _get_df_cell_transition(
self.adata if other_adata is None else other_adata,
[source_annotation_key, target_annotation_key],
[target_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key],
key if other_adata is None else other_key,
target,
)
Expand Down Expand Up @@ -273,6 +299,90 @@ def _cell_transition_online(
forward=forward,
)

def _annotation_mapping(
    self: AnalysisMixinProtocol[K, B],
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    source: K,
    target: K,
    key: str | None = None,
    forward: bool = True,
    other_adata: str | None = None,
    scale_by_marginals: bool = True,
    batch_size: int | None = None,
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Transfer an annotation between the ``source`` and ``target`` distributions.

    Parameters
    ----------
    mapping_mode
        How the transferred label is chosen:

        - ``'sum'`` - compute a cell transition with ``aggregation_mode='cell'`` (unless overridden) and take
          the label with the largest value via :meth:`pandas.DataFrame.idxmax`.
        - ``'max'`` - push/pull the mass of individual cells and take the label of the arg-max entry.
    annotation_label
        Column holding the annotation; also the name of the single column of the returned frame.
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    key
        Forwarded as ``key`` to the cell transition and as ``filter_key`` when selecting the annotated cells.
    forward
        If :obj:`True`, the annotation is read from the source side, otherwise from the target side.
    other_adata
        Object used instead of :attr:`adata` to look up the annotation when ``forward = False``.
        NOTE(review): annotated as ``str`` but used in place of ``self.adata`` - presumably an ``AnnData``.
    scale_by_marginals
        Forwarded to ``push``/``pull``; only used if ``mapping_mode = 'max'``.
    batch_size
        Number of cells processed per ``push``/``pull`` call; only used if ``mapping_mode = 'max'``.
        If :obj:`None`, everything is processed in a single call.
    cell_transition_kwargs
        Extra keyword arguments for the cell transition; only used if ``mapping_mode = 'sum'``.
        Entries given here take precedence over the values derived from the other arguments.

    Returns
    -------
    One-column :class:`~pandas.DataFrame` named ``annotation_label`` with the transferred labels.
    For ``'max'`` the column is categorical and the frame carries a default integer index.

    Raises
    ------
    ValueError
        If ``mapping_mode = 'max'`` and ``batch_size`` is not positive.
    NotImplementedError
        If ``mapping_mode`` is neither ``'sum'`` nor ``'max'``.
    """
    if mapping_mode == "sum":
        # Work on a copy: the default is a read-only proxy and a caller-supplied mapping must not be mutated.
        ct_kwargs = dict(cell_transition_kwargs)
        if forward:
            source_groups, target_groups, axis = annotation_label, None, 0  # arg-max over rows
        else:
            source_groups, target_groups, axis = None, annotation_label, 1  # arg-max over columns
        derived = {
            "aggregation_mode": "cell",  # labels are assigned per cell
            "key": key,
            "source": source,
            "target": target,
            "other_adata": other_adata,
            # the transition is computed in the direction opposite to the label transfer
            "forward": not forward,
            "source_groups": source_groups,
            "target_groups": target_groups,
        }
        for name, value in derived.items():
            ct_kwargs.setdefault(name, value)
        transition = self._cell_transition(**ct_kwargs)
        return transition.idxmax(axis=axis).to_frame(name=annotation_label)

    if mapping_mode == "max":
        if batch_size is not None and batch_size <= 0:
            # A zero step makes `range` fail and a negative one silently produces an empty result.
            raise ValueError(f"Expected `batch_size` to be positive, found `{batch_size}`.")

        if forward:
            annotated = _get_df_cell_transition(
                self.adata,
                annotation_keys=[annotation_label],
                filter_key=key,
                filter_value=source,
            )
            n_out = self.solutions[(source, target)].shape[1]
            transport = self.push
        else:
            annotated = _get_df_cell_transition(
                self.adata if other_adata is None else other_adata,
                annotation_keys=[annotation_label],
                filter_key=key,
                filter_value=target,
            )
            n_out = self.solutions[(source, target)].shape[0]
            transport = self.pull

        # `max(..., 1)` keeps `range` valid when the output axis is empty.
        step = batch_size if batch_size is not None else max(n_out, 1)
        labels: list[Any] = []
        # NOTE(review): batches are enumerated over the size of the output axis, while the arg-max indices are
        # used unshifted to index the annotated frame. Confirm against the `subset` semantics of `push`/`pull`
        # that `batch_size != None` (and unequal source/target sizes) yield the intended labels.
        for start in range(0, n_out, step):
            tm_batch: ArrayLike = transport(
                source=source,
                target=target,
                data=None,
                subset=(start, step),
                normalize=True,
                return_all=False,
                scale_by_marginals=scale_by_marginals,
                split_mass=True,
                key_added=None,
            )
            best = np.array(tm_batch.argmax(1))
            # The arg-max yields positions, so index positionally: integer keys passed to
            # `Series.__getitem__` only fall back to positions on non-integer indexes, which is deprecated.
            labels.extend(annotated[annotation_label].iloc[best])
        return pd.DataFrame(pd.Categorical(labels), columns=[annotation_label])

    raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")

def _sample_from_tmap(
self: AnalysisMixinProtocol[K, B],
source: K,
Expand All @@ -284,7 +394,7 @@ def _sample_from_tmap(
account_for_unbalancedness: bool = False,
interpolation_parameter: Optional[Numeric_t] = None,
seed: Optional[int] = None,
) -> Tuple[List[Any], List[ArrayLike]]:
) -> tuple[list[Any], list[ArrayLike]]:
rng = np.random.RandomState(seed)
if account_for_unbalancedness and interpolation_parameter is None:
raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.")
Expand Down Expand Up @@ -321,7 +431,7 @@ def _sample_from_tmap(

rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples)
rows, counts = np.unique(rows_sampled, return_counts=True)
all_cols_sampled: List[str] = []
all_cols_sampled: list[str] = []
for batch in range(0, len(rows), batch_size):
rows_batch = rows[batch : batch + batch_size]
counts_batch = counts[batch : batch + batch_size]
Expand Down Expand Up @@ -354,7 +464,7 @@ def _sample_from_tmap(
def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
# TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
path: Sequence[Tuple[K, K]],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
**_: Any,
) -> LinearOperator:
Expand All @@ -365,7 +475,7 @@ def _interpolate_transport(
fst, *rest = path
return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
tmp = np.full(len(self.adata), np.nan)
for k, v in data.items():
mask = self.adata.obs[key] == k
Expand All @@ -377,8 +487,8 @@ def _annotation_aggregation_transition(
source: K,
target: K,
annotation_key: str,
annotations_1: List[Any],
annotations_2: List[Any],
annotations_1: list[Any],
annotations_2: list[Any],
df: pd.DataFrame,
tm: pd.DataFrame,
forward: bool,
Expand Down Expand Up @@ -413,8 +523,8 @@ def _cell_aggregation_transition(
target: str,
annotation_key: str,
# TODO(MUCDK): unused variables, del below
annotations_1: List[Any],
annotations_2: List[Any],
annotations_1: list[Any],
annotations_2: list[Any],
df_1: pd.DataFrame,
df_2: pd.DataFrame,
tm: pd.DataFrame,
Expand Down Expand Up @@ -450,9 +560,9 @@ def compute_feature_correlation(
obs_key: str,
corr_method: Literal["pearson", "spearman"] = "pearson",
significance_method: Literal["fisher", "perm_test"] = "fisher",
annotation: Optional[Dict[str, Iterable[str]]] = None,
annotation: Optional[dict[str, Iterable[str]]] = None,
layer: Optional[str] = None,
features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None,
features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None,
confidence_level: float = 0.95,
n_perms: int = 1000,
seed: Optional[int] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _check_argument_compatibility_cell_transition(
raise ValueError("Unable to infer distributions, missing `adata` and `key`.")
if forward and target_annotation is None:
raise ValueError("No target annotation provided.")
if not forward and source_annotation is None:
if aggregation_mode == "annotation" and (not forward and source_annotation is None):
raise ValueError("No source annotation provided.")
if (aggregation_mode == "annotation") and (source_annotation is None or target_annotation is None):
raise ValueError(
Expand Down
53 changes: 52 additions & 1 deletion src/moscot/problems/cross_modality/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional
import types
from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Optional

import pandas as pd

Expand All @@ -24,6 +25,9 @@ class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]):
def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
...

def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame:
    """Declare the annotation-mapping hook that :meth:`annotation_mapping` delegates to."""
    ...


class CrossModalityTranslationMixin(AnalysisMixin[K, B]):
"""Cross modality translation analysis mixin class."""
Expand Down Expand Up @@ -183,6 +187,53 @@ def cell_transition( # type: ignore[misc]
key_added=key_added,
)

def annotation_mapping(  # type: ignore[misc]
    self: CrossModalityTranslationMixinProtocol[K, B],
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    forward: bool,
    source: str = "src",
    target: str = "tgt",
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Transfer annotations between distributions.

    This function transfers annotations (e.g. cell type labels) between distributions of cells.

    Parameters
    ----------
    mapping_mode
        How to decide which label to transfer. Valid options are:

        - ``'max'`` - pick the label of the annotated cell with the highest matching probability.
        - ``'sum'`` - aggregate the annotated cells by label then
          pick the label with the highest total matching probability.
    annotation_label
        Key in :attr:`~anndata.AnnData.obs` where the annotation is stored.
    forward
        If :obj:`True`, transfer the annotations from ``source`` to ``target``.
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    cell_transition_kwargs
        Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.

    Returns
    -------
    :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations.
    """
    # Resolve the problem-specific pieces first, then hand everything to the shared implementation.
    batch_key = self.batch_key
    target_adata = self.adata_tgt
    return self._annotation_mapping(
        mapping_mode=mapping_mode,
        annotation_label=annotation_label,
        source=source,
        target=target,
        key=batch_key,
        forward=forward,
        other_adata=target_adata,
        cell_transition_kwargs=cell_transition_kwargs,
    )

@property
def batch_key(self) -> Optional[str]:
"""Batch key in :attr:`~anndata.AnnData.obs`."""
Expand Down
Loading

0 comments on commit 7a4883e

Please sign in to comment.