Constrained Hurdle Gamma Error #768

Open
zwelitunyiswa opened this issue Dec 30, 2023 · 4 comments

Comments

@zwelitunyiswa

@tomicapretto I tried to use the constrained() function on a hurdle gamma model and got the error "NotImplementedError: Truncation not implemented for SymbolicRandomVariable MarginalMixtureRV{inline=True}".

@tomicapretto
Collaborator

Hi there! Could you share a minimal example? To me, this looks like a PyMC issue. With an example, I could reproduce the model in plain PyMC and check whether the problem is there.

@zwelitunyiswa
Author

zwelitunyiswa commented Jan 5, 2024

@tomicapretto

No problem. Try this:

import bambi as bmb
import polars as pl
import pymc as pm

dist_pre = pm.Gamma.dist(mu=9, sigma=10)
draws_pre = pm.draw(dist_pre, draws=40)

dist_post = pm.HurdleGamma.dist(psi=.7, mu=5, sigma=3)
draws_post = pm.draw(dist_post, draws=40)

groups = ['group1'] * 20 + ['group2'] * 20

data = pl.DataFrame({'group':groups, 'pre':draws_pre, 'post':draws_post}).with_columns(pl.all().exclude('group').round(2))

formula = bmb.Formula('constrained(post, 0, 30) ~ group + pre')

model = bmb.Model(formula=formula, data=data.to_pandas(), family='hurdle_gamma', link='log')

idata = model.fit()  # raises the NotImplementedError quoted above

idata

@GStechschulte
Collaborator

Hey @zwelitunyiswa, I haven't ignored this. Things are a bit busy at work, but I will try to get to it soon. Thanks!

@tomicapretto
Collaborator

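The problem is that PyMC does not implement truncation for mixture distributions, and the hurdle distributions (such as HurdleGamma and HurdleLogNormal) are mixtures of a point mass at zero and a continuous component. That is the MarginalMixtureRV in the error message. The same problem can be reproduced in plain PyMC: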

import numpy as np
import pymc as pm

dist_pre = pm.Gamma.dist(mu=9, sigma=10)
draws_pre = pm.draw(dist_pre, draws=40)

dist_post = pm.HurdleGamma.dist(psi=.7, mu=5, sigma=3)
draws_post = pm.draw(dist_post, draws=40)

groups = np.asarray(['group1'] * 20 + ['group2'] * 20)
group_names, group_idxs = np.unique(groups, return_inverse=True)
coords = {"group": group_names}

# This works: the hurdle likelihood on its own
with pm.Model(coords=coords) as model:
    baseline = pm.Normal("baseline", dims="group")
    slope = pm.Normal("slope")

    mu = pm.math.exp(baseline[group_idxs] + slope * draws_pre)
    sigma = pm.HalfNormal("sigma")
    psi = pm.Beta("psi", alpha=2, beta=2)

    pm.HurdleLogNormal("post", psi=psi, mu=mu, sigma=sigma, observed=draws_post)

    idata = pm.sample(tune=100, draws=100, random_seed=1234)

# This does not work: pm.Truncated raises NotImplementedError for the hurdle (mixture) distribution
with pm.Model(coords=coords) as model_truncated:
    baseline = pm.Normal("baseline", dims="group")
    slope = pm.Normal("slope")

    mu = pm.math.exp(baseline[group_idxs] + slope * draws_pre)
    sigma = pm.HalfNormal("sigma")
    psi = pm.Beta("psi", alpha=2, beta=2)

    pm.Truncated(
        "post",
        pm.HurdleLogNormal.dist(psi=psi, mu=mu, sigma=sigma),
        lower=0,
        upper=30,
        observed=draws_post
    )

    idata = pm.sample(tune=100, draws=100, random_seed=1234)
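As a rough illustration of why the hurdle distributions count as mixtures, here is a minimal standard-library sketch of how a hurdle gamma value is generated. It assumes PyMC's parameterization for pm.HurdleGamma (`psi` is the probability of a nonzero value; `mu` and `sigma` are the mean and standard deviation of the Gamma component) and ignores PyMC's implementation details. `draw_hurdle_gamma` is a hypothetical helper, not a PyMC function.

```python
import random


def draw_hurdle_gamma(psi, mu, sigma, n, seed=None):
    """Draw n values from a hurdle gamma mixture (hypothetical helper).

    With probability 1 - psi a draw is exactly 0 (the point-mass component);
    otherwise it comes from a Gamma with mean `mu` and standard deviation `sigma`.
    """
    rng = random.Random(seed)
    # Convert mean/sd to the shape/scale form used by random.gammavariate:
    # mean = shape * scale, variance = shape * scale**2.
    shape = (mu / sigma) ** 2
    scale = sigma**2 / mu
    draws = []
    for _ in range(n):
        # First pick the mixture component, then sample from it.
        if rng.random() < psi:
            draws.append(rng.gammavariate(shape, scale))  # nonzero component
        else:
            draws.append(0.0)  # point mass at zero
    return draws


# Same parameters used to simulate `post` in the reproduction: psi=0.7, mu=5, sigma=3.
samples = draw_hurdle_gamma(psi=0.7, mu=5, sigma=3, n=100_000, seed=1234)
zero_share = sum(x == 0.0 for x in samples) / len(samples)
nonzero = [x for x in samples if x > 0]
print(f"share of zeros: {zero_share:.3f}")
print(f"mean of nonzero draws: {sum(nonzero) / len(nonzero):.2f}")
```

Each draw first decides which component it belongs to. A truncated version has to renormalize over both components at once, which is the piece that is missing for mixtures.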

On top of that, why do you use a lower bound of 0? The support of the distribution you're using already restricts values to [0, ∞), so the lower bound adds nothing and only the upper bound of 30 actually constrains the response.
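To make the lower-bound point concrete, and to show what truncating a hurdle mixture involves, here is a minimal standard-library sketch. It is not how PyMC or Bambi implement truncation. The function names are hypothetical, and the code assumes the same parameterization as pm.HurdleGamma (`psi` is the probability of a nonzero value; `mu` and `sigma` are the mean and standard deviation of the Gamma part). For truncation to the closed interval [lower, upper], the normalizing constant is (1 - psi)·1[lower ≤ 0 ≤ upper] + psi·(G(upper) - G(lower)), where G is the Gamma CDF. Both mixture components enter this constant. Because G(0) = 0 and the point mass at 0 is kept either way, lower=0 gives the same constant as setting only upper=30.

```python
import math


def gamma_cdf(x, mu, sigma, max_terms=1000, tol=1e-12):
    """CDF of a Gamma with mean mu and sd sigma, via the incomplete-gamma series."""
    if x <= 0:
        return 0.0
    shape = (mu / sigma) ** 2
    rate = mu / sigma**2
    z = x * rate
    # Regularized lower incomplete gamma P(shape, z) as a power series:
    # P(a, z) = z**a * exp(-z) / Gamma(a + 1) * sum_n z**n / ((a + 1) ... (a + n))
    term = total = 1.0
    for n in range(1, max_terms):
        term *= z / (shape + n)
        total += term
        if term < tol * total:
            break
    return math.exp(shape * math.log(z) - z - math.lgamma(shape + 1)) * total


def gamma_pdf(x, mu, sigma):
    """Density of a Gamma with mean mu and sd sigma, for x > 0."""
    shape = (mu / sigma) ** 2
    rate = mu / sigma**2
    return math.exp(
        shape * math.log(rate) + (shape - 1) * math.log(x) - rate * x - math.lgamma(shape)
    )


def hurdle_gamma_mass(lower, upper, psi, mu, sigma):
    """P(lower <= Y <= upper) for a hurdle gamma Y, where psi = P(Y > 0)."""
    # The point mass at zero only counts if 0 lies inside the closed interval.
    zero_part = (1 - psi) if lower <= 0 <= upper else 0.0
    # The Gamma component contributes psi times its own interval probability.
    gamma_part = psi * (gamma_cdf(upper, mu, sigma) - gamma_cdf(lower, mu, sigma))
    return zero_part + gamma_part


def truncated_hurdle_gamma_density(y, lower, upper, psi, mu, sigma):
    """Hurdle gamma likelihood renormalized to [lower, upper].

    At y == 0 this returns a probability (the point mass); elsewhere a density.
    """
    if y < max(lower, 0) or y > upper:
        return 0.0
    norm = hurdle_gamma_mass(lower, upper, psi, mu, sigma)
    if y == 0:
        return (1 - psi) / norm
    return psi * gamma_pdf(y, mu, sigma) / norm


# Parameters used to simulate `post` in the reproduction, truncated to at most 30.
params = dict(psi=0.7, mu=5, sigma=3)
with_lower = hurdle_gamma_mass(0, 30, **params)
upper_only = hurdle_gamma_mass(-math.inf, 30, **params)
print(with_lower, upper_only)  # identical: lower=0 changes nothing
```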
