UOT: how to use LRSinkhorn for UOT? #428

Open
feiyang-k opened this issue Sep 8, 2023 · 7 comments
Labels
question Further information is requested


@feiyang-k

Hi,

I really love LRSinkhorn, and not just for its speed (though yes, it is really fast). I think the method is natural and may have broader implications for practical problems as well.

The problem is that I cannot use it for UOT problems (even though the code does seem to support them!).

I get a warning that UOT is not supported by LRSinkhorn, but the code does seem to include functionality for computing UOT. The documentation also has pages on ott.solvers.linear.lr_utils.unbalanced_dykstra_lse and ott.solvers.linear.lr_utils.unbalanced_dykstra_kernel, but I don't know how to use these functions.

  • Is it possible to calculate UOT with LRSinkhorn with the current ott package? Or is there any alternative way to do that?

(I can compute the UOT with a batch-wise method, which gives me the dual solution. In this case, however, I want the transport map, which tells me which samples in the larger distribution each point is mapped to.)

Thank you!!

@marcocuturi
Contributor

Thanks for your interest in unbalanced LR!

We've laid out the ideas here: https://arxiv.org/abs/2305.19727 and @michalk8 has been pushing these modifications. We haven't written a proper tutorial for it yet, though, and we might need a few more weeks.

@michalk8
Collaborator

michalk8 commented Sep 8, 2023

Hi @feiyang-k ,

> Is it possible to calculate UOT with LRSinkhorn with the current ott package? Or is there any alternative way to do that?

We've just released ott-jax==0.4.4; could you please try updating?

> The documentation also has pages on ott.solvers.linear.lr_utils.unbalanced_dykstra_lse and ott.solvers.linear.lr_utils.unbalanced_dykstra_kernel, but I don't know how to use these functions.

These aren't meant to be called directly; you can just use LRSinkhorn, as in this tutorial, and pass tau_a or tau_b to the LinearProblem. The two functions above are exposed in the docs so that users know which arguments can be passed via kwargs_dys in LRSinkhorn.

> (I can compute the UOT with a batch-wise method, which gives me the dual solution. In this case, however, I want the transport map, which tells me which samples in the larger distribution each point is mapped to.)

Both UOT and ULOT give you access to the transport map: you can materialize it (which, however, defeats the purpose of the low-rank factorization) or apply it to a vector/matrix directly.
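To make the "apply without materializing" point concrete, here is a plain-NumPy sketch of the arithmetic, assuming the low-rank coupling is P = Q diag(1/g) Rᵀ with factors exposed as out.q, out.r, out.g (as LRSinkhornOutput does in the versions discussed here). The helpers apply_lr and top_target are hypothetical names for illustration, not ott functions; top_target answers the "which target does each source send the most mass to" question by building rows of P in chunks:

```python
import numpy as np

def apply_lr(q, r, g, v):
    """Return P @ v for P = Q diag(1/g) R^T without materializing P.

    q: (n, k), r: (m, k), g: (k,), v: (m,) or (m, d). Cost is O((n + m) k d).
    """
    v2 = v.reshape(v.shape[0], -1)              # treat a vector as one column
    out = q @ ((r.T @ v2) / g[:, None])         # (n, k) @ (k, d)
    return out.reshape((q.shape[0],) + v.shape[1:])

def top_target(q, r, g, chunk=1024):
    """For each source i, the target j receiving the most mass, i.e. argmax_j P[i, j].

    Rows of P are formed chunk by chunk, so peak memory is O(chunk * m), not O(n * m).
    """
    scaled_rt = (r / g).T                       # diag(1/g) R^T, shape (k, m)
    best = np.empty(q.shape[0], dtype=np.int64)
    for start in range(0, q.shape[0], chunk):
        block = q[start:start + chunk] @ scaled_rt  # a horizontal slab of P
        best[start:start + chunk] = block.argmax(axis=1)
    return best

# Usage with random positive factors (in practice: out.q, out.r, out.g).
rng = np.random.default_rng(0)
n, m, k = 50, 80, 5
q = rng.random((n, k))
r = rng.random((m, k))
g = rng.random(k) + 0.1
v = rng.standard_normal(m)

P = q @ np.diag(1.0 / g) @ r.T                  # dense coupling, only for checking
print(np.allclose(apply_lr(q, r, g, v), P @ v))  # True
print(np.array_equal(top_target(q, r, g, chunk=16), P.argmax(axis=1)))
```

The dense P is built here only to check the result; on a 1k × 10k problem you would call the helpers on the factors alone.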

We haven't added a tutorial for unbalanced LR solvers yet, but as @marcocuturi says, we will add one in the near future.
P.S.: if the resulting coupling is not good (by whatever metric you use to evaluate it), consider using the k-means initializer for LRSinkhorn (the default is random) and look at the convergence/cost curves stored in out.errors and out.costs, where out = solver(prob).
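A small sketch for reading those curves, assuming (not confirmed in this thread) that unused slots of out.costs and out.errors are padded with -1, and that the error is compared against the solver's threshold. trim_history and summarize are hypothetical helpers; the curves in the usage line are made up:

```python
import numpy as np

def trim_history(values):
    """Drop the -1 placeholders assumed to pad unused slots of out.costs / out.errors."""
    values = np.asarray(values, dtype=float)
    return values[values != -1]

def summarize(costs, errors, threshold):
    """Summarize recorded cost/error curves and whether the error ever fell below threshold."""
    costs, errors = trim_history(costs), trim_history(errors)
    return {
        "n_recorded": int(errors.size),
        "final_cost": float(costs[-1]) if costs.size else float("nan"),
        "final_error": float(errors[-1]) if errors.size else float("nan"),
        "reached_threshold": bool(errors.size and errors.min() < threshold),
    }

# Made-up curves; with a real run: summarize(out.costs, out.errors, threshold=...)
print(summarize([3.0, 2.5, 2.4, -1, -1], [0.5, 0.05, 5e-4, -1, -1], threshold=1e-3))
```

If every slot is filled and reached_threshold is False, the solver ran its full iteration budget, which is worth ruling out before blaming the initializer.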

@michalk8 michalk8 added the question Further information is requested label Sep 8, 2023
@michalk8 michalk8 self-assigned this Sep 8, 2023
@feiyang-k
Author

feiyang-k commented Sep 11, 2023

Hi @marcocuturi and @michalk8 ,

Thanks so much for the detailed info! The referenced paper looks exciting. This is exactly what I've been looking for.

I updated to ott-jax==0.4.4 with jax==0.4.6 and jaxlib==0.4.6+cuda11+cudnn82 and tried a 1k × 10k problem: LRSinkhorn with rank=200 solves it in a few seconds. If I change it to ot_prob = linear_problem.LinearProblem(geom, tau_a=0.1), the solver does not finish even after several minutes. Changing the initializer to k-means does not help.

I also tried the docs branch with ott-jax==0.4.5dev and updated to jax==0.4.10 and jaxlib==0.4.10+cuda12+cudnn88. The situation is the same. The GPU is an NVIDIA RTX A6000 (driver version 530.30.02, CUDA 12.1), with Python 3.9.0.

Any ideas? Which jaxlib version are you using to develop these functions?

Thanks!


This 10k × 10k LR-OT problem solves in 6.8 s:


```python
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld10k: the (10000, d) point cloud being matched
geom = pointcloud.PointCloud(cld10k, cld10k, epsilon=1e-3)
ot_prob = linear_problem.LinearProblem(geom)

solver = sinkhorn_lr.LRSinkhorn(rank=200)
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob, use_danskin=True)
transp_cost
```

This 1k × 10k LR-UOT problem does not finish within 10 minutes:


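```python
import jax
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld1k: (1000, d) source points, cld10k: (10000, d) target points
geom = pointcloud.PointCloud(cld1k, cld10k, epsilon=1e-3)
ot_prob = linear_problem.LinearProblem(geom, tau_a=0.1)  # relax the first marginal

solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=200, initializer="k-means"))
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob)
transp_cost
```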

@michalk8
Collaborator

Here are some of my thoughts:

  • I don't think jaxlib is the cause; as long as jax can use it correctly, it should be fine
  • tau_a might be too low; I would try increasing it (unless your application really calls for such a strongly unbalanced problem)
  • another slow-down can come from the k-means initializer, since by default it runs 100 iterations (with 10 random k-means++ initializations, similar to scikit-learn)
  • to check whether initialization is what slows it down (my hunch is yes, since you cluster 1k/10k points into 200 clusters), the easiest way is to pass a callback function and measure how long one Sinkhorn iteration takes (it should be similar to, or even faster than, the balanced case)
  • you can also play with the parameters of Dykstra's algorithm, which can be passed as LRSinkhorn(..., kwargs_dys={"min_iter": ...})
  • adding a bit of unbalancedness on the other marginal, e.g. tau_b=0.999, may improve the convergence of the inner Dykstra iterations
  • I'd also check whether the convergence threshold makes sense (by plotting ot_lr.errors); maybe the solver converges but never reaches the threshold, which would result in running the full 2k iterations
  • lastly, there's no need to call ot_lr.compute_reg_ot_cost at the end; you can just access the precomputed property transp_cost = ot_lr.reg_ot_cost

@feiyang-k
Author

feiyang-k commented Sep 11, 2023

Thanks @michalk8! This helped a lot. I tried the suggestions one by one.

  • In this particular case, k-means doesn't seem particularly helpful, as the initial error is even larger.

  • In my experience, k-means clustering with scikit-learn for 10k samples and 200 clusters usually finishes in seconds.

  • The callback functions are super useful, especially for parameter tuning!

  • I added a gamma parameter to the solver. This time it converges quickly! gamma=1 and gamma=0.1 both work well; gamma=10 doesn't converge. This may be worth adding to the tutorial :)

```python
solver = sinkhorn_lr.LRSinkhorn(
    rank=200,
    initializer="random",
    progress_fn=progress_fn,
    gamma=0.1,
    kwargs_dys={"max_iter": 1000},
)
```

Thanks so much!

@feiyang-k feiyang-k reopened this Sep 11, 2023
@michalk8
Collaborator

> I added a gamma parameter to the solver. This time it converges quickly! gamma=1 and gamma=0.1 both work well; gamma=10 doesn't converge. This may be worth adding to the tutorial :)

True, we should add this to the tutorial; I'll create an issue for it. We should also mention there (and in the docs) that gamma * epsilon should be < 1 when also using entropic regularization.
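One possible reading of the gamma * epsilon < 1 rule, as a sketch: assume each outer step is a KL mirror-descent step of size γ on the factor Q (with R and g held fixed), applied to the transport cost plus an entropic term ε Σ Q(log Q - 1). This is a derivation under that assumption, not taken from the ott source:

```latex
% Objective in Q (R, g held fixed), with entropic term on Q:
F(Q) = \langle C,\, Q \operatorname{diag}(1/g) R^\top \rangle
       + \varepsilon \sum_{i,l} Q_{il} \left(\log Q_{il} - 1\right)

% Gradient at the current iterate Q_k:
\nabla_Q F(Q_k) = C R \operatorname{diag}(1/g) + \varepsilon \log Q_k

% KL mirror-descent step with step size \gamma (before projecting onto the constraints):
Q_{k+1} \propto Q_k \odot \exp\!\big(-\gamma \nabla_Q F(Q_k)\big)
        = \exp\!\big(-\gamma\, C R \operatorname{diag}(1/g) + (1 - \gamma\varepsilon) \log Q_k\big)
```

The previous iterate enters with exponent 1 - γε: for γε < 1 the step is a damped update, at γε = 1 it forgets Q_k entirely, and for γε > 1 Q_k enters with a negative power. With epsilon = 0 (which I believe is LRSinkhorn's default) the condition holds trivially and gamma acts purely as a step size.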

@duzc-Repos

Hi, I want to use the ULOT algorithm for our single-cell sequencing data analysis, not only for its speed but also for its potential biological interpretation under low-rank constraints.
The question is whether the LOT or ULOT algorithm is suitable for data that do not lie on the unit simplex. More specifically, let $x, y \in \mathbb{R}^n$ be two (non-negative) gene expression vectors: can I get the transport plan between $x$ and $y$ when $\|x\|_1 = a$ and $\|y\|_1 = b$, with $a, b > 1$ and $a \neq b$?
I don't want to normalise these two vectors, because the expression level itself carries meaningful information (so, in a sense, is this unbalanced optimal mass transport?).

Thank you !!
