Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): Linear GENOT #662

Merged
merged 334 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
334 commits
Select commit Hold shift + click to select a range
f86eb0e
fix jaxsampler
MUCDK Dec 29, 2022
ac3c781
fix jaxsampler
MUCDK Dec 29, 2022
385bcfd
fix jaxsampler
MUCDK Dec 29, 2022
cfebc4d
fix tests
MUCDK Dec 29, 2022
e056042
add plot_convergence
MUCDK Dec 30, 2022
a7af86b
remove jit from _compute_unbalanced marginals
MUCDK Jan 2, 2023
c8cbe9d
fix sinkhorn_divergence
MUCDK Jan 2, 2023
ff145d1
adapt tox.ini file
MUCDK Jan 4, 2023
9e241db
shape mismatch fixed without precommit
AlejandroTL Jan 11, 2023
3cc279f
remove print statement
MUCDK Jan 12, 2023
c10c5b6
finish merge
MUCDK Jan 12, 2023
74adb1e
adapt callbacks and rename tag `cost` to `cost_matrix` (#426)
MUCDK Dec 9, 2022
a282376
Feature/correlation test (#423)
MUCDK Dec 14, 2022
7360819
fix sankey return statement (#428)
MUCDK Dec 14, 2022
e08f63f
Bump version: 0.1.0 → 0.1.1
MUCDK Dec 14, 2022
728537c
fix return statements
MUCDK Dec 14, 2022
f2153e9
add save tests
MUCDK Dec 14, 2022
4f1d26e
fix return type in mpl (#432)
MUCDK Dec 14, 2022
0ffc730
Simplify linear operator (#431)
michalk8 Dec 14, 2022
cee2466
Explicitly jit the solvers (#433)
michalk8 Dec 16, 2022
6b11188
Feature/interpolate colors sankey (#434)
MUCDK Dec 23, 2022
e28816f
Remove `FGWSolver` (#437)
michalk8 Jan 4, 2023
3ff63be
fix bug in SinkhornProblem (#442)
MUCDK Jan 4, 2023
865ad5a
fix pre commits
MUCDK Jan 12, 2023
2571935
make push/pull always use source/target (#443)
MUCDK Jan 5, 2023
3d22677
fix strip plotting in sankey (#445)
MUCDK Jan 16, 2023
b5370e0
Feature/spearman correlation (#444)
MUCDK Jan 17, 2023
34ef904
Delete logo.png
MUCDK Jan 23, 2023
9806854
Feature/plot order (#453)
MUCDK Jan 27, 2023
50dbdca
Expose marginal kwargs for `moscot.temporal` and check for numeric ty…
MUCDK Feb 2, 2023
46e1c65
adapt plot_convergence (#454)
MUCDK Feb 3, 2023
61f5481
Bug/docs generic analysis mixin (#455)
MUCDK Feb 3, 2023
082c879
Docs/improvements (#456)
MUCDK Feb 3, 2023
a53ef22
remove uns_key from set_plotting_vars (#458)
MUCDK Feb 3, 2023
3a9f992
resolve `fig referenced before assignment` (#460)
MUCDK Feb 3, 2023
c323cb5
move generic mixins tests to problems` (#461)
MUCDK Feb 3, 2023
edb3c11
Tests/spatiotemporalproblem (#464)
MUCDK Feb 5, 2023
e3c3911
Feature/move taggedarray (#457)
MUCDK Feb 5, 2023
fdc5d54
add marginal_kwargs to prepare method of TemporalNeuralProblem
MUCDK Feb 5, 2023
859700c
fix to scaling in
lucaeyring Mar 3, 2023
22a48f5
Revert "fix to scaling in"
lucaeyring Mar 3, 2023
dfa2d82
fix to scaling argument in marginal_kwargs
lucaeyring Mar 3, 2023
4bcd966
updated conditional not pipeline
lucaeyring Mar 21, 2023
cb7bc2d
merge into condot branch
lucaeyring Mar 21, 2023
ddb1772
merge into condot branch
lucaeyring Mar 21, 2023
9636f1c
incoporated comments
lucaeyring Mar 23, 2023
f27dab0
incoporated comments
lucaeyring Mar 23, 2023
816b084
incoporated comments
lucaeyring Mar 24, 2023
1eb60f4
removed new_adata for push/pull
lucaeyring Mar 27, 2023
2b6686f
Merge pull request #9 from theislab/condot_revamp
MUCDK Mar 27, 2023
9e053ca
Merge pull request #7 from theislab/conditional_not_precommit
lucaeyring Mar 27, 2023
0d96dbe
[ci skip] start docs
MUCDK Mar 27, 2023
83e6d71
added temporal neural test
lucaeyring Mar 28, 2023
df346d9
[ci skip] continue docs
MUCDK Mar 28, 2023
ec0ab31
continue docs
MUCDK Mar 29, 2023
fbd4181
continue docs
MUCDK Mar 29, 2023
10799ec
change validation epsilon
MUCDK Mar 29, 2023
7f94ef9
fixed error when not computing wasserstein baseline
lucaeyring Mar 29, 2023
1421c65
fixed error when not computing wasserstein baseline
lucaeyring Mar 29, 2023
b828752
Merge pull request #11 from theislab/feature/docs
lucaeyring Mar 29, 2023
1b584ae
Merge branch 'main' into temporal_neural_test
lucaeyring Mar 29, 2023
d37c7d3
Merge pull request #12 from theislab/temporal_neural_test
lucaeyring Mar 29, 2023
dd106af
correct typo
MUCDK Mar 31, 2023
453922c
fix bug
MUCDK Mar 31, 2023
298f8fb
added neural tests
lucaeyring Mar 31, 2023
491a9a4
[ci skip] draft CondNeuralOutput
MUCDK Apr 8, 2023
e6b1f9a
include CondDualPotentials and CondDualSolver
MUCDK Apr 8, 2023
77fea48
merge moscot main restructuring
lucaeyring May 1, 2023
b1316bc
fixes to main merge
lucaeyring May 1, 2023
1ff844f
fix typo
MUCDK May 8, 2023
29ff674
fix test_cell_transition_subset_pipeline
MUCDK May 8, 2023
518a8e5
fix tests
MUCDK May 8, 2023
2a5bcd7
update conditionalDualPotentials
MUCDK May 25, 2023
6594d43
update conditionalDualPotentials
MUCDK May 25, 2023
2c43b07
fix most pre-commit hooks and fix tests
MUCDK May 25, 2023
594e2f9
fix pandas version to <2.0
MUCDK May 25, 2023
ebc7877
fix tests for non-conditional solvers
MUCDK May 25, 2023
df0b6a8
merge continue_docs
MUCDK May 25, 2023
db7b2c0
continue
MUCDK May 25, 2023
3220f9e
fix
MUCDK May 25, 2023
4fb6e51
continue fixing
MUCDK May 25, 2023
fdf2882
fix ICNN setup
MUCDK May 29, 2023
35db509
fix tests
MUCDK May 29, 2023
65850ae
Merge pull request #18 from theislab/neural_tests_local
MUCDK May 29, 2023
b7d8a00
swap role of f and g, such that push/pull is correct again
MUCDK May 30, 2023
ea7187e
[ci skip] restructure to include more general neural solvers
MUCDK Jun 1, 2023
0c9a587
[ci skip] restructure ICNNs to allow passing instances of ICNN
MUCDK Jun 1, 2023
d3a529d
adapt tests
MUCDK Jun 1, 2023
aba225d
Filled in Monge Gap structure
gocato Jun 10, 2023
d5005b1
Added Monge Gap paper to documentation
gocato Jun 10, 2023
2e0fb4a
Ammend PointCloud Import
gocato Jun 10, 2023
f328056
Merge remote-tracking branch 'origin/feature/add_monge_gap' into feat…
gocato Jun 10, 2023
4692dc5
Update _utils.py
gocato Jun 10, 2023
36b1953
Merge remote-tracking branch 'origin/feature/add_monge_gap' into feat…
gocato Jun 10, 2023
66e5794
Solve compatibility issue with ProblemKind
gocato Jun 10, 2023
2ad39a8
Solve missing Import
gocato Jun 10, 2023
16a797d
Fix call to deprecated function
gocato Jun 10, 2023
7d63082
Fix style and comment issues
gocato Jun 12, 2023
9aff229
add callback, swap f & g
lucaeyring Jun 21, 2023
8f91e4a
add callback, swap f & g
lucaeyring Jun 21, 2023
2d63db7
add callback, swap f & g
lucaeyring Jun 21, 2023
5452485
Merge pull request #26 from theislab/feature/add_callback
MUCDK Jun 21, 2023
6f413a2
Merge pull request #27 from theislab/dev
MUCDK Jun 21, 2023
43b873d
intermediate save
MUCDK Sep 1, 2023
602b8a7
intermediate save
MUCDK Sep 1, 2023
b85c8e0
intermediate save
MUCDK Sep 5, 2023
de6facc
intermediate save
MUCDK Sep 5, 2023
a32038c
partially resolve precommit errors
MUCDK Oct 10, 2023
fa9d6cf
[ci skip] fix merge conflicts
MUCDK Oct 10, 2023
501b018
resolve conflict
MUCDK Oct 10, 2023
0419670
remove pairwise policy
MUCDK Oct 10, 2023
ce7667a
add neural dependencies
MUCDK Oct 10, 2023
f97b08c
add neural dependencies
MUCDK Oct 10, 2023
243c73a
add flax
MUCDK Oct 10, 2023
4516612
fix _call_kwargs
MUCDK Oct 11, 2023
68e7bf3
fix marginal kwargs
MUCDK Oct 11, 2023
536f681
remove monge gap solver
MUCDK Oct 11, 2023
3185371
clean condneuralsolver
MUCDK Oct 11, 2023
582ad43
[ci skip] introduce new data container for joint neural problems
MUCDK Oct 12, 2023
7e3d4f7
add conditions in distirbutioncontainer
MUCDK Oct 17, 2023
51915db
resolve unfreeze/freeze
MUCDK Oct 17, 2023
47508b3
enable pretraining and weight clipping
MUCDK Oct 17, 2023
5363d98
make dicts compatible with older python versions
MUCDK Oct 17, 2023
0dd6b8b
resolve precommit errors partially
MUCDK Oct 17, 2023
21f3309
resolve precommit errors partially
MUCDK Oct 17, 2023
3bbb4da
adapt tests
MUCDK Oct 19, 2023
2cced99
[ci skip] draft unbalancedNeuralMixin
MUCDK Oct 21, 2023
aa49c10
[ci skip] fix naming of posterior marginals
MUCDK Oct 21, 2023
16d3204
[ci skip] add MLP_marginals
MUCDK Oct 23, 2023
c61633c
adapt neural output to incorporate learnt rescaling functions
MUCDK Oct 24, 2023
2dc6ff3
fix _solve in neuraldualsolver
MUCDK Oct 25, 2023
badd57b
incorporate feedback
MUCDK Oct 25, 2023
bd983fd
fix distributioncollection class
MUCDK Oct 25, 2023
b00e2b7
unify _split_data
MUCDK Oct 25, 2023
79f050e
fix tests
MUCDK Oct 25, 2023
5c19f1f
fix some precommit hooks
MUCDK Oct 25, 2023
4e964e5
make neural dependencies optional
MUCDK Oct 25, 2023
517db5f
make neural dependencies optional
MUCDK Oct 25, 2023
742f02c
delete old files
MUCDK Oct 25, 2023
d763de4
adapt pyproject.toml
MUCDK Oct 25, 2023
2c64e3b
adapt pyproject.toml
MUCDK Oct 25, 2023
73c7830
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2023
6d40537
[ci skip] adjust _format_params
MUCDK Oct 25, 2023
9483e69
adapt neuraldualsolver to be more similar to ott-jax
MUCDK Oct 26, 2023
32300e5
adapt neuraldualsolver
MUCDK Oct 26, 2023
88b9a59
TODO: make JaxSampler return conditions
MUCDK Oct 26, 2023
dd74215
add basic neural test
MUCDK Oct 27, 2023
dce2b21
[ci skip] intermediate save
MUCDK Oct 30, 2023
d2ece68
adapt neuraldualsolver and finish tests for neural backend
MUCDK Oct 30, 2023
3bd674f
[ci skip] TODO: re-iterate on initialisation of neural solver
MUCDK Oct 30, 2023
61e3a01
adapt distributioncontainer
MUCDK Nov 3, 2023
b949d6f
fix dict bug
MUCDK Nov 3, 2023
e90934c
resolve passing of arguments in solver call methods
MUCDK Nov 3, 2023
295eb6e
[ci skip] adapt `solve` in `CondOTProblem`
MUCDK Nov 3, 2023
08fcec2
adapt tests and valid loader conditions
MUCDK Nov 5, 2023
868146a
adapt neural backend tests
MUCDK Nov 5, 2023
f84bd31
fix mypy errors
MUCDK Nov 5, 2023
e74e0a4
Merge branch 'main' into feature/moscot_not
MUCDK Nov 5, 2023
9482ccc
make basesolveroutput to basediscretesolveroutput
MUCDK Nov 5, 2023
182bab7
move `to` to BaseSolverOutput`
MUCDK Nov 5, 2023
166bd49
adapt transport_matrix docs
MUCDK Nov 5, 2023
fea3ac9
adapt transport_matrix docs
MUCDK Nov 5, 2023
7669f7b
adapt tests
MUCDK Nov 5, 2023
e53f9c0
adapt tests
MUCDK Nov 5, 2023
cc6bfb5
update unbalancedness mixin
MUCDK Nov 8, 2023
a438629
use implementation from moscot
MUCDK Nov 9, 2023
3f422bf
uncomment unused code
MUCDK Nov 9, 2023
3e9dfc5
before passing states to loss-fn
MUCDK Nov 10, 2023
f4b7c76
intermediate save
MUCDK Nov 10, 2023
4152baf
adapt neuraldualsolver
MUCDK Nov 10, 2023
06e55c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2023
0b46bae
resolve some / not all pre commit errors
MUCDK Nov 10, 2023
4566eb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2023
c981af7
(wip): tests run, code swapped out for now
ilan-gold Feb 7, 2024
8763321
(wip): `NeuralSolver`s implemented minus quad/linear
ilan-gold Feb 7, 2024
c2312e3
(wip): begin more generic problem
ilan-gold Feb 7, 2024
7f09740
(wip): more refactoring to pass arguments to GENOT
ilan-gold Feb 14, 2024
10c1102
(chore): remove more kantorovich
ilan-gold Feb 14, 2024
4a7a79e
(chore): update branch to moscot neural + first test moving to solving
ilan-gold Feb 14, 2024
f10be61
(fix): split data remains in numpy
ilan-gold Feb 14, 2024
03bec6b
(fix): push/pull api
ilan-gold Feb 14, 2024
a99231a
(fix): make push test work
ilan-gold Feb 14, 2024
b1f3ea4
(feat): allow for custom optimziers
ilan-gold Feb 14, 2024
e3b31d6
(chore): remove unclear test
ilan-gold Feb 14, 2024
06376a1
(refactor): change to composition API
ilan-gold Feb 19, 2024
0f955d5
(refactor): start towards model-specific problems
ilan-gold Feb 26, 2024
f1a0718
(chore): clean up all unnecessary classes
ilan-gold Feb 26, 2024
94441f0
(chore): updating to moscot latest
ilan-gold Feb 26, 2024
2583c56
Merge branch 'main' into ig/neural_solvers
ilan-gold Feb 26, 2024
591f3b4
(chore): remove (hopefully) final ICNN vestiges
ilan-gold Feb 26, 2024
96bd942
Merge branch 'main' into ig/neural_solvers
ilan-gold Feb 26, 2024
37c3757
(chore): more cleanup
ilan-gold Feb 26, 2024
5b4afdd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
85ac491
Merge branch 'main' into ig/neural_solvers
ilan-gold Feb 26, 2024
25844d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2024
a0f94a4
(fix): pass pre-commit hooks
ilan-gold Feb 26, 2024
f017c9e
Merge branch 'ig/neural_solvers' of github.com:theislab/moscot into i…
ilan-gold Feb 26, 2024
6e85255
(chore): remove duplicatec docs
ilan-gold Feb 26, 2024
133ed51
(chore): add torch for testing
ilan-gold Feb 26, 2024
556f6ce
(fix): add ott jax branch as dep
ilan-gold Feb 26, 2024
4bea0d2
(fix): repo name
ilan-gold Feb 26, 2024
e0a5628
(chore): remove unbalanced, update api, fix tests + drive by typing fix
ilan-gold Apr 10, 2024
d09a9eb
(feat): first pass at neural mixin
ilan-gold Apr 10, 2024
6ac3d3f
(chore): add my name to todos
ilan-gold Apr 10, 2024
b2faf8a
(fix): conditions left out if not necessary
ilan-gold Apr 17, 2024
79e76e1
(feat): logs and fix conditional attr
ilan-gold Apr 17, 2024
fbfd93f
(fix): add `seed` to call_kwargs so reproducibility works
ilan-gold Apr 17, 2024
ff27786
(chore): remove `is_conditional` business
ilan-gold Apr 18, 2024
7e90a53
(fix): create hidden dims arg for velocity field
ilan-gold Apr 18, 2024
1dbd78c
(chore): raise not implemented error for `pull`
ilan-gold May 2, 2024
f64d4ff
(fix): default args
ilan-gold May 2, 2024
050360c
(fix): add explicit policy
ilan-gold May 16, 2024
556a815
(fix): allow iteration to continue
ilan-gold May 16, 2024
d18154a
(chore): add star policy to GENOT
ilan-gold May 16, 2024
24ee61e
(chore): notebooks
ilan-gold May 22, 2024
77cd780
Merge branch 'main' into ig/neural_solvers
ilan-gold May 22, 2024
83f9fb1
(chore): remove deps
ilan-gold May 22, 2024
a4ec4f9
(chore): remove unnecessary spaces
ilan-gold May 22, 2024
0145b5b
(chore): simplify quad handling
ilan-gold May 22, 2024
98177c1
(fix): need to require `optax`/`flax`
ilan-gold May 22, 2024
ea3dc93
(fix): use `ott-jax[neural]`
ilan-gold May 22, 2024
b9b66d5
(chore): fix docs
ilan-gold May 24, 2024
16e45a1
(fix): small test fixes
ilan-gold May 24, 2024
4cbc912
(chore): small notebook changes
ilan-gold May 24, 2024
cb04bef
(fix): broken link in citation
ilan-gold May 24, 2024
884996d
(chore): make notebook dependent on ci
ilan-gold May 24, 2024
0f353fc
(fix): small todos just to push something
ilan-gold May 24, 2024
341068c
(fix): variable is a string
ilan-gold May 24, 2024
64f89ea
(fix): pass environment variable to tox
ilan-gold May 24, 2024
85342f4
(fix): actually pass through
ilan-gold May 24, 2024
0098193
(fix): hidden dims ci
ilan-gold May 24, 2024
6f9890a
(fix): re-add notebook
ilan-gold May 27, 2024
7860056
Merge branch 'main' into ig/neural_solvers
ilan-gold May 28, 2024
46e5fc2
(chore): make`recall_target` and `aggregate_to_topk`
ilan-gold Jun 19, 2024
bcd288c
(chore): fix default arguments
ilan-gold Jun 19, 2024
822628d
(chore): `project_transport_matrix` -> `project_to_transport_matrix`
ilan-gold Jun 19, 2024
5835390
(fix): remove dead `NeuralAnalysisMixin` code
ilan-gold Jun 19, 2024
d245cb0
Merge branch 'ig/neural_solvers' of github.com:theislab/moscot into i…
ilan-gold Jun 19, 2024
90e9fc6
Merge branch 'main' into ig/neural_solvers
ilan-gold Jun 19, 2024
e9cd90e
(feat): allow custom `data_match_fn`
ilan-gold Jun 24, 2024
cf11513
(fix): inherit from `MutableMapping` instead of `dict`
ilan-gold Aug 5, 2024
0f91c4e
Merge branch 'main' into ig/neural_solvers
ilan-gold Aug 5, 2024
ab51550
(Fix): docs
ilan-gold Aug 5, 2024
2ce5c0b
(fix): notebooks
ilan-gold Aug 5, 2024
9b38b65
(fix): docs reference
ilan-gold Aug 5, 2024
60e074c
(fix): remove `attr`
ilan-gold Aug 5, 2024
c8a9cbd
(fix): erroneous change
ilan-gold Aug 5, 2024
b97a2b4
(fix): remove empty
ilan-gold Aug 5, 2024
d3b31ec
(fix): notebooks again?
ilan-gold Aug 5, 2024
2678705
(chore): ok?
ilan-gold Aug 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,24 @@
nitpicky = True
nitpick_ignore = [
("py:class", "numpy.float64"),
# see: https://github.com/numpy/numpydoc/issues/275
("py:class", "None. Remove all items from D."),
("py:class", "a set-like object providing a view on D's items"),
("py:class", "a set-like object providing a view on D's keys"),
("py:class", "v, remove specified key and return the corresponding value."), # noqa: E501
("py:class", "None. Update D from dict/iterable E and F."),
("py:class", "an object providing a view on D's values"),
("py:class", "a shallow copy of D"),
Comment on lines +66 to +73
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what these mean. Is this intentional or was it left from debugging

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these are warnings from the RTD that should be ignored base on the static class rendering file. But they are indeed too many I'm not sure if it's wanted.

Copy link
Collaborator Author

@ilan-gold ilan-gold Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, what to do here. We can't import from MutableMapping for DistributionCollection without re-implementing all the abstract methods, which seems like a waste. The other option would be to remove the DistributionCollection class since it is used only for repr/cleanness of API. @MUCDK this code was pulled from your starting point, so perhaps you can chime in.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be fine with keeping it as it is if it's not too much of an issue. @selmanozleyen wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If typing.Dict doesn't work as well I guess using those ignores is a better solution than re-implementation.

]
# TODO(michalk8): remove once typing has been cleaned-up
nitpick_ignore_regex = [
(r"py:class", r"moscot\..*(K|B|O)"),
(r"py:class", r"numpy\._typing.*"),
(r"py:class", r"moscot\..*Protocol.*"),
(
r"py:class",
r"moscot.base.output.BaseSolverOutput",
), # https://github.com/sphinx-doc/sphinx/issues/10974 means there is simply no way around this with generics
]


Expand Down
8 changes: 6 additions & 2 deletions docs/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Backends
backends.ott.GWSolver
backends.ott.OTTOutput
backends.ott.GraphOTTOutput
backends.ott.GENOTLinSolver
backends.ott.output.OTTNeuralOutput
backends.utils.get_solver
backends.utils.get_available_backends

Expand Down Expand Up @@ -44,6 +46,7 @@ Problems
problems.BaseCompoundProblem
problems.CompoundProblem
cost.BaseCost
problems.CondOTProblem

Mixins
^^^^^^
Expand All @@ -62,14 +65,13 @@ Solvers

solver.BaseSolver
solver.OTSolver
output.BaseSolverOutput

Output
^^^^^^
.. autosummary::
:toctree: genapi

output.BaseSolverOutput
output.BaseDiscreteSolverOutput
output.MatrixSolverOutput

Utils
Expand Down Expand Up @@ -100,6 +102,8 @@ Miscellaneous
data.apoptosis_markers
tagged_array.TaggedArray
tagged_array.Tag
tagged_array.DistributionCollection
tagged_array.DistributionContainer

.. currentmodule:: moscot.base.problems
.. autosummary::
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
9 changes: 9 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,12 @@ @article{srivatsan:20
year={2020},
publisher={American Association for the Advancement of Science}
}

@misc{klein2023generative,
title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces},
author={Dominik Klein and Théo Uscidda and Fabian Theis and Marco Cuturi},
year={2023},
eprint={2310.09254},
archivePrefix={arXiv},
primaryClass={stat.ML}
}
1 change: 1 addition & 0 deletions docs/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Generic Problems
generic.SinkhornProblem
generic.GWProblem
generic.FGWProblem
generic.GENOTLinProblem

Plotting
~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dependencies = [
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"docrep>=0.3.2",
"ott-jax>=0.4.6",
"ott-jax[neural]>=0.4.6",
"cloudpickle>=2.2.0",
"rich>=13.5",
]
Expand Down Expand Up @@ -267,11 +267,11 @@ skip_missing_interpreters = true

[testenv]
extras = test
pass_env = PYTEST_*,CI
commands =
python -m pytest {tty:--color=yes} {posargs: \
--cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \
--no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered}
passenv = PYTEST_*,CI

[testenv:lint-code]
description = Lint the code.
Expand Down
7 changes: 4 additions & 3 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from ott.geometry import costs

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
from moscot.costs import register_cost

__all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"]


register_cost("euclidean", backend="ott")(costs.Euclidean)
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
Expand Down
109 changes: 105 additions & 4 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Any, Literal, Optional, Tuple, Union
from collections import defaultdict
from functools import partial
from typing import Any, Dict, Iterable, Literal, Optional, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
from ott.tools import sinkhorn_divergence as sdiv
from ott.neural import datasets
from ott.solvers import utils as solver_utils
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div

from moscot._logging import logger
from moscot._types import ArrayLike, ScaleCost_t
Expand All @@ -22,22 +27,27 @@ def sinkhorn_divergence(
a: Optional[ArrayLike] = None,
b: Optional[ArrayLike] = None,
epsilon: Union[float, epsilon_scheduler.Epsilon] = 1e-1,
tau_a: float = 1.0,
tau_b: float = 1.0,
scale_cost: ScaleCost_t = 1.0,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> float:
point_cloud_1 = jnp.asarray(point_cloud_1)
point_cloud_2 = jnp.asarray(point_cloud_2)
a = None if a is None else jnp.asarray(a)
b = None if b is None else jnp.asarray(b)

output = sdiv.sinkhorn_divergence(
output = sinkhorn_div(
pointcloud.PointCloud,
x=point_cloud_1,
y=point_cloud_2,
batch_size=batch_size,
a=a,
b=b,
epsilon=epsilon,
sinkhorn_kwargs={"tau_a": tau_a, "tau_b": tau_b},
scale_cost=scale_cost,
epsilon=epsilon,
**kwargs,
)
xy_conv, xx_conv, *yy_conv = output.converged
Expand All @@ -52,6 +62,17 @@ def sinkhorn_divergence(
return float(output.divergence)


@partial(jax.jit, static_argnames=["k"])
def get_nearest_neighbors(
input_batch: jnp.ndarray, target: jnp.ndarray, k: int = 30
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Get the k nearest neighbors of the input batch in the target."""
if target.shape[0] < k:
raise ValueError(f"k is {k}, but must be smaller or equal than {target.shape[0]}.")
pairwise_euclidean_distances = pointcloud.PointCloud(input_batch, target).cost_matrix
return jax.lax.approx_min_k(pairwise_euclidean_distances, k=k, recall_target=0.95, aggregate_to_topk=True)

ilan-gold marked this conversation as resolved.
Show resolved Hide resolved

def check_shapes(geom_x: geometry.Geometry, geom_y: geometry.Geometry, geom_xy: geometry.Geometry) -> None:
n, m = geom_xy.shape
n_, m_ = geom_x.shape[0], geom_y.shape[0]
Expand Down Expand Up @@ -133,3 +154,83 @@ def _instantiate_geodesic_cost(
cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)


def data_match_fn(
src_lin: Optional[jnp.ndarray] = None,
tgt_lin: Optional[jnp.ndarray] = None,
src_quad: Optional[jnp.ndarray] = None,
tgt_quad: Optional[jnp.ndarray] = None,
*,
typ: Literal["lin", "quad", "fused"],
**data_match_fn_kwargs,
) -> jnp.ndarray:
if typ == "lin":
return solver_utils.match_linear(x=src_lin, y=tgt_lin, **data_match_fn_kwargs)
if typ == "quad":
return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, **data_match_fn_kwargs)
if typ == "fused":
return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin, **data_match_fn_kwargs)
raise NotImplementedError(f"Unknown type: {typ}.")


class Loader:

def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None):
self.dataset = dataset
self.batch_size = batch_size
self._rng = np.random.default_rng(seed)

def __iter__(self):
return self

def __next__(self) -> Dict[str, jnp.ndarray]:
data = defaultdict(list)
for _ in range(self.batch_size):
ix = self._rng.integers(0, len(self.dataset))
for k, v in self.dataset[ix].items():
data[k].append(v)
return {k: jnp.vstack(v) for k, v in data.items()}

def __len__(self):
return len(self.dataset)


class MultiLoader:
"""Dataset for OT problems with conditions.

This data loader wraps several data loaders and samples from them.

Args:
datasets: Datasets to sample from.
seed: Random seed.
"""

def __init__(
self,
datasets: Iterable[Loader],
seed: Optional[int] = None,
):
self.datasets = tuple(datasets)
self._rng = np.random.default_rng(seed)
self._iterators: list[MultiLoader] = []
self._it = 0

def __next__(self) -> Dict[str, jnp.ndarray]:
self._it += 1

ix = self._rng.choice(len(self._iterators))
iterator = self._iterators[ix]
if self._it < len(self):
return next(iterator)
# reset the consumed iterator and return it's first element
self._iterators[ix] = iterator = iter(self.datasets[ix])
return next(iterator)

def __iter__(self) -> "MultiLoader":
self._it = 0
self._iterators = [iter(ds) for ds in self.datasets]
return self

def __len__(self) -> int:
return max((len(ds) for ds in self.datasets), default=0)
Loading
Loading