Skip to content

Commit

Permalink
interpret get name of LazyVariable (#773)
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte authored Jan 23, 2024
1 parent 1a3cf8a commit 9a1387a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
9 changes: 6 additions & 3 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from typing import Union

import numpy as np
from formulae.terms.call import Call
import pandas as pd
import xarray as xr

from formulae.terms.call import Call
from formulae.terms.call_resolver import LazyVariable

from bambi import Model
from bambi.utils import listify
from bambi.interpret.logs import log_interpret_defaults
Expand Down Expand Up @@ -229,15 +231,16 @@ def get_model_covariates(model: Model) -> np.ndarray:
"""
Return covariates specified in the model.
"""

terms = get_model_terms(model)
covariates = []
for term in terms.values():
if hasattr(term, "components"):
for component in term.components:
# if the component is a function call, use the argument names
if isinstance(component, Call):
covariates.append([arg.name for arg in component.call.args])
covariates.append(
[arg.name for arg in component.call.args if isinstance(arg, LazyVariable)]
)
else:
covariates.append([component.name])
elif hasattr(term, "factor"):
Expand Down
43 changes: 43 additions & 0 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,37 @@ def food_choice():
return model, idata


@pytest.fixture(scope="module")
def formulae_transform():
"""
A model with a 'formulae' stateful transformation (polynomial) on a term.
"""
np.random.seed(0)
x1 = np.random.normal(size=100)
x2 = np.random.normal(size=100)
y = 2 + 3*x1 + 1.5*x1**2 + 2*x2 + np.random.normal(scale=1, size=100)
data = pd.DataFrame({'x1': x1, "x2": x2, 'y': y})
model = bmb.Model('y ~ poly(x1, 2) + x2', data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


@pytest.fixture(scope="module")
def nonformulae_transform():
"""
A model with a non-formulae transformation on a term.
"""
np.random.seed(0)
x1 = np.random.uniform(1, 50, 50)
noise = np.random.normal(0, 1, 50)
y = 3 * np.log(x1) + noise
data = pd.DataFrame({'x1': x1, 'y': y})

model = bmb.Model('y ~ np.log(x1)', data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


# Improvement:
# * Test the actual plots are what we are indeed the desired result.
# * Test using the dictionary and the list gives the same plot
Expand Down Expand Up @@ -402,6 +433,18 @@ def test_group_effects(self, sleep_study):
def test_categorical_response(self, food_choice, covariates):
model, idata = food_choice
plot_predictions(model, idata, covariates)


def test_term_transformations(self, formulae_transform, nonformulae_transform):
model, idata = formulae_transform

# Test that the plot works with a formulae transformation
plot_predictions(model, idata, ["x2", "x1"])

model, idata = nonformulae_transform

# Test that the plot works with a non-formulae transformation
plot_predictions(model, idata, "x1")


class TestComparison:
Expand Down

0 comments on commit 9a1387a

Please sign in to comment.