Is any method implemented in ott to help reduce memory overhead? (Update: I found the batch_size option) #422

Open
feiyang-k opened this issue Sep 1, 2023 · 3 comments
Labels
question Further information is requested

Comments

@feiyang-k

Hi,

In our ML tasks, the problem size is often defined as num_of_training_samples × num_of_validation_samples. Our GPUs currently have 40–80 GB of memory per card, which can handle problems of around 350k × 10k. This is fine for classic datasets such as CIFAR-10, but still far from the million-scale sizes of modern datasets (or the billion-scale sizes of language corpora). Is there a standard approach to reduce memory overhead? Does the package provide any approximation or batch-wise methods that help with memory usage?
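As a rough back-of-the-envelope check on the sizes above: a dense float32 matrix of shape n × m takes 4·n·m bytes. The sketch below counts a single such matrix only (an illustrative assumption); any further arrays a solver or autodiff keeps come on top of this.

```python
def dense_matrix_gb(n_rows: int, n_cols: int, bytes_per_entry: int = 4) -> float:
    """Size in GB (10**9 bytes) of ONE dense n_rows x n_cols matrix (float32 by default)."""
    return n_rows * n_cols * bytes_per_entry / 1e9


for n_train, n_val in [(350_000, 10_000), (1_000_000, 10_000), (1_000_000, 100_000)]:
    print(f"{n_train:>9,} x {n_val:>7,}: {dense_matrix_gb(n_train, n_val):8.1f} GB per float32 matrix")
```

At 1M × 10k a single dense float32 matrix is already 40 GB, which is why avoiding materialization matters at million scale.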

Thanks!

Update: I found the batch_size option for online cost computation, which looks highly relevant!

@marcocuturi
Contributor

Hi @feiyang-k !!

Indeed, scalability will be an important issue for OT tasks. There are two possible workarounds at the moment:

  • the batch_size option when creating a geometry object. This ensures that the num_train × num_val matrix you mentioned is never materialized.
  • the low-rank approach. In that case you would need to run a low-rank (LR) Sinkhorn solver, which usually scales much better. A very basic tutorial can be seen here
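The idea behind the batch_size option in the first bullet (never materializing the full num_train × num_val cost matrix) can be illustrated with a log-domain Sinkhorn loop that recomputes the cost in row blocks inside every potential update. This is a minimal JAX sketch under stated assumptions (squared-Euclidean cost, sizes divisible by the block size, a fixed iteration count); the function names are hypothetical and it is not ott's implementation.

```python
import jax
import jax.numpy as jnp


def sq_dist(x_blk, y):
    # Squared Euclidean cost between a block of points and all of y: shape (len(x_blk), len(y)).
    return jnp.sum((x_blk[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def blocked_potential_update(x, y, g, log_a, eps, batch_size):
    """f_i = eps*log(a_i) - eps*logsumexp_j((g_j - C_ij)/eps), computed batch_size rows at a time.

    Only a (batch_size, len(y)) slice of the cost matrix is built per block.
    """
    n, d = x.shape
    if n % batch_size:
        raise ValueError("this sketch assumes len(x) is divisible by batch_size")
    x_blocks = x.reshape(n // batch_size, batch_size, d)
    la_blocks = log_a.reshape(n // batch_size, batch_size)

    def one_block(blk):
        x_blk, la_blk = blk
        c = sq_dist(x_blk, y)  # cost slice recomputed on the fly
        return eps * la_blk - eps * jax.nn.logsumexp((g[None, :] - c) / eps, axis=1)

    # lax.map processes the blocks sequentially (it is a scan under the hood).
    return jax.lax.map(one_block, (x_blocks, la_blocks)).reshape(n)


def blocked_sinkhorn(x, y, a, b, eps=0.5, n_iters=300, batch_size=4):
    """Log-domain entropic Sinkhorn returning the dual potentials (f, g)."""
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, fg):
        f, g = fg
        f = blocked_potential_update(x, y, g, log_a, eps, batch_size)
        # The cost is symmetric in (x, y), so the g-update reuses the same routine.
        g = blocked_potential_update(y, x, f, log_b, eps, batch_size)
        return f, g

    init = (jnp.zeros(x.shape[0]), jnp.zeros(y.shape[0]))
    return jax.lax.fori_loop(0, n_iters, body, init)


# Small usage example: 8 source and 12 target points in 2D, uniform weights.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(kx, (8, 2))
y = jax.random.uniform(ky, (12, 2))
a = jnp.full(8, 1 / 8)
b = jnp.full(12, 1 / 12)
f, g = blocked_sinkhorn(x, y, a, b)
```

Peak memory for the cost term goes from n × m entries to batch_size × m, at the price of recomputing distances in every iteration.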

@michalk8 michalk8 added the question Further information is requested label Sep 1, 2023
@feiyang-k feiyang-k reopened this Sep 8, 2023
@feiyang-k
Author

Thanks @! These methods are marvelous! They work like magic. I had some vague ideas about improving the efficiency of OT solutions, and the moment I saw these methods, I finally understood what those ideas were about :)

The batch_size option works smoothly. I'm very interested in the low-rank method. It seems to implement an idea similar to clustering, which has been used to improve the scalability of computing distributional divergence metrics, but in a much more natural and elegant way. I really enjoyed reading about it!

I read through the reference paper as well as the documentation.

Unlike the regular Sinkhorn algorithm, which (like other OT solvers) works on the dual problem, LRSinkhorn directly imposes a low-rank factorization on the coupling matrix and thus solves the primal problem. The results are given as q, r and g, whose product gives the transport plan.
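In the low-rank Sinkhorn formulation the coupling is factored as P = Q diag(1/g) R^T, with Q1 = a, R1 = b and Q^T 1 = R^T 1 = g. Assuming q, r, g follow that convention (worth checking against the LRSinkhorn output documentation), the plan can be applied to vectors without ever forming the n × m matrix. The helper names below are hypothetical, and the toy factors use R = Q with a = b purely so that the constraints hold.

```python
import jax
import jax.numpy as jnp


def lr_plan_dense(q, r, g):
    """Materialize P = Q diag(1/g) R^T (only sensible for small problems)."""
    return (q / g[None, :]) @ r.T


def lr_plan_matvec(q, r, g, v):
    """P @ v without forming P: O((n + m) * rank) work instead of O(n * m)."""
    return q @ ((r.T @ v) / g)


def lr_plan_rmatvec(q, r, g, u):
    """P^T @ u without forming P."""
    return r @ ((q.T @ u) / g)


# Toy factors (assumption): with a = b, taking R = Q satisfies
# Q 1 = a, R 1 = b and Q^T 1 = R^T 1 = g.
key_s, key_v = jax.random.split(jax.random.PRNGKey(0))
n, rank = 6, 2
a = jnp.full(n, 1 / n)
s = jax.nn.softmax(jax.random.normal(key_s, (n, rank)), axis=1)  # soft cluster assignments
q = a[:, None] * s   # each row i sums to a_i
g = q.sum(axis=0)    # g = Q^T 1
r = q
v = jax.random.normal(key_v, (n,))

P = lr_plan_dense(q, r, g)
row_marginal = lr_plan_matvec(q, r, g, jnp.ones(n))  # equals a under the constraints
```

The marginal check works because P1 = Q diag(1/g) R^T 1 = Q diag(1/g) g = Q1 = a, which is also what makes the clustering reading natural: the rank-r middle factor acts as a set of anchor points through which mass is routed.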

  • Actually, I'm more interested in the dual solutions than in the primal one. I need the derivatives to explore the corresponding practical problem. Recovering the dual solution directly from the primal can be tricky due to numerical issues. In this case, do you have any idea what would be a good way to obtain the dual solutions?

  • I also tried using automatic differentiation in JAX for this, since the dual solutions are the derivatives with respect to the marginals a and b, but I immediately ran into memory issues. I saw similar issues in another post. The difference here is that it seems I cannot use Danskin for LRSinkhorn; is that right?

  • More interestingly, I saw in the source code and in parts of the documentation that there is a use_danskin attribute, but setting it to True does not seem to make any difference. Is this a usable feature?
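On the dual side, for the standard (full-rank) entropic problem: reading f and g off a primal plan via f_i + g_j = C_ij + eps·log P_ij breaks as soon as entries of P underflow to 0, whereas a log-domain Sinkhorn loop returns f and g directly. The sketch below also shows the Danskin-style shortcut: treating the converged potentials as constants makes the gradient of the dual objective with respect to a equal to f, with no backpropagation through the iterations. It is a hypothetical illustration (names, eps and iteration count are assumptions), not ott code, and it does not show what use_danskin does for LRSinkhorn.

```python
import jax
import jax.numpy as jnp


def sinkhorn_potentials(C, a, b, eps, n_iters=300):
    """Dense log-domain Sinkhorn; returns the dual potentials (f, g) directly."""
    def body(_, fg):
        f, g = fg
        g = eps * jnp.log(b) - eps * jax.nn.logsumexp((f[:, None] - C) / eps, axis=0)
        f = eps * jnp.log(a) - eps * jax.nn.logsumexp((g[None, :] - C) / eps, axis=1)
        return f, g

    init = (jnp.zeros(C.shape[0]), jnp.zeros(C.shape[1]))
    return jax.lax.fori_loop(0, n_iters, body, init)


def dual_objective(f, g, C, a, b, eps):
    """Entropic dual <f, a> + <g, b> - eps * sum_ij exp((f_i + g_j - C_ij) / eps), constant dropped."""
    return f @ a + g @ b - eps * jnp.sum(jnp.exp((f[:, None] + g[None, :] - C) / eps))


def ot_cost_danskin(a, b, C, eps):
    # Solve with gradients blocked at the inputs, so the loop itself is never differentiated.
    a_sg, b_sg = jax.lax.stop_gradient((a, b))
    f, g = sinkhorn_potentials(C, a_sg, b_sg, eps)
    # The dual objective is linear in (a, b), so d/da = f and d/db = g (Danskin).
    return dual_objective(f, g, C, a, b, eps)


def ot_cost_unrolled(a, b, C, eps):
    # Backprop through every iteration: the reverse pass keeps per-iteration
    # intermediates, which is what tends to exhaust memory on large problems.
    f, g = sinkhorn_potentials(C, a, b, eps)
    return dual_objective(f, g, C, a, b, eps)


# Small usage example: 5 source and 6 target points in 2D, uniform weights.
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(kx, (5, 2))
y = jax.random.uniform(ky, (6, 2))
C = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
a = jnp.full(5, 1 / 5)
b = jnp.full(6, 1 / 6)
eps = 0.5

f, g = sinkhorn_potentials(C, a, b, eps)
grad_danskin = jax.grad(ot_cost_danskin)(a, b, C, eps)    # f, by construction
grad_unrolled = jax.grad(ot_cost_unrolled)(a, b, C, eps)  # close to f once converged
```

Both gradients should agree with f up to convergence error; the difference is that the unrolled version must keep intermediates from every iteration for the backward pass.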

Thanks a lot!

3 participants