Leakage in X-Learner in-sample prediction #80

Open
kklein opened this issue Aug 12, 2024 · 1 comment · Fixed by #83 · May be fixed by #84
Labels
bug Something isn't working

Comments


kklein commented Aug 12, 2024

Issue at hand

@ArseniyZvyagintsevQC brought the following to our attention:

Let us assume a binary treatment variant scenario in which we want to work with in-sample predictions, i.e. is_oos=False.

The current implementation fits five models, three of which are considered nuisance models and two of which are considered treatment effect models:

| model | target | cross-fitting dataset | stage | name |
| --- | --- | --- | --- | --- |
| $\hat{\mu}_0$ | $Y_i$ | $\{(X_i, Y_i) \mid W_i=0\}$ | nuisance | "treatment_variant" |
| $\hat{\mu}_1$ | $Y_i$ | $\{(X_i, Y_i) \mid W_i=1\}$ | nuisance | "treatment_variant" |
| $\hat{e}$ | $W_i$ | $\{(X_i, W_i)\}$ | nuisance/propensity | "propensity_model" |
| $\hat{\tau}_0$ | $\hat{\mu}_1(X_i) - Y_i$ | $\{(X_i, Y_i) \mid W_i=0\}$ | treatment | "control_effect_model" |
| $\hat{\tau}_1$ | $Y_i - \hat{\mu}_0(X_i)$ | $\{(X_i, Y_i) \mid W_i=1\}$ | treatment | "treatment_effect_model" |

More background on this here.

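Note that each of these models is cross-fitted. More precisely, each is cross-fitted with respect to the data it has seen at training time, so an in-sample prediction for a data point comes from a fold model that did not train on that point.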
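To make the cross-fitting idea concrete, here is a minimal sketch in plain NumPy. It is not the library's implementation: `fit_ols`, `predict_ols`, `cross_fit` and `predict_in_sample` are hypothetical helpers, and ordinary least squares stands in for an arbitrary base model. It shows that the in-sample prediction for a point does not depend on that point's own outcome.

```python
import numpy as np


def fit_ols(X, y):
    """Least-squares fit with an intercept; returns the coefficients."""
    A = np.column_stack([np.ones(len(X)), X])
    return np.linalg.lstsq(A, y, rcond=None)[0]


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


def cross_fit(X, y, n_folds, seed=0):
    """Fit one model per fold, each trained on the remaining folds."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), n_folds)
    models = [
        fit_ols(np.delete(X, held_out, axis=0), np.delete(y, held_out))
        for held_out in folds
    ]
    return folds, models


def predict_in_sample(folds, models, X):
    """Score every training point with the model that held it out."""
    preds = np.empty(len(X))
    for held_out, coef in zip(folds, models):
        preds[held_out] = predict_ols(coef, X[held_out])
    return preds


rng = np.random.default_rng(42)
X = rng.normal(size=(60, 2))
y = X @ np.array([1.0, -2.0]) + rng.normal(scale=0.1, size=60)

folds, models = cross_fit(X, y, n_folds=5)
baseline = predict_in_sample(folds, models, X)

# Perturb only the outcome of point 0 and refit with the same folds.
y_perturbed = y.copy()
y_perturbed[0] += 100.0
folds_p, models_p = cross_fit(X, y_perturbed, n_folds=5)
perturbed = predict_in_sample(folds_p, models_p, X)

# The prediction for point 0 comes from a model that never saw y[0].
print(np.isclose(baseline[0], perturbed[0]))
```

Predictions for the other points do move, because their fold models trained on point 0. That is expected and is not leakage with respect to point 0.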

Let's suppose now that we are at inference time and encounter an in-sample data point $i$. Wlog, let's assume that $W_i=1$.
In order to come up with a CATE estimate, the predict method will run

  • $\hat{\tau}_0(X_i)$ with is_oos=True since this datapoint has not been seen during training time of the model $\hat{\tau}_0$
  • $\hat{\tau}_1(X_i)$ with is_oos=False since this datapoint has indeed been seen during the training time of the model $\hat{\tau}_1$

The latter call makes sure we avoid leakage in $\hat{\tau}_1$. The former call, however, does not completely avoid leakage:
even though $i$ hasn't been seen in the training of $\hat{\tau}_0$, it has been seen in $\hat{\mu}_1$, which is, in turn, used by $\hat{\tau}_0$. Therefore, the observed outcome $Y_i$ can leak into the estimate $\hat{\tau}(X_i)$.
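The leakage path can be reproduced in a small experiment. The sketch below is an illustration under assumptions, not the library's code: the helpers are hypothetical, least squares replaces the real base models, and out-of-sample prediction is assumed to average all fold models. It perturbs only $Y_i$ of a treated point and checks whether the out-of-sample estimate $\hat{\tau}_0(X_i)$ moves.

```python
import numpy as np


def fit_ols(X, y):
    A = np.column_stack([np.ones(len(X)), X])
    return np.linalg.lstsq(A, y, rcond=None)[0]


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


def cross_fit(X, y, n_folds, seed=0):
    """Return one model per fold, each trained on the remaining folds."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), n_folds)
    return [
        fit_ols(np.delete(X, held_out, axis=0), np.delete(y, held_out))
        for held_out in folds
    ]


def predict_oos(models, X):
    """Assumed out-of-sample rule: average the predictions of all fold models."""
    return np.mean([predict_ols(coef, X) for coef in models], axis=0)


def tau_0_at(x_query, X, Y, W, n_folds=5):
    """Estimate tau_0 at x_query, treating x_query as out-of-sample for tau_0."""
    X0, Y0 = X[W == 0], Y[W == 0]
    X1, Y1 = X[W == 1], Y[W == 1]
    mu_1 = cross_fit(X1, Y1, n_folds)        # trained on treated units only
    pseudo = predict_oos(mu_1, X0) - Y0      # training target of tau_0
    tau_0 = cross_fit(X0, pseudo, n_folds)   # trained on control units only
    return predict_oos(tau_0, x_query[None, :])[0]


rng = np.random.default_rng(0)
n = 200
X = rng.normal(size=(n, 2))
W = rng.integers(0, 2, size=n)
Y = X[:, 0] + W * (1.0 + X[:, 1]) + rng.normal(scale=0.1, size=n)

i = int(np.flatnonzero(W == 1)[0])   # a treated, in-sample point
before = tau_0_at(X[i], X, Y, W)

Y_perturbed = Y.copy()
Y_perturbed[i] += 100.0              # change only the outcome of point i
after = tau_0_at(X[i], X, Y_perturbed, W)

# The two estimates differ: Y_i reaches tau_0(X_i) through mu_1.
print(before, after)
```

Point $i$ is never part of the training data of $\hat{\tau}_0$, yet its outcome changes the pseudo-outcomes that $\hat{\tau}_0$ is trained on, which is the leak described above.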

Next steps

We can devise an extreme, naïve approach to counteract this issue: train every type of model once per data point, each time leaving that data point out. Clearly, this ensures the absence of data leakage. The challenge lies in coming up with a design that

  • allows for arbitrary numbers (>1, <=n) of cross-fitting folds, i.e. not fixing it to be equal to the number of training data points
  • integrates well into the structure of the library
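For reference, the naïve per-data-point approach could look roughly like the sketch below. This is only an illustration under assumptions: `loo_effect_estimates` is a hypothetical helper, least squares stands in for the base models, and the propensity model as well as the final combination of the two effect estimates are left out for brevity.

```python
import numpy as np


def fit_ols(X, y):
    A = np.column_stack([np.ones(len(X)), X])
    return np.linalg.lstsq(A, y, rcond=None)[0]


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


def loo_effect_estimates(i, X, Y, W):
    """Refit every outcome and effect model without point i, then score X[i]."""
    keep = np.arange(len(X)) != i
    Xr, Yr, Wr = X[keep], Y[keep], W[keep]
    X0, Y0 = Xr[Wr == 0], Yr[Wr == 0]
    X1, Y1 = Xr[Wr == 1], Yr[Wr == 1]
    mu_0 = fit_ols(X0, Y0)
    mu_1 = fit_ols(X1, Y1)
    tau_0 = fit_ols(X0, predict_ols(mu_1, X0) - Y0)
    tau_1 = fit_ols(X1, Y1 - predict_ols(mu_0, X1))
    x = X[i][None, :]
    return predict_ols(tau_0, x)[0], predict_ols(tau_1, x)[0]


rng = np.random.default_rng(0)
n = 100
X = rng.normal(size=(n, 2))
W = rng.integers(0, 2, size=n)
Y = X[:, 0] + W * (1.0 + X[:, 1]) + rng.normal(scale=0.1, size=n)

i = int(np.flatnonzero(W == 1)[0])
estimates = loo_effect_estimates(i, X, Y, W)

# Changing Y_i has no effect, since no model ever trains on point i.
Y_perturbed = Y.copy()
Y_perturbed[i] += 100.0
estimates_perturbed = loo_effect_estimates(i, X, Y_perturbed, W)
print(np.allclose(estimates, estimates_perturbed))
```

Answering in-sample queries for all n data points this way requires n refits of every model type, which is why it is not a practical design.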
@kklein kklein added the bug Something isn't working label Aug 12, 2024

kklein commented Aug 14, 2024

Preliminary idea

Currently we are training

  • n_folds many $\hat{\mu}_0$ models
  • n_folds many $\hat{\mu}_k$ models for every $k$
  • n_folds many $\hat{\tau}_{0,k}$ models for every $k$
  • n_folds many $\hat{\tau}_{k,0}$ models for every $k$

In order to answer an in-sample query of

> Give me models $\hat{\tau}_{0,k}$ and $\hat{\tau}_{k,0}$ which have seen no information about sample $i$ at all

we could train

  • n_folds * n_folds many $\hat{\tau}_{0,k}$ models for every $k$
  • n_folds * n_folds many $\hat{\tau}_{k,0}$ models for every $k$

In the scenario described in the issue, we would then run the predict method as such:

  • $\hat{\tau}_0(X_i)$ is estimated by fetching the n_folds many $\hat{\tau}_0$ models which are based on the one $\hat{\mu}_1$ model that has not seen data point $i$; these model estimates can be aggregated
  • $\hat{\tau}_1(X_i)$ is estimated by fetching the n_folds many $\hat{\tau}_1$ models which have not seen $i$ themselves and may have used any of the $\hat{\mu}_0$ models; these model estimates can be aggregated
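The selection step for $\hat{\tau}_0$ could be sketched as follows. This is a rough illustration of the idea, not a proposal for the library's API: `tau_0_leak_free` and `make_folds` are hypothetical names, least squares stands in for the base models, and the mean is assumed as the aggregation rule. The grid holds one $\hat{\tau}_0$ model per pair of ($\hat{\tau}_0$ fold, $\hat{\mu}_1$ fold), and only models whose $\hat{\mu}_1$ held out the query point are used.

```python
import numpy as np


def fit_ols(X, y):
    A = np.column_stack([np.ones(len(X)), X])
    return np.linalg.lstsq(A, y, rcond=None)[0]


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


def make_folds(n, n_folds, seed):
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n), n_folds)


def tau_0_leak_free(j, X0, Y0, X1, Y1, n_folds=3):
    """Estimate tau_0 at treated point j (row of X1) without using Y1[j]."""
    folds_1 = make_folds(len(X1), n_folds, seed=1)  # folds of the mu_1 models
    folds_0 = make_folds(len(X0), n_folds, seed=0)  # folds of the tau_0 models
    mu_1 = [fit_ols(np.delete(X1, f, axis=0), np.delete(Y1, f)) for f in folds_1]

    # Grid of tau_0 models, keyed by (tau_0 fold a, mu_1 fold b).
    grid = {}
    for b, coef in enumerate(mu_1):
        pseudo = predict_ols(coef, X0) - Y0
        for a, f in enumerate(folds_0):
            grid[a, b] = fit_ols(np.delete(X0, f, axis=0), np.delete(pseudo, f))

    # Keep only the models whose mu_1 was trained with point j held out.
    b_clean = next(b for b, f in enumerate(folds_1) if j in f)
    clean = [coef for (a, b), coef in grid.items() if b == b_clean]
    x = X1[j][None, :]
    return np.mean([predict_ols(coef, x)[0] for coef in clean])


rng = np.random.default_rng(0)
X0, X1 = rng.normal(size=(40, 2)), rng.normal(size=(30, 2))
Y0 = X0[:, 0] + rng.normal(scale=0.1, size=40)
Y1 = X1[:, 0] + 1.0 + rng.normal(scale=0.1, size=30)

estimate = tau_0_leak_free(0, X0, Y0, X1, Y1)

# Perturbing the outcome of the query point leaves the estimate unchanged.
Y1_perturbed = Y1.copy()
Y1_perturbed[0] += 100.0
estimate_perturbed = tau_0_leak_free(0, X0, Y0, X1, Y1_perturbed)
print(np.isclose(estimate, estimate_perturbed))
```

The analogous step for $\hat{\tau}_1$ would filter on the $\hat{\tau}_1$ fold that held out the query point instead, since $\hat{\mu}_0$ never trains on treated units.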

Importantly, this would

  • Redefine the meaning of is_oos
  • Massively increase the computational burden
