Skip to content

Commit fa35df8

Browse files
committed
udpates for move form external package to parsnip
1 parent ff95b1e commit fa35df8

File tree

4 files changed

+55
-73
lines changed

4 files changed

+55
-73
lines changed

NAMESPACE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ importFrom(generics,tidy)
252252
importFrom(generics,varying_args)
253253
importFrom(glue,glue_collapse)
254254
importFrom(magrittr,"%>%")
255-
importFrom(parsnip,translate)
256255
importFrom(purrr,as_vector)
257256
importFrom(purrr,imap)
258257
importFrom(purrr,imap_lgl)

R/gen_additive_mod.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@
6565
#' [_Tidy Models with R_](https://tmwr.org)
6666
#' @examples
6767
#'
68-
#' show_engines("gen_additive_mod")
68+
#' #show_engines("gen_additive_mod")
6969
#'
70-
#' gen_additive_mod()
70+
#' #gen_additive_mod()
7171
#'
7272
#'
7373
#' @export
74-
gen_additive_mod <- function(mode = "regression",
74+
gen_additive_mod <- function(mode = "unknown",
7575
select_features = NULL,
7676
adjust_deg_free = NULL) {
7777

@@ -148,10 +148,9 @@ update.gen_additive_mod <- function(object,
148148

149149

150150
#' @export
151-
#' @importFrom parsnip translate
152151
translate.gen_additive_mod <- function(x, engine = x$engine, ...) {
153152
if (is.null(engine)) {
154-
message("Used `engine = 'gam'` for translation.")
153+
message("Used `engine = 'mgcv'` for translation.")
155154
engine <- "gam"
156155
}
157156
x <- translate.default(x, engine, ...)

R/gen_additive_mod_data.R

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,36 @@
11

22
set_new_model("gen_additive_mod")
33

4+
# ------------------------------------------------------------------------------
45
#### REGRESION ----
5-
model = "gen_additive_mod"
6-
mode = "regression"
7-
engine = "gam"
8-
9-
set_model_engine(model = model, mode = mode, eng = engine)
10-
set_dependency(model = model, eng = engine, pkg = "mgcv")
11-
set_dependency(model = model, eng = engine, pkg = "parnsip")
6+
set_model_engine(model = "gen_additive_mod", mode = "regression", eng = "mgcv")
7+
set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv")
128

139
#Args
1410

11+
# TODO make dials PR
1512
set_model_arg(
1613
model = "gen_additive_mod",
17-
eng = "gam",
14+
eng = "mgcv",
1815
parsnip = "select_features",
1916
original = "select",
20-
func = list(pkg = "parnsip", fun = "select_features"),
17+
func = list(pkg = "dials", fun = "select_features"),
2118
has_submodel = FALSE
2219
)
2320

2421
set_model_arg(
2522
model = "gen_additive_mod",
26-
eng = "gam",
23+
eng = "mgcv",
2724
parsnip = "adjust_deg_free",
2825
original = "gamma",
29-
func = list(pkg = "parnsip", fun = "adjust_deg_free"),
26+
func = list(pkg = "dials", fun = "adjust_deg_free"),
3027
has_submodel = FALSE
3128
)
3229

3330
set_encoding(
34-
model = model,
35-
eng = engine,
36-
mode = mode,
31+
model = "gen_additive_mod",
32+
eng = "mgcv",
33+
mode = "regression",
3734
options = list(
3835
predictor_indicators = "none",
3936
compute_intercept = FALSE,
@@ -43,24 +40,21 @@ set_encoding(
4340
)
4441

4542
set_fit(
46-
model = model,
47-
eng = engine,
48-
mode = mode,
43+
model = "gen_additive_mod",
44+
eng = "mgcv",
45+
mode = "regression",
4946
value = list(
5047
interface = "formula",
5148
protect = c("formula", "data"),
5249
func = c(pkg = "mgcv", fun = "gam"),
53-
defaults = list(
54-
select = FALSE,
55-
gamma = 1
56-
)
50+
defaults = list()
5751
)
5852
)
5953

6054
set_pred(
61-
model = model,
62-
eng = engine,
63-
mode = mode,
55+
model = "gen_additive_mod",
56+
eng = "mgcv",
57+
mode = "regression",
6458
type = "numeric",
6559
value = list(
6660
pre = NULL,
@@ -75,13 +69,14 @@ set_pred(
7569
)
7670

7771
set_pred(
78-
model = model,
79-
eng = engine,
80-
mode = mode,
72+
model = "gen_additive_mod",
73+
eng = "mgcv",
74+
mode = "regression",
8175
type = "conf_int",
8276
value = list(
8377
pre = NULL,
8478
post = function(results, object) {
79+
# TODO fix this; see the logistic regression code
8580
res <-tibble::tibble(.pre_lower = results$fit - 2*results$se.fit,
8681
.pre_upper = results$fit + 2*results$se.fit)
8782
},
@@ -96,9 +91,9 @@ set_pred(
9691
)
9792

9893
set_pred(
99-
model = model,
100-
eng = engine,
101-
mode = mode,
94+
model = "gen_additive_mod",
95+
eng = "mgcv",
96+
mode = "regression",
10297
type = "raw",
10398
value = list(
10499
pre = NULL,
@@ -111,20 +106,16 @@ set_pred(
111106
)
112107
)
113108

109+
# ------------------------------------------------------------------------------
114110
#### CLASSIFICATION
111+
set_model_engine(model = "gen_additive_mod", mode = "classification", eng = "mgcv")
112+
set_dependency(model = "gen_additive_mod", eng = "mgcv", pkg = "mgcv")
115113

116-
model = "gen_additive_mod"
117-
mode = "classification"
118-
engine = "gam"
119-
120-
set_model_engine(model = model, mode = mode, eng = engine)
121-
set_dependency(model = model, eng = engine, pkg = "mgcv")
122-
set_dependency(model = model, eng = engine, pkg = "parnsip")
123114

124115
set_encoding(
125-
model = model,
126-
eng = engine,
127-
mode = mode,
116+
model = "gen_additive_mod",
117+
eng = "mgcv",
118+
mode = "classification",
128119
options = list(
129120
predictor_indicators = "none",
130121
compute_intercept = FALSE,
@@ -134,46 +125,39 @@ set_encoding(
134125
)
135126

136127
set_fit(
137-
model = model,
138-
eng = engine,
139-
mode = mode,
128+
model = "gen_additive_mod",
129+
eng = "mgcv",
130+
mode = "classification",
140131
value = list(
141132
interface = "formula",
142133
protect = c("formula", "data"),
143134
func = c(pkg = "mgcv", fun = "gam"),
144135
defaults = list(
145-
select = FALSE,
146-
gamma = 1,
147136
family = stats::binomial(link = "logit")
148137
)
149138
)
150139
)
151140

152-
prob_to_class_2 <- function(x, object){
153-
154-
x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
155-
unname(x)
156-
}
157-
158141
set_pred(
159-
model = model,
160-
eng = engine,
161-
mode = mode,
142+
model = "gen_additive_mod",
143+
eng = "mgcv",
144+
mode = "classification",
162145
type = "class",
163146
value = list(
164147
pre = NULL,
165148
post = function(results, object) {
166149

167150
tbl <-tibble::as_tibble(results)
168151

169-
if (ncol(tbl)==1){
170-
res<-prob_to_class_2(tbl, object) %>%
152+
if (ncol(tbl) == 1) {
153+
res <- prob_to_class_2(tbl, object) %>%
171154
tibble::as_tibble() %>%
172155
stats::setNames("values") %>%
173156
dplyr::mutate(values = as.factor(values))
174157
} else{
175158
res <- tbl %>%
176-
apply(.,1,function(x) which(max(x)==x)[1])-1 %>% #modify in the future for something more elegant when gets the formula ok
159+
apply(., 1, function(x)
160+
which(max(x) == x)[1]) - 1 %>% #modify in the future for something more elegant when gets the formula ok
177161
tibble::as_tibble()
178162
}
179163

@@ -188,14 +172,14 @@ set_pred(
188172
)
189173

190174
set_pred(
191-
model = model,
192-
eng = engine,
193-
mode = mode,
175+
model = "gen_additive_mod",
176+
eng = "mgcv",
177+
mode = "classification",
194178
type = "prob",
195179
value = list(
196180
pre = NULL,
197181
post = function(results, object) {
198-
res <-tibble::as_tibble(results)
182+
res <- tibble::as_tibble(results)
199183
},
200184
func = c(fun = "predict"),
201185
args = list(
@@ -207,9 +191,9 @@ set_pred(
207191
)
208192

209193
set_pred(
210-
model = model,
211-
eng = engine,
212-
mode = mode,
194+
model = "gen_additive_mod",
195+
eng = "mgcv",
196+
mode = "classification",
213197
type = "raw",
214198
value = list(
215199
pre = NULL,

man/gen_additive_mod.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)