Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

advanced interpret usage #762

Merged
merged 23 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
2aadbb2
re-run notebooks and advanced usage docs
GStechschulte Nov 26, 2023
da07124
added select_draws and data_grid functions
GStechschulte Nov 26, 2023
732572f
move data generation functions to create_data.py for better cohesion
GStechschulte Nov 26, 2023
9f3ee27
move sorting of dict values to ConditionalInfo dataclass
GStechschulte Nov 26, 2023
e45c587
removal of keyword args. to functions in create_data.py
GStechschulte Nov 26, 2023
338d880
create_grid function for internal and user-level functions
GStechschulte Nov 26, 2023
90969b5
add select_draws and data_grid as modules
GStechschulte Nov 26, 2023
234be08
remove code-cells
GStechschulte Nov 26, 2023
7c2f48a
initial tests for interpret helper functions
GStechschulte Nov 28, 2023
c916c7c
add kwargs, docstrings, and error handling
GStechschulte Nov 28, 2023
f5d35cf
improved docstrings and inline comments
GStechschulte Nov 28, 2023
bd2405a
remove functions that have been deleted from utils.py
GStechschulte Nov 28, 2023
67d7a00
lowercase inline comments
GStechschulte Nov 28, 2023
5bf8933
re-run docs and add advanced interpret docs
GStechschulte Dec 1, 2023
3003300
finalize tests
GStechschulte Dec 1, 2023
57ee7fb
update interpret logger tests to reflect new message
GStechschulte Dec 1, 2023
fb2c04a
update logger to parse create_data func
GStechschulte Dec 1, 2023
4108fa9
remove double backticks
GStechschulte Dec 1, 2023
5e70027
eps logic and add logger decorator
GStechschulte Dec 1, 2023
177ef8c
re-run slopes and advanced usage notebooks
GStechschulte Dec 5, 2023
05b1a2c
remove elif block and update filterwarnings to specific message
GStechschulte Dec 5, 2023
ca8682d
remove elif block and update filterwarnings to specific message
GStechschulte Dec 5, 2023
5c361da
remove elif block and update filterwarnings to specific message
GStechschulte Dec 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bambi/interpret/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import logging

from bambi.interpret.effects import comparisons, predictions, slopes
from bambi.interpret.helpers import data_grid, select_draws
from bambi.interpret.plotting import plot_comparisons, plot_predictions, plot_slopes

__all__ = [
"comparisons",
"data_grid",
"logger",
"select_draws",
"slopes",
"predictions",
"plot_comparisons",
Expand Down
242 changes: 127 additions & 115 deletions bambi/interpret/create_data.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,131 @@
import itertools

from typing import Union
from statistics import mode

import numpy as np
import pandas as pd

from pandas.api.types import (
is_categorical_dtype,
is_float_dtype,
is_integer_dtype,
is_numeric_dtype,
is_object_dtype,
is_string_dtype,
)
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved

from bambi import Model
from bambi.interpret.utils import (
ConditionalInfo,
enforce_dtypes,
get_covariates,
get_model_covariates,
make_group_panel_values,
make_main_values,
set_default_values,
VariableInfo,
)

from bambi.interpret.logs import log_interpret_defaults

def _pairwise_grid(data_dict: dict) -> pd.DataFrame:
"""Creates a pairwise grid (cartesian product) of data by using the
key-values of the dictionary.

@log_interpret_defaults
def create_grid(
condition: ConditionalInfo, variable: Union[VariableInfo, None] = None, **kwargs
) -> pd.DataFrame:
"""Creates a grid of data by using the covariates passed into the 'conditional'
and 'variable' argument.

Values for the grid are either:
1.) computed using an equally spaced grid (`np.linspace`), mean, and or mode
depending on the covariate dtype.
2.) a user specified value or range of values if `condition.user_passed = True`

Parameters
----------
data_dict : dict
A dictionary containing the covariates as keys and their values as the
values.
condition : ConditionalInfo
Information about data passed to the conditional parameter of 'comparisons',
'predictions', or 'slopes' related functions.
variable : VariableInfo, optional
Information about data passed to the variable of interest parameter. This
is 'contrast' for 'comparisons', 'wrt' for 'slopes', and 'None' for 'predictions'.
**kwargs : dict
Optional keywords arguments such as 'effect_type' (the effect being computed),
and 'num' (the number of values to return when computing a `np.linspace` grid).

