[feature request] Parallelism support for sequential plate/guide-side enumeration #3219
Comments
Hmm, I'd guess the most straightforward approach to inter-distribution CPU parallelism would be to rely on the PyTorch JIT, simply by using JitTrace_ELBO or a similar loss. Pros:
Cons:
@amifalk Did you have any progress in this area? I'm facing the same issue when dealing with model selection from a set of models with significantly different structures. I have a partial solution using masking. However, for complicated model structures and a large set of models, the masking becomes quite complicated and prone to mistakes that cannot be easily debugged.
Sorry, no updates currently @pavleb. We ended up resolving the speed issues by moving over to NumPyro.
For mixture models with arbitrary distributions over each feature, sampling currently must be done serially, even though these operations are trivially parallelizable.
To sample priors from a hierarchical mixture model with one continuous and one binary feature, you would need to do something like
For mixture models with a large number of features, this can become very slow.
I would love to be able to use a Joblib-like syntax for loops like these, i.e.
I have tried something like this, but the Joblib backend and Pyro don't play nicely together; the model doesn't converge.
In a similar vein, adding parallelism for sequential guide-side enumeration could also enable dramatic speedups. For example, when trying to fit CrossCat with SVI and two truncated stick-breaking processes over views and clusters (my personal use case), enumerating out the view assignments in the model is not possible. Enumerating the views out in the guide is much too slow if they can't be done simultaneously across multiple cores. Since each model run doesn't share information with the others, it seems like this should be possible in theory.
I realize this may be difficult for reasons mentioned in #2354, but is any parallelism like this possible in Pyro?