FastLogSemiring #108

Open
w-cheng opened this issue Sep 30, 2021 · 3 comments

w-cheng commented Sep 30, 2021

Hi,

Thanks for making this library. It's great to have these different CRFs wrapped up in a common, easy-to-use framework.

I've been playing with the LinearChainCRF, and one thing I noticed is that memory usage can be very high during the backward pass of the loss, on both CPU and GPU. I found that the FastLogSemiring in fast_semirings.py uses genbmm.logbmm(), and that memory usage on the GPU drops significantly if I change the default LogSemiring used in the StructDistribution class to FastLogSemiring. However, I haven't seen this documented anywhere, so my questions are:

  1. Is FastLogSemiring ready to be used? It is not included in test_semirings.py.
  2. If so, what would be the best way to switch between LogSemiring and FastLogSemiring? Is there a plan to introduce a parameter in the StructDistribution class to choose between the semirings?
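
For context, here is a minimal pure-Python sketch of the operation a log-semiring matrix product computes, which is my understanding of what genbmm.logbmm() does in batched form on the GPU. The function names are hypothetical and the code is not taken from either library. A broadcast-based implementation would first build every a[i][k] + b[k][j] term (n * k * m values per batch element) before reducing, which is one plausible source of the extra memory; a fused kernel can avoid storing that intermediate.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    peak = max(values)
    if peak == -math.inf:
        return -math.inf
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def log_matmul(a, b):
    """Log-semiring product: out[i][j] = logsumexp_k(a[i][k] + b[k][j])."""
    inner, cols = len(b), len(b[0])
    return [
        [logsumexp([row[k] + b[k][j] for k in range(inner)]) for j in range(cols)]
        for row in a
    ]


def broadcast_intermediate_size(batch, n, k, m):
    """Element count of the (batch, n, k, m) tensor a broadcast version builds."""
    return batch * n * k * m
```

Exponentiating the result of log_matmul on log-valued inputs should recover the ordinary matrix product, which makes for a quick sanity check.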
Collaborator

srush commented Sep 30, 2021

Yes! It works and is heavily tested. We should make it the default. It just requires that the GPU kernels in genbmm be installed.

Author

w-cheng commented Oct 5, 2021

What do you think of checking for the genbmm library at import time, like this:

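try:
    # genbmm is an optional dependency that provides the GPU kernels
    import genbmm

    from .semirings import FastLogSemiring

    has_genbmm = True
except ImportError:
    # Fall back to the standard semirings when genbmm is not installed
    has_genbmm = False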

and then adding a method to the StructDistribution class:

    def default_log_semiring(self):
        return FastLogSemiring if has_genbmm and self.log_potentials.is_cuda else LogSemiring

So instead of returning LogSemiring by default in the marginals and partition properties, we would call default_log_semiring().
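
A self-contained sketch of this fallback logic, for illustration only. The two classes are empty stand-ins for the real semirings, and genbmm_available and pick_log_semiring are hypothetical names that do not exist in the library. It uses importlib.util.find_spec instead of try/except, so the availability check does not import the package.

```python
import importlib.util


class LogSemiring:
    """Stand-in for the standard log semiring (placeholder, not the real class)."""


class FastLogSemiring:
    """Stand-in for the genbmm-backed log semiring (placeholder)."""


def genbmm_available():
    """Return True if a top-level package named genbmm can be found."""
    return importlib.util.find_spec("genbmm") is not None


def pick_log_semiring(is_cuda, has_genbmm=None):
    """Use the fast semiring only when genbmm is present and the data is on GPU."""
    if has_genbmm is None:
        has_genbmm = genbmm_available()
    return FastLogSemiring if (has_genbmm and is_cuda) else LogSemiring
```

Passing has_genbmm explicitly makes the choice easy to test on a machine without a GPU or without genbmm installed.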

Collaborator

srush commented Oct 6, 2021

Yes, that would be great. You can do the same for max too.
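
One way to cover both cases with a single helper is a lookup table from each standard semiring to its fast counterpart. This is only a sketch: the classes are empty stand-ins, it assumes a FastMaxSemiring counterpart exists alongside FastLogSemiring, and FAST_VARIANTS and default_semiring are hypothetical names.

```python
class LogSemiring:
    """Stand-in for the standard log semiring (placeholder)."""


class MaxSemiring:
    """Stand-in for the standard max semiring (placeholder)."""


class FastLogSemiring:
    """Stand-in for the genbmm-backed log semiring (placeholder)."""


class FastMaxSemiring:
    """Stand-in for an assumed genbmm-backed max semiring (placeholder)."""


# Maps each standard semiring to the fast variant to prefer on GPU.
FAST_VARIANTS = {LogSemiring: FastLogSemiring, MaxSemiring: FastMaxSemiring}


def default_semiring(base, is_cuda, has_genbmm):
    """Return the fast variant of base when usable, otherwise base itself."""
    if has_genbmm and is_cuda:
        return FAST_VARIANTS.get(base, base)
    return base
```

Semirings without an entry in the table fall through unchanged, so the helper stays safe to call for any semiring.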