Returns
-------
pd.DataFrame
A dataframe containing values used as input to the fitted Bambi model to
generate predictions.
A dataframe containing pairwise combinations of values.
"""
keys, values = zip(*data_dict.items())
data_grid = pd.DataFrame([dict(zip(keys, v)) for v in itertools.product(*values)])
return data_grid
model, observed_data = condition.model, condition.model.data

if condition.user_passed:
# shallow copy of user-passed data dictionary
data_dict = {**condition.conditional}
else:
data_dict = {}
# values here are the names of the covariates
for covariate in condition.covariates.values():
x = observed_data[covariate]
num = kwargs.get("num", 50)
if is_numeric_dtype(x) or is_float_dtype(x):
values = np.linspace(np.min(x), np.max(x), num)
elif is_integer_dtype(x):
values = np.quantile(x, np.linspace(0, 1, 5))
elif is_categorical_dtype(x) or is_string_dtype(x) or is_object_dtype(x):
values = np.unique(x)
else:
raise TypeError(
f"Unsupported data type of {x.dtype} for covariate '{covariate.name}'"
)

data_dict[covariate] = values

if variable:
data_dict[variable.name] = variable.values

# set typical values as defaults for unspecified covariates
data_dict = set_default_values(model, data_dict)
data_grid = _pairwise_grid(data_dict)

# can't enforce dtype on 'with respect to' variable for 'slopes' as it
# may remove floating point in the epsilon
effect = kwargs.get("effect_type", None)
if effect == "slopes":
except_col = variable.name
else:
except_col = None

data_grid = enforce_dtypes(observed_data, data_grid, except_col)

# after computing default values, fractional values may have been computed.
# Enforcing the dtype of "int" may create duplicate rows as it will round
# the fractional values.
data_grid = data_grid.drop_duplicates()

return data_grid.reset_index(drop=True)

def _grid_level(
condition_info: ConditionalInfo,
variable_info: Union[VariableInfo, None],
user_passed: bool,
kind: str,
) -> pd.DataFrame:
"""Creates a "grid" of data by using the covariates passed into the
`conditional` argument. Values for the grid are either: (1) computed
using a equally spaced grid, mean, and or mode (depending on the
covariate dtype), and (2) a user specified value or range of values.

def _pairwise_grid(data_dict: dict) -> pd.DataFrame:
"""Creates a pairwise grid (cartesian product) of data by using the
key-values of the dictionary.

Parameters
----------
condition_info : ConditionalInfo
Information about the conditional argument passed into the plot
function.
variable_info : VariableInfo, optional
Information about the variable of interest. This is `contrast` for
'comparisons', `wrt` for 'slopes', and `None` for 'predictions'.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.
kind : str
The kind of effect being computed. Either "comparisons", "predictions",
or "slopes".
data_dict : dict
A dictionary containing the covariates as keys and their values as the
values.

Returns
-------
pd.DataFrame
A dataframe containing values used as input to the fitted Bambi model to
generate predictions.
"""
covariates = get_covariates(condition_info.covariates)

if kind == "predictions":
# Compute pairwise grid of values if the user passed a dict.
if user_passed:
data_dict = {**condition_info.conditional}
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
for key, value in data_dict.items():
if not isinstance(value, (list, np.ndarray)):
data_dict[key] = [value]
data_grid = _pairwise_grid(data_dict)
else:
# Compute a grid of values
main_values = make_main_values(
condition_info.model.data[covariates.main], covariates.main
)
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
data_grid = pd.DataFrame(data_dict)
else:
# Compute pairwise grid of values if the user passed a dict.
if user_passed:
data_dict = {**condition_info.conditional}
else:
# Compute a grid of values
main_values = make_main_values(
condition_info.model.data[covariates.main], covariates.main
)
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)

data_dict[variable_info.name] = variable_info.values
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
data_grid = _pairwise_grid(data_dict)

# Can't enforce dtype on numeric 'wrt' for 'slopes 'as it may remove floating point epsilons
except_col = None if kind in ("comparisons", "predictions") else {variable_info.name}
data_grid = enforce_dtypes(condition_info.model.data, data_grid, except_col)

# After computing default values, fractional values may have been computed.
# Enforcing the dtype of "int" may create duplicate rows as it will round
# the fractional values.
data_grid = data_grid.drop_duplicates()

return data_grid.reset_index(drop=True)
keys, values = zip(*data_dict.items())
cross_joined_data = pd.DataFrame([dict(zip(keys, v)) for v in itertools.product(*values)])
return cross_joined_data


def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFrame:
def _differences_unit_level(variable_info: VariableInfo, effect_type: str) -> pd.DataFrame:
"""Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
Expand All @@ -141,8 +136,8 @@ def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFr
variable_info : VariableInfo
Information about the variable of interest. This is `contrast` for
'comparisons' and `wrt` for 'slopes'.
kind : str
The kind of effect being computed. Either "comparisons" or "slopes".
effect_type : str
The type of effect being computed. Either "comparisons" or "slopes".

Returns
-------
Expand All @@ -153,10 +148,9 @@ def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFr
"""
covariates = get_model_covariates(variable_info.model)
df = variable_info.model.data[covariates].drop(labels=variable_info.name, axis=1)

variable_vals = variable_info.values

if kind == "comparisons":
if effect_type == "comparisons":
variable_vals = np.array(variable_info.values)[..., None]
variable_vals = np.repeat(variable_vals, variable_info.model.data.shape[0], axis=1)

Expand All @@ -165,11 +159,13 @@ def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFr
unit_level_df_dict[f"contrast_{idx}"] = df.copy()
unit_level_df_dict[f"contrast_{idx}"][variable_info.name] = value

return pd.concat(unit_level_df_dict.values())
unit_level_df = pd.concat(unit_level_df_dict.values())

return unit_level_df.reset_index(drop=True)


def create_differences_data(
condition_info: ConditionalInfo, variable_info: VariableInfo, user_passed: bool, kind: str
condition_info: ConditionalInfo, variable_info: VariableInfo, effect_type: str
) -> pd.DataFrame:
"""Creates either unit level or grid level data for 'comparisons' and 'slopes'
depending if the user passed covariate values.
Expand All @@ -182,10 +178,8 @@ def create_differences_data(
variable_info : VariableInfo
Information about the variable of interest. This is `contrast` for
'comparisons' and `wrt` for 'slopes'.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.
kind : str
The kind of effect being computed. Either "comparisons" or "slopes".
effect_type : str
The type of effect being computed. Either "comparisons" or "slopes".

Returns
-------
Expand All @@ -195,14 +189,13 @@ def create_differences_data(
is returned. Otherwise, a grid of values is created using the covariates
passed into the `conditional` argument.
"""

if not condition_info.covariates:
return _differences_unit_level(variable_info, kind)
return _differences_unit_level(variable_info, effect_type)

return _grid_level(condition_info, variable_info, user_passed, kind)
return create_grid(condition_info, variable_info, effect_type=effect_type)


def create_predictions_data(condition_info: ConditionalInfo, user_passed: bool) -> pd.DataFrame:
def create_predictions_data(condition_info: ConditionalInfo) -> pd.DataFrame:
"""Creates either unit level or grid level data for 'predictions' depending
if the user passed covariates.

Expand All @@ -211,8 +204,6 @@ def create_predictions_data(condition_info: ConditionalInfo, user_passed: bool)
condition_info : ConditionalInfo
Information about the conditional argument passed into the plot
function.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.

Returns
-------
Expand All @@ -222,9 +213,30 @@ def create_predictions_data(condition_info: ConditionalInfo, user_passed: bool)
is returned. Otherwise, a grid of values is created using the covariates
passed into the `conditional` argument.
"""
# Unit level data used the observed (empirical) data
# unit level data uses the observed (empirical) data
if not condition_info.covariates:
covariates = get_model_covariates(condition_info.model)
return condition_info.model.data[covariates]

return _grid_level(condition_info, None, user_passed, "predictions")
return create_grid(condition_info, None)


@log_interpret_defaults
def set_default_values(model: Model, data_dict: dict) -> dict:
"""
Set default values for each variable in the model if the user did not
pass them in the data_dict.
"""
# set unspecified covariates to "typical" values
unique_covariates = get_model_covariates(model)
for name in unique_covariates:
if name not in data_dict:
x = model.data[name]
if is_numeric_dtype(x) or is_integer_dtype(x) or is_float_dtype(x):
data_dict[name] = np.array([np.mean(x)])
elif is_categorical_dtype(x) or is_string_dtype(x) or is_object_dtype(x):
data_dict[name] = np.array([mode(x)])
else:
raise TypeError(f"Unsupported data type of {x.dtype} for covariate '{name}'")

return data_dict
Loading
Loading