Skip to content

Commit 4ac31fb

Browse files
committed
Updating for mars() engine-specific args
1 parent 2bf7b04 commit 4ac31fb

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

R/mars.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ update.mars <-
9393
parameters = NULL,
9494
num_terms = NULL, prod_degree = NULL, prune_method = NULL,
9595
fresh = FALSE, ...) {
96-
update_dot_check(...)
96+
97+
eng_args <- update_engine_parameters(object$eng_args, ...)
9798

9899
if (!is.null(parameters)) {
99100
parameters <- check_final_param(parameters)
@@ -109,12 +110,15 @@ update.mars <-
109110

110111
if (fresh) {
111112
object$args <- args
113+
object$eng_args <- eng_args
112114
} else {
113115
null_args <- map_lgl(args, null_value)
114116
if (any(null_args))
115117
args <- args[!null_args]
116118
if (length(args) > 0)
117119
object$args[names(args)] <- args
120+
if (length(eng_args) > 0)
121+
object$eng_args[names(eng_args)] <- eng_args
118122
}
119123

120124
new_model_spec(

tests/testthat/test_mars.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,12 @@ test_that('updating', {
7878
expr1 <- mars() %>% set_engine("earth", model = FALSE)
7979
expr1_exp <- mars(num_terms = 1) %>% set_engine("earth", model = FALSE)
8080

81-
expr2 <- mars(num_terms = varying()) %>% set_engine("earth")
81+
expr2 <- mars(num_terms = varying()) %>% set_engine("earth", nk = varying())
8282
expr2_exp <- mars(num_terms = varying()) %>% set_engine("earth", nk = 10)
8383

84-
expr3 <- mars(num_terms = 1, prod_degree = varying()) %>% set_engine("earth")
85-
expr3_exp <- mars(num_terms = 1) %>% set_engine("earth")
84+
expr3 <- mars(num_terms = 1, prod_degree = varying()) %>% set_engine("earth", nk = varying())
85+
expr3_fre <- mars(num_terms = 1) %>% set_engine("earth", nk = varying())
86+
expr3_exp <- mars(num_terms = 1) %>% set_engine("earth", nk = 10)
8687

8788
expr4 <- mars(num_terms = 0) %>% set_engine("earth", nk = 10)
8889
expr4_exp <- mars(num_terms = 0) %>% set_engine("earth", nk = 10, trace = 2)
@@ -91,7 +92,9 @@ test_that('updating', {
9192
expr5_exp <- mars(num_terms = 1) %>% set_engine("earth", nk = 10, trace = 2)
9293

9394
expect_equal(update(expr1, num_terms = 1), expr1_exp)
94-
expect_equal(update(expr3, num_terms = 1, fresh = TRUE), expr3_exp)
95+
expect_equal(update(expr2, nk = 10), expr2_exp)
96+
expect_equal(update(expr3, num_terms = 1, fresh = TRUE), expr3_fre)
97+
expect_equal(update(expr3, num_terms = 1, fresh = TRUE, nk = 10), expr3_exp)
9598

9699
param_tibb <- tibble::tibble(num_terms = 3, prod_degree = 1)
97100
param_list <- as.list(param_tibb)

0 commit comments

Comments
 (0)