Skip to content

Commit

Permalink
Merge pull request #232 from lnccbrown/225-better-defaults-for-models…
Browse files Browse the repository at this point in the history
…ample

Better defaults for model.sample()
  • Loading branch information
digicosmos86 authored Jul 25, 2023
2 parents 7c5edec + 1d917ae commit 2c13fac
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
12 changes: 3 additions & 9 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,9 @@ Some functionalities in HSSM are available through optional dependencies.

### Sampling with JAX through `numpyro` or `blackjax`

JAX-based sampling is done through `numpyro` and `blackjax`. You need to have `numpyro`
installed if you want to use the `nuts_numpyro` sampler.

```bash
pip install numpyro
```

Likewise, you need to have `blackjax` installed if you want to use the `nuts_blackjax`
sampler.
JAX-based sampling is done through `numpyro` and `blackjax`. `numpyro` is installed as
a dependency by default. You need to have `blackjax` installed if you want to use the
`nuts_blackjax` sampler.

```bash
pip install blackjax
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ssm-simulators = "^0.3.0"
huggingface-hub = "^0.15.1"
onnxruntime = "^1.15.0"
bambi = "^0.12.0"
numpyro = "^0.12.1"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
Expand All @@ -39,7 +40,6 @@ hddm-wfpt = { git = "https://github.com/brown-ccv/hddm-wfpt.git" }
ipywidgets = "^8.0.3"
graphviz = "^0.20.1"
ruff = "^0.0.272"
numpyro = "^0.12.1"
mkdocs = "^1.4.3"
mkdocs-material = "^9.1.17"
mkdocstrings-python = "^1.1.2"
Expand Down
33 changes: 28 additions & 5 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,19 +481,21 @@ def _transform_params(

def sample(
self,
sampler: Literal[
"mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"
] = "mcmc",
sampler: Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"]
| None = None,
**kwargs,
) -> az.InferenceData | pm.Approximation:
"""Perform sampling using the `fit` method via bambi.Model.
Parameters
----------
sampler
The sampler to use. Can be either "mcmc" (default), "nuts_numpyro",
The sampler to use. Can be one of "mcmc", "nuts_numpyro",
"nuts_blackjax", "laplace", or "vi". If using `blackbox` likelihoods,
this cannot be "nuts_numpyro" or "nuts_blackjax".
this cannot be "nuts_numpyro" or "nuts_blackjax". By default it is None, and
sampler will automatically be chosen: when the model uses the
`approx_differentiable` likelihood, and `jax` backend, "nuts_numpyro" will
be used. Otherwise, "mcmc" (the default PyMC NUTS sampler) will be used.
kwargs
Other arguments passed to bmb.Model.fit(). Please see [here]
(https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
Expand All @@ -506,6 +508,15 @@ def sample(
(default), "nuts_numpyro", "nuts_blackjax" or "laplace". An `Approximation`
object if `"vi"`.
"""
if sampler is None:
if (
self.loglik_kind == "approx_differentiable"
and self.model_config["backend"] == "jax"
):
sampler = "nuts_numpyro"
else:
sampler = "mcmc"

supported_samplers = ["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"]

if sampler not in supported_samplers:
Expand All @@ -522,6 +533,18 @@ def sample(
if "step" not in kwargs:
kwargs["step"] = pm.Slice(model=self.pymc_model)

if (
self.loglik_kind == "approx_differentiable"
and self.model_config["backend"] == "jax"
and sampler == "mcmc"
and kwargs.get("cores", None) != 1
):
_logger.warning(
"Parallel sampling might not work with `jax` backend and the PyMC NUTS "
+ "sampler on some platforms. Please consider using `nuts_numpyro` or "
+ "`nuts_blackjax` sampler if that is a problem."
)

self._inference_obj = self.model.fit(inference_method=sampler, **kwargs)

return self.traces
Expand Down

0 comments on commit 2c13fac

Please sign in to comment.