update predict.modelfit(type = "quantile") #1203

Open · topepo opened this issue Sep 12, 2024 · 6 comments

Comments

@topepo (Member) commented Sep 12, 2024

We are adding a mode for quantile regression but have one engine that already enables such prediction (using the censored regression mode).

We should allow that but make some adjustments to harmonize both approaches.
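
A minimal sketch (not from the issue; the engine and arguments are assumed from censored's current behavior, and hfrick's reprex later in this thread shows the full output) of the existing interface, where quantile levels are requested at prediction time:

library(censored)  # loads parsnip and survival

# censored regression model whose engine already supports type = "quantile"
fit_s <- survival_reg() %>%
  set_engine("flexsurv") %>%
  set_mode("censored regression") %>%
  fit(Surv(stop, event) ~ rx + size + enum, data = bladder)

# quantile levels are passed at prediction time via the `quantile` argument
predict(fit_s, new_data = bladder[1:3, ], type = "quantile", quantile = (1:3) / 4)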

topepo added a commit to dajmcdon/parsnip that referenced this issue Sep 12, 2024
topepo added a commit that referenced this issue Sep 13, 2024
* small change to predict checks

* add vctrs for quantiles and test, refactor *_rq_preds

* revise tests

* Apply some of the suggestions from code review

Co-authored-by: Simon P. Couch <[email protected]>

* rename tests on suggestion from code review

* export missing funs from vctrs for formatting

* convert errors to snapshot tests

* pass call through input check

* update snapshots for caller_env

* rename to parsnip_quantiles, add format snapshot tests

* Apply suggestions from @topepo

Co-authored-by: Max Kuhn <[email protected]>

* rename parsnip_quantiles to quantile_pred

* rename parsnip_quantiles to quantile_pred and add vector probability check

* fix: two bugs introduced earlier

* add formatting tests for single quantile

* replace walk with a loop to avoid "Error in map()"

* remove row/col names

* adjust quantile_pred format

* as_tibble method

* updated NEWS file

* add PR number

* small new update

* helper methods

* update docs

* re-enable quantiles prediction for #1203

* update some tests

* no longer needed

* use tibble::new_tibble

* braces

* test as_tibble

* remove print methods

---------

Co-authored-by: Simon P. Couch <[email protected]>
Co-authored-by: Max Kuhn <[email protected]>
Co-authored-by: ‘topepo’ <‘[email protected]’>

@topepo (Member Author) commented Sep 16, 2024

Some notes...

The problem is that we have a predict() method that takes type = "quantile".

With the new quantile regression mode, we specify the quantile levels with set_mode(). The current predict() method has a quantile argument, which is problematic.
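
A hedged sketch of the clash, assuming the planned set_mode() signature (quantile levels declared as a mode argument) and the quantreg engine for linear_reg():

library(parsnip)

# new mode: quantile levels are fixed when the model is specified
spec <- linear_reg() %>%
  set_engine("quantreg") %>%
  set_mode("quantile regression", quantile_levels = c(0.1, 0.5, 0.9))

# old interface: levels are requested again at prediction time, which the new
# mode makes redundant (or contradictory)
# predict(fit, new_data, type = "quantile", quantile = c(0.1, 0.5, 0.9))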

A few models have quantile prediction methods. Two survival engines for parametric models (flexsurv and survival) have them, and the bayesian package offers this prediction type for both regression and classification models.

Proposed changes:

  • Throw a warning if a model uses type = "quantile" when the mode is regression. Other modes use the existing interface.
  • Currently, we do not produce an interval estimate for quantile predictions. We will remove the interval and level arguments to predict_quantile().
  • We will want to transition the quantile prediction type to be specifically reserved for regression models built to predict quantiles. Some ordinary regression models can compute quantiles of the prediction distribution, but these are not optimized for accuracy. We can grandfather the existing censored regression models as-is since they cannot have a censored regression mode and a quantile regression mode.
  • We will want to enforce a specific tidy format that mirrors our new format, using a pred_quantile vctrs class (see the sketch below). This will require a breaking change to the existing engines.
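
A sketch of the tidy format the last bullet refers to, using the quantile_pred() constructor that, per the commit log below, eventually landed in hardhat; the matrix values and the .pred_quantile column name are illustrative, not quoted from the implementation:

library(hardhat)
library(tibble)

# one row per observation, one column per quantile level
values <- matrix(
  c(1.1, 2.0, 3.2,
    0.9, 1.8, 2.9),
  nrow = 2, byrow = TRUE
)

qp <- quantile_pred(values, quantile_levels = c(0.25, 0.5, 0.75))

# predictions would then come back as a single vctrs column
tibble(.pred_quantile = qp)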

@topepo (Member Author) commented Sep 16, 2024

Regarding the bayesian package... it will be a breaking change. However, the package doesn't really follow any of our guidelines for naming arguments and prediction columns or for using tidy data formats.

library(tidymodels)
library(bayesian)
#> Loading required package: brms
#> Loading required package: Rcpp
#> 
#> Attaching package: 'Rcpp'
#> The following object is masked from 'package:rsample':
#> 
#>     populate
#> Loading 'brms' package (version 2.21.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:dials':
#> 
#>     mixture
#> The following object is masked from 'package:stats':
#> 
#>     ar
# regression example

bayesian_fit <-
  bayesian() %>%
  set_mode("regression") %>%
  set_engine("brms") %>%
  fit(
    rating ~ treat + period + carry + (1 | subject),
    data = inhaler
  )
#> Compiling Stan program...
#> Trying to compile a simple C file
#> Running /Library/Frameworks/R.framework/Resources/bin/R CMD SHLIB foo.c
#> using C compiler: ‘Apple clang version 15.0.0 (clang-1500.3.9.4)’
#> using SDK: ‘’
#> clang -arch arm64 -I"/Library/Frameworks/R.framework/Resources/include" -DNDEBUG   -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/Rcpp/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/unsupported"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/BH/include" -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/src/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppParallel/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/rstan/include" -DEIGEN_NO_DEBUG  -DBOOST_DISABLE_ASSERTS  -DBOOST_PENDING_INTEGER_LOG2_HPP  -DSTAN_THREADS  -DUSE_STANC3 -DSTRICT_R_HEADERS  -DBOOST_PHOENIX_NO_VARIADIC_EXPRESSION  -D_HAS_AUTO_PTR_ETC=0  -include '/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp'  -D_REENTRANT -DRCPP_PARALLEL_USE_TBB=1   -I/opt/R/arm64/include    -fPIC  -falign-functions=64 -Wall -g -O2  -c foo.c -o foo.o
#> In file included from <built-in>:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp:22:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Dense:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Core:19:
#> /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:679:10: fatal error: 'cmath' file not found
#> #include <cmath>
#>          ^~~~~~~
#> 1 error generated.
#> make: *** [foo.o] Error 1
#> Start sampling
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
#> Chain 1: 
#> Chain 1: Gradient evaluation took 7.7e-05 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.77 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1: 
#> Chain 1: 
#> Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 1: 
#> Chain 1:  Elapsed Time: 0.748 seconds (Warm-up)
#> Chain 1:                0.356 seconds (Sampling)
#> Chain 1:                1.104 seconds (Total)
#> Chain 1: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 2).
#> Chain 2: 
#> Chain 2: Gradient evaluation took 2.9e-05 seconds
#> Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.29 seconds.
#> Chain 2: Adjust your expectations accordingly!
#> Chain 2: 
#> Chain 2: 
#> Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 2: 
#> Chain 2:  Elapsed Time: 0.726 seconds (Warm-up)
#> Chain 2:                0.355 seconds (Sampling)
#> Chain 2:                1.081 seconds (Total)
#> Chain 2: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 3).
#> Chain 3: 
#> Chain 3: Gradient evaluation took 2.6e-05 seconds
#> Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0.26 seconds.
#> Chain 3: Adjust your expectations accordingly!
#> Chain 3: 
#> Chain 3: 
#> Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 3: 
#> Chain 3:  Elapsed Time: 0.7 seconds (Warm-up)
#> Chain 3:                0.355 seconds (Sampling)
#> Chain 3:                1.055 seconds (Total)
#> Chain 3: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 4).
#> Chain 4: 
#> Chain 4: Gradient evaluation took 2.3e-05 seconds
#> Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0.23 seconds.
#> Chain 4: Adjust your expectations accordingly!
#> Chain 4: 
#> Chain 4: 
#> Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 4: 
#> Chain 4:  Elapsed Time: 0.734 seconds (Warm-up)
#> Chain 4:                0.355 seconds (Sampling)
#> Chain 4:                1.089 seconds (Total)
#> Chain 4:

# Results are not in a tidy format and do not follow the tidymodels rules for
# naming prediction columns.
predict(bayesian_fit, inhaler, type = "quantile", quantile = c(.3, .5, .7))
#> Warning in c(0.3, 0.5, 0.7): For regression models, making quantile prediction requires a model with a
#> "quantile regression" mode as of parsnip version 1.3.0.
#> # A tibble: 572 × 5
#>    Estimate Est.Error   Q30   Q50   Q70
#>       <dbl>     <dbl> <dbl> <dbl> <dbl>
#>  1     1.21     0.581 0.911  1.22  1.51
#>  2     1.19     0.591 0.877  1.18  1.50
#>  3     1.21     0.592 0.907  1.20  1.52
#>  4     1.19     0.612 0.859  1.18  1.52
#>  5     1.20     0.613 0.886  1.19  1.50
#>  6     1.19     0.593 0.886  1.17  1.48
#>  7     1.22     0.595 0.910  1.21  1.51
#>  8     1.20     0.599 0.901  1.21  1.52
#>  9     1.19     0.595 0.875  1.18  1.51
#> 10     1.18     0.594 0.877  1.18  1.50
#> # ℹ 562 more rows
# Classification example

# data from: https://stats.oarc.ucla.edu/r/dae/mixed-effects-logistic-regression/
hdp <- 
  read.csv("https://stats.idre.ucla.edu/stat/data/hdp.csv") %>% 
  mutate(
    Married = factor(Married, levels = 0:1, labels = c("no", "yes")),
    DID = factor(DID),
    HID = factor(HID),
    CancerStage = factor(CancerStage),
    remission = factor(ifelse(remission == 1, "yes", "no"))
  )

bayesian_fit <-
  bayesian(family = bernoulli(link = "logit")) %>%
  set_mode("classification") %>%
  set_engine("brms") %>%
  fit(remission ~ IL6 + CRP + (1 | DID), data = hdp)
#> Compiling Stan program...
#> Trying to compile a simple C file
#> Running /Library/Frameworks/R.framework/Resources/bin/R CMD SHLIB foo.c
#> using C compiler: ‘Apple clang version 15.0.0 (clang-1500.3.9.4)’
#> using SDK: ‘’
#> clang -arch arm64 -I"/Library/Frameworks/R.framework/Resources/include" -DNDEBUG   -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/Rcpp/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/unsupported"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/BH/include" -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/src/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppParallel/include/"  -I"/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/rstan/include" -DEIGEN_NO_DEBUG  -DBOOST_DISABLE_ASSERTS  -DBOOST_PENDING_INTEGER_LOG2_HPP  -DSTAN_THREADS  -DUSE_STANC3 -DSTRICT_R_HEADERS  -DBOOST_PHOENIX_NO_VARIADIC_EXPRESSION  -D_HAS_AUTO_PTR_ETC=0  -include '/Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp'  -D_REENTRANT -DRCPP_PARALLEL_USE_TBB=1   -I/opt/R/arm64/include    -fPIC  -falign-functions=64 -Wall -g -O2  -c foo.c -o foo.o
#> In file included from <built-in>:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/StanHeaders/include/stan/math/prim/fun/Eigen.hpp:22:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Dense:1:
#> In file included from /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/Core:19:
#> /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library/RcppEigen/include/Eigen/src/Core/util/Macros.h:679:10: fatal error: 'cmath' file not found
#> #include <cmath>
#>          ^~~~~~~
#> 1 error generated.
#> make: *** [foo.o] Error 1
#> Start sampling
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
#> Chain 1: 
#> Chain 1: Gradient evaluation took 0.000487 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 4.87 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1: 
#> Chain 1: 
#> Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 1: 
#> Chain 1:  Elapsed Time: 14.297 seconds (Warm-up)
#> Chain 1:                4.489 seconds (Sampling)
#> Chain 1:                18.786 seconds (Total)
#> Chain 1: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 2).
#> Chain 2: 
#> Chain 2: Gradient evaluation took 0.000297 seconds
#> Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 2.97 seconds.
#> Chain 2: Adjust your expectations accordingly!
#> Chain 2: 
#> Chain 2: 
#> Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 2: 
#> Chain 2:  Elapsed Time: 12.056 seconds (Warm-up)
#> Chain 2:                4.56 seconds (Sampling)
#> Chain 2:                16.616 seconds (Total)
#> Chain 2: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 3).
#> Chain 3: 
#> Chain 3: Gradient evaluation took 0.000302 seconds
#> Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 3.02 seconds.
#> Chain 3: Adjust your expectations accordingly!
#> Chain 3: 
#> Chain 3: 
#> Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 3: 
#> Chain 3:  Elapsed Time: 12.867 seconds (Warm-up)
#> Chain 3:                4.546 seconds (Sampling)
#> Chain 3:                17.413 seconds (Total)
#> Chain 3: 
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 4).
#> Chain 4: 
#> Chain 4: Gradient evaluation took 0.000295 seconds
#> Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 2.95 seconds.
#> Chain 4: Adjust your expectations accordingly!
#> Chain 4: 
#> Chain 4: 
#> Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
#> Chain 4: 
#> Chain 4:  Elapsed Time: 12.375 seconds (Warm-up)
#> Chain 4:                4.507 seconds (Sampling)
#> Chain 4:                16.882 seconds (Total)
#> Chain 4:
#> Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#bulk-ess

# This doesn't seem to work:
predict(bayesian_fit, hdp %>% select(-remission), type = "quantile", 
        quantile = c(.3, .5, .7))
#> # A tibble: 8,525 × 5
#>    Estimate Est.Error   Q30   Q50   Q70
#>       <dbl>     <dbl> <dbl> <dbl> <dbl>
#>  1   0.0192     0.137     0     0     0
#>  2   0.0335     0.180     0     0     0
#>  3   0.0158     0.125     0     0     0
#>  4   0.03       0.171     0     0     0
#>  5   0.026      0.159     0     0     0
#>  6   0.031      0.173     0     0     0
#>  7   0.0265     0.161     0     0     0
#>  8   0.0245     0.155     0     0     0
#>  9   0.0215     0.145     0     0     0
#> 10   0.0215     0.145     0     0     0
#> # ℹ 8,515 more rows

Created on 2024-09-16 with reprex v2.1.1

@hfrick (Member) commented Sep 17, 2024

Currently, we do not produce an interval estimate for quantile predictions. We will remove the interval and level arguments to predict_quantile().

That's incorrect: we do produce them for the flexsurv and flexsurvspline engines in censored for survival_reg() models. I'd like us to keep that functionality.

library(censored)
#> Loading required package: parsnip
#> Loading required package: survival

# flexsurv engine
set.seed(1)
fit_s <- survival_reg() %>%
  set_engine("flexsurv") %>%
  set_mode("censored regression") %>%
  fit(Surv(stop, event) ~ rx + size + enum, data = bladder)

pred <- predict(fit_s,
  new_data = bladder[1:3, ], type = "quantile",
  interval = "confidence", level = 0.7
)
pred
#> # A tibble: 3 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [9 × 4]>
#> 2 <tibble [9 × 4]>
#> 3 <tibble [9 × 4]>
pred$.pred[[1]]
#> # A tibble: 9 × 4
#>   .quantile .pred_quantile .pred_lower .pred_upper
#>       <dbl>          <dbl>       <dbl>       <dbl>
#> 1       0.1           3.57        2.75        4.46
#> 2       0.2           7.33        5.83        8.86
#> 3       0.3          11.5         9.33       13.7 
#> 4       0.4          16.2        13.3        19.4 
#> 5       0.5          21.7        18.0        25.9 
#> 6       0.6          28.3        23.6        33.8 
#> 7       0.7          36.8        30.8        44.1 
#> 8       0.8          48.5        40.5        58.5 
#> 9       0.9          68.4        56.4        83.6

# flexsurvspline engine
set.seed(1)
fit_s <- survival_reg() %>%
  set_engine("flexsurvspline", k = 1) %>%
  set_mode("censored regression") %>%
  fit(Surv(stop, event) ~ rx + size + enum, data = bladder)

pred <- predict(fit_s,
  new_data = bladder[1:3, ], type = "quantile",
  interval = "confidence", level = 0.7
)
pred
#> # A tibble: 3 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [9 × 4]>
#> 2 <tibble [9 × 4]>
#> 3 <tibble [9 × 4]>
pred$.pred[[1]]
#> # A tibble: 9 × 4
#>   .quantile .pred_quantile .pred_lower .pred_upper
#>       <dbl>          <dbl>       <dbl>       <dbl>
#> 1       0.1           3.86        3.08        4.70
#> 2       0.2           7.17        5.90        8.67
#> 3       0.3          10.8         8.94       13.1 
#> 4       0.4          15.2        12.6        18.3 
#> 5       0.5          20.6        17.2        24.8 
#> 6       0.6          27.6        23.0        33.5 
#> 7       0.7          37.1        31.1        45.2 
#> 8       0.8          51.2        42.4        64.3 
#> 9       0.9          76.2        61.4       100.

Created on 2024-09-17 with reprex v2.1.0

@hfrick (Member) commented Sep 17, 2024

We will want to transition the quantile prediction type to be specifically reserved for regression models built to predict quantiles. Some ordinary regression models can compute quantiles of the prediction distribution, but these are not optimized for accuracy. We can grandfather the existing censored regression models as-is since they cannot have a censored regression mode and a quantile regression mode.

Why do we want to reserve type = "quantile" for models with mode = "quantile regression"? Wouldn't the mode be enough distinction? We can document that only the quantiles predicted by quantile regression models are optimized for accuracy but still allow other types of quantiles.

topepo pushed a commit to tidymodels/censored that referenced this issue Sep 23, 2024

@topepo (Member Author) commented Sep 23, 2024

For survival::survreg() objects, the quantile levels do not appear to be stored anywhere in the output of predict(). We may need a wrapper to add an attribute or to pre-format it into a tidy format. I've exported parsnip::matrix_to_quantile_pred() since the output is similar to that produced by quantreg.
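
A hedged illustration of the gap (the model and the conversion are assumptions, not from the issue): survival's predict() for survreg objects returns a bare matrix of quantiles with nothing recording the requested levels, so a wrapper has to carry them along. parsnip::matrix_to_quantile_pred() plays that role; hardhat::quantile_pred() is used here only to show the target structure.

library(survival)
library(hardhat)

fit <- survreg(Surv(time, status) ~ ph.ecog, data = lung)

p <- c(0.1, 0.5, 0.9)
q_mat <- predict(fit, newdata = lung[1:3, ], type = "quantile", p = p)

dim(q_mat)  # a 3 x 3 matrix; the levels in `p` are not stored on it

# reattach the levels to get the tidy vctrs representation
quantile_pred(unname(q_mat), quantile_levels = p)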

topepo pushed a commit that referenced this issue Sep 25, 2024
topepo added a commit that referenced this issue Sep 26, 2024
* quantile -> quantile_levels for #1203

* defer test until censored updates in new PR

* update docs for quantile_levels

* update test

* disable quantile predictions for surv_reg

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
topepo pushed a commit to tidymodels/censored that referenced this issue Sep 30, 2024
topepo added a commit that referenced this issue Oct 11, 2024
* add a quantile regression mode to test with

* update type checkers

* avoid confusion with global all_models object

* add quantile_level argument to set_mode()

* initial data for quantreg

* some initial tests

* fix some issues

* enable quantile prediction

* tests for quantreg

* Quantile predictions output constructor (#1191)

* quantile regression updates for new hardhat model (#1207)

* bump hardhat version

* remove parts now in hardhat

* update for new hardhat version

* quantile_levels (plural now)

* news update

* typo

* rename helper function

* run CI on PRs from branches

* forgotten remote

* actions for edited PRs

* plural

* expand branch list

* export function for censored to use

* updated snapshot

* remake snapshot

* Revert "remake snapshot"

This reverts commit 954e326.

* updated snapshot

* Update R/arguments.R

Co-authored-by: Hannah Frick <[email protected]>

* typo

* changes from reviewer feedback

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
Co-authored-by: Hannah Frick <[email protected]>

* Change to `quantile` argument to `quantile levels` (#1208)

* post conflict merge updates

* update news

* version bump and fix typo

* revert GHA branches

* small bug fix

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>
Co-authored-by: Emil Hvitfeldt <[email protected]>

* don't export median

* add call arg

* added documentation on model

* add mode

* convert error to warning

* remove rankdeficient

* added skip

* add deprecated `quantile` arg back in

* remove numeric prediction

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
Co-authored-by: Daniel McDonald <[email protected]>
Co-authored-by: Simon P. Couch <[email protected]>
Co-authored-by: Hannah Frick <[email protected]>
Co-authored-by: Emil Hvitfeldt <[email protected]>

@hfrick (Member) commented Oct 15, 2024

Adding a todo as part of this: the docs for predict.model_fit() currently describe the return value of predict(type = "quantile") as a list column. This needs updating to the new vctrs class and, more user-facing, to the new column name for it.
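
A hedged sketch of the two return shapes the docs need to distinguish; the quantile_pred() constructor is taken from the commit log above, while the as_tibble() method and the exact column names are assumptions:

library(hardhat)
library(tibble)

qp <- quantile_pred(
  matrix(c(3.6, 21.7, 68.4,
           3.9, 20.6, 76.2), nrow = 2, byrow = TRUE),
  quantile_levels = c(0.1, 0.5, 0.9)
)

# new documented return shape: one vctrs column, one element per row of new_data
tibble(.pred_quantile = qp)

# long, tidy expansion of the same predictions (replaces the old list column
# of per-row tibbles shown in the censored reprex earlier in this thread)
as_tibble(qp)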
