Use bayeux to access a wide range of samplers #775

Merged on Mar 29, 2024 (34 commits)

@GStechschulte
Collaborator

GStechschulte commented Feb 4, 2024

I have been following @ColCarroll's bayeux library and thought it would be interesting to see how Bambi could incorporate it to offer users a wide range of samplers (more than nuts_blackjax and nuts_numpyro).

Edit: Now I access the samplers programmatically using the inference_method arg. This removes previously needed code for nuts_blackjax and nuts_numpyro. If a user passes an MCMC inference method other than the PyMC MCMC sampler mcmc, Bambi will use bayeux to call that sampler.

import bambi as bmb

data = bmb.load_data("ANES")
clinton_data = data.loc[data["vote"].isin(["clinton", "trump"]), :]

model = bmb.Model("vote['clinton'] ~ party_id + party_id:age", clinton_data, family="bernoulli")
model.build()

idata = model.fit(inference_method="blackjax_hmc")

However, when cleaning the InferenceData, I am getting an xarray error:

"name": "ValueError",
	"message": "('chain', 'draw') must be a permuted list of FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 500, 'party_id_dim_0': 2, 'party_id:age_dim_0': 3}), unless `...` is included"

It seems xarray is not liking something that bayeux is doing with the InferenceData.

Another thought is that using bayeux with Bambi is so easy

model = bmb.Model()
model.build()

bx_model = bx.Model.from_pymc(model.backend.model)
bx_model.<some sampling func>

that maybe we just add documentation explaining how to use bayeux with Bambi to avoid overhead from Bambi's side?

To Do:

  • add additional tests in test_alternative_samplers.py
  • update docstrings referencing JAX based samplers
  • add docs on using alternative backends

@ColCarroll
Collaborator

Really cool! A few suggestions --

  • bayeux could be invisible here, and you could access all the methods programmatically. That's done here, but i can factor that out into a function that gives methods instead of strings -- currently it only adds a method if the underlying library, e.g. optax, is installed. i'm not sure how to avoid using a string at some point. You could have an api like model.fit.bx that initializes the bayeux.Model?
  • I'm happy to add a from_bambi constructor on the bayeux side to make your second option even easier.

@ahartikainen
Collaborator

What kind of API does bayeux have? Could we enable support for external samplers if we define the specific API we support (need)? (Users could create a class for external samplers if needed?)

Of course, that does not mean we could not have text-based support for certain libraries?

@ColCarroll
Collaborator

bayeux is inspired by arviz, in that it just provides a representation of a model that is general enough for most samplers, but it does make the decision that it is specialized to JAX-based models (most of the algorithms use autodiff, vectorization is baked in, and automatic function inverses are also used). If you've got a sampler that accepts a JAX-based log density, you could use bayeux with it (or contribute it to bayeux!)

@GStechschulte
Collaborator Author

Really cool! A few suggestions --

  • bayeux could be invisible here, and you could access all the methods programmatically. That's done here, but i can factor that out into a function that gives methods instead of strings -- currently it only adds a method if the underlying library, e.g. optax, is installed. i'm not sure how to avoid using a string at some point. You could have an api like model.fit.bx that initializes the bayeux.Model?
  • I'm happy to add a from_bambi constructor on the bayeux side to make your second option even easier.

Thanks for the suggestions! That makes sense. I am liking the second option, but I will run some ideas past the others first before asking for the feature. Thanks!

@GStechschulte
Collaborator Author

GStechschulte commented Feb 9, 2024

In this example, the error is because bayeux is appending _0 to party_id_dim of the posterior dims. This results in Bambi discarding all posterior dims because the dims in the PyMC model are inconsistent with the dims of the InferenceData returned by bayeux.

For example:

print(bayeux_idata.posterior.dims, pymc_idata.posterior.dims)
(FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 500, 'party_id_dim_0': 2, 'party_id:age_dim_0': 3}),
 FrozenMappingWarningOnValuesAccess({'chain': 4, 'draw': 1000, 'party_id_dim': 2, 'party_id:age_dim': 3}))

@GStechschulte
Collaborator Author

GStechschulte commented Feb 10, 2024

Update: I have added logic to the idata cleaning step to (1) identify bayeux idata and remove the trailing numeric suffix from the dim names, and (2) rename the posterior dims to be consistent with the PyMC model coords.

Although this works for simple models, I haven't tried this logic with more complex Bambi models such as HSGP, or with models that have a large number of dims and/or factors. Since the idata contains very "important data", I also think it could be worthwhile, for now, not to clean the idata when the user calls samplers from bayeux, in order to avoid unknown effects appearing in the inference data.
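
The two steps in the update above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code added in this PR: `strip_numeric_suffix` and `build_rename_map` are hypothetical helper names, and the dim names are copied from the printed output earlier in the thread.

```python
import re


def strip_numeric_suffix(dim_name: str) -> str:
    """Drop a trailing '_<digits>' suffix, e.g. 'party_id_dim_0' -> 'party_id_dim'."""
    return re.sub(r"_\d+$", "", dim_name)


def build_rename_map(idata_dims, model_dims, keep=("chain", "draw")):
    """Map sampler dim names onto the model's dim names (hypothetical helper).

    A dim is renamed only when its stripped form differs from the original
    and matches a dim the model knows about; everything else is left alone.
    """
    rename = {}
    for name in idata_dims:
        if name in keep:
            continue
        candidate = strip_numeric_suffix(name)
        if candidate != name and candidate in model_dims:
            rename[name] = candidate
    return rename


# Dim names taken from the example output in this thread.
bayeux_dims = ["chain", "draw", "party_id_dim_0", "party_id:age_dim_0"]
pymc_dims = ["party_id_dim", "party_id:age_dim"]

rename_map = build_rename_map(bayeux_dims, pymc_dims)
print(rename_map)
```

A mapping like this could then be handed to a rename call on the posterior group (for example xarray's `Dataset.rename`). Requiring the stripped name to exist in the model keeps the helper from touching dims whose names legitimately end in a number.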

@zwelitunyiswa

zwelitunyiswa commented Feb 10, 2024 via email

@ColCarroll
Collaborator

I think you could -- @GStechschulte has a good outline here. If @tomicapretto thinks this is a reasonable idea in principle, I'd be happy to either collaborate on this to (programmatically) get the bayeux inference methods in, or send a follow-up that generalizes it a bit, and allows optimization and VI.

@tomicapretto
Collaborator

I think this is really cool, thanks @GStechschulte and thanks @ColCarroll for bayeux. I'm not sure I am aware of all the details, but what is the reason why bayeux is appending _dim_0 to dimension names? As far as I remember that was an xarray thing. Or is it that bayeux is not receiving dimension names from the PyMC model and thus it appends _dim_0?

Another thing, I see we're replacing blackjax, jax, jaxlib, and numpyro with bayeux. However, as far as I know bayeux does not install these dependencies, so doing pip install bambi[jax] won't give users access to JAX based samplers, right? (I'm not very familiar with bayeux so I may be wrong).

@ColCarroll
Collaborator

bayeux will pull those in, but i agree it is better to be explicit and require dependencies (in case bayeux makes weird decisions).

I'll double check on the naming conventions!

@ColCarroll
Collaborator

Oh right, yes: bayeux has no concept of the dimensions from pymc. That would have to be implemented as a post-processing step to rename the arviz dimensions.

@tomicapretto
Collaborator

Oh right, yes: bayeux has no concept of the dimensions from pymc. That would have to be implemented as a post-processing step to rename the arviz dimensions.

Thanks for the answer, it makes much more sense now!

@GStechschulte
Collaborator Author

@ColCarroll thanks for the information.

@tomicapretto I can apply this post processing step on Bambi's side.

@tomicapretto
Collaborator

@ColCarroll thanks for the information.

@tomicapretto I can apply this post processing step on Bambi's side.

Sounds great, just let me know if you need help or a second opinion :)

@GStechschulte
Collaborator Author

Two updates:

  1. I added a processing step for when Bambi cleans the idata: it renames the idata dims and coordinates to match those of the underlying PyMC model.
  2. I explicitly added JAX based sampler dependencies.

Regarding

I'd be happy to either collaborate on this to (programmatically) get the bayeux inference methods in, or send a follow-up that generalizes it a bit, and allows optimization and VI.

@ColCarroll I'd be happy to collaborate and see how you would do this 👍🏼

@GStechschulte
Collaborator Author

@ColCarroll thanks a lot for the review! I will incorporate these in the coming days.

@GStechschulte
Collaborator Author

Ugh. pylint is making the CI fail. It says it cannot import bayeux. However, when I check the logs of the step "Install Bambi and all its dependencies", I can see that bayeux was installed.

@ColCarroll
Collaborator

the package is named bayeux-ml -- i'm checking now that you got that one.
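
A quick way to tell the two names apart is to check the import name and the distribution name separately. The sketch below uses only the standard library and assumes, as stated in this thread, that the import name is `bayeux` and the distribution is published as `bayeux-ml`; `is_importable` and `installed_version` are hypothetical helper names.

```python
import importlib.util
from importlib import metadata


def is_importable(module_name: str) -> bool:
    """Return True if a top-level module with this name can be found."""
    return importlib.util.find_spec(module_name) is not None


def installed_version(distribution_name: str):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution_name)
    except metadata.PackageNotFoundError:
        return None


# Import name vs. distribution name (both taken from this thread).
print("import bayeux works:", is_importable("bayeux"))
print("bayeux-ml version:", installed_version("bayeux-ml"))
```

If the second line prints `None` while an install log still mentions "bayeux", the environment most likely picked up a different distribution than the one intended.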

@ColCarroll
Collaborator

Oh. Last failures are because bayeux uses Python 3.10-only features (union types with a pipe, like str | None). I can get rid of those, but also NEP 29 suggests that Python 3.9 get dropped in 30 days. I see at least three choices here:

  1. I release a Python 3.9 compatible bayeux version, then we can merge this.
  2. We disable testing on Python 3.9 in this PR and merge. Things are confusing for users on Python 3.9 who pip install bambi[jax], but I can't find a way to restrict python versions in optional dependencies.
  3. We just wait 30 days, do a PR removing Python 3.9 support, then merge this.

Any preferences?
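
For context, option 1 would mostly mean rewriting annotations such as `str | None` (PEP 604, Python 3.10+) in the older `typing` spelling. The snippet below is a generic sketch of that rewrite, not code from bayeux; `pick_method` and `as_list` are hypothetical functions made up for the illustration.

```python
from typing import Optional, Union


# Python 3.10+ spelling: def pick_method(name: str | None = None) -> str
def pick_method(name: Optional[str] = None) -> str:
    """Return the requested method name, falling back to a default."""
    return "mcmc" if name is None else name


# Python 3.10+ spelling: def as_list(value: str | list) -> list
def as_list(value: Union[str, list]) -> list:
    """Wrap a single string in a list; copy anything that is already a list."""
    return [value] if isinstance(value, str) else list(value)


print(pick_method(), pick_method("blackjax_hmc"))
print(as_list("blackjax_hmc"))
```

Adding `from __future__ import annotations` is another route on Python 3.9, since it defers evaluation of annotations, but it does not help where the union is evaluated at runtime (for example inside `isinstance`).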

@tomicapretto
Collaborator

tomicapretto commented Mar 7, 2024

I'm in favor of following NumPy's deprecation cycle, but at the same time I feel it moves faster than many users. For that reason, even though the pipe operator is a feature I really like, I have not been using it because it does not work on Python <= 3.9.

According to PyPI stats, users are downloading Bambi using Python 3.7, 3.8, and 3.9. I know PyPI stats may not be a crystalline reflection of reality, but it's still useful. I don't want to break their installations.

So, we could drop official support for 3.9 as soon as NumPy does (this includes testing with Python >= 3.10 only), but I would not start using Python >= 3.10 specific features in the default installation of Bambi right now. I guess there will be a point where it'll become unavoidable (for example, when PyMC or PyTensor start to do so) and that's fine.

What do others think?

Edit: I think I didn't answer your questions straightforwardly.

I would wait, do a PR removing Python 3.9 support, and then merge this. I don't want to ask you to work on a Python 3.9 compatible version that will be discarded in the near future (unless you think it's worth it).

@GStechschulte
Collaborator Author

Thanks @ColCarroll for the information, and I agree with @tomicapretto's thoughts. Looking forward to eventually merging this!

@tomicapretto
Collaborator

According to PyPI stats, users are downloading Bambi using Python 3.7, 3.8, and 3.9. I know PyPI stats may not be a crystalline reflection of reality, but it's still useful. I don't want to break their installations.

Just want to add that after some comms in our Slack it looks like we won't be breaking users' installations. They'll just get an older version of Bambi if they use Python < 3.10

@GStechschulte
Collaborator Author

@tomicapretto so do we want to wait the x days to drop support for Python 3.9, or to drop it now?

@tomicapretto
Collaborator

tomicapretto commented Mar 14, 2024

@tomicapretto so do we want to wait the x days to drop support for Python 3.9, or to drop it now?

sorry for the slow response, we can do it now if you agree on that

Edit: also, it won't impact most users until we make a release, so now I realize it's even less dangerous.

@codecov-commenter

codecov-commenter commented Mar 19, 2024

Codecov Report

Attention: Patch coverage is 82.45614%, with 10 lines in your changes missing coverage. Please review.

Project coverage is 90.14%. Comparing base (9a1387a) to head (9f9d769).
Report is 3 commits behind head on main.

Files                   Patch %   Lines
bambi/backend/pymc.py   77.77%    10 Missing ⚠️

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #775      +/-   ##
==========================================
+ Coverage   89.86%   90.14%   +0.27%     
==========================================
  Files          46       46              
  Lines        3810     3836      +26     
==========================================
+ Hits         3424     3458      +34     
+ Misses        386      378       -8     

@GStechschulte
Collaborator Author

GStechschulte commented Mar 19, 2024

CI failing on 3.12 seems related to this PyTensor issue and corresponding PR.

@tomicapretto
Collaborator

CI failing on 3.12 seems related to this PyTensor issue and corresponding PR.

Ha! Should we avoid supporting 3.12 in the immediate future?

@tomicapretto
Collaborator

I was giving this a shot and just found that the development version of PyMC now supports Python 3.12

https://github.com/pymc-devs/pymc/blob/61ce412aa599939eaf299a6328059c184e6c25db/setup.py#L33-L36

This is not the case in the latest release 5.11. I think we can just wait a bit until PyMC releases a new version to upgrade the PyMC version requirement in our pyproject.toml and then we're done.

Another option would be to pin PyMC to the current development version, but that would have to be updated as soon as PyMC releases the version supporting Python 3.12, so I think it's just better to wait a bit.

https://github.com/pymc-devs/pymc/blob/a06081e1e9649bd56e3528cb96380efdf6bb2dc0/setup.py#L33-L35

@tomicapretto
Collaborator

There's a new version of PyMC supporting Python 3.12... 🤞

@tomicapretto
Collaborator

@GStechschulte I'll let you have the honor of clicking the green button :D

@GStechschulte
Collaborator Author

@GStechschulte I'll let you have the honor of clicking the green button :D

Woooooo! Gracias! 🎉

@GStechschulte GStechschulte merged commit 714ccb7 into bambinos:main Mar 29, 2024
4 checks passed
@ColCarroll
Collaborator

Thanks for all the persistence @GStechschulte!
