Skip to content

Commit

Permalink
(feat): Linear GENOT (#662)
Browse files Browse the repository at this point in the history
* fix jaxsampler

* fix jaxsampler

* fix jaxsampler

* fix tests

* add plot_convergence

* remove jit from _compute_unbalanced marginals

* fix sinkhorn_divergence

* adapt tox.ini file

* shape mismatch fixed without precommit

* remove print statement

* finish merge

* adapt callbacks and rename tag `cost` to `cost_matrix` (#426)

* rename tag cost to cost_matrix

* fix renaming

* [CI skip], adapt callback

* incorporate requested changes

* add test for quad custom callback

* adapt kwargs for callback

* fix handle_joint_attr

* incorporate requested changes

* Feature/correlation test (#423)

* fix test in FGWProblem

* add correlation test

* add first test for correlation computation

* add more tests

* fix tests

* add tfs to compute_feature_correlation

* add testing for no nans in compute_feature_correlation

* incorporate requested changes

* fix docstring

* fix sankey return statement (#428)

* fix sankey return statement

* adapt test

* adapt return_fig

* Bump version: 0.1.0 → 0.1.1

* fix return statements

* add save tests

* fix return type in mpl (#432)

* fix return type in mpl

* change import acronyms

* fix tests

* Simplify linear operator (#431)

* Simplify linear operator

* Simplify `align`, fix test

* Explicitly jit the solvers (#433)

* Feature/interpolate colors sankey (#434)

* fix return type in mpl

* change import acronyms

* fix tests

* add interpolation option to sankey

* add test to interpolate color

* define colors for pull/push`

* adapt tests

* introduce axes in mpl.push/pull

* incorporate requested changes

* change default color

* adapt plotting

* introduce scaling

* fix scale

* make start/end categorical in plot

* regenerate images

* Remove `FGWSolver` (#437)

* Remove `FGWSolver`

* Fix `tox.ini`

* Fix wrong shape check

* Use pure GW in generic solver

* Update tests

* fix bug in SinkhornProblem (#442)

* fix bug in SinkhornProblem

* fix tox.ini

* fix pre commits

* make push/pull always use source/target (#443)

* make push/pull always use source/target

* fix bug in StarPolicy _apply

* adapt plotting to source/target

* fix strip plotting in sankey (#445)

* fix strip plotting in sankey

* simplify code

* Feature/spearman correlation (#444)

* add spearman correlation

* add tests

* adapt tests

* Delete logo.png

* Feature/plot order (#453)

* make push/pull plot in good order

* [CI skip], try setting adata.uns color explicitly

* [CI skip], fix copying of adata

* fix pre commits

* fix bug

* Expose marginal kwargs for `moscot.temporal` and check for numeric type of `temporal_key` (#449)

* make marginal_kwargs explicit in temporal problems

* introduce check for numeric dtype in temporal mixin

* add alternative way for marginal prior

* adapt tolerances in tests

* correct docs

* fix bug

* Fix math rendering

* fix test


Co-authored-by: Michal Klein <[email protected]>

* adapt plot_convergence (#454)

* Bug/docs generic analysis mixin (#455)

* adapt plot_convergence

* remove temporal-alluding docs in generic analysis mixin

* Docs/improvements (#456)

* adapt plot_convergence

* remove temporal-alluding docs in generic analysis mixin

* docs suggestions

* remove uns_key from set_plotting_vars (#458)

* resolve `fig referenced before assignment` (#460)

* move generic mixins tests to problems` (#461)

* Tests/spatiotemporalproblem (#464)

* add more tests for spatiotemporalProblem

* move some functions from TemporalProblem to TemporalMixin

* add tests LineageProblem

* fix tests

* Feature/move taggedarray (#457)

* adapt plot_convergence

* remove temporal-alluding docs in generic analysis mixin

* docs suggestions

* move tagged array

* move taggedarray back to solvers

* add marginal_kwargs to prepare method of TemporalNeuralProblem

* fix to scaling in

* Revert "fix to scaling in"

This reverts commit 0a6f7db.

* fix to scaling argument in marginal_kwargs

* updated conditional not pipeline

* merge into condot branch

* incoporated comments

* incoporated comments

* incoporated comments

* removed new_adata for push/pull

* [ci skip] start docs

* added temporal neural test

* [ci skip] continue  docs

* continue docs

* continue docs

* change validation epsilon

* fixed error when not computing wasserstein baseline

* fixed error when not computing wasserstein baseline

* correct typo

* fix bug

* added neural tests

* [ci skip] draft CondNeuralOutput

* include CondDualPotentials and CondDualSolver

* fixes to main merge

* fix test_cell_transition_subset_pipeline

* fix tests

* update conditionalDualPotentials

* update conditionalDualPotentials

* fix most pre-commit hooks and fix tests

* fix pandas version to <2.0

* fix tests for non-conditional solvers

* continue

* fix

* continue fixing

* fix ICNN setup

* fix tests

* swap role of f and g, such that push/pull is correct again

* [ci skip] restructure to include more general neural solvers

* [ci skip] restructure ICNNs to allow passing instances of ICNN

* adapt tests

* Filled in Monge Gap structure

* Added Monge Gap paper to documentation

* Ammend PointCloud Import

* Update _utils.py

Ammend PointCloud import

* Solve compatibility issue with ProblemKind

* Solve missing Import

* Fix call to deprecated function

* Fix style and comment issues

* add callback, swap f & g

* add callback, swap f & g

* add callback, swap f & g

* intermediate save

* intermediate save

* intermediate save

* [ci skip] fix merge conflicts

* resolve conflict

* remove pairwise policy

* add neural dependencies

* add neural dependencies

* add flax

* fix _call_kwargs

* fix marginal kwargs

* remove monge gap solver

* clean condneuralsolver

* [ci skip] introduce new data container for joint neural problems

* add conditions in distirbutioncontainer

* resolve unfreeze/freeze

* enable pretraining and weight clipping

* make dicts compatible with older python versions

* resolve precommit errors partially

* resolve precommit errors partially

* adapt tests

* [ci skip] draft unbalancedNeuralMixin

* [ci skip] fix naming of posterior marginals

* [ci skip] add MLP_marginals

* adapt neural output to incorporate learnt rescaling functions

* fix _solve in neuraldualsolver

* incorporate feedback

* fix distributioncollection class

* unify _split_data

* fix tests

* fix some precommit hooks

* make neural dependencies optional

* make neural dependencies optional

* delete old files

* adapt pyproject.toml

* adapt pyproject.toml

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [ci skip] adjust _format_params

* adapt neuraldualsolver to be more similar to ott-jax

* adapt neuraldualsolver

* TODO: make JaxSampler return conditions

* add basic neural test

* [ci skip] intermediate save

* adapt neuraldualsolver and finish tests for neural backend

* [ci skip] TODO: re-iterate on initialisation of neural solver

* adapt distributioncontainer

* fix dict bug

* resolve passing of arguments in solver call methods

* [ci skip] adapt `solve` in `CondOTProblem`

* adapt tests and valid loader conditions

* adapt neural backend tests

* fix mypy errors

* make basesolveroutput to basediscretesolveroutput

* move `to` to BaseSolverOutput`
"

* adapt transport_matrix docs

* adapt transport_matrix docs

* adapt tests

* adapt tests

* update unbalancedness mixin

* use implementation from moscot

* uncomment unused code

* before passing states to loss-fn

* intermediate save

* adapt neuraldualsolver

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve some / not all pre commit errors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (wip): tests run, code swapped out for now

* (wip): `NeuralSolver`s implemented minus quad/linear

* (wip): begin more generic problem

* (wip): more refactoring to pass arguments to GENOT

* (chore): remove more kantorovich

* (chore): update branch to moscot neural + first test moving to solving

* (fix): split data remains in numpy

* (fix): push/pull api

* (fix): make push test work

* (feat): allow for custom optimziers

* (chore): remove unclear test

* (refactor): change to composition API

* (refactor): start towards model-specific problems

* (chore): clean up all unnecessary classes

* (chore): updating to moscot latest

* Merge branch 'main' into ig/neural_solvers

* (chore): remove (hopefully) final ICNN vestiges

* (chore): more cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): pass pre-commit hooks

* (chore): remove duplicatec docs

* (chore): add torch for testing

* (fix): add ott jax branch as dep

* (fix): repo name

* (chore): remove unbalanced, update api, fix tests + drive by typing fix

* (feat): first pass at neural mixin

* (chore): add my name to todos

* (fix): conditions left out if not necessary

* (feat): logs and fix conditional attr

* (fix): add `seed` to call_kwargs so reproducibility works

* (chore): remove `is_conditional` business

* (fix): create hidden dims arg for velocity field

* (chore): raise not implemented error for `pull`

* (fix): default args

* (fix): add explicit policy

* (fix): allow iteration to continue

* (chore): add star policy to GENOT

* (chore): notebooks

* (chore): remove deps

* (chore): remove unnecessary spaces

* (chore): simplify quad handling

* (fix): need to require `optax`/`flax`

* (fix): use `ott-jax[neural]`

* (chore): fix docs

* (fix): small test fixes

* (chore): small notebook changes

* (fix): broken link in citation

* (chore): make notebook dependent on ci

* (fix): small todos just to push something

* (fix): variable is a string

* (fix): pass environment variable to tox

* (fix): actually pass through

* (fix): hidden dims ci

* (fix): re-add notebook

* (chore): make`recall_target` and  `aggregate_to_topk`

* (chore): fix default arguments

* (chore): `project_transport_matrix` -> `project_to_transport_matrix`

* (fix): remove dead `NeuralAnalysisMixin` code

* (feat): allow custom `data_match_fn`

* (fix): inherit from `MutableMapping` instead of `dict`

* (Fix): docs

* (fix): notebooks

* (fix): docs reference

* (fix): remove `attr`

* (fix): erroneous change

* (fix): remove empty

* (fix): notebooks again?

* (chore): ok?

---------

Co-authored-by: Dominik Klein <[email protected]>
Co-authored-by: Dominik Klein <[email protected]>
Co-authored-by: AlejandroTL <[email protected]>
Co-authored-by: michalk8 <[email protected]>
Co-authored-by: lucaeyring <[email protected]>
Co-authored-by: gocato <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
8 people authored Aug 21, 2024
1 parent 36d93bf commit 6e65ac1
Show file tree
Hide file tree
Showing 45 changed files with 1,530 additions and 147 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ["3.8", "3.10"]
python: ["3.9", "3.10"]
include:
- os: macos-latest
python: "3.9"
Expand Down
12 changes: 12 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,24 @@
nitpicky = True
nitpick_ignore = [
("py:class", "numpy.float64"),
# see: https://github.com/numpy/numpydoc/issues/275
("py:class", "None. Remove all items from D."),
("py:class", "a set-like object providing a view on D's items"),
("py:class", "a set-like object providing a view on D's keys"),
("py:class", "v, remove specified key and return the corresponding value."), # noqa: E501
("py:class", "None. Update D from dict/iterable E and F."),
("py:class", "an object providing a view on D's values"),
("py:class", "a shallow copy of D"),
]
# TODO(michalk8): remove once typing has been cleaned-up
nitpick_ignore_regex = [
(r"py:class", r"moscot\..*(K|B|O)"),
(r"py:class", r"numpy\._typing.*"),
(r"py:class", r"moscot\..*Protocol.*"),
(
r"py:class",
r"moscot.base.output.BaseSolverOutput",
), # https://github.com/sphinx-doc/sphinx/issues/10974 means there is simply no way around this with generics
]


Expand Down
8 changes: 6 additions & 2 deletions docs/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Backends
backends.ott.GWSolver
backends.ott.OTTOutput
backends.ott.GraphOTTOutput
backends.ott.GENOTLinSolver
backends.ott.output.OTTNeuralOutput
backends.utils.get_solver
backends.utils.get_available_backends

Expand Down Expand Up @@ -44,6 +46,7 @@ Problems
problems.BaseCompoundProblem
problems.CompoundProblem
cost.BaseCost
problems.CondOTProblem

Mixins
^^^^^^
Expand All @@ -62,14 +65,13 @@ Solvers

solver.BaseSolver
solver.OTSolver
output.BaseSolverOutput

Output
^^^^^^
.. autosummary::
:toctree: genapi

output.BaseSolverOutput
output.BaseDiscreteSolverOutput
output.MatrixSolverOutput

Utils
Expand Down Expand Up @@ -100,6 +102,8 @@ Miscellaneous
data.apoptosis_markers
tagged_array.TaggedArray
tagged_array.Tag
tagged_array.DistributionCollection
tagged_array.DistributionContainer

.. currentmodule:: moscot.base.problems
.. autosummary::
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Installation
============
:mod:`moscot` requires Python version >= 3.8 to run.
:mod:`moscot` requires Python version >= 3.9 to run.

PyPI
----
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
9 changes: 9 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,12 @@ @article{srivatsan:20
year={2020},
publisher={American Association for the Advancement of Science}
}

@misc{klein2023generative,
title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces},
author={Dominik Klein and Théo Uscidda and Fabian Theis and Marco Cuturi},
year={2023},
eprint={2310.09254},
archivePrefix={arXiv},
primaryClass={stat.ML}
}
1 change: 1 addition & 0 deletions docs/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Generic Problems
generic.SinkhornProblem
generic.GWProblem
generic.FGWProblem
generic.GENOTLinProblem

Plotting
~~~~~~~~
Expand Down
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "moscot"
dynamic = ["version"]
description = "Multi-omic single-cell optimal transport tools"
readme = "README.rst"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 4 - Beta",
Expand All @@ -19,7 +19,6 @@ classifiers = [
"Operating System :: Microsoft :: Windows",
"Typing :: Typed",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Bio-Informatics",
Expand Down Expand Up @@ -55,7 +54,7 @@ dependencies = [
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"docrep>=0.3.2",
"ott-jax>=0.4.6",
"ott-jax[neural]>=0.4.6",
"cloudpickle>=2.2.0",
"rich>=13.5",
"docstring_inheritance>=2.0.0"
Expand Down Expand Up @@ -263,16 +262,16 @@ max_line_length = 120
legacy_tox_ini = """
[tox]
min_version = 4.0
env_list = lint-code,py{3.8,3.9,3.10,3.11}
env_list = lint-code,py{3.9,3.10,3.11}
skip_missing_interpreters = true
[testenv]
extras = test
pass_env = PYTEST_*,CI
commands =
python -m pytest {tty:--color=yes} {posargs: \
--cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \
--no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered}
passenv = PYTEST_*,CI
[testenv:lint-code]
description = Lint the code.
Expand Down
7 changes: 4 additions & 3 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from ott.geometry import costs

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
from moscot.costs import register_cost

__all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"]


register_cost("euclidean", backend="ott")(costs.Euclidean)
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
Expand Down
115 changes: 111 additions & 4 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Any, Literal, Optional, Tuple, Union
from collections import defaultdict
from functools import partial
from typing import Any, Dict, Iterable, Literal, Optional, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
from ott.tools import sinkhorn_divergence as sdiv
from ott.neural import datasets
from ott.solvers import utils as solver_utils
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div

from moscot._logging import logger
from moscot._types import ArrayLike, ScaleCost_t
Expand All @@ -22,22 +27,27 @@ def sinkhorn_divergence(
a: Optional[ArrayLike] = None,
b: Optional[ArrayLike] = None,
epsilon: Union[float, epsilon_scheduler.Epsilon] = 1e-1,
tau_a: float = 1.0,
tau_b: float = 1.0,
scale_cost: ScaleCost_t = 1.0,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> float:
point_cloud_1 = jnp.asarray(point_cloud_1)
point_cloud_2 = jnp.asarray(point_cloud_2)
a = None if a is None else jnp.asarray(a)
b = None if b is None else jnp.asarray(b)

output = sdiv.sinkhorn_divergence(
output = sinkhorn_div(
pointcloud.PointCloud,
x=point_cloud_1,
y=point_cloud_2,
batch_size=batch_size,
a=a,
b=b,
epsilon=epsilon,
sinkhorn_kwargs={"tau_a": tau_a, "tau_b": tau_b},
scale_cost=scale_cost,
epsilon=epsilon,
**kwargs,
)
xy_conv, xx_conv, *yy_conv = output.converged
Expand All @@ -52,6 +62,23 @@ def sinkhorn_divergence(
return float(output.divergence)


@partial(jax.jit, static_argnames=["k"])
def get_nearest_neighbors(
input_batch: jnp.ndarray,
target: jnp.ndarray,
k: int = 30,
recall_target: float = 0.95,
aggregate_to_topk: bool = True,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Get the k nearest neighbors of the input batch in the target."""
if target.shape[0] < k:
raise ValueError(f"k is {k}, but must be smaller or equal than {target.shape[0]}.")
pairwise_euclidean_distances = pointcloud.PointCloud(input_batch, target).cost_matrix
return jax.lax.approx_min_k(
pairwise_euclidean_distances, k=k, recall_target=recall_target, aggregate_to_topk=aggregate_to_topk
)


def check_shapes(geom_x: geometry.Geometry, geom_y: geometry.Geometry, geom_xy: geometry.Geometry) -> None:
n, m = geom_xy.shape
n_, m_ = geom_x.shape[0], geom_y.shape[0]
Expand Down Expand Up @@ -133,3 +160,83 @@ def _instantiate_geodesic_cost(
cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)


def data_match_fn(
src_lin: Optional[jnp.ndarray] = None,
tgt_lin: Optional[jnp.ndarray] = None,
src_quad: Optional[jnp.ndarray] = None,
tgt_quad: Optional[jnp.ndarray] = None,
*,
typ: Literal["lin", "quad", "fused"],
**data_match_fn_kwargs,
) -> jnp.ndarray:
if typ == "lin":
return solver_utils.match_linear(x=src_lin, y=tgt_lin, **data_match_fn_kwargs)
if typ == "quad":
return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, **data_match_fn_kwargs)
if typ == "fused":
return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin, **data_match_fn_kwargs)
raise NotImplementedError(f"Unknown type: {typ}.")


class Loader:

def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None):
self.dataset = dataset
self.batch_size = batch_size
self._rng = np.random.default_rng(seed)

def __iter__(self):
return self

def __next__(self) -> Dict[str, jnp.ndarray]:
data = defaultdict(list)
for _ in range(self.batch_size):
ix = self._rng.integers(0, len(self.dataset))
for k, v in self.dataset[ix].items():
data[k].append(v)
return {k: jnp.vstack(v) for k, v in data.items()}

def __len__(self):
return len(self.dataset)


class MultiLoader:
"""Dataset for OT problems with conditions.
This data loader wraps several data loaders and samples from them.
Args:
datasets: Datasets to sample from.
seed: Random seed.
"""

def __init__(
self,
datasets: Iterable[Loader],
seed: Optional[int] = None,
):
self.datasets = tuple(datasets)
self._rng = np.random.default_rng(seed)
self._iterators: list[MultiLoader] = []
self._it = 0

def __next__(self) -> Dict[str, jnp.ndarray]:
self._it += 1

ix = self._rng.choice(len(self._iterators))
iterator = self._iterators[ix]
if self._it < len(self):
return next(iterator)
# reset the consumed iterator and return it's first element
self._iterators[ix] = iterator = iter(self.datasets[ix])
return next(iterator)

def __iter__(self) -> "MultiLoader":
self._it = 0
self._iterators = [iter(ds) for ds in self.datasets]
return self

def __len__(self) -> int:
return max((len(ds) for ds in self.datasets), default=0)
Loading

0 comments on commit 6e65ac1

Please sign in to comment.