Skip to content

Commit a8f3c83

Browse files
committed
model specification constructor function
1 parent c4a208f commit a8f3c83

File tree

12 files changed

+154
-112
lines changed

12 files changed

+154
-112
lines changed

R/arguments.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,14 @@ set_args <- function(object, ...) {
9898
object$eng_args[[i]] <- the_dots[[i]]
9999
}
100100
}
101-
object
101+
new_model_spec(
102+
cls = class(object)[1],
103+
args = object$args,
104+
eng_args = object$eng_args,
105+
mode = object$mode,
106+
method = NULL,
107+
engine = object$engine
108+
)
102109
}
103110

104111
#' @rdname set_args

R/boost_tree.R

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,14 @@ boost_tree <-
122122
sample_size = enquo(sample_size)
123123
)
124124

125-
if (!(mode %in% boost_tree_modes))
126-
stop("`mode` should be one of: ",
127-
paste0("'", boost_tree_modes, "'", collapse = ", "),
128-
call. = FALSE)
129-
130-
out <- list(args = args, eng_args = NULL,
131-
mode = mode, method = NULL, engine = NULL)
132-
class(out) <- make_classes("boost_tree")
133-
out
125+
new_model_spec(
126+
"boost_tree",
127+
args,
128+
eng_args = NULL,
129+
mode,
130+
method = NULL,
131+
engine = NULL
132+
)
134133
}
135134

136135
#' @export
@@ -191,7 +190,14 @@ update.boost_tree <-
191190
object$args[names(args)] <- args
192191
}
193192

194-
object
193+
new_model_spec(
194+
"boost_tree",
195+
args = object$args,
196+
eng_args = object$eng_args,
197+
mode = object$mode,
198+
method = NULL,
199+
engine = object$engine
200+
)
195201
}
196202

197203
# ------------------------------------------------------------------------------

R/engines.R

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,21 @@ check_installs <- function(x) {
7474
#' translate(mod, engine = "glmnet")
7575
#' @export
7676
set_engine <- function(object, engine, ...) {
77+
if (!inherits(object, "model_spec")) {
78+
stop("`object` should have class 'model_spec'.", call. = FALSE)
79+
}
7780
if (!is.character(engine) | length(engine) != 1)
7881
stop("`engine` should be a single character value.", call. = FALSE)
7982

8083
object$engine <- engine
8184
object <- check_engine(object)
8285

83-
84-
object$eng_args <- enquos(...)
85-
object
86+
new_model_spec(
87+
cls = class(object)[1],
88+
args = object$args,
89+
eng_args = enquos(...),
90+
mode = object$mode,
91+
method = NULL,
92+
engine = object$engine
93+
)
8694
}

R/linear_reg.R

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,14 @@ linear_reg <-
107107
mixture = enquo(mixture)
108108
)
109109

110-
if (!(mode %in% linear_reg_modes))
111-
stop(
112-
"`mode` should be one of: ",
113-
paste0("'", linear_reg_modes, "'", collapse = ", "),
114-
call. = FALSE
115-
)
116-
117-
# write a constructor function
118-
out <- list(
110+
new_model_spec(
111+
"linear_reg",
119112
args = args,
120113
eng_args = NULL,
121114
mode = mode,
122115
method = NULL,
123116
engine = NULL
124117
)
125-
class(out) <- make_classes("linear_reg")
126-
out
127118
}
128119

129120
#' @export
@@ -171,7 +162,14 @@ update.linear_reg <-
171162
object$args[names(args)] <- args
172163
}
173164

174-
object
165+
new_model_spec(
166+
"linear_reg",
167+
args = object$args,
168+
eng_args = object$eng_args,
169+
mode = object$mode,
170+
method = NULL,
171+
engine = object$engine
172+
)
175173
}
176174

177175
# ------------------------------------------------------------------------------

R/logistic_reg.R

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,23 +106,14 @@ logistic_reg <-
106106
mixture = enquo(mixture)
107107
)
108108

109-
if (!(mode %in% logistic_reg_modes))
110-
stop(
111-
"`mode` should be one of: ",
112-
paste0("'", logistic_reg_modes, "'", collapse = ", "),
113-
call. = FALSE
114-
)
115-
116-
# write a constructor function
117-
out <- list(
109+
new_model_spec(
110+
"logistic_reg",
118111
args = args,
119112
eng_args = NULL,
120113
mode = mode,
121114
method = NULL,
122115
engine = NULL
123116
)
124-
class(out) <- make_classes("logistic_reg")
125-
out
126117
}
127118

128119
#' @export
@@ -170,7 +161,14 @@ update.logistic_reg <-
170161
object$args[names(args)] <- args
171162
}
172163

173-
object
164+
new_model_spec(
165+
"logistic_reg",
166+
args = object$args,
167+
eng_args = object$eng_args,
168+
mode = object$mode,
169+
method = NULL,
170+
engine = object$engine
171+
)
174172
}
175173

176174
# ------------------------------------------------------------------------------

R/mars.R

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,14 @@ mars <-
7070
prune_method = enquo(prune_method)
7171
)
7272

73-
if (!(mode %in% mars_modes))
74-
stop("`mode` should be one of: ",
75-
paste0("'", mars_modes, "'", collapse = ", "),
76-
call. = FALSE)
77-
78-
out <- list(args = args, eng_args = NULL,
79-
mode = mode, method = NULL, engine = NULL)
80-
class(out) <- make_classes("mars")
81-
out
73+
new_model_spec(
74+
"mars",
75+
args = args,
76+
eng_args = NULL,
77+
mode = mode,
78+
method = NULL,
79+
engine = NULL
80+
)
8281
}
8382

8483
#' @export
@@ -127,7 +126,14 @@ update.mars <-
127126
object$args[names(args)] <- args
128127
}
129128

130-
object
129+
new_model_spec(
130+
"mars",
131+
args = object$args,
132+
eng_args = object$eng_args,
133+
mode = object$mode,
134+
method = NULL,
135+
engine = object$engine
136+
)
131137
}
132138

133139
# ------------------------------------------------------------------------------

R/misc.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,17 @@ update_dot_check <- function(...) {
202202
invisible(NULL)
203203
}
204204

205+
# ------------------------------------------------------------------------------
206+
207+
new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
208+
spec_modes <- get(paste0(cls, "_modes"))
209+
if (!(mode %in% spec_modes))
210+
stop("`mode` should be one of: ",
211+
paste0("'", spec_modes, "'", collapse = ", "),
212+
call. = FALSE)
213+
214+
out <- list(args = args, eng_args = eng_args,
215+
mode = mode, method = method, engine = engine)
216+
class(out) <- make_classes(cls)
217+
out
218+
}

R/mlp.R

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,14 @@ mlp <-
9797
activation = enquo(activation)
9898
)
9999

100-
if (!(mode %in% mlp_modes))
101-
stop("`mode` should be one of: ",
102-
paste0("'", mlp_modes, "'", collapse = ", "),
103-
call. = FALSE)
104-
105-
# write a constructor function
106-
out <- list(args = args, eng_args = NULL,
107-
mode = mode, method = NULL, engine = NULL)
108-
109-
class(out) <- make_classes("mlp")
110-
out
100+
new_model_spec(
101+
"mlp",
102+
args = args,
103+
eng_args = NULL,
104+
mode = mode,
105+
method = NULL,
106+
engine = NULL
107+
)
111108
}
112109

113110
#' @export
@@ -165,7 +162,14 @@ update.mlp <-
165162
object$args[names(args)] <- args
166163
}
167164

168-
object
165+
new_model_spec(
166+
"mlp",
167+
args = object$args,
168+
eng_args = object$eng_args,
169+
mode = object$mode,
170+
method = NULL,
171+
engine = object$engine
172+
)
169173
}
170174

171175
# ------------------------------------------------------------------------------

R/multinom_reg.R

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,14 @@ multinom_reg <-
8989
mixture = enquo(mixture)
9090
)
9191

92-
if (!(mode %in% multinom_reg_modes))
93-
stop(
94-
"`mode` should be one of: ",
95-
paste0("'", multinom_reg_modes, "'", collapse = ", "),
96-
call. = FALSE
97-
)
98-
99-
# write a constructor function
100-
out <- list(
92+
new_model_spec(
93+
"multinom_reg",
10194
args = args,
10295
eng_args = NULL,
10396
mode = mode,
10497
method = NULL,
10598
engine = NULL
10699
)
107-
class(out) <- make_classes("multinom_reg")
108-
out
109100
}
110101

111102
#' @export
@@ -153,7 +144,14 @@ update.multinom_reg <-
153144
object$args[names(args)] <- args
154145
}
155146

156-
object
147+
new_model_spec(
148+
"multinom_reg",
149+
args = object$args,
150+
eng_args = object$eng_args,
151+
mode = object$mode,
152+
method = NULL,
153+
engine = object$engine
154+
)
157155
}
158156

159157
# ------------------------------------------------------------------------------

R/nearest_neighbor.R

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,14 @@ nearest_neighbor <- function(mode = "unknown",
7979
dist_power = enquo(dist_power)
8080
)
8181

82-
## TODO: make a utility function here
83-
if (!(mode %in% nearest_neighbor_modes)) {
84-
stop("`mode` should be one of: ",
85-
paste0("'", nearest_neighbor_modes, "'", collapse = ", "),
86-
call. = FALSE)
87-
}
88-
89-
# write a constructor function
90-
out <- list(args = args, eng_args = NULL,
91-
mode = mode, method = NULL, engine = NULL)
92-
93-
class(out) <- make_classes("nearest_neighbor")
94-
out
82+
new_model_spec(
83+
"nearest_neighbor",
84+
args = args,
85+
eng_args = NULL,
86+
mode = mode,
87+
method = NULL,
88+
engine = NULL
89+
)
9590
}
9691

9792
#' @export
@@ -132,7 +127,14 @@ update.nearest_neighbor <- function(object,
132127
object$args[names(args)] <- args
133128
}
134129

135-
object
130+
new_model_spec(
131+
"nearest_neighbor",
132+
args = object$args,
133+
eng_args = object$eng_args,
134+
mode = object$mode,
135+
method = NULL,
136+
engine = object$engine
137+
)
136138
}
137139

138140

0 commit comments

Comments
 (0)