From ab536739c5a3652648448abbc9c87ef20d384e2c Mon Sep 17 00:00:00 2001 From: Martin Stancsics Date: Mon, 28 Aug 2023 07:54:14 +0200 Subject: [PATCH] Correctly create missing category from model_spec (#297) --- src/tabmat/formula.py | 6 +++++- tests/test_formula.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/tabmat/formula.py b/src/tabmat/formula.py index 61335af3..84732f10 100644 --- a/src/tabmat/formula.py +++ b/src/tabmat/formula.py @@ -425,6 +425,7 @@ def from_categorical( reduced_rank: bool, missing_method: str = "fail", missing_name: str = "(MISSING)", + force_convert: bool = False, ) -> "_InteractableCategoricalVector": """Create an interactable categorical vector from a pandas categorical.""" categories = list(cat.categories) @@ -441,7 +442,7 @@ def from_categorical( "if [cat_]missing_method='fail'." ) - if missing_method == "convert" and -1 in codes: + if missing_method == "convert" and (-1 in codes or force_convert): codes[codes == -1] = len(categories) categories.append(missing_name) @@ -718,14 +719,17 @@ def encode_contrasts( order to avoid spanning the intercept. """ levels = levels if levels is not None else _state.get("categories") + force_convert = _state.get("force_convert", False) cat = pandas.Categorical(data._values, categories=levels) _state["categories"] = cat.categories + _state["force_convert"] = missing_method == "convert" and cat.isna().any() return _InteractableCategoricalVector.from_categorical( cat, reduced_rank=reduced_rank, missing_method=missing_method, missing_name=missing_name, + force_convert=force_convert, ) diff --git a/tests/test_formula.py b/tests/test_formula.py index 65930106..f2d8c224 100644 --- a/tests/test_formula.py +++ b/tests/test_formula.py @@ -690,6 +690,10 @@ def test_cat_missing_C(): assert result.column_names == expected_names assert result.model_spec.get_model_matrix(df).column_names == expected_names + np.testing.assert_equal(result.model_spec.get_model_matrix(df).A, result.A) + np.testing.assert_equal( + result.model_spec.get_model_matrix(df[:2]).A, result.A[:2, :] + ) @pytest.mark.parametrize(