Skip to content

Commit 33a23f1

Browse files
committed
Update models to enquo args and create a list immediately. Move model spec testing to the check_args() generic.
1 parent f3bee97 commit 33a23f1

File tree

11 files changed

+214
-154
lines changed

11 files changed

+214
-154
lines changed

R/boost_tree.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ check_args.boost_tree <- function(object) {
253253
if (is.numeric(args$min_n) && args$min_n < 0)
254254
stop("`min_n` should be >= 1", call. = FALSE)
255255

256+
invisible(object)
256257
}
257258

258259
# xgboost helpers --------------------------------------------------------------

R/fit_helpers.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ form_form <-
2222
}
2323

2424
# evaluate quoted args once here to check them
25-
check_args(object)
25+
object <- check_args(object)
2626

2727
# sub in arguments to actual syntax for corresponding engine
2828
object <- translate(object, engine = object$engine)
@@ -74,7 +74,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
7474
}
7575

7676
# evaluate quoted args once here to check them
77-
check_args(object)
77+
object <- check_args(object)
7878

7979
# sub in arguments to actual syntax for corresponding engine
8080
object <- translate(object, engine = object$engine)

R/linear_reg.R

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,13 @@ linear_reg <-
104104
penalty = NULL,
105105
mixture = NULL,
106106
...) {
107+
107108
others <- enquos(...)
108-
penalty <- enquo(penalty)
109-
mixture <- enquo(mixture)
109+
110+
args <- list(
111+
penalty = enquo(penalty),
112+
mixture = enquo(mixture)
113+
)
110114

111115
if (!(mode %in% linear_reg_modes))
112116
stop(
@@ -115,15 +119,6 @@ linear_reg <-
115119
call. = FALSE
116120
)
117121

118-
if (all(is.numeric(penalty)) && any(penalty < 0))
119-
stop("The amount of regularization should be >= 0", call. = FALSE)
120-
if (is.numeric(mixture) && (mixture < 0 | mixture > 1))
121-
stop("The mixture proportion should be within [0,1]", call. = FALSE)
122-
if (is.numeric(mixture) && length(mixture) > 1)
123-
stop("Only one value of `mixture` is allowed.", call. = FALSE)
124-
125-
args <- list(penalty = penalty, mixture = mixture)
126-
127122
no_value <- !vapply(others, is.null, logical(1))
128123
others <- others[no_value]
129124

@@ -169,16 +164,13 @@ update.linear_reg <-
169164
penalty = NULL, mixture = NULL,
170165
fresh = FALSE,
171166
...) {
172-
others <- enquos(...)
173-
penalty <- enquo(penalty)
174-
mixture <- enquo(mixture)
175167

176-
if (is.numeric(penalty) && penalty < 0)
177-
stop("The amount of regularization should be >= 0", call. = FALSE)
178-
if (is.numeric(mixture) && (mixture < 0 | mixture > 1))
179-
stop("The mixture proportion should be within [0,1]", call. = FALSE)
168+
others <- enquos(...)
180169

181-
args <- list(penalty = penalty, mixture = mixture)
170+
args <- list(
171+
penalty = enquo(penalty),
172+
mixture = enquo(mixture)
173+
)
182174

183175
if (fresh) {
184176
object$args <- args
@@ -200,6 +192,21 @@ update.linear_reg <-
200192
object
201193
}
202194

195+
# ------------------------------------------------------------------------------
196+
197+
check_args.linear_reg <- function(object) {
198+
199+
args <- lapply(object$args, rlang::eval_tidy)
200+
201+
if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
202+
stop("The amount of regularization should be >= 0", call. = FALSE)
203+
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
204+
stop("The mixture proportion should be within [0,1]", call. = FALSE)
205+
if (is.numeric(args$mixture) && length(args$mixture) > 1)
206+
stop("Only one value of `mixture` is allowed.", call. = FALSE)
207+
208+
invisible(object)
209+
}
203210

204211
# ------------------------------------------------------------------------------
205212

R/logistic_reg.R

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,13 @@ logistic_reg <-
102102
penalty = NULL,
103103
mixture = NULL,
104104
...) {
105+
105106
others <- enquos(...)
106-
penalty <- enquo(penalty)
107-
mixture <- enquo(mixture)
107+
108+
args <- list(
109+
penalty = enquo(penalty),
110+
mixture = enquo(mixture)
111+
)
108112

109113
if (!(mode %in% logistic_reg_modes))
110114
stop(
@@ -113,13 +117,6 @@ logistic_reg <-
113117
call. = FALSE
114118
)
115119

116-
if (is.numeric(penalty) && penalty < 0)
117-
stop("The amount of regularization should be >= 0", call. = FALSE)
118-
if (is.numeric(mixture) && (mixture < 0 | mixture > 1))
119-
stop("The mixture proportion should be within [0,1]", call. = FALSE)
120-
121-
args <- list(penalty = penalty, mixture = mixture)
122-
123120
no_value <- !vapply(others, is.null, logical(1))
124121
others <- others[no_value]
125122

@@ -165,16 +162,13 @@ update.logistic_reg <-
165162
penalty = NULL, mixture = NULL,
166163
fresh = FALSE,
167164
...) {
168-
others <- enquos(...)
169-
penalty <- enquo(penalty)
170-
mixture <- enquo(mixture)
171165

172-
if (is.numeric(penalty) && penalty < 0)
173-
stop("The amount of regularization should be >= 0", call. = FALSE)
174-
if (is.numeric(mixture) && (mixture < 0 | mixture > 1))
175-
stop("The mixture proportion should be within [0,1]", call. = FALSE)
166+
others <- enquos(...)
176167

177-
args <- list(penalty = penalty, mixture = mixture)
168+
args <- list(
169+
penalty = enquo(penalty),
170+
mixture = enquo(mixture)
171+
)
178172

179173
if (fresh) {
180174
object$args <- args
@@ -196,6 +190,21 @@ update.logistic_reg <-
196190
object
197191
}
198192

193+
# ------------------------------------------------------------------------------
194+
195+
check_args.logistic_reg <- function(object) {
196+
197+
args <- lapply(object$args, rlang::eval_tidy)
198+
199+
if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
200+
stop("The amount of regularization should be >= 0", call. = FALSE)
201+
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
202+
stop("The mixture proportion should be within [0,1]", call. = FALSE)
203+
if (is.numeric(args$mixture) && length(args$mixture) > 1)
204+
stop("Only one value of `mixture` is allowed.", call. = FALSE)
205+
206+
invisible(object)
207+
}
199208

200209
# ------------------------------------------------------------------------------
201210

R/mars.R

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,29 +69,20 @@ mars <-
6969
function(mode = "unknown",
7070
num_terms = NULL, prod_degree = NULL, prune_method = NULL,
7171
...) {
72+
7273
others <- enquos(...)
73-
num_terms <- enquo(num_terms)
74-
prod_degree <- enquo(prod_degree)
75-
prune_method <- enquo(prune_method)
74+
75+
args <- list(
76+
num_terms = enquo(num_terms),
77+
prod_degree = enquo(prod_degree),
78+
prune_method = enquo(prune_method)
79+
)
7680

7781
if (!(mode %in% mars_modes))
7882
stop("`mode` should be one of: ",
7983
paste0("'", mars_modes, "'", collapse = ", "),
8084
call. = FALSE)
8185

82-
if (is.numeric(prod_degree) && prod_degree < 0)
83-
stop("`prod_degree` should be >= 1", call. = FALSE)
84-
if (is.numeric(num_terms) && num_terms < 0)
85-
stop("`num_terms` should be >= 1", call. = FALSE)
86-
if (!is_varying(prune_method) &&
87-
!is.null(prune_method) &&
88-
is.character(prune_method))
89-
stop("`prune_method` should be a single string value", call. = FALSE)
90-
91-
args <- list(num_terms = num_terms,
92-
prod_degree = prod_degree,
93-
prune_method = prune_method)
94-
9586
no_value <- !vapply(others, is.null, logical(1))
9687
others <- others[no_value]
9788

@@ -131,14 +122,14 @@ update.mars <-
131122
num_terms = NULL, prod_degree = NULL, prune_method = NULL,
132123
fresh = FALSE,
133124
...) {
125+
134126
others <- enquos(...)
135-
num_terms <- enquo(num_terms)
136-
prod_degree <- enquo(prod_degree)
137-
prune_method <- enquo(prune_method)
138127

139-
args <- list(num_terms = num_terms,
140-
prod_degree = prod_degree,
141-
prune_method = prune_method)
128+
args <- list(
129+
num_terms = enquo(num_terms),
130+
prod_degree = enquo(prod_degree),
131+
prune_method = enquo(prune_method)
132+
)
142133

143134
if (fresh) {
144135
object$args <- args
@@ -179,6 +170,26 @@ translate.mars <- function(x, engine, ...) {
179170

180171
# ------------------------------------------------------------------------------
181172

173+
check_args.mars <- function(object) {
174+
175+
args <- lapply(object$args, rlang::eval_tidy)
176+
177+
if (is.numeric(args$prod_degree) && args$prod_degree < 0)
178+
stop("`prod_degree` should be >= 1", call. = FALSE)
179+
180+
if (is.numeric(args$num_terms) && args$num_terms < 0)
181+
stop("`num_terms` should be >= 1", call. = FALSE)
182+
183+
if (!is_varying(args$prune_method) &&
184+
!is.null(args$prune_method) &&
185+
is.character(args$prune_method))
186+
stop("`prune_method` should be a single string value", call. = FALSE)
187+
188+
invisible(object)
189+
}
190+
191+
# ------------------------------------------------------------------------------
192+
182193
#' @importFrom purrr map_dfr
183194
earth_submodel_pred <- function(object, new_data, terms = 2:3, ...) {
184195
map_dfr(terms, earth_reg_updater, object = object, newdata = new_data, ...)

R/misc.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,5 +176,5 @@ check_args <- function(object) {
176176
}
177177

178178
check_args.default <- function(object) {
179-
# nothing to do
179+
invisible(object)
180180
}

R/mlp.R

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -94,41 +94,22 @@ mlp <-
9494
hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL,
9595
activation = NULL,
9696
...) {
97+
9798
others <- enquos(...)
98-
hidden_units <- enquo(hidden_units)
99-
penalty <- enquo(penalty)
100-
dropout <- enquo(dropout)
101-
epochs <- enquo(epochs)
102-
activation <- enquo(activation)
103-
104-
105-
act_funs <- c("linear", "softmax", "relu", "elu")
106-
if (is.numeric(hidden_units))
107-
if (hidden_units < 2)
108-
stop("There must be at least two hidden units", call. = FALSE)
109-
if (is.numeric(penalty))
110-
if (penalty < 0)
111-
stop("The amount of weight decay must be >= 0.", call. = FALSE)
112-
if (is.numeric(dropout))
113-
if (dropout < 0 | dropout >= 1)
114-
stop("The dropout proportion must be on [0, 1).", call. = FALSE)
115-
if (is.numeric(penalty) & is.numeric(dropout))
116-
if (dropout > 0 & penalty > 0)
117-
stop("Both weight decay and dropout should not be specified.", call. = FALSE)
118-
if (is.character(activation))
119-
if (!any(activation %in% c(act_funs)))
120-
stop("`activation should be one of: ",
121-
paste0("'", act_funs, "'", collapse = ", "),
122-
call. = FALSE)
99+
100+
args <- list(
101+
hidden_units = enquo(hidden_units),
102+
penalty = enquo(penalty),
103+
dropout = enquo(dropout),
104+
epochs = enquo(epochs),
105+
activation = enquo(activation)
106+
)
123107

124108
if (!(mode %in% mlp_modes))
125109
stop("`mode` should be one of: ",
126110
paste0("'", mlp_modes, "'", collapse = ", "),
127111
call. = FALSE)
128112

129-
args <- list(hidden_units = hidden_units, penalty = penalty, dropout = dropout,
130-
epochs = epochs, activation = activation)
131-
132113
no_value <- !vapply(others, is.null, logical(1))
133114
others <- others[no_value]
134115

@@ -177,14 +158,14 @@ update.mlp <-
177158
fresh = FALSE,
178159
...) {
179160
others <- enquos(...)
180-
hidden_units <- enquo(hidden_units)
181-
penalty <- enquo(penalty)
182-
dropout <- enquo(dropout)
183-
epochs <- enquo(epochs)
184-
activation <- enquo(activation)
185161

186-
args <- list(hidden_units = hidden_units, penalty = penalty, dropout = dropout,
187-
epochs = epochs, activation = activation)
162+
args <- list(
163+
hidden_units = enquo(hidden_units),
164+
penalty = enquo(penalty),
165+
dropout = enquo(dropout),
166+
epochs = enquo(epochs),
167+
activation = enquo(activation)
168+
)
188169

189170
# TODO make these blocks into a function and document well
190171
if (fresh) {
@@ -231,3 +212,36 @@ translate.mlp <- function(x, engine, ...) {
231212
}
232213
x
233214
}
215+
216+
# ------------------------------------------------------------------------------
217+
218+
check_args.mlp <- function(object) {
219+
220+
args <- lapply(object$args, rlang::eval_tidy)
221+
222+
if (is.numeric(args$hidden_units))
223+
if (args$hidden_units < 2)
224+
stop("There must be at least two hidden units", call. = FALSE)
225+
226+
if (is.numeric(args$penalty))
227+
if (args$penalty < 0)
228+
stop("The amount of weight decay must be >= 0.", call. = FALSE)
229+
230+
if (is.numeric(args$dropout))
231+
if (args$dropout < 0 | args$dropout >= 1)
232+
stop("The dropout proportion must be on [0, 1).", call. = FALSE)
233+
234+
if (is.numeric(args$penalty) & is.numeric(args$dropout))
235+
if (args$dropout > 0 & args$penalty > 0)
236+
stop("Both weight decay and dropout should not be specified.", call. = FALSE)
237+
238+
act_funs <- c("linear", "softmax", "relu", "elu")
239+
240+
if (is.character(args$activation))
241+
if (!any(args$activation %in% c(act_funs)))
242+
stop("`activation should be one of: ",
243+
paste0("'", act_funs, "'", collapse = ", "),
244+
call. = FALSE)
245+
246+
invisible(object)
247+
}

0 commit comments

Comments
 (0)