Add the c-index with IPCW #71

Open · wants to merge 13 commits into main
Conversation

@Vincent-Maladiere (Collaborator) commented Jul 9, 2024

What does this PR propose?

This PR proposes to add the c-index as defined in [1]. I think this is ready to be reviewed for merging, with some questions/suggestions in the TODO section below.

show maths

[screenshot: the IPCW estimator of the cause-specific c-index from [1], expressed with comparable-pair indicators, the IPCW weights $\hat{W}_{ij,1}$ and $\hat{W}_{ij,2}$, and the predicted incidence $M$]

where:

[screenshot: definition of one component of the formula]

and

[screenshot: definition of another component]

and

[screenshot: definition of another component]

where $M$ is the probability of incidence of the event of interest.
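
Since the screenshots do not render here, here is a schematic of the estimator's structure (my reconstruction, loosely following the notation of [1]; see the paper for the exact definitions of the indicators and weights):

$$\hat{C}_1(\tau) = \frac{\sum_{i \neq j} \left( \tilde{A}_{ij} \, \hat{W}_{ij,1}^{-1} + \tilde{B}_{ij} \, \hat{W}_{ij,2}^{-1} \right) \mathbb{1}\{M_i > M_j\}}{\sum_{i \neq j} \left( \tilde{A}_{ij} \, \hat{W}_{ij,1}^{-1} + \tilde{B}_{ij} \, \hat{W}_{ij,2}^{-1} \right)}$$

where, roughly, $\tilde{A}_{ij}$ indicates the pairs in which $i$ experiences the event of interest before $\tilde{T}_j$, $\tilde{B}_{ij}$ the pairs in which $j$ experiences a competing event no later than $\tilde{T}_i$, and the $\hat{W}$ terms are IPCW weights built from the censoring survival function $\hat{G}$.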

  • The concordance_index_incidence function is inspired by the concordance_index function in lifelines, with some significant differences (see the usage sketch after this list):
    • It computes the formula above, designed for competing risks and IPCW
    • It accepts truncation times taus
    • It can use virtually any estimator to compute the IPCW (only Kaplan-Meier for now)
    • It uses increasing probabilities of incidence instead of decreasing survival probabilities
  • The main advantage of the lifelines design is its use of a balanced tree, leading to an $O(n \log n)$ time complexity instead of the $O(n^2)$ of scikit-survival's concordance_index_ipcw.
  • To support IPCW, I extended the _BTree class from lifelines.utils.btree.py by adding a weighted-count mechanism. I credited lifelines in hazardous.metrics._btree.py, but I can also credit it in hazardous.metrics._concordance_index.py if necessary.
  • I added an extensive test suite on small, handcrafted datasets, so the expected results can be checked by hand.
  • A good start on the documentation has been made by @judithabk6
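
For reviewers, a minimal usage sketch. The argument names mirror the private _concordance_index_incidence_report call used in the benchmark further down; the public import path and final signature are assumptions and may differ:

import numpy as np
import pandas as pd

# Assumed public import path; the benchmark below uses the private
# hazardous.metrics._concordance_index._concordance_index_incidence_report.
from hazardous.metrics import concordance_index_incidence

rng = np.random.default_rng(0)
n = 100

def make_y(n):
    # "event": 0 = censored, 1 = event of interest, 2 = competing event
    return pd.DataFrame({
        "event": rng.integers(0, 3, size=n),
        "duration": rng.uniform(1, 100, size=n),
    })

y_train, y_test = make_y(n), make_y(n)
time_grid = np.linspace(1, 100, num=20)
# Predicted cumulative incidence of the event of interest, increasing over time.
y_pred = np.sort(rng.uniform(0, 1, size=(n, time_grid.size)), axis=1)

cindex = concordance_index_incidence(
    y_test,
    y_pred,
    y_train=y_train,      # used to fit the IPCW (censoring) estimator
    time_grid=time_grid,  # grid on which y_pred is evaluated
    taus=None,            # optional truncation times
)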

TODO

  • Keep improving the documentation
  • Should we add tests on bigger datasets and compare our results with scikit-survival for the survival case, and with a specific R package dealing with competing risks?
  • In the formula of $\hat{W}_{ij,1}$ (see maths above), I didn't differentiate between $\tilde{T}_i^-$ and $\tilde{T}_i$.
  • Use the Cox IPCW estimator. This depends on first adding Cox to the IPCW estimator (to be done in another PR).
  • Should we use the tied_tol parameter for ties in predictions?

cc @ogrisel @GaelVaroquaux @juAlberge @glemaitre

[1] Wolbers, M., Blanche, P., Koller, M. T., Witteman, J. C., & Gerds, T. A. (2014). Concordance for prognostic models with competing risks. Biostatistics.

@Vincent-Maladiere (Collaborator, Author) commented Jul 9, 2024

The CI for the doc fails because the previous boosting tree model is missing. This should be fixed when #53 is merged.

@Vincent-Maladiere (Collaborator, Author) commented Jul 24, 2024

Update on performance

Our implementation is 100x slower than scikit-survival's concordance_index_ipcw. This is due to computing the weights (the IPCWs) inside the balanced tree, which lifelines doesn't do.

code benchmark
import numpy as np
import pandas as pd
from time import time
from lifelines import CoxPHFitter
from lifelines.datasets import load_kidney_transplant
from sklearn.model_selection import train_test_split

from hazardous.metrics._concordance_index import _concordance_index_incidence_report

df = load_kidney_transplant()

# make the dataset 100x larger for benchmarking purposes
df = pd.concat([df] * 100, axis=0)

df_train, df_test = train_test_split(df, stratify=df["death"])
cox = CoxPHFitter().fit(df_train, duration_col="time", event_col="death")

t_min, t_max = df["time"].min(), df["time"].max()
time_grid = np.linspace(t_min, t_max, 20)
y_pred = 1 - cox.predict_survival_function(df_test, times=time_grid).T.to_numpy()

y_train = df_train[["death", "time"]].rename(columns=dict(
    death="event", time="duration"
))
y_test = df_test[["death", "time"]].rename(columns=dict(
    death="event", time="duration"
))

tic = time()
result = _concordance_index_incidence_report(
    y_test=y_test,
    y_pred=y_pred,
    time_grid=time_grid,
    taus=None,
    y_train=y_train,
) 
print(f"our implementation: {time() - tic:.2f}s")

# scikit-survival
from sksurv.metrics import concordance_index_ipcw

def make_recarray(y):
    event, duration = y["event"].values, y["duration"].values
    return np.array(
        [(event[i], duration[i]) for i in range(len(event))],
        dtype=[("e", bool), ("t", float)],
    )

tic = time()
concordance_index_ipcw(
    make_recarray(y_train),
    make_recarray(y_test),
    y_pred[:, -1],
    tau=None,
)
print(f"scikit-survival: {time() - tic:.2f}s")

# lifelines
from lifelines.utils import concordance_index

tic = time()  # reset the timer, otherwise the scikit-survival time is included
concordance_index(
    event_times=y_test["duration"],
    predicted_scores=1 - y_pred[:, -1],
    event_observed=y_test["event"],
)
print(f"lifelines: {time() - tic:.2f}s")

On a dataset with 20k rows:

our implementation: 18.10s
scikit-survival: 0.24s
lifelines: 0.27s

The flamegraph is quite clear about the culprit: the list comprehension that computes the IPCW weight for each pair. When I remove the IPCWs, the performance becomes similar to lifelines.
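
To give an order of magnitude of the pure-Python overhead (illustrative only, not the actual code): evaluating the weights element by element in the interpreter is far slower than a vectorized NumPy equivalent, and inside the tree this cost is paid per pair rather than per sample.

import numpy as np
from time import time

rng = np.random.default_rng(0)
G = rng.uniform(0.5, 1.0, size=20_000)  # stand-in for censoring survival values

tic = time()
w_loop = [1.0 / (g * g) for g in G]     # per-element Python work
print(f"list comprehension: {time() - tic:.4f}s")

tic = time()
w_vec = 1.0 / G**2                      # vectorized equivalent
print(f"numpy:              {time() - tic:.4f}s")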

Speedscope views of our implementation [two profiling screenshots]

I tried to fix this performance issue using numba's @jitclass on the BTree, but it is still very slow. I put the numba BTree on a separate draft branch for reference.

Conclusion

I only see two ways forward:

  1. Either my computation of the IPCW inside the BTree is flawed, and we can fix the performance issue;
  2. or the BTree is not suited to our metric, and we have to fall back to a non-optimized pairwise implementation like scikit-survival's, with $O(n^2)$ instead of $O(n \log n)$ time complexity. This would simplify the code base, though (see the sketch below).
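
For the record, a rough sketch of what option 2 could look like for the plain survival setting (event of interest only, no taus, ties in predictions ignored); the competing-risks pairs and exact tie handling would come on top of this:

import numpy as np

def naive_cindex_ipcw(durations, events, preds, ipcw):
    """O(n^2) pairwise IPCW c-index sketch for the survival case.

    ipcw[i] is 1 / G(T_i) from a censoring estimator (all ones = unweighted),
    preds[i] is the predicted incidence at the horizon (higher = higher risk).
    """
    num = den = 0.0
    n = len(durations)
    for i in range(n):
        if events[i] != 1:
            continue  # pairs are comparable only if i had the event of interest
        for j in range(n):
            if i == j or durations[j] <= durations[i]:
                continue  # j must still be at risk when i has the event
            w = ipcw[i] ** 2  # 1 / W_{ij,1}, with W_{ij,1} = G(T_i)^2 (KM case)
            den += w
            num += w * (preds[i] > preds[j])
    return num / den

# Tiny smoke test on synthetic data.
rng = np.random.default_rng(0)
n = 500
durations = rng.exponential(10, size=n)
events = rng.integers(0, 2, size=n)
preds = -durations + rng.normal(0, 2, size=n)  # higher risk for shorter durations
print(naive_cindex_ipcw(durations, events, preds, np.ones(n)))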

@jjerphan (Member)
Pinged by @Vincent-Maladiere, but I have no time for it.

Random pile of pieces of advice:

  • find out if a better algorithm exists first
  • profile to see where the bottleneck is
  • see if tree-based structures can be used from another library (e.g. pydatastructures)
  • use another language (like Cython or C++) to implement the costly algorithmic part

@GaelVaroquaux (Member) commented Jul 26, 2024 via email

@Vincent-Maladiere (Collaborator, Author)
After giving it some more thought, there is room for improvement with the current balanced tree design:

  1. When we don't use an IPCW estimator (like lifelines):
    $$W_{ij,1} = W_{ij,2} = 1$$
  2. When we use a non-conditional IPCW estimator (Kaplan-Meier, like scikit-survival):
    $$W_{ij,1} = W_{i,1} = \hat{G}(T_i)^2 \quad \text{and} \quad W_{ij,2} = \hat{G}(T_i) \hat{G}(T_j)$$

However, when we use a conditional IPCW estimator (like Cox or SurvivalBoost), we have:
$$W_{ij,1} = \hat{G}(T_i | X_i) \hat{G}(T_i | X_j) \quad \text{and} \quad W_{ij,2} = \hat{G}(T_i | X_i) \hat{G}(T_j | X_j)$$

In this case, the balanced tree is no longer suitable, and we should use the naive implementation.
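
A small illustration of this point (hypothetical values, just to show the shape of the computation): with Kaplan-Meier the pair weight only depends on $i$, so the tree only needs unweighted counts plus a per-sample factor, whereas conditional weights require one evaluation per pair.

import numpy as np

rng = np.random.default_rng(0)
n = 5

# Non-conditional (Kaplan-Meier): one censoring survival value per sample.
G = rng.uniform(0.5, 1.0, size=n)            # G(T_i)
W1_km = (G ** 2)[:, None] * np.ones((n, n))  # W_{ij,1} = G(T_i)^2: constant in j,
                                             # so a tree count per i is enough.

# Conditional (e.g. Cox): a value per (sample time, covariate) combination.
G_cond = rng.uniform(0.5, 1.0, size=(n, n))  # G_cond[i, j] ~ G(T_i | X_j), hypothetical
W1_cond = np.diag(G_cond)[:, None] * G_cond  # W_{ij,1} = G(T_i|X_i) * G(T_i|X_j):
                                             # depends on both i and j -> O(n^2) anyway.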

So, to keep things simple, I suggest we only implement the naive version for now, and possibly return to the balanced tree later for the non-conditional and unweighted cases.

WDYT?

@GaelVaroquaux (Member) commented Jul 26, 2024 via email

@Vincent-Maladiere (Collaborator, Author)
Here is the revised version. When used on a survival dataset, it gives results identical to scikit-survival, with a slightly shorter running time.

[figure: cindex_duration benchmark plot]

@Vincent-Maladiere (Collaborator, Author)
This PR is now ready to be reviewed :)
