From cbbf955864304c9773011a44db515b192dfaeb49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Thu, 20 Jul 2023 18:56:07 -0300 Subject: [PATCH] Implement censored families (#697) --- bambi/backend/model_components.py | 2 +- bambi/backend/pymc.py | 6 +++- bambi/backend/terms.py | 28 +++++++++++++++- bambi/data/datasets.py | 21 ++++++++++++ bambi/defaults/families.py | 21 ++++++++++++ bambi/families/univariate.py | 38 +++++++++++++++++++++ bambi/models.py | 2 ++ bambi/terms/response.py | 4 +-- bambi/terms/utils.py | 2 +- docs/CHANGELOG.md | 6 ++-- docs/notebooks/getting_started.ipynb | 2 ++ tests/test_built_models.py | 50 ++++++++++++++++++++++++++++ tests/test_utils.py | 20 +++++++++-- 13 files changed, 191 insertions(+), 11 deletions(-) diff --git a/bambi/backend/model_components.py b/bambi/backend/model_components.py index 7182ce2b6..686959fa5 100644 --- a/bambi/backend/model_components.py +++ b/bambi/backend/model_components.py @@ -107,7 +107,7 @@ def build_common_terms(self, pymc_backend, bmb_model): # If there's an intercept, center the data # Also store the design matrix without the intercept to uncenter the intercept later - if self.has_intercept: + if self.has_intercept and bmb_model.center_predictors: self.design_matrix_without_intercept = data data = data - data.mean(0) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 74f6b7b89..8734dde9b 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -271,7 +271,11 @@ def _clean_results(self, idata, omit_offsets, include_mean): for pymc_component in self.distributional_components.values(): bambi_component = pymc_component.component - if bambi_component.intercept_term and bambi_component.common_terms: + if ( + bambi_component.intercept_term + and bambi_component.common_terms + and self.spec.center_predictors + ): chain_n = len(idata.posterior["chain"]) draw_n = len(idata.posterior["draw"]) shape, dims = (chain_n, draw_n), ("chain", "draw") diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index c70a93575..5624bd94c 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -275,7 +275,30 @@ def build_response_distribution(self, kwargs, pymc_backend): kwargs = self.family.transform_backend_kwargs(kwargs) kwargs = self.robustify_dims(pymc_backend, kwargs) - return distribution(self.name, **kwargs) + + if self.term.is_censored: + dims = kwargs.pop("dims", None) + data_matrix = kwargs.pop("observed") + + # Get values of the response variable + observed = np.squeeze(data_matrix[:, 0]) + + # Get censoring codes + censoring_code = np.squeeze(data_matrix[:, 1]) + + is_left_censored = censoring_code == -1 + is_right_censored = censoring_code == 1 + + lower = np.where(is_left_censored, observed, -np.inf) + upper = np.where(is_right_censored, observed, np.inf) + stateless_dist = distribution.dist(**kwargs) + dist_rv = pm.Censored( + self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims + ) + else: + dist_rv = distribution(self.name, **kwargs) + + return dist_rv @property def name(self): @@ -293,6 +316,9 @@ def robustify_dims(self, pymc_backend, kwargs): if isinstance(self.family, (Multinomial, DirichletMultinomial)): return kwargs + if self.term.is_censored: + return kwargs + dims, data = kwargs["dims"], kwargs["observed"] dims_n = len(dims) ndim_diff = data.ndim - dims_n diff --git a/bambi/data/datasets.py b/bambi/data/datasets.py index 663a3a6e6..dc063cb18 100644 --- a/bambi/data/datasets.py +++ b/bambi/data/datasets.py @@ -162,6 +162,27 @@ * vs: Engine 
(0 = V-shaped, 1 = straight)
 * am: Transmission (0 = automatic, 1 = manual)
 * gear: Number of forward gears
+""",
+    ),
+    "kidney": FileMetadata(
+        filename="kidney.csv",
+        url="https://figshare.com/ndownloader/files/41645361",
+        checksum="46e49372b4e8c3044dca0ffbb4eb2244f56d7398746802e351baac6c12625564",
+        description="""
+This dataset describes the first and second recurrence times of infection in kidney patients,
+together with information on risk variables such as age, sex, and disease type.
+It is taken from McGilchrist and Aisbett (1991).
+
+* time: Days to first or second recurrence of the infection, or the time of censoring
+* censored: Indicates censoring status. 0 indicates no censoring and 1 indicates right censoring
+* patient: Patient ID
+* recur: Indicates whether the infection occurs for the first or second time
+* age: Age of the patient
+* sex: Sex of the patient
+* disease: The type of disease. Can be "AN", "GN", "PKD", or "other"
+
+McGilchrist, C. A., & Aisbett, C. W. (1991). Regression with frailty in survival analysis.
+Biometrics, 47(2), 461-466.
 """,
     ),
 }
diff --git a/bambi/defaults/families.py b/bambi/defaults/families.py
index ce4729e5b..40d8f15f6 100644
--- a/bambi/defaults/families.py
+++ b/bambi/defaults/families.py
@@ -7,6 +7,7 @@
     Binomial,
     Categorical,
     Cumulative,
+    Exponential,
     Gamma,
     Gaussian,
     HurdleGamma,
@@ -20,6 +21,7 @@
     StudentT,
     VonMises,
     Wald,
+    Weibull,
     ZeroInflatedBinomial,
     ZeroInflatedNegativeBinomial,
     ZeroInflatedPoisson,
@@ -105,6 +107,15 @@
         "link": {"a": "log"},
         "family": DirichletMultinomial,
     },
+    "exponential": {
+        "likelihood": {
+            "name": "Exponential",
+            "params": ["mu"],
+            "parent": "mu",
+        },
+        "link": {"mu": "log"},
+        "family": Exponential,
+    },
     "gamma": {
         "likelihood": {
             "name": "Gamma",
@@ -243,6 +254,16 @@
         "family": Wald,
         "default_priors": {"lam": "HalfCauchy"},
     },
+    "weibull": {
+        "likelihood": {
+            "name": "Weibull",
+            "params": ["mu", "alpha"],
+            "parent": "mu",
+        },
+        "link": {"mu": "log", "alpha": "log"},
+        "family": Weibull,
+        "default_priors": {"alpha": "HalfCauchy"},
+    },
     "zero_inflated_binomial": {
         "likelihood": {
             "name": "ZeroInflatedBinomial",
diff --git a/bambi/families/univariate.py b/bambi/families/univariate.py
index 30391a12b..a01f4900e 100644
--- a/bambi/families/univariate.py
+++ b/bambi/families/univariate.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytensor.tensor as pt
+import scipy.special as sp
 import xarray as xr
 
 from bambi.families.family import Family
@@ -233,6 +234,22 @@ def transform_kwargs(kwargs):
         return kwargs
 
 
+class Exponential(UnivariateFamily):
+    SUPPORTED_LINKS = {"mu": ["identity", "log", "inverse"]}
+
+    @staticmethod
+    def transform_backend_kwargs(kwargs):
+        mu = kwargs.pop("mu")
+        kwargs["lam"] = 1 / mu
+        return kwargs
+
+    @staticmethod
+    def transform_kwargs(kwargs):
+        mu = kwargs.pop("mu")
+        kwargs["lam"] = 1 / mu
+        return kwargs
+
+
 class Gamma(UnivariateFamily):
     SUPPORTED_LINKS = {"mu": ["identity", "log", "inverse"], "alpha": ["log"]}
 
@@ -407,6 +424,27 @@ class Wald(UnivariateFamily):
     SUPPORTED_LINKS = {"mu": ["inverse", "inverse_squared", "identity", "log"], "lam": ["log"]}
 
 
+class Weibull(UnivariateFamily):
+    SUPPORTED_LINKS = {"mu": ["log", "identity", "inverse"], "alpha": ["log"]}
+
+    @staticmethod
+    def transform_backend_kwargs(kwargs):
+        # The Weibull distribution is specified using alpha (shape) and beta (scale).
+        # We request a prior for alpha and we model 'mu' as a function of the linear predictor.
+ # Here we determine 'beta' out of the value of 'mu' and 'alpha' + mu = kwargs.pop("mu") + alpha = kwargs.get("alpha") + kwargs["beta"] = mu / pt.gamma(1 + 1 / alpha) + return kwargs + + @staticmethod + def transform_kwargs(kwargs): + mu = kwargs.pop("mu") + alpha = kwargs.get("alpha") + kwargs["beta"] = mu / sp.gamma(1 + 1 / alpha) + return kwargs + + class ZeroInflatedBinomial(BinomialBaseFamily): SUPPORTED_LINKS = { "p": ["identity", "logit", "probit", "cloglog"], diff --git a/bambi/models.py b/bambi/models.py index 2d184ef99..a1bf37c8b 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -104,6 +104,7 @@ def __init__( dropna=False, auto_scale=True, noncentered=True, + center_predictors=True, extra_namespace=None, ): # attributes that are set later @@ -121,6 +122,7 @@ def __init__( self.formula = formula self.noncentered = noncentered self.potentials = potentials + self.center_predictors = center_predictors # Read and clean data if not isinstance(data, pd.DataFrame): diff --git a/bambi/terms/response.py b/bambi/terms/response.py index 56ea4a829..98fe15986 100644 --- a/bambi/terms/response.py +++ b/bambi/terms/response.py @@ -2,14 +2,14 @@ from bambi.terms.base import BaseTerm -# from bambi.new_terms.utils import is_censored_response +from bambi.terms.utils import is_censored_response class ResponseTerm(BaseTerm): def __init__(self, response, family): self.term = response.term.term self.family = family - # self.is_censored = is_censored_response(self.term.term) + self.is_censored = is_censored_response(self.term) @property def term(self): diff --git a/bambi/terms/utils.py b/bambi/terms/utils.py index cdd3f07c2..9d026304e 100644 --- a/bambi/terms/utils.py +++ b/bambi/terms/utils.py @@ -26,7 +26,7 @@ def is_censored_response(term): """Determines if a formulae term represents a censored response""" if not is_single_component(term): return False - component = term.term.components[0] # get the first (and single) component + component = term.components[0] # get the first (and single) component if not is_call_component(component): return False return is_call_of_kind(component, "censored") diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index a43c952cb..08435cd42 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,15 +1,17 @@ # Changelog - ## 0.X.X ### New features +* Bambi now supports censored responses (#697) +* Implement `"exponential"` and `"weibull"` families (#697) +* Add `"kidney"` dataset (#697) + ### Maintenance and fixes ### Documentation - ### Deprecation diff --git a/docs/notebooks/getting_started.ipynb b/docs/notebooks/getting_started.ipynb index 60ad12b3f..2de9bc685 100644 --- a/docs/notebooks/getting_started.ipynb +++ b/docs/notebooks/getting_started.ipynb @@ -937,6 +937,7 @@ "binomial | Binomial | logit | \n", "categorical | Categorical | softmax | \n", "cumulative | Cumulative | logit | \n", + "exponential | Exponential | log | \n", "dirichlet_multinomial | DirichletMultinomial | logit |\n", "gamma | Gamma | inverse |\n", "gaussian | Normal | identity |\n", @@ -952,6 +953,7 @@ "t | StudentT | identity |\n", "vonmises | VonMises | tan(x / 2) |\n", "wald | InverseGaussian | inverse squared |\n", + "weibull | Weibull | log |\n", "zero_inflated_binomial | ZeroInflatedBinomial | logit |\n", "zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log |\n", "zero_inflated_poisson | ZeroInflatedPoisson | log |\n", diff --git a/tests/test_built_models.py b/tests/test_built_models.py index 8e2366f60..9362181c5 100644 --- a/tests/test_built_models.py +++ 
b/tests/test_built_models.py @@ -997,3 +997,53 @@ def test_predict_new_groups(data, formula, family, df_new, request): model = bmb.Model(formula, data, family=family) idata = model.fit(tune=100, draws=100) model.predict(idata, data=df_new, sample_new_groups=True) + + +def test_censored_response(): + data = bmb.load_data("kidney") + data["status"] = np.where(data["censored"] == 0, "none", "right") + + # Model 1, with intercept + priors = { + "Intercept": bmb.Prior("Normal", mu=0, sigma=1), + "sex": bmb.Prior("Normal", mu=0, sigma=2), + "age": bmb.Prior("Normal", mu=0, sigma=1), + "alpha": bmb.Prior("Gamma", alpha=3, beta=5), + } + model = bmb.Model( + "censored(time, status) ~ 1 + sex + age", data, family="weibull", link="log", priors=priors + ) + idata = model.fit(tune=100, draws=100, random_seed=121195) + model.predict(idata, kind="pps") + model.predict(idata, data=data, kind="pps") + + # Model 2, without intercept + priors = { + "sex": bmb.Prior("Normal", mu=0, sigma=2), + "age": bmb.Prior("Normal", mu=0, sigma=1), + "alpha": bmb.Prior("Gamma", alpha=3, beta=5), + } + model = bmb.Model( + "censored(time, status) ~ 0 + sex + age", data, family="weibull", link="log", priors=priors + ) + idata = model.fit(tune=100, draws=100, random_seed=121195) + model.predict(idata, kind="pps") + model.predict(idata, data=data, kind="pps") + + # Model 3, with group-specific effects + priors = { + "alpha": bmb.Prior("Gamma", alpha=3, beta=5), + "sex": bmb.Prior("Normal", mu=0, sigma=1), + "age": bmb.Prior("Normal", mu=0, sigma=1), + "1|patient": bmb.Prior("Normal", mu=0, sigma=bmb.Prior("InverseGamma", alpha=5, beta=10)), + } + model = bmb.Model( + "censored(time, status) ~ 1 + sex + age + (1|patient)", + data, + family="weibull", + link="log", + priors=priors, + ) + idata = model.fit(tune=100, draws=100, random_seed=121195) + model.predict(idata, kind="pps") + model.predict(idata, data=data, kind="pps") diff --git a/tests/test_utils.py b/tests/test_utils.py index eaeb61037..44b72866f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,6 @@ import numpy as np import pandas as pd -import bambi as bmb from bambi.utils import listify from bambi.backend.pymc import probit, cloglog from bambi.transformations import censored @@ -25,7 +24,6 @@ def test_cloglog(): assert (x > 0).all() and (x < 1).all() -@pytest.mark.skip(reason="Censored still not ported") def test_censored(): df = pd.DataFrame( { @@ -39,9 +37,25 @@ def test_censored(): x = censored(df["x"], df["status"]) assert x.shape == (5, 2) + assert (x[:, -1] == np.array([0, 1, 2, -1, 0])).all() x = censored(df["x"], df["y"], df["status"]) assert x.shape == (5, 3) + assert (x[:, -1] == np.array([0, 1, 2, -1, 0])).all() - with pytest.raises(AssertionError): + # Statuses are not the expected + with pytest.raises(AssertionError, match="Statuses must be in"): censored(df_bad["x"], df_bad["status"]) + + # Upper bound is not always larger than lower bound + df_bad = pd.DataFrame({"l": [1, 2], "r": [1, 1], "status": ["foo", "bar"]}) + + with pytest.raises(AssertionError, match="Upper bound must be larger than lower bound"): + censored(df_bad["l"], df_bad["r"], df_bad["status"]) + + # Bad number of arguments + with pytest.raises(ValueError, match="needs 2 or 3 argument values"): + censored(df["x"]) + + with pytest.raises(ValueError, match="needs 2 or 3 argument values"): + censored(df["x"], df["x"], df["x"], df["x"])
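A minimal usage sketch of the censored-response workflow added by this patch, following the pattern of test_censored_response above; the sampler settings below are illustrative, and priors are left at their defaults:

    import numpy as np
    import bambi as bmb

    # Load the kidney dataset registered by this patch
    data = bmb.load_data("kidney")

    # censored() expects status labels such as "none" or "right", so recode the
    # 0/1 values stored in the "censored" column before building the model
    data["status"] = np.where(data["censored"] == 0, "none", "right")

    # Weibull regression with a right-censored response and a log link
    model = bmb.Model(
        "censored(time, status) ~ 1 + sex + age",
        data,
        family="weibull",
        link="log",
    )
    idata = model.fit(tune=1000, draws=1000, random_seed=1234)

    # Posterior predictive sampling also works for censored responses
    model.predict(idata, kind="pps")

    # The new exponential family works the same way for an uncensored response
    exp_model = bmb.Model("time ~ 1 + sex + age", data, family="exponential")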