Skip to content

Commit

Permalink
quantile -> quantile_levels for #1203
Browse files Browse the repository at this point in the history
  • Loading branch information
‘topepo’ committed Sep 25, 2024
1 parent bef131b commit 539e5a7
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 10 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).

## Breaking Change

* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`.

# parsnip 1.2.1

Expand Down
2 changes: 1 addition & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

# ----------------------------------------------------------------------------

other_args <- c("interval", "level", "std_error", "quantile",
other_args <- c("interval", "level", "std_error", "quantile_levels",
"time", "eval_time", "increasing")
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
Expand Down
21 changes: 16 additions & 5 deletions R/predict_quantile.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#' @keywords internal
#' @rdname other_predict
#' @param quantile A vector of numbers between 0 and 1 for the quantile being
#' predicted.
#' @param quantile_levels A vector of values between zero and one.
#' @inheritParams predict.model_fit
#' @method predict_quantile model_fit
#' @export predict_quantile.model_fit
#' @export
predict_quantile.model_fit <- function(object,
new_data,
quantile = (1:9)/10,
quantile_levels = NULL,
interval = "none",
level = 0.95,
...) {
Expand All @@ -20,15 +19,27 @@ predict_quantile.model_fit <- function(object,
return(NULL)
}

if (object$spec$mode != "quantile regression") {
if (is.null(quantile_levels)) {
quantile_levels <- (1:9)/10
}
hardhat::check_quantile_levels(quantile_levels)
# Pass some extra arguments to be used in post-processor
object$quantile_levels <- quantile_levels
} else {
if (!is.null(quantile_levels)) {
cli::cli_abort("{.arg quantile_levels} are specified by {.fn set_mode}
when the mode is {.val quantile regression}.")
}
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$quantile$pre)) {
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
}

# Pass some extra arguments to be used in post-processor
object$spec$method$pred$quantile$args$p <- quantile
pred_call <- make_pred_call(object$spec$method$pred$quantile)

res <- eval_tidy(pred_call)
Expand Down
5 changes: 2 additions & 3 deletions man/other_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/set_args.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions tests/testthat/test-linear_reg_quantreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', {
expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row"))
expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10)

expect_snapshot(
ten_quant_pred <- predict(ten_quant, new_data = sac_test),
error = TRUE
)

###

ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])
Expand Down

0 comments on commit 539e5a7

Please sign in to comment.