@@ -94,41 +94,22 @@ mlp <-
9494 hidden_units = NULL , penalty = NULL , dropout = NULL , epochs = NULL ,
9595 activation = NULL ,
9696 ... ) {
97+
9798 others <- enquos(... )
98- hidden_units <- enquo(hidden_units )
99- penalty <- enquo(penalty )
100- dropout <- enquo(dropout )
101- epochs <- enquo(epochs )
102- activation <- enquo(activation )
103-
104-
105- act_funs <- c(" linear" , " softmax" , " relu" , " elu" )
106- if (is.numeric(hidden_units ))
107- if (hidden_units < 2 )
108- stop(" There must be at least two hidden units" , call. = FALSE )
109- if (is.numeric(penalty ))
110- if (penalty < 0 )
111- stop(" The amount of weight decay must be >= 0." , call. = FALSE )
112- if (is.numeric(dropout ))
113- if (dropout < 0 | dropout > = 1 )
114- stop(" The dropout proportion must be on [0, 1)." , call. = FALSE )
115- if (is.numeric(penalty ) & is.numeric(dropout ))
116- if (dropout > 0 & penalty > 0 )
117- stop(" Both weight decay and dropout should not be specified." , call. = FALSE )
118- if (is.character(activation ))
119- if (! any(activation %in% c(act_funs )))
120- stop(" `activation should be one of: " ,
121- paste0(" '" , act_funs , " '" , collapse = " , " ),
122- call. = FALSE )
99+
100+ args <- list (
101+ hidden_units = enquo(hidden_units ),
102+ penalty = enquo(penalty ),
103+ dropout = enquo(dropout ),
104+ epochs = enquo(epochs ),
105+ activation = enquo(activation )
106+ )
123107
124108 if (! (mode %in% mlp_modes ))
125109 stop(" `mode` should be one of: " ,
126110 paste0(" '" , mlp_modes , " '" , collapse = " , " ),
127111 call. = FALSE )
128112
129- args <- list (hidden_units = hidden_units , penalty = penalty , dropout = dropout ,
130- epochs = epochs , activation = activation )
131-
132113 no_value <- ! vapply(others , is.null , logical (1 ))
133114 others <- others [no_value ]
134115
@@ -177,14 +158,14 @@ update.mlp <-
177158 fresh = FALSE ,
178159 ... ) {
179160 others <- enquos(... )
180- hidden_units <- enquo(hidden_units )
181- penalty <- enquo(penalty )
182- dropout <- enquo(dropout )
183- epochs <- enquo(epochs )
184- activation <- enquo(activation )
185161
186- args <- list (hidden_units = hidden_units , penalty = penalty , dropout = dropout ,
187- epochs = epochs , activation = activation )
162+ args <- list (
163+ hidden_units = enquo(hidden_units ),
164+ penalty = enquo(penalty ),
165+ dropout = enquo(dropout ),
166+ epochs = enquo(epochs ),
167+ activation = enquo(activation )
168+ )
188169
189170 # TODO make these blocks into a function and document well
190171 if (fresh ) {
@@ -231,3 +212,36 @@ translate.mlp <- function(x, engine, ...) {
231212 }
232213 x
233214}
215+
216+ # ------------------------------------------------------------------------------
217+
218+ check_args.mlp <- function (object ) {
219+
220+ args <- lapply(object $ args , rlang :: eval_tidy )
221+
222+ if (is.numeric(args $ hidden_units ))
223+ if (args $ hidden_units < 2 )
224+ stop(" There must be at least two hidden units" , call. = FALSE )
225+
226+ if (is.numeric(args $ penalty ))
227+ if (args $ penalty < 0 )
228+ stop(" The amount of weight decay must be >= 0." , call. = FALSE )
229+
230+ if (is.numeric(args $ dropout ))
231+ if (args $ dropout < 0 | args $ dropout > = 1 )
232+ stop(" The dropout proportion must be on [0, 1)." , call. = FALSE )
233+
234+ if (is.numeric(args $ penalty ) & is.numeric(args $ dropout ))
235+ if (args $ dropout > 0 & args $ penalty > 0 )
236+ stop(" Both weight decay and dropout should not be specified." , call. = FALSE )
237+
238+ act_funs <- c(" linear" , " softmax" , " relu" , " elu" )
239+
240+ if (is.character(args $ activation ))
241+ if (! any(args $ activation %in% c(act_funs )))
242+ stop(" `activation should be one of: " ,
243+ paste0(" '" , act_funs , " '" , collapse = " , " ),
244+ call. = FALSE )
245+
246+ invisible (object )
247+ }
0 commit comments