Skip to content

Commit

Permalink
add args. and error handling for new predictions functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte committed Oct 13, 2023
1 parent eceed09 commit 296df2c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 50 deletions.
63 changes: 45 additions & 18 deletions bambi/interpret/effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
merge,
VariableInfo,
)
from bambi.utils import get_aliased_name, listify
from bambi.utils import get_aliased_name

# TODO: aliases for type hints?
# TODO: functions for error handling

SUPPORTED_SLOPES = ("dydx", "eyex")
SUPPORTED_COMPARISONS = {
Expand Down Expand Up @@ -415,7 +413,7 @@ def average_by(self, variable: Union[bool, str]) -> pd.DataFrame:
A dataframe containing the marginal or group by average.
"""
if variable is True:
contrast_df_avg = average_over(self.summary_df, None)
contrast_df_avg = average_over(self.summary_df, "all")
contrast_df_avg.insert(0, "term", self.variable.name)
contrast_df_avg.insert(1, "estimate_type", self.estimate_name)
if self.kind != "slopes" and len(self.variable.values) < 3:
Expand All @@ -433,7 +431,7 @@ def average_by(self, variable: Union[bool, str]) -> pd.DataFrame:
def predictions(
model: Model,
idata: az.InferenceData,
conditional: Union[str, list, None] = None,
conditional: Union[str, dict, list, None] = None,
average_by: Union[str, list, bool, None] = None,
target: str = "mean",
pps: bool = False,
Expand All @@ -451,8 +449,9 @@ def predictions(
idata : arviz.InferenceData
The InferenceData object that contains the samples from the posterior distribution of
the model.
conditional : str, list, optional
The covariates we would like to condition on.
conditional : str, list, dict, optional
The covariates we would like to condition on. If dict, keys are the covariate names and
values are the values to condition on.
average_by: str, list, bool, optional
The covariates we would like to average by. The passed covariate(s) will marginalize
over the other covariates in the model. If True, it averages over all covariates
Expand Down Expand Up @@ -484,23 +483,29 @@ def predictions(
------
ValueError
If ``pps`` is ``True`` and ``target`` is not ``"mean"``.
If ``conditional`` is a list and the length is greater than 3.
If ``prob`` is not > 0 and < 1.
"""

if pps and target != "mean":
raise ValueError("When passing 'pps=True', target must be 'mean'")

# TODO: error handling
if isinstance(conditional, list):
if len(conditional) > 3:
raise ValueError(
f"Only 3 covariates can be passed to 'conditional'. {len(conditional)} "
"were passed. If you would like to pass more than 3 covariates, use "
"a dictionary."
)

conditional_info = ConditionalInfo(model, conditional)

transforms = transforms if transforms is not None else {}

if prob is None:
prob = az.rcParams["stats.hdi_prob"]
if not 0 < prob < 1:
raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.")

cap_data = create_predictions_data(conditional_info, model)
cap_data = create_predictions_data(conditional_info, conditional_info.user_passed)

if target != "mean":
component = model.components[target]
Expand All @@ -521,13 +526,13 @@ def predictions(
idata = model.predict(
idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False, kind="pps"
)
y_hat = response_transform(idata.posterior_predictive[response.name])
y_hat = response_transform(idata["posterior_predictive"][response.name])
y_hat_mean = y_hat.mean(("chain", "draw"))
else:
idata = model.predict(
idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False
)
y_hat = response_transform(idata.posterior[response.name_target])
y_hat = response_transform(idata["posterior"][response.name_target])
y_hat_mean = y_hat.mean(("chain", "draw"))

if use_hdi and pps:
Expand All @@ -543,6 +548,7 @@ def predictions(
upper_bound = 1 - lower_bound
response.lower_bound, response.upper_bound = lower_bound, upper_bound

cap_data = cap_data.copy()
if y_hat_mean.ndim > 1:
cap_data = merge(y_hat_mean, y_hat_bounds, cap_data)
cap_data = cap_data.rename(
Expand All @@ -559,6 +565,8 @@ def predictions(
cap_data[response.upper_bound_name] = y_hat_bounds[1]

if average_by is not None:
if average_by is True:
average_by = "all"
cap_data = average_over(cap_data, covariate=average_by)

return cap_data
Expand Down Expand Up @@ -587,8 +595,9 @@ def comparisons(
the model.
contrast : str, dict
The predictor name whose contrast we would like to compare.
conditional : str, dict, list
The covariates we would like to condition on.
conditional : str, list, dict, optional
The covariates we would like to condition on. If dict, keys are the covariate names and
values are the values to condition on.
average_by: str, list, bool, optional
The covariates we would like to average by. The passed covariate(s) will marginalize
over the other covariates in the model. If True, it averages over all covariates
Expand Down Expand Up @@ -620,10 +629,10 @@ def comparisons(
If ``contrast`` is a dict that is passed more than 2 values and
``conditional`` is ``None``.
If ``conditional`` is None and ``contrast`` is categorical with > 2 values.
If ``conditional`` is a list and the length is greater than 3.
If ``comparison_type`` is not 'diff' or 'ratio'.
If ``prob`` is not > 0 and < 1.
"""

contrast_name = contrast
if isinstance(contrast, dict):
if len(contrast) > 1:
Expand All @@ -637,6 +646,14 @@ def comparisons(
f"{contrast_name} was passed {len(contrast_values)} values."
)

if isinstance(conditional, list):
if len(conditional) > 3:
raise ValueError(
f"Only 3 covariates can be passed to 'conditional'. {len(conditional)} "
"were passed. If you would like to pass more than 3 covariates, "
"use a dictionary."
)

if conditional is None:
if is_categorical_dtype(model.data[contrast_name]) or is_string_dtype(
model.data[contrast_name]
Expand Down Expand Up @@ -728,8 +745,9 @@ def slopes(
the model.
wrt : str, dict
The slope of the regression with respect to (wrt) this predictor will be computed.
conditional : str, dict, list
The covariates we would like to condition on.
conditional : str, list, dict, optional
The covariates we would like to condition on. If dict, keys are the covariate names and
values are the values to condition on.
average_by: str, list, bool, optional
The covariates we would like to average by. The passed covariate(s) will marginalize
over the other covariates in the model. If True, it averages over all covariates
Expand Down Expand Up @@ -771,6 +789,7 @@ def slopes(
If length of ``wrt`` is greater than 1.
If ``conditional`` is ``None`` and ``wrt`` is passed more than 2 values.
If ``conditional`` is ``None`` and default ``wrt`` has more than 2 unique values.
If ``conditional`` is a list and the length is greater than 3.
If ``slope`` is not 'dydx', 'dyex', 'eyex', or 'eydx'.
If ``prob`` is not > 0 and < 1.
"""
Expand All @@ -787,6 +806,14 @@ def slopes(
f"{wrt_name} was passed {len(wrt_values)} values."
)

if isinstance(conditional, list):
if len(conditional) > 3:
raise ValueError(
f"Only 3 covariates can be passed to 'conditional'. {len(conditional)} "
" were passed. If you would like to pass more than 3 covariates, "
"use a dictionary."
)

if not isinstance(wrt, dict) and conditional is None:
if is_categorical_dtype(model.data[wrt_name]) or is_string_dtype(model.data[wrt_name]):
num_levels = len(model.data[wrt_name].unique())
Expand Down
62 changes: 32 additions & 30 deletions bambi/interpret/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
from bambi.interpret.utils import get_covariates, ConditionalInfo
from bambi.utils import get_aliased_name, listify

# TODO: aliases for type hints?
# TODO: functions for error handling


def _plot_differences(
model: Model,
Expand All @@ -31,7 +28,6 @@ def _plot_differences(
"""
Common function used for both 'plot_comparisons' and 'plot_slopes'.
"""

if (subplot_kwargs and not average_by) or (subplot_kwargs and average_by):
for key, value in subplot_kwargs.items():
conditional_info.covariates.update({key: value})
Expand Down Expand Up @@ -87,7 +83,7 @@ def _plot_differences(
def plot_predictions(
model: Model,
idata: az.InferenceData,
conditional: Union[str, list, None] = None,
conditional: Union[str, list, dict, None] = None,
average_by: Union[str, list, None] = None,
target: str = "mean",
sample_new_groups: bool = False,
Expand All @@ -109,8 +105,9 @@ def plot_predictions(
idata : arviz.InferenceData
The InferenceData object that contains the samples from the posterior distribution of
the model.
conditional : str, list, optional
A sequence of between one and three names of variables in the model.
conditional : str, list, dict, optional
The covariates we would like to condition on. If dict, keys are the covariate names and
values are the values to condition on.
average_by: str, list, bool, optional
The covariates we would like to average by. The passed covariate(s) will marginalize
over the other covariates in the model. If True, it averages over all covariates
Expand Down Expand Up @@ -154,30 +151,31 @@ def plot_predictions(
Raises
------
ValueError
If number of values passed with ``conditional`` is >= 2 and
``average_by`` are both ``None``.
If ``conditional`` and ``average_by`` are both ``None``.
If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``.
If main covariate is not numeric or categoric.
"""

if conditional is None and average_by is None:
raise ValueError("Must specify at least one of 'conditional' or 'average_by'.")

if isinstance(conditional, dict):
conditional = {key: sorted(listify(value)) for key, value in conditional.items()}
elif conditional is not None:
conditional = {
key: np.array(sorted(listify(value))).flatten() for key, value in conditional.items()
}
elif isinstance(conditional, str):
conditional = listify(conditional)
if len(conditional) > 3 and average_by is None:
raise ValueError(
"Must specify a covariate to 'average_by' when number of covariates"
"passed to 'conditional' is greater than 3."
)

if conditional is not None and len(conditional) > 3 and average_by is None:
raise ValueError(
"Must specify a covariate to 'average_by' when number of covariates "
"passed to 'conditional' is greater than 3."
)

if average_by is True:
raise ValueError(
"Plotting when 'average_by = True' is not possible as 'True' marginalizes "
"over all covariates resulting in a single comparison estimate. "
"Please specify a covariate(s) to 'average_by'."
"over all covariates resulting in a single prediction estimate. "
"Please pass a covariate(s) to 'average_by'."
)

cap_data = predictions(
Expand All @@ -194,6 +192,7 @@ def plot_predictions(
)

conditional_info = ConditionalInfo(model, conditional)
transforms = transforms if transforms is not None else {}

if (subplot_kwargs and not average_by) or (subplot_kwargs and average_by):
for key, value in subplot_kwargs.items():
Expand Down Expand Up @@ -267,7 +266,8 @@ def plot_comparisons(
contrast : str, dict, list
The predictor name whose contrast we would like to compare.
conditional : str, dict, list
The covariates we would like to condition on.
The covariates we would like to condition on. If dict, keys are the covariate names and
values are the values to condition on.
average_by: str, list, optional
The covariates we would like to average by. The passed covariate(s) will marginalize
over the other covariates in the model. Defaults to ``None``.
Expand Down Expand Up @@ -306,11 +306,11 @@ def plot_comparisons(
Raises
------
ValueError
If the number of contrast levels is greater than 2 and ``average_by`` is ``None``.
If ``conditional`` and ``average_by`` are both ``None``.
If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``.
Warning
If length of ``contrast`` is greater than 2.
If ``average_by`` is ``True``.
If main covariate is not numeric or categoric.
"""
contrast_name = contrast
if isinstance(contrast, dict):
Expand Down Expand Up @@ -350,8 +350,8 @@ def plot_comparisons(
if average_by is True:
raise ValueError(
"Plotting when 'average_by = True' is not possible as 'True' marginalizes "
"over all covariates resulting in a single comparison estimate. "
"Please specify a covariate(s) to 'average_by'."
"over all covariates resulting in a single prediction estimate. "
"Please pass a covariate(s) to 'average_by'."
)

conditional_info = ConditionalInfo(model, conditional)
Expand Down Expand Up @@ -413,7 +413,8 @@ def plot_slopes(
If 'wrt' is numeric, the derivative is computed, else if string or categorical,
'comparisons' is called to compute difference in group means.
conditional : str, dict, list
The covariates we would like to condition on.
The covariates we would like to condition on. If dict, keys are the covariate names and
values are the values to condition on.
average_by: str, list, bool, optional
The covariates we would like to average by. The passed covariate(s) will marginalize
over the other covariates in the model. If True, it averages over all covariates
Expand Down Expand Up @@ -464,11 +465,12 @@ def plot_slopes(
Raises
------
ValueError
If number of values passed with ``conditional`` is >= 2 and
``average_by`` are both ``None``.
If the number of ``wrt`` values is greater than 2 and ``average_by`` is ``None``.
If ``conditional`` and ``average_by`` are both ``None``.
If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``.
If ``average_by`` is ``True``.
If ``slope`` is not one of ('dydx', 'dyex', 'eyex', 'eydx').
If main covariate is not numeric or categoric.
"""
wrt_name = wrt
if isinstance(wrt, dict):
Expand Down Expand Up @@ -508,8 +510,8 @@ def plot_slopes(
if average_by is True:
raise ValueError(
"Plotting when 'average_by = True' is not possible as 'True' marginalizes "
"over all covariates resulting in a single slope estimate. "
"Please specify a covariate(s) to 'average_by'."
"over all covariates resulting in a single prediction estimate. "
"Please pass a covariate(s) to 'average_by'."
)

if slope not in ("dydx", "dyex", "eyex", "eydx"):
Expand Down
5 changes: 3 additions & 2 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,12 @@ class Covariates:
panel: Union[str, None]


def average_over(data: pd.DataFrame, covariate: Union[str, list, None]) -> pd.DataFrame:
    """
    Average estimates by specified covariate in the model. data.columns[-3:] are
    the columns: 'estimate', 'lower', and 'upper'.

    Parameters
    ----------
    data : pandas.DataFrame
        Dataframe whose last three columns hold the values to average.
    covariate : str, list, None
        Column name(s) to group by before averaging. The string ``"all"`` averages over
        every row and returns a single-row dataframe. ``None`` is accepted as an alias of
        ``"all"`` so that callers written against the previous sentinel keep working.

    Returns
    -------
    pandas.DataFrame
        A single-row dataframe with the three averaged columns when averaging over all
        rows, otherwise one row per group with the grouping column(s) followed by the
        three averaged columns.
    """
    estimate_columns = list(data.columns[-3:])

    # The isinstance guard keeps list (or array-like) arguments from being compared
    # against the sentinel string.
    if covariate is None or (isinstance(covariate, str) and covariate == "all"):
        return pd.DataFrame(data[estimate_columns].mean()).T

    return data.groupby(covariate, as_index=False)[estimate_columns].mean()
Expand Down Expand Up @@ -359,6 +359,7 @@ def set_default_values(model: Model, data_dict: dict, kind: str) -> dict:
if not isinstance(value, (list, np.ndarray)):
data_dict[key] = [value]
return data_dict

return data_dict


Expand Down

0 comments on commit 296df2c

Please sign in to comment.