Skip to content

Commit

Permalink
Change to quantile argument to quantile levels (#1208)
Browse files Browse the repository at this point in the history
* quantile -> quantile_levels for #1203

* defer test until censored updates in new PR

* update docs for quantile_levels

* update test

* disable quantile predictions for surv_reg

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
  • Loading branch information
topepo and ‘topepo’ authored Sep 26, 2024
1 parent bef131b commit 861a64a
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 61 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).

## Breaking Change

* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`.
* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model.

# parsnip 1.2.1

Expand Down
2 changes: 1 addition & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

# ----------------------------------------------------------------------------

other_args <- c("interval", "level", "std_error", "quantile",
other_args <- c("interval", "level", "std_error", "quantile_levels",
"time", "eval_time", "increasing")
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
Expand Down
23 changes: 18 additions & 5 deletions R/predict_quantile.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#' @keywords internal
#' @rdname other_predict
#' @param quantile A vector of numbers between 0 and 1 for the quantile being
#' predicted.
#' @param quantile_levels A vector of values between zero and one for the
#' quantile to be predicted. If the model has a `"censored regression"` mode,
#' this value should be `NULL`. For other modes, the default is `(1:9)/10`.
#' @inheritParams predict.model_fit
#' @method predict_quantile model_fit
#' @export predict_quantile.model_fit
#' @export
predict_quantile.model_fit <- function(object,
new_data,
quantile = (1:9)/10,
quantile_levels = NULL,
interval = "none",
level = 0.95,
...) {
Expand All @@ -20,15 +21,27 @@ predict_quantile.model_fit <- function(object,
return(NULL)
}

if (object$spec$mode == "quantile regression") {
if (!is.null(quantile_levels)) {
cli::cli_abort("When the mode is {.val quantile regression},
{.arg quantile_levels} are specified by {.fn set_mode}.")
}
} else {
if (is.null(quantile_levels)) {
quantile_levels <- (1:9)/10
}
hardhat::check_quantile_levels(quantile_levels)
# Pass some extra arguments to be used in post-processor
object$quantile_levels <- quantile_levels
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$quantile$pre)) {
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
}

# Pass some extra arguments to be used in post-processor
object$spec$method$pred$quantile$args$p <- quantile
pred_call <- make_pred_call(object$spec$method$pred$quantile)

res <- eval_tidy(pred_call)
Expand Down
38 changes: 0 additions & 38 deletions R/surv_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,6 @@ set_pred(
)
)

set_pred(
model = "surv_reg",
eng = "flexsurv",
mode = "regression",
type = "quantile",
value = list(
pre = NULL,
post = flexsurv_quant,
func = c(fun = "summary"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
quantiles = expr(quantile)
)
)
)

# ------------------------------------------------------------------------------

set_model_engine("surv_reg", mode = "regression", eng = "survival")
Expand Down Expand Up @@ -133,22 +114,3 @@ set_pred(
)
)
)

set_pred(
model = "surv_reg",
eng = "survival",
mode = "regression",
type = "quantile",
value = list(
pre = NULL,
post = survreg_quant,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
p = expr(quantile)
)
)
)
7 changes: 4 additions & 3 deletions man/other_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/set_args.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions tests/testthat/_snaps/linear_reg_quantreg.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# linear quantile regression via quantreg - multiple quantiles

Code
ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:
9) / 9)
Condition
Error in `predict_quantile()`:
! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`.

5 changes: 5 additions & 0 deletions tests/testthat/test-linear_reg_quantreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', {
expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row"))
expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10)

expect_snapshot(
ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:9)/9),
error = TRUE
)

###

ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])
Expand Down
14 changes: 1 addition & 13 deletions tests/testthat/test-surv_reg_survreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ complete_form <- survival::Surv(time) ~ group
# ------------------------------------------------------------------------------

test_that('survival execution', {
skip_on_travis()

rlang::local_options(lifecycle_verbosity = "quiet")
surv_basic <- surv_reg() %>% set_engine("survival")
surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival")
Expand Down Expand Up @@ -46,7 +44,7 @@ test_that('survival execution', {
})

test_that('survival prediction', {
skip_on_travis()
skip_if_not_installed("survival")

rlang::local_options(lifecycle_verbosity = "quiet")
surv_basic <- surv_reg() %>% set_engine("survival")
Expand All @@ -61,16 +59,6 @@ test_that('survival prediction', {
exp_pred <- predict(extract_fit_engine(res), head(lung))
exp_pred <- tibble(.pred = unname(exp_pred))
expect_equal(exp_pred, predict(res, head(lung)))

exp_quant <- predict(extract_fit_engine(res), head(lung), p = (2:4)/5, type = "quantile")
exp_quant <-
apply(exp_quant, 1, function(x)
tibble(.pred = x, .quantile = (2:4) / 5))
exp_quant <- tibble(.pred = exp_quant)
obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4)/5)

expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant))

})


0 comments on commit 861a64a

Please sign in to comment.