Skip to content

Commit

Permalink
Correctly create missing category from model_spec (#297)
Browse files (browse the repository at this point in the history)
  • Loading branch information
stanmart authored Aug 28, 2023
1 parent 41ad4af commit ab53673
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/tabmat/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def from_categorical(
reduced_rank: bool,
missing_method: str = "fail",
missing_name: str = "(MISSING)",
force_convert: bool = False,
) -> "_InteractableCategoricalVector":
"""Create an interactable categorical vector from a pandas categorical."""
categories = list(cat.categories)
Expand All @@ -441,7 +442,7 @@ def from_categorical(
"if [cat_]missing_method='fail'."
)

- if missing_method == "convert" and -1 in codes:
+ if missing_method == "convert" and (-1 in codes or force_convert):
codes[codes == -1] = len(categories)
categories.append(missing_name)

Expand Down Expand Up @@ -718,14 +719,17 @@ def encode_contrasts(
order to avoid spanning the intercept.
"""
levels = levels if levels is not None else _state.get("categories")
force_convert = _state.get("force_convert", False)
cat = pandas.Categorical(data._values, categories=levels)
_state["categories"] = cat.categories
_state["force_convert"] = missing_method == "convert" and cat.isna().any()

return _InteractableCategoricalVector.from_categorical(
cat,
reduced_rank=reduced_rank,
missing_method=missing_method,
missing_name=missing_name,
force_convert=force_convert,
)


Expand Down
4 changes: 4 additions & 0 deletions tests/test_formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,10 @@ def test_cat_missing_C():

assert result.column_names == expected_names
assert result.model_spec.get_model_matrix(df).column_names == expected_names
np.testing.assert_equal(result.model_spec.get_model_matrix(df).A, result.A)
np.testing.assert_equal(
result.model_spec.get_model_matrix(df[:2]).A, result.A[:2, :]
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit ab53673

Please sign in to comment.