Skip to content

Commit 00f3cdb

Browse files
authored
create helper for newdata error (#912)
1 parent 3725810 commit 00f3cdb

19 files changed

+73
-35
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
* Fixed bug with prediction from a boosted tree model fitted with `"xgboost"` using a custom objective function (#875).
1616

1717
* Several internal functions (to help work with `Surv` objects) were added as a standalone file that can be used in other packages via `usethis::use_standalone("tidymodels/parsnip")`.
18+
19+
* Rather than being implemented in each method, the check for the `new_data` argument being mistakenly passed as `newdata` to `multi_predict()` now happens in the generic. Packages re-exporting the `multi_predict()` generic and implementing now-duplicate checks may see new failures and can remove their own analogous checks. This check already existed in all `predict()` methods (via `predict.model_fit()`) and all parsnip `multi_predict()` methods (#525).
20+
1821
* `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545).
1922

23+
2024
# parsnip 1.0.4
2125

2226
* For censored regression models, a "reverse Kaplan-Meier" curve is computed for the censoring distribution. This can be used when evaluating this type of model (#855).

R/aaa_multi_predict.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ multi_predict <- function(object, ...) {
2222
rlang::warn("Model fit failed; cannot make predictions.")
2323
return(NULL)
2424
}
25+
check_for_newdata(...)
2526
UseMethod("multi_predict")
2627
}
2728

R/boost_tree.R

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -480,10 +480,6 @@ get_event_level <- function(model_spec){
480480
#' @param trees An integer vector for the number of trees in the ensemble.
481481
multi_predict._xgb.Booster <-
482482
function(object, new_data, type = NULL, trees = NULL, ...) {
483-
if (any(names(enquos(...)) == "newdata")) {
484-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
485-
}
486-
487483
if (is.null(trees)) {
488484
trees <- object$fit$nIter
489485
}
@@ -608,9 +604,6 @@ C5.0_train <-
608604
#' @rdname multi_predict
609605
multi_predict._C5.0 <-
610606
function(object, new_data, type = NULL, trees = NULL, ...) {
611-
if (any(names(enquos(...)) == "newdata"))
612-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
613-
614607
if (is.null(trees))
615608
trees <- min(object$fit$trials)
616609
trees <- sort(trees)

R/glmnet-engines.R

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,6 @@ multi_predict_glmnet <- function(object,
173173
type = NULL,
174174
penalty = NULL,
175175
...) {
176-
177-
if (any(names(enquos(...)) == "newdata")) {
178-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
179-
}
180-
181176
if (object$spec$mode == "classification") {
182177
if (is_quosure(penalty)) {
183178
penalty <- eval_tidy(penalty)

R/mars.R

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ earth_reg_updater <- function(num, object, new_data, ...) {
150150
#' @export
151151
multi_predict._earth <-
152152
function(object, new_data, type = NULL, num_terms = NULL, ...) {
153-
if (any(names(enquos(...)) == "newdata"))
154-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
155-
156153
load_libs(object, quiet = TRUE, attach = TRUE)
157154

158155
if (is.null(num_terms))

R/misc.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,13 @@ check_case_weights <- function(x, spec) {
483483
}
484484
invisible(NULL)
485485
}
486+
487+
# -----------------------------------------------------------------------------
488+
check_for_newdata <- function(..., call = rlang::caller_env()) {
489+
if (any(names(list(...)) == "newdata")) {
490+
rlang::abort(
491+
"Please use `new_data` instead of `newdata`.",
492+
call = call
493+
)
494+
}
495+
}

R/mlp.R

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,6 @@ mlp_num_weights <- function(p, hidden_units, classes) {
366366
#' @export
367367
multi_predict._torch_mlp <-
368368
function(object, new_data, type = NULL, epochs = NULL, ...) {
369-
if (any(names(enquos(...)) == "newdata"))
370-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
371-
372369
load_libs(object, quiet = TRUE, attach = TRUE)
373370

374371
if (is.null(epochs))

R/nearest_neighbor.R

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,6 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
149149
#' @export
150150
multi_predict._train.kknn <-
151151
function(object, new_data, type = NULL, neighbors = NULL, ...) {
152-
if (any(names(enquos(...)) == "newdata"))
153-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
154-
155152
if (is.null(neighbors))
156153
neighbors <- rlang::eval_tidy(object$fit$call$ks)
157154
neighbors <- sort(neighbors)

R/predict.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,13 @@ make_pred_call <- function(x) {
324324
cl
325325
}
326326

327-
check_pred_type_dots <- function(object, type, ...) {
327+
check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) {
328328
the_dots <- list(...)
329329
nms <- names(the_dots)
330330

331331
# ----------------------------------------------------------------------------
332332

333-
if (any(names(the_dots) == "newdata")) {
334-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
335-
}
333+
check_for_newdata(..., call = call)
336334

337335
# ----------------------------------------------------------------------------
338336

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# submodel prediction
2+
3+
Code
4+
multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 4, type = "prob")
5+
Condition
6+
Error in `multi_predict()`:
7+
! Please use `new_data` instead of `newdata`.
8+

0 commit comments

Comments
 (0)