diff --git a/bambi/interpret/effects.py b/bambi/interpret/effects.py index f24b02735..d268c169e 100644 --- a/bambi/interpret/effects.py +++ b/bambi/interpret/effects.py @@ -19,10 +19,8 @@ merge, VariableInfo, ) -from bambi.utils import get_aliased_name, listify +from bambi.utils import get_aliased_name -# TODO: aliases for type hints? -# TODO: functions for error handling SUPPORTED_SLOPES = ("dydx", "eyex") SUPPORTED_COMPARISONS = { @@ -415,7 +413,7 @@ def average_by(self, variable: Union[bool, str]) -> pd.DataFrame: A dataframe containing the marginal or group by average. """ if variable is True: - contrast_df_avg = average_over(self.summary_df, None) + contrast_df_avg = average_over(self.summary_df, "all") contrast_df_avg.insert(0, "term", self.variable.name) contrast_df_avg.insert(1, "estimate_type", self.estimate_name) if self.kind != "slopes" and len(self.variable.values) < 3: @@ -433,7 +431,7 @@ def average_by(self, variable: Union[bool, str]) -> pd.DataFrame: def predictions( model: Model, idata: az.InferenceData, - conditional: Union[str, list, None] = None, + conditional: Union[str, dict, list, None] = None, average_by: Union[str, list, bool, None] = None, target: str = "mean", pps: bool = False, @@ -451,8 +449,9 @@ def predictions( idata : arviz.InferenceData The InferenceData object that contains the samples from the posterior distribution of the model. - conditional : str, list, optional - The covariates we would like to condition on. + conditional : str, list, dict, optional + The covariates we would like to condition on. If dict, keys are the covariate names and + values are the values to condition on. average_by: str, list, bool, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. If True, it averages over all covariates @@ -484,15 +483,21 @@ def predictions( ------ ValueError If ``pps`` is ``True`` and ``target`` is not ``"mean"``. + If ``conditional`` is a list and the length is greater than 3. + If ``prob`` is not > 0 and < 1. """ - if pps and target != "mean": raise ValueError("When passing 'pps=True', target must be 'mean'") - # TODO: error handling + if isinstance(conditional, list): + if len(conditional) > 3: + raise ValueError( + f"Only 3 covariates can be passed to 'conditional'. {len(conditional)} " + "were passed. If you would like to pass more than 3 covariates, use " + "a dictionary." + ) conditional_info = ConditionalInfo(model, conditional) - transforms = transforms if transforms is not None else {} if prob is None: @@ -500,7 +505,7 @@ def predictions( if not 0 < prob < 1: raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.") - cap_data = create_predictions_data(conditional_info, model) + cap_data = create_predictions_data(conditional_info, conditional_info.user_passed) if target != "mean": component = model.components[target] @@ -521,13 +526,13 @@ def predictions( idata = model.predict( idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False, kind="pps" ) - y_hat = response_transform(idata.posterior_predictive[response.name]) + y_hat = response_transform(idata["posterior_predictive"][response.name]) y_hat_mean = y_hat.mean(("chain", "draw")) else: idata = model.predict( idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False ) - y_hat = response_transform(idata.posterior[response.name_target]) + y_hat = response_transform(idata["posterior"][response.name_target]) y_hat_mean = y_hat.mean(("chain", "draw")) if use_hdi and pps: @@ -543,6 +548,7 @@ def predictions( upper_bound = 1 - lower_bound response.lower_bound, response.upper_bound = lower_bound, upper_bound + cap_data = cap_data.copy() if y_hat_mean.ndim > 1: cap_data = merge(y_hat_mean, y_hat_bounds, cap_data) cap_data = cap_data.rename( @@ -559,6 +565,8 @@ def predictions( cap_data[response.upper_bound_name] = y_hat_bounds[1] if average_by is not None: + if average_by is True: + average_by = "all" cap_data = average_over(cap_data, covariate=average_by) return cap_data @@ -587,8 +595,9 @@ def comparisons( the model. contrast : str, dict The predictor name whose contrast we would like to compare. - conditional : str, dict, list - The covariates we would like to condition on. + conditional : str, list, dict, optional + The covariates we would like to condition on. If dict, keys are the covariate names and + values are the values to condition on. average_by: str, list, bool, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. If True, it averages over all covariates @@ -620,10 +629,10 @@ def comparisons( If `wrt` is a dict and length of ``contrast`` is greater than 2 and ``conditional`` is ``None``. If ``conditional`` is None and ``contrast`` is categorical with > 2 values. + If ``conditional`` is a list and the length is greater than 3. If ``comparison_type`` is not 'diff' or 'ratio'. If ``prob`` is not > 0 and < 1. """ - contrast_name = contrast if isinstance(contrast, dict): if len(contrast) > 1: @@ -637,6 +646,14 @@ def comparisons( f"{contrast_name} was passed {len(contrast_values)} values." ) + if isinstance(conditional, list): + if len(conditional) > 3: + raise ValueError( + f"Only 3 covariates can be passed to 'conditional'. {len(conditional)} " + "were passed. If you would like to pass more than 3 covariates, " + "use a dictionary." + ) + if conditional is None: if is_categorical_dtype(model.data[contrast_name]) or is_string_dtype( model.data[contrast_name] @@ -728,8 +745,9 @@ def slopes( the model. wrt : str, dict The slope of the regression with respect to (wrt) this predictor will be computed. - conditional : str, dict, list - The covariates we would like to condition on. + conditional : str, list, dict, optional + The covariates we would like to condition on. If dict, keys are the covariate names and + values are the values to condition on. average_by: str, list, bool, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. If True, it averages over all covariates @@ -771,6 +789,7 @@ def slopes( If length of ``wrt`` is greater than 1. If ``conditional`` is ``None`` and ``wrt`` is passed more than 2 values. If ``conditional`` is ``None`` and default ``wrt`` has more than 2 unique values. + If ``conditional`` is a list and the length is greater than 3. If ``slope`` is not 'dydx', 'dyex', 'eyex', or 'eydx'. If ``prob`` is not > 0 and < 1. """ @@ -787,6 +806,14 @@ def slopes( f"{wrt_name} was passed {len(wrt_values)} values." ) + if isinstance(conditional, list): + if len(conditional) > 3: + raise ValueError( + f"Only 3 covariates can be passed to 'conditional'. {len(conditional)} " + " were passed. If you would like to pass more than 3 covariates, " + "use a dictionary." + ) + if not isinstance(wrt, dict) and conditional is None: if is_categorical_dtype(model.data[wrt_name]) or is_string_dtype(model.data[wrt_name]): num_levels = len(model.data[wrt_name].unique()) diff --git a/bambi/interpret/plotting.py b/bambi/interpret/plotting.py index 3f0429b21..330c77d70 100644 --- a/bambi/interpret/plotting.py +++ b/bambi/interpret/plotting.py @@ -13,9 +13,6 @@ from bambi.interpret.utils import get_covariates, ConditionalInfo from bambi.utils import get_aliased_name, listify -# TODO: aliases for type hints? -# TODO: functions for error handling - def _plot_differences( model: Model, @@ -31,7 +28,6 @@ def _plot_differences( """ Common function used for both 'plot_comparisons' and 'plot_slopes'. """ - if (subplot_kwargs and not average_by) or (subplot_kwargs and average_by): for key, value in subplot_kwargs.items(): conditional_info.covariates.update({key: value}) @@ -87,7 +83,7 @@ def _plot_differences( def plot_predictions( model: Model, idata: az.InferenceData, - conditional: Union[str, list, None] = None, + conditional: Union[str, list, dict, None] = None, average_by: Union[str, list, None] = None, target: str = "mean", sample_new_groups: bool = False, @@ -109,8 +105,9 @@ def plot_predictions( idata : arviz.InferenceData The InferenceData object that contains the samples from the posterior distribution of the model. - conditional : str, list, optional - A sequence of between one and three names of variables in the model. + conditional : str, list, dict, optional + The covariates we would like to condition on. If dict, keys are the covariate names and + values are the values to condition on. average_by: str, list, bool, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. If True, it averages over all covariates @@ -154,30 +151,31 @@ def plot_predictions( Raises ------ ValueError - If number of values passed with ``conditional`` is >= 2 and - ``average_by`` are both ``None``. If ``conditional`` and ``average_by`` are both ``None``. If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``. + If main covariate is not numeric or categoric. """ - if conditional is None and average_by is None: raise ValueError("Must specify at least one of 'conditional' or 'average_by'.") if isinstance(conditional, dict): - conditional = {key: sorted(listify(value)) for key, value in conditional.items()} - elif conditional is not None: + conditional = { + key: np.array(sorted(listify(value))).flatten() for key, value in conditional.items() + } + elif isinstance(conditional, str): conditional = listify(conditional) - if len(conditional) > 3 and average_by is None: - raise ValueError( - "Must specify a covariate to 'average_by' when number of covariates" - "passed to 'conditional' is greater than 3." - ) + + if conditional is not None and len(conditional) > 3 and average_by is None: + raise ValueError( + "Must specify a covariate to 'average_by' when number of covariates " + "passed to 'conditional' is greater than 3." + ) if average_by is True: raise ValueError( "Plotting when 'average_by = True' is not possible as 'True' marginalizes " - "over all covariates resulting in a single comparison estimate. " - "Please specify a covariate(s) to 'average_by'." + "over all covariates resulting in a single prediction estimate. " + "Please pass a covariate(s) to 'average_by'." ) cap_data = predictions( @@ -194,6 +192,7 @@ def plot_predictions( ) conditional_info = ConditionalInfo(model, conditional) + transforms = transforms if transforms is not None else {} if (subplot_kwargs and not average_by) or (subplot_kwargs and average_by): for key, value in subplot_kwargs.items(): @@ -267,7 +266,8 @@ def plot_comparisons( contrast : str, dict, list The predictor name whose contrast we would like to compare. conditional : str, dict, list - The covariates we would like to condition on. + The covariates we would like to condition on. If dict, keys are the covariate names and + values are the values to condition on. average_by: str, list, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. Defaults to ``None``. @@ -306,11 +306,11 @@ def plot_comparisons( Raises ------ ValueError + If the number of contrast levels is greater than 2 and ``average_by`` is ``None``. If ``conditional`` and ``average_by`` are both ``None``. If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``. - - Warning - If length of ``contrast`` is greater than 2. + If ``average_by`` is ``True``. + If main covariate is not numeric or categoric. """ contrast_name = contrast if isinstance(contrast, dict): @@ -350,8 +350,8 @@ def plot_comparisons( if average_by is True: raise ValueError( "Plotting when 'average_by = True' is not possible as 'True' marginalizes " - "over all covariates resulting in a single comparison estimate. " - "Please specify a covariate(s) to 'average_by'." + "over all covariates resulting in a single prediction estimate. " + "Please pass a covariate(s) to 'average_by'." ) conditional_info = ConditionalInfo(model, conditional) @@ -413,7 +413,8 @@ def plot_slopes( If 'wrt' is numeric, the derivative is computed, else if string or categorical, 'comparisons' is called to compute difference in group means. conditional : str, dict, list - The covariates we would like to condition on. + The covariates we would like to condition on. If dict, keys are the covariate names and + values are the values to condition on. average_by: str, list, bool, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. If True, it averages over all covariates @@ -464,11 +465,12 @@ def plot_slopes( Raises ------ ValueError - If number of values passed with ``conditional`` is >= 2 and - ``average_by`` are both ``None``. + If the number of ``wrt`` values is greater than 2 and ``average_by`` is ``None``. If ``conditional`` and ``average_by`` are both ``None``. If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``. + If ``average_by`` is ``True``. If ``slope`` is not one of ('dydx', 'dyex', 'eyex', 'eydx'). + If main covariate is not numeric or categoric. """ wrt_name = wrt if isinstance(wrt, dict): @@ -508,8 +510,8 @@ def plot_slopes( if average_by is True: raise ValueError( "Plotting when 'average_by = True' is not possible as 'True' marginalizes " - "over all covariates resulting in a single slope estimate. " - "Please specify a covariate(s) to 'average_by'." + "over all covariates resulting in a single prediction estimate. " + "Please pass a covariate(s) to 'average_by'." ) if slope not in ("dydx", "dyex", "eyex", "eydx"): diff --git a/bambi/interpret/utils.py b/bambi/interpret/utils.py index f2208e4e4..7409a79f5 100644 --- a/bambi/interpret/utils.py +++ b/bambi/interpret/utils.py @@ -180,12 +180,12 @@ class Covariates: panel: Union[str, None] -def average_over(data: pd.DataFrame, covariate: Union[str, list, None]) -> pd.DataFrame: +def average_over(data: pd.DataFrame, covariate: Union[str, list]) -> pd.DataFrame: """ Average estimates by specified covariate in the model. data.columns[-3:] are the columns: 'estimate', 'lower', and 'upper'. """ - if covariate is None: + if covariate == "all": return pd.DataFrame(data[data.columns[-3:]].mean()).T else: return data.groupby(covariate, as_index=False)[data.columns[-3:]].mean() @@ -359,6 +359,7 @@ def set_default_values(model: Model, data_dict: dict, kind: str) -> dict: if not isinstance(value, (list, np.ndarray)): data_dict[key] = [value] return data_dict + return data_dict