UOT: how to use LRSinkhorn for UOT? #428
Thanks for your interest in unbalanced LR! We've laid out the ideas here: https://arxiv.org/abs/2305.19727, and @michalk8 has been pushing these modifications. We haven't written a proper tutorial for it yet, though, and it might take a few more weeks.
Hi @feiyang-k ,
we've just pushed the
These aren't supposed to be used by the user, you can just use the
Both UOT and ULOT give you access to the transport map: you can materialize it (which, however, defeats the purpose of the low-rank approach) or apply the map to a vector/matrix. We haven't added a tutorial for the unbalanced LR solvers yet, but as @marcocuturi says, we will add one in the near future.
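To make "apply the map without materializing it" concrete: in the low-rank OT papers the coupling is factored as P = Q diag(1/g) Rᵀ, with Q of shape [n, rank], R of shape [m, rank] and g of shape [rank]. The sketch below is plain NumPy and assumes factors named `q`, `r`, `g`; I believe ott's `LRSinkhornOutput` exposes factors under these names (plus its own `matrix` and `apply`), but `apply_lr_coupling` is a hypothetical helper, not ott's API. It only shows why applying P costs O((n + m) · rank) per column instead of O(n · m).

```python
import numpy as np


def _divide_by_g(x, g):
    # Divide the leading (rank) axis of x by g; x has shape [rank] or [rank, k].
    return x / g.reshape(-1, *([1] * (x.ndim - 1)))


def apply_lr_coupling(q, r, g, v, axis=1):
    """Apply P = Q diag(1/g) R^T to v without forming the [n, m] matrix P.

    axis=1: v lives on the m target points, returns P @ v.
    axis=0: v lives on the n source points, returns P.T @ v.
    """
    if axis == 1:
        return q @ _divide_by_g(r.T @ v, g)
    return r @ _divide_by_g(q.T @ v, g)


# Random positive factors, only to check the identity (not a solver output).
rng = np.random.default_rng(0)
n, m, rank = 300, 500, 8
q = rng.random((n, rank))
r = rng.random((m, rank))
g = rng.random(rank) + 0.5
v = rng.random(m)           # a vector on the target points
u = rng.random((n, 4))      # a matrix on the source points
dense_p = (q / g) @ r.T     # materialized here only for comparison
```

The dense `dense_p` is built only to check the result; with 1k × 10k clouds you would keep just the factors.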
Hi @marcocuturi and @michalk8, thanks so much for the detailed info! The referenced paper looks exciting; this is exactly what I've been looking for. I updated to I tried the Is there any idea on this? What is the Thanks!

This block, a 10k × 10k LR-OT problem, solves in 6.8 s:

```python
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld10k: my point cloud, an array of shape [10000, d]
geom = pointcloud.PointCloud(cld10k, cld10k, epsilon=1e-3)
ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn_lr.LRSinkhorn(rank=200)
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob, use_danskin=True)
transp_cost
```

This block, a 1k × 10k LR-UOT problem, does not complete after 10 minutes:

```python
import jax
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld1k, cld10k: my point clouds, of shapes [1000, d] and [10000, d]
geom = pointcloud.PointCloud(cld1k, cld10k, epsilon=1e-3)
ot_prob = linear_problem.LinearProblem(geom, tau_a=0.1)  # relax the first marginal
solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=200, initializer="k-means"))
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob)
transp_cost
```
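For intuition on what `tau_a=0.1` does, here is a minimal dense (full-rank) sketch of the scaling iterations of Chizat et al. for entropic OT with KL-relaxed marginals: a KL penalty of strength rho on a marginal becomes an exponent tau = rho / (rho + epsilon) on the corresponding scaling update, so tau = 1 enforces the marginal exactly. I believe this is the parametrization behind ott's `tau_a`/`tau_b`, but this is plain NumPy, not the LR solver, and `unbalanced_sinkhorn` is a hypothetical name.

```python
import numpy as np


def unbalanced_sinkhorn(a, b, cost, epsilon, tau_a=1.0, tau_b=1.0, n_iters=1000):
    """Dense entropic OT with KL-relaxed marginals, in scaling form.

    tau = rho / (rho + epsilon): tau = 1 enforces a marginal exactly,
    smaller tau lets the coupling deviate from it.
    """
    kernel = np.exp(-cost / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = (a / (kernel @ v)) ** tau_a
        v = (b / (kernel.T @ u)) ** tau_b
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
x = rng.random((50, 2))              # source cloud in [0, 1]^2
y = rng.random((80, 2)) + 0.5        # target cloud shifted to [0.5, 1.5]^2
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
a = np.full(50, 1 / 50)
b = np.full(80, 1 / 80)
p_bal = unbalanced_sinkhorn(a, b, cost, epsilon=0.1)
p_unb = unbalanced_sinkhorn(a, b, cost, epsilon=0.1, tau_a=0.1)
```

With `tau_b` left at 1, the second marginal still holds exactly, so the total mass stays that of `b`; relaxing the first marginal lets source points far from every target send much less mass than their weight in `a`.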
Here are some of my thoughts:
Thanks @michalk8! This helped a lot. I tried them one by one.
Thanks so much!
True, we should add this to the tutorial; I will create an issue for it. We should also mention there (and in the docs) to have
Hi, I want to use the ULOT algorithm for our single-cell sequencing data analysis, not only for its speed but also for its potential biological interpretability under low-rank constraints. Thank you!!
Hi,
I really love the `LRSinkhorn`. Not just for its speed (but yes, it is really fast): I think this method is natural and may have broader implications for practical problems as well. The current problem is that I cannot use it for UOT problems (but the code does seem to support them!)
I got a warning message that UOT is not supported for `LRSinkhorn`, but the code does seem to include the functionality for computing UOT. The documentation also includes pages on `ott.solvers.linear.lr_utils.unbalanced_dykstra_lse` and `ott.solvers.linear.lr_utils.unbalanced_dykstra_kernel`. I don't know how to use these methods. (I can calculate the UOT with a batch-wise method, which gives me the dual solution. But in this case, I want the transport map, which tells me which samples in the larger distribution my samples are mapped to.)
Thank you!!
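For the question of which samples in the larger distribution each point is mapped to, two common read-outs of a coupling can be computed from low-rank factors without forming the full matrix: the barycentric projection (P Y)ᵢ / (P 1)ᵢ, and the per-row argmax, built a few rows at a time. A minimal NumPy sketch, assuming factors `q`, `r`, `g` with P = Q diag(1/g) Rᵀ; both function names are hypothetical, not part of ott.

```python
import numpy as np


def barycentric_projection(q, r, g, y):
    """Mass-weighted mean target position for each source point.

    Computes (P @ y) / (P @ 1) with P = Q diag(1/g) R^T, never forming P.
    Dividing by the row mass matters for unbalanced couplings, whose row
    sums need not equal the source weights.
    """
    p_y = q @ ((r.T @ y) / g[:, None])   # [n, d]
    row_mass = q @ (r.sum(axis=0) / g)   # [n]
    return p_y / row_mass[:, None]


def most_likely_target(q, r, g, chunk_size=256):
    """Index of the target point receiving the most mass from each source.

    Materializes at most chunk_size rows of P at a time.
    """
    n = q.shape[0]
    best = np.empty(n, dtype=np.int64)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        rows = (q[start:stop] / g) @ r.T   # [stop - start, m] block of P
        best[start:stop] = rows.argmax(axis=1)
    return best


# Random positive factors, only to exercise the helpers (not a solver output).
rng = np.random.default_rng(1)
n, m, rank, d = 300, 700, 10, 3
q = rng.random((n, rank))
r = rng.random((m, rank))
g = rng.random(rank) + 0.5
y = rng.random((m, d))                   # target point coordinates
mapped = barycentric_projection(q, r, g, y)
targets = most_likely_target(q, r, g, chunk_size=64)
```

The projection gives a continuous position per source point; the argmax gives a discrete target index, which is closer to "which sample is it mapped to" but ignores how the mass is spread.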