From 83c744bccb666bd290bc5e97abaeff529545f0e8 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 12 Sep 2024 18:52:09 -0400 Subject: [PATCH] re-enable quantiles prediction for #1203 --- R/predict_quantile.R | 15 ++++++++++----- man/other_predict.Rd | 9 ++++++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/R/predict_quantile.R b/R/predict_quantile.R index f9154d6a9..efe0458f8 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -6,7 +6,12 @@ #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export -predict_quantile.model_fit <- function(object, new_data, ...) { +predict_quantile.model_fit <- function(object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ...) { check_spec_pred_type(object, "quantile") @@ -18,12 +23,11 @@ predict_quantile.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$quantile$pre)) { + if (!is.null(object$spec$method$pred$quantile$pre)) new_data <- object$spec$method$pred$quantile$pre(new_data, object) - } # Pass some extra arguments to be used in post-processor - object$spec$method$pred$quantile$args$quantile_level <- object$quantile_level + object$spec$method$pred$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) @@ -40,5 +44,6 @@ predict_quantile.model_fit <- function(object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_quantile <- function (object, ...) +predict_quantile <- function (object, ...) { UseMethod("predict_quantile") +} diff --git a/man/other_predict.Rd b/man/other_predict.Rd index bc1d104bf..6c997e28d 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -46,7 +46,14 @@ predict_linear_pred(object, ...) predict_numeric(object, ...) -\method{predict_quantile}{model_fit}(object, new_data, ...) +\method{predict_quantile}{model_fit}( + object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ... +) \method{predict_survival}{model_fit}( object,