Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

X-Learner: Use the same sample splits in all base models. #84

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
12 changes: 6 additions & 6 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,12 @@ def _predict_in_sample(
) -> np.ndarray:
if not self._test_indices:
raise ValueError()
if len(X) != sum(len(fold) for fold in self._test_indices):
raise ValueError(
"Trying to predict in-sample on data that is unlike data encountered in training. "
f"Training data included {sum(len(fold) for fold in self._test_indices)} "
f"observations while prediction data includes {len(X)} observations."
)
# if len(X) != sum(len(fold) for fold in self._test_indices):
# raise ValueError(
# "Trying to predict in-sample on data that is unlike data encountered in training. "
# f"Training data included {sum(len(fold) for fold in self._test_indices)} "
# f"observations while prediction data includes {len(X)} observations."
# )
n_outputs = self._n_outputs(method)
predictions = self._initialize_prediction_tensor(
n_observations=len(X),
Expand Down
122 changes: 82 additions & 40 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,31 +99,36 @@ def fit_all_nuisance(

qualified_fit_params = self._qualified_fit_params(fit_params)

self._cvs: list = []
if not synchronize_cross_fitting:
raise ValueError()

self._cv_split_indices = self._split(X)
self._treatment_cv_split_indices = {}

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
if synchronize_cross_fitting:
cv_split_indices = self._split(
index_matrix(X, self._treatment_variants_indices[treatment_variant])
treatment_indices = np.where(
Copy link
Collaborator Author

@kklein kklein Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an opaque way of turning an array [True, True, False, False, True] into an array [0, 1, 4]. Not sure if there's a neater way of doing that.

Copy link
Contributor

@MatthiasLoefflerQC MatthiasLoefflerQC Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[index for index, value in enumerate(vector) if value] would work too, I guess, and is more verbose, but I like the np.where :)

self._treatment_variants_indices[treatment_variant]
)[0]
self._treatment_cv_split_indices[treatment_variant] = [
(
np.intersect1d(train_indices, treatment_indices),
np.intersect1d(test_indices, treatment_indices),
)
else:
cv_split_indices = None
self._cvs.append(cv_split_indices)
for train_indices, test_indices in self._cv_split_indices
MatthiasLoefflerQC marked this conversation as resolved.
Show resolved Hide resolved
]

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=X,
y=y,
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[NUISANCE][VARIANT_OUTCOME_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

Expand Down Expand Up @@ -160,14 +165,14 @@ def fit_all_treatment(
) -> Self:
if self._treatment_variants_indices is None:
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
if not hasattr(self, "_treatment_cv_split_indices"):
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _cvs, "
"The nuisance models need to be fitted before fitting the treatment models. "
"In particular, the MetaLearner's attribute _treatment_cv_split_indices, "
"typically set during nuisance fitting, does not exist."
)
qualified_fit_params = self._qualified_fit_params(fit_params)
Expand All @@ -180,34 +185,31 @@ def fit_all_treatment(
is_oos=False,
)
)

for treatment_variant in range(1, self.n_variants):
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
y, w, treatment_variant, conditional_average_outcome_estimates
)
treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
X=X,
y=imputed_te_treatment,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][TREATMENT_EFFECT_MODEL],
cv=self._cvs[treatment_variant],
cv=self._treatment_cv_split_indices[treatment_variant],
)
)

treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(X, self._treatment_variants_indices[0]),
X=X,
y=imputed_te_control,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
n_jobs_cross_fitting=n_jobs_cross_fitting,
fit_params=qualified_fit_params[TREATMENT][CONTROL_EFFECT_MODEL],
cv=self._cvs[0],
cv=self._treatment_cv_split_indices[0],
)
)

Expand Down Expand Up @@ -278,19 +280,18 @@ def predict(
oos_method=oos_method,
)
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
X=index_matrix(X, treatment_variant_indices),
X=X,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[treatment_variant_indices]
tau_hat_control[control_indices] = self.predict_treatment(
X=index_matrix(X, control_indices),
X=X,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
)
)[control_indices]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need is_oos=False below (and likewise for tau_hat_treatment)? Might be worth a try.

tau_hat_control[non_control_indices] = self.predict_treatment(
X=index_matrix(X, non_control_indices),
model_kind=CONTROL_EFFECT_MODEL,
Expand Down Expand Up @@ -424,16 +425,8 @@ def _pseudo_outcome(
This function can be used with both in-sample or out-of-sample data.
"""
validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants)

treatment_indices = w == treatment_variant
control_indices = w == 0

treatment_outcome = index_matrix(
conditional_average_outcome_estimates, control_indices
)[:, treatment_variant]
control_outcome = index_matrix(
conditional_average_outcome_estimates, treatment_indices
)[:, 0]
treatment_outcome = conditional_average_outcome_estimates[:, treatment_variant]
control_outcome = conditional_average_outcome_estimates[:, 0]

if self.is_classification:
# Get the probability of positive class, multiclass is currently not supported.
Expand All @@ -443,8 +436,8 @@ def _pseudo_outcome(
control_outcome = control_outcome[:, 0]
treatment_outcome = treatment_outcome[:, 0]

imputed_te_treatment = y[treatment_indices] - control_outcome
imputed_te_control = treatment_outcome - y[control_indices]
imputed_te_treatment = y - control_outcome
imputed_te_control = treatment_outcome - y

return imputed_te_control, imputed_te_treatment

Expand Down Expand Up @@ -534,3 +527,52 @@ def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
final_model = build(input_dict, {output_name: cate})
check_model(final_model, full_check=True)
return final_model

def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
if self._treatment_variants_indices is None:
raise ValueError(
"The metalearner needs to be fitted before predicting."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"typically set during fitting, is None."
)
# TODO: Consider multiprocessing
n_obs = len(X)
cao_tensor = self._nuisance_tensors(n_obs)[VARIANT_OUTCOME_MODEL][0]
predict_method_name = self.nuisance_model_specifications()[
VARIANT_OUTCOME_MODEL
]["predict_method"](self)
conditional_average_outcomes_list = []

for tv in range(self.n_variants):
if is_oos:
conditional_average_outcomes_list.append(
self.predict_nuisance(
X=X,
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=tv,
is_oos=True,
kklein marked this conversation as resolved.
Show resolved Hide resolved
oos_method=oos_method,
)
)
else:
# TODO: Consider moving this logic to CrossFitEstimator.predict.
cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv]
conditional_average_outcome_estimates = cao_tensor.copy()

for fold_index, test_indices in zip(
range(cfe.n_folds), cfe._test_indices # type: ignore[arg-type]
):
fold_model = cfe._estimators[fold_index]
predict_method = getattr(fold_model, predict_method_name)
fold_estimates = predict_method(X[test_indices])
conditional_average_outcome_estimates[test_indices] = fold_estimates

conditional_average_outcomes_list.append(
conditional_average_outcome_estimates
)

return np.stack(conditional_average_outcomes_list, axis=1).reshape(
n_obs, self.n_variants, -1
)
Loading