From 861a64a4a9f570ca6bc9540c7935b4134214a64f Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 26 Sep 2024 15:38:52 -0400 Subject: [PATCH] Change to `quantile` argument to `quantile levels` (#1208) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * quantile -> quantile_levels for #1203 * defer test until censored updates in new PR * update docs for quantile_levels * update test * disable quantile predictions for surv_reg --------- Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’> --- NEWS.md | 4 +++ R/predict.R | 2 +- R/predict_quantile.R | 23 +++++++++--- R/surv_reg_data.R | 38 -------------------- man/other_predict.Rd | 7 ++-- man/set_args.Rd | 2 +- tests/testthat/_snaps/linear_reg_quantreg.md | 9 +++++ tests/testthat/test-linear_reg_quantreg.R | 5 +++ tests/testthat/test-surv_reg_survreg.R | 14 +------- 9 files changed, 43 insertions(+), 61 deletions(-) create mode 100644 tests/testthat/_snaps/linear_reg_quantreg.md diff --git a/NEWS.md b/NEWS.md index e2a63b619..dfa96528f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,10 @@ * New `extract_fit_time()` method has been added that returns the time it took to train the model (#853). +## Breaking Change + +* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`. +* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model. # parsnip 1.2.1 diff --git a/R/predict.R b/R/predict.R index 3a2681048..397b92112 100644 --- a/R/predict.R +++ b/R/predict.R @@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) # ---------------------------------------------------------------------------- - other_args <- c("interval", "level", "std_error", "quantile", + other_args <- c("interval", "level", "std_error", "quantile_levels", "time", "eval_time", "increasing") is_pred_arg <- names(the_dots) %in% other_args if (any(!is_pred_arg)) { diff --git a/R/predict_quantile.R b/R/predict_quantile.R index fc2d91b15..56ec31bde 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,14 +1,15 @@ #' @keywords internal #' @rdname other_predict -#' @param quantile A vector of numbers between 0 and 1 for the quantile being -#' predicted. +#' @param quantile_levels A vector of values between zero and one for the +#' quantile to be predicted. If the model has a `"censored regression"` mode, +#' this value should be `NULL`. For other modes, the default is `(1:9)/10`. #' @inheritParams predict.model_fit #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export predict_quantile.model_fit <- function(object, new_data, - quantile = (1:9)/10, + quantile_levels = NULL, interval = "none", level = 0.95, ...) { @@ -20,6 +21,20 @@ predict_quantile.model_fit <- function(object, return(NULL) } + if (object$spec$mode == "quantile regression") { + if (!is.null(quantile_levels)) { + cli::cli_abort("When the mode is {.val quantile regression}, + {.arg quantile_levels} are specified by {.fn set_mode}.") + } + } else { + if (is.null(quantile_levels)) { + quantile_levels <- (1:9)/10 + } + hardhat::check_quantile_levels(quantile_levels) + # Pass some extra arguments to be used in post-processor + object$quantile_levels <- quantile_levels + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -27,8 +42,6 @@ predict_quantile.model_fit <- function(object, new_data <- object$spec$method$pred$quantile$pre(new_data, object) } - # Pass some extra arguments to be used in post-processor - object$spec$method$pred$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 9313ede22..a37dc50bd 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -59,25 +59,6 @@ set_pred( ) ) -set_pred( - model = "surv_reg", - eng = "flexsurv", - mode = "regression", - type = "quantile", - value = list( - pre = NULL, - post = flexsurv_quant, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - quantiles = expr(quantile) - ) - ) -) - # ------------------------------------------------------------------------------ set_model_engine("surv_reg", mode = "regression", eng = "survival") @@ -133,22 +114,3 @@ set_pred( ) ) ) - -set_pred( - model = "surv_reg", - eng = "survival", - mode = "regression", - type = "quantile", - value = list( - pre = NULL, - post = survreg_quant, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - p = expr(quantile) - ) - ) -) diff --git a/man/other_predict.Rd b/man/other_predict.Rd index 6c997e28d..313ff4d72 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -49,7 +49,7 @@ predict_numeric(object, ...) \method{predict_quantile}{model_fit}( object, new_data, - quantile = (1:9)/10, + quantile_levels = NULL, interval = "none", level = 0.95, ... @@ -103,8 +103,9 @@ interval estimates.} \item{std_error}{A single logical for whether the standard error should be returned (assuming that the model can compute it).} -\item{quantile}{A vector of numbers between 0 and 1 for the quantile being -predicted.} +\item{quantile_levels}{A vector of values between zero and one for the +quantile to be predicted. If the model has a \code{"censored regression"} mode, +this value should be \code{NULL}. For other modes, the default is \code{(1:9)/10}.} } \description{ These are internal functions not meant to be directly called by the user. diff --git a/man/set_args.Rd b/man/set_args.Rd index 6d3b60f3d..b31e4ad4c 100644 --- a/man/set_args.Rd +++ b/man/set_args.Rd @@ -21,7 +21,7 @@ set_mode(object, mode, ...) "regression")} \item{quantile_levels}{A vector of values between zero and one (only for the -\verb{quantile regression } mode); otherwise, it is \code{NULL}. The model uses these +\code{"quantile regression"} mode); otherwise, it is \code{NULL}. The model uses these values to appropriately train quantile regression models to make predictions for these values (e.g., \code{quantile_levels = 0.5} is the median).} } diff --git a/tests/testthat/_snaps/linear_reg_quantreg.md b/tests/testthat/_snaps/linear_reg_quantreg.md new file mode 100644 index 000000000..cba265991 --- /dev/null +++ b/tests/testthat/_snaps/linear_reg_quantreg.md @@ -0,0 +1,9 @@ +# linear quantile regression via quantreg - multiple quantiles + + Code + ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0: + 9) / 9) + Condition + Error in `predict_quantile()`: + ! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`. + diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 7edc7c3a5..0785fe7b5 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row")) expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10) + expect_snapshot( + ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:9)/9), + error = TRUE + ) + ### ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,]) diff --git a/tests/testthat/test-surv_reg_survreg.R b/tests/testthat/test-surv_reg_survreg.R index dbb279998..cda216c51 100644 --- a/tests/testthat/test-surv_reg_survreg.R +++ b/tests/testthat/test-surv_reg_survreg.R @@ -10,8 +10,6 @@ complete_form <- survival::Surv(time) ~ group # ------------------------------------------------------------------------------ test_that('survival execution', { - skip_on_travis() - rlang::local_options(lifecycle_verbosity = "quiet") surv_basic <- surv_reg() %>% set_engine("survival") surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival") @@ -46,7 +44,7 @@ test_that('survival execution', { }) test_that('survival prediction', { - skip_on_travis() + skip_if_not_installed("survival") rlang::local_options(lifecycle_verbosity = "quiet") surv_basic <- surv_reg() %>% set_engine("survival") @@ -61,16 +59,6 @@ test_that('survival prediction', { exp_pred <- predict(extract_fit_engine(res), head(lung)) exp_pred <- tibble(.pred = unname(exp_pred)) expect_equal(exp_pred, predict(res, head(lung))) - - exp_quant <- predict(extract_fit_engine(res), head(lung), p = (2:4)/5, type = "quantile") - exp_quant <- - apply(exp_quant, 1, function(x) - tibble(.pred = x, .quantile = (2:4) / 5)) - exp_quant <- tibble(.pred = exp_quant) - obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4)/5) - - expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant)) - })