Skip to content

Commit

Permalink
Implement censored families (#697)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Jul 20, 2023
1 parent e868ede commit cbbf955
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 11 deletions.
2 changes: 1 addition & 1 deletion bambi/backend/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def build_common_terms(self, pymc_backend, bmb_model):

# If there's an intercept, center the data
# Also store the design matrix without the intercept to uncenter the intercept later
if self.has_intercept:
if self.has_intercept and bmb_model.center_predictors:
self.design_matrix_without_intercept = data
data = data - data.mean(0)

Expand Down
6 changes: 5 additions & 1 deletion bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,11 @@ def _clean_results(self, idata, omit_offsets, include_mean):

for pymc_component in self.distributional_components.values():
bambi_component = pymc_component.component
if bambi_component.intercept_term and bambi_component.common_terms:
if (
bambi_component.intercept_term
and bambi_component.common_terms
and self.spec.center_predictors
):
chain_n = len(idata.posterior["chain"])
draw_n = len(idata.posterior["draw"])
shape, dims = (chain_n, draw_n), ("chain", "draw")
Expand Down
28 changes: 27 additions & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,30 @@ def build_response_distribution(self, kwargs, pymc_backend):
kwargs = self.family.transform_backend_kwargs(kwargs)

kwargs = self.robustify_dims(pymc_backend, kwargs)
return distribution(self.name, **kwargs)

if self.term.is_censored:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")

# Get values of the response variable
observed = np.squeeze(data_matrix[:, 0])

# Get censoring codes
censoring_code = np.squeeze(data_matrix[:, 1])

is_left_censored = censoring_code == -1
is_right_censored = censoring_code == 1

lower = np.where(is_left_censored, observed, -np.inf)
upper = np.where(is_right_censored, observed, np.inf)
stateless_dist = distribution.dist(**kwargs)
dist_rv = pm.Censored(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)
else:
dist_rv = distribution(self.name, **kwargs)

return dist_rv

@property
def name(self):
Expand All @@ -293,6 +316,9 @@ def robustify_dims(self, pymc_backend, kwargs):
if isinstance(self.family, (Multinomial, DirichletMultinomial)):
return kwargs

if self.term.is_censored:
return kwargs

dims, data = kwargs["dims"], kwargs["observed"]
dims_n = len(dims)
ndim_diff = data.ndim - dims_n
Expand Down
21 changes: 21 additions & 0 deletions bambi/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,27 @@
* vs: Engine (0 = V-shaped, 1 = straight)
* am: Transmission (0 = automatic, 1 = manual)
* gear: Number of forward gears
""",
),
"kidney": FileMetadata(
filename="kidney.csv",
url="https://figshare.com/ndownloader/files/41645361",
checksum="46e49372b4e8c3044dca0ffbb4eb2244f56d7398746802e351baac6c12625564",
description="""
It describes the first and second recurrence times of infection in kidney patients together with
information on risk variables such as age, sex, and disease type.
This dataset is taken from McGilchrist and Aisbett (1991).
* time: Days to first or second recurrence of the infection, or the time of censoring
* censored: Indicates censoring status. 0 indicates no censoring and 1 indicates right censoring
* patient: Patient ID
* recur: Indicates if the infection occurs for first or second time.
* age: Age of the patient
* sex: Sex of the patient
* disease: The type of disease. Can be "AN", "GN", "PKG", or "other"
McGilchrist, C. A., & Aisbett, C. W. (1991). Regression with frailty in survival analysis.
Biometrics, 47(2), 461-466
""",
),
}
Expand Down
21 changes: 21 additions & 0 deletions bambi/defaults/families.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Binomial,
Categorical,
Cumulative,
Exponential,
Gamma,
Gaussian,
HurdleGamma,
Expand All @@ -20,6 +21,7 @@
StudentT,
VonMises,
Wald,
Weibull,
ZeroInflatedBinomial,
ZeroInflatedNegativeBinomial,
ZeroInflatedPoisson,
Expand Down Expand Up @@ -105,6 +107,15 @@
"link": {"a": "log"},
"family": DirichletMultinomial,
},
"exponential": {
"likelihood": {
"name": "Exponential",
"params": ["mu"],
"parent": "mu",
},
"link": {"mu": "log"},
"family": Exponential,
},
"gamma": {
"likelihood": {
"name": "Gamma",
Expand Down Expand Up @@ -243,6 +254,16 @@
"family": Wald,
"default_priors": {"lam": "HalfCauchy"},
},
"weibull": {
"likelihood": {
"name": "Weibull",
"params": ["mu", "alpha"],
"parent": "mu",
},
"link": {"mu": "log", "alpha": "log"},
"family": Weibull,
"default_priors": {"alpha": "HalfCauchy"},
},
"zero_inflated_binomial": {
"likelihood": {
"name": "ZeroInflatedBinomial",
Expand Down
38 changes: 38 additions & 0 deletions bambi/families/univariate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytensor.tensor as pt
import scipy.special as sp
import xarray as xr

from bambi.families.family import Family
Expand Down Expand Up @@ -233,6 +234,22 @@ def transform_kwargs(kwargs):
return kwargs


class Exponential(UnivariateFamily):
    """Exponential family whose modeled parameter is ``mu``.

    Both transformation hooks replace the ``mu`` entry of ``kwargs`` with
    ``lam``, computed as the reciprocal of ``mu``.
    """

    SUPPORTED_LINKS = {"mu": ["identity", "log", "inverse"]}

    @staticmethod
    def transform_backend_kwargs(kwargs):
        """Swap ``mu`` for ``lam = 1 / mu`` in ``kwargs`` (modified in place) and return it."""
        kwargs["lam"] = 1 / kwargs.pop("mu")
        return kwargs

    @staticmethod
    def transform_kwargs(kwargs):
        """Swap ``mu`` for ``lam = 1 / mu`` in ``kwargs`` (modified in place) and return it."""
        kwargs["lam"] = 1 / kwargs.pop("mu")
        return kwargs


class Gamma(UnivariateFamily):
SUPPORTED_LINKS = {"mu": ["identity", "log", "inverse"], "alpha": ["log"]}

Expand Down Expand Up @@ -407,6 +424,27 @@ class Wald(UnivariateFamily):
SUPPORTED_LINKS = {"mu": ["inverse", "inverse_squared", "identity", "log"], "lam": ["log"]}


class Weibull(UnivariateFamily):
    """Weibull family parametrized through ``mu`` and the shape ``alpha``.

    The distribution itself is specified with ``alpha`` (shape) and ``beta`` (scale).
    A prior is requested for ``alpha`` while ``mu`` is modeled as a function of the
    linear predictor, so both hooks derive ``beta`` as ``mu / gamma(1 + 1 / alpha)``.
    """

    SUPPORTED_LINKS = {"mu": ["log", "identity", "inverse"], "alpha": ["log"]}

    @staticmethod
    def transform_backend_kwargs(kwargs):
        """Replace ``mu`` with ``beta`` using the pytensor gamma function.

        ``alpha`` is kept in ``kwargs``; ``mu`` is removed. The dictionary is
        modified in place and returned.
        """
        shape = kwargs.get("alpha")
        mean = kwargs.pop("mu")
        kwargs["beta"] = mean / pt.gamma(1 + 1 / shape)
        return kwargs

    @staticmethod
    def transform_kwargs(kwargs):
        """Replace ``mu`` with ``beta`` using the scipy gamma function.

        ``alpha`` is kept in ``kwargs``; ``mu`` is removed. The dictionary is
        modified in place and returned.
        """
        shape = kwargs.get("alpha")
        mean = kwargs.pop("mu")
        kwargs["beta"] = mean / sp.gamma(1 + 1 / shape)
        return kwargs


class ZeroInflatedBinomial(BinomialBaseFamily):
SUPPORTED_LINKS = {
"p": ["identity", "logit", "probit", "cloglog"],
Expand Down
2 changes: 2 additions & 0 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
dropna=False,
auto_scale=True,
noncentered=True,
center_predictors=True,
extra_namespace=None,
):
# attributes that are set later
Expand All @@ -121,6 +122,7 @@ def __init__(
self.formula = formula
self.noncentered = noncentered
self.potentials = potentials
self.center_predictors = center_predictors

# Read and clean data
if not isinstance(data, pd.DataFrame):
Expand Down
4 changes: 2 additions & 2 deletions bambi/terms/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from bambi.terms.base import BaseTerm

# from bambi.new_terms.utils import is_censored_response
from bambi.terms.utils import is_censored_response


class ResponseTerm(BaseTerm):
def __init__(self, response, family):
self.term = response.term.term
self.family = family
# self.is_censored = is_censored_response(self.term.term)
self.is_censored = is_censored_response(self.term)

@property
def term(self):
Expand Down
2 changes: 1 addition & 1 deletion bambi/terms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def is_censored_response(term):
"""Determines if a formulae term represents a censored response"""
if not is_single_component(term):
return False
component = term.term.components[0] # get the first (and single) component
component = term.components[0] # get the first (and single) component
if not is_call_component(component):
return False
return is_call_of_kind(component, "censored")
6 changes: 4 additions & 2 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Changelog


## 0.X.X

### New features

* Bambi now supports censored responses (#697)
* Implement `"exponential"` and `"weibull"` families (#697)
* Add `"kidney"` dataset (#697)

### Maintenance and fixes

### Documentation


### Deprecation


Expand Down
2 changes: 2 additions & 0 deletions docs/notebooks/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,7 @@
"binomial | Binomial | logit | \n",
"categorical | Categorical | softmax | \n",
"cumulative | Cumulative | logit | \n",
"exponential | Exponential | log | \n",
"dirichlet_multinomial | DirichletMultinomial | logit |\n",
"gamma | Gamma | inverse |\n",
"gaussian | Normal | identity |\n",
Expand All @@ -952,6 +953,7 @@
"t | StudentT | identity |\n",
"vonmises | VonMises | tan(x / 2) |\n",
"wald | InverseGaussian | inverse squared |\n",
"weibull | Weibull | log |\n",
"zero_inflated_binomial | ZeroInflatedBinomial | logit |\n",
"zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log |\n",
"zero_inflated_poisson | ZeroInflatedPoisson | log |\n",
Expand Down
50 changes: 50 additions & 0 deletions tests/test_built_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,3 +997,53 @@ def test_predict_new_groups(data, formula, family, df_new, request):
model = bmb.Model(formula, data, family=family)
idata = model.fit(tune=100, draws=100)
model.predict(idata, data=df_new, sample_new_groups=True)


def test_censored_response():
    """Smoke test: build, fit and predict Weibull models with a censored response.

    Three models are exercised on the "kidney" dataset: one with an intercept, one
    without an intercept, and one with a group-specific intercept per patient.
    """
    data = bmb.load_data("kidney")
    # Map the 0/1 censoring indicator to the status labels passed to `censored()`
    data["status"] = np.where(data["censored"] == 0, "none", "right")

    def fit_and_predict(formula, priors):
        # Build the model, sample briefly, and predict both in-sample and with explicit data
        model = bmb.Model(formula, data, family="weibull", link="log", priors=priors)
        idata = model.fit(tune=100, draws=100, random_seed=121195)
        model.predict(idata, kind="pps")
        model.predict(idata, data=data, kind="pps")

    # Model 1, with intercept
    fit_and_predict(
        "censored(time, status) ~ 1 + sex + age",
        {
            "Intercept": bmb.Prior("Normal", mu=0, sigma=1),
            "sex": bmb.Prior("Normal", mu=0, sigma=2),
            "age": bmb.Prior("Normal", mu=0, sigma=1),
            "alpha": bmb.Prior("Gamma", alpha=3, beta=5),
        },
    )

    # Model 2, without intercept
    fit_and_predict(
        "censored(time, status) ~ 0 + sex + age",
        {
            "sex": bmb.Prior("Normal", mu=0, sigma=2),
            "age": bmb.Prior("Normal", mu=0, sigma=1),
            "alpha": bmb.Prior("Gamma", alpha=3, beta=5),
        },
    )

    # Model 3, with group-specific effects
    fit_and_predict(
        "censored(time, status) ~ 1 + sex + age + (1|patient)",
        {
            "alpha": bmb.Prior("Gamma", alpha=3, beta=5),
            "sex": bmb.Prior("Normal", mu=0, sigma=1),
            "age": bmb.Prior("Normal", mu=0, sigma=1),
            "1|patient": bmb.Prior(
                "Normal", mu=0, sigma=bmb.Prior("InverseGamma", alpha=5, beta=10)
            ),
        },
    )
20 changes: 17 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import pandas as pd

import bambi as bmb
from bambi.utils import listify
from bambi.backend.pymc import probit, cloglog
from bambi.transformations import censored
Expand All @@ -25,7 +24,6 @@ def test_cloglog():
assert (x > 0).all() and (x < 1).all()


@pytest.mark.skip(reason="Censored still not ported")
def test_censored():
df = pd.DataFrame(
{
Expand All @@ -39,9 +37,25 @@ def test_censored():

x = censored(df["x"], df["status"])
assert x.shape == (5, 2)
assert (x[:, -1] == np.array([0, 1, 2, -1, 0])).all()

x = censored(df["x"], df["y"], df["status"])
assert x.shape == (5, 3)
assert (x[:, -1] == np.array([0, 1, 2, -1, 0])).all()

with pytest.raises(AssertionError):
# Statuses are not among the expected values
with pytest.raises(AssertionError, match="Statuses must be in"):
censored(df_bad["x"], df_bad["status"])

# Upper bound is not always larger than lower bound
df_bad = pd.DataFrame({"l": [1, 2], "r": [1, 1], "status": ["foo", "bar"]})

with pytest.raises(AssertionError, match="Upper bound must be larger than lower bound"):
censored(df_bad["l"], df_bad["r"], df_bad["status"])

# Bad number of arguments
with pytest.raises(ValueError, match="needs 2 or 3 argument values"):
censored(df["x"])

with pytest.raises(ValueError, match="needs 2 or 3 argument values"):
censored(df["x"], df["x"], df["x"], df["x"])

0 comments on commit cbbf955

Please sign in to comment.