11
22set_new_model(" gen_additive_mod" )
33
4+ # ------------------------------------------------------------------------------
45# ### REGRESION ----
5- model = " gen_additive_mod"
6- mode = " regression"
7- engine = " gam"
8-
9- set_model_engine(model = model , mode = mode , eng = engine )
10- set_dependency(model = model , eng = engine , pkg = " mgcv" )
11- set_dependency(model = model , eng = engine , pkg = " parnsip" )
6+ set_model_engine(model = " gen_additive_mod" , mode = " regression" , eng = " mgcv" )
7+ set_dependency(model = " gen_additive_mod" , eng = " mgcv" , pkg = " mgcv" )
128
139# Args
1410
11+ # TODO make dials PR
1512set_model_arg(
1613 model = " gen_additive_mod" ,
17- eng = " gam " ,
14+ eng = " mgcv " ,
1815 parsnip = " select_features" ,
1916 original = " select" ,
20- func = list (pkg = " parnsip " , fun = " select_features" ),
17+ func = list (pkg = " dials " , fun = " select_features" ),
2118 has_submodel = FALSE
2219)
2320
2421set_model_arg(
2522 model = " gen_additive_mod" ,
26- eng = " gam " ,
23+ eng = " mgcv " ,
2724 parsnip = " adjust_deg_free" ,
2825 original = " gamma" ,
29- func = list (pkg = " parnsip " , fun = " adjust_deg_free" ),
26+ func = list (pkg = " dials " , fun = " adjust_deg_free" ),
3027 has_submodel = FALSE
3128)
3229
3330set_encoding(
34- model = model ,
35- eng = engine ,
36- mode = mode ,
31+ model = " gen_additive_mod " ,
32+ eng = " mgcv " ,
33+ mode = " regression " ,
3734 options = list (
3835 predictor_indicators = " none" ,
3936 compute_intercept = FALSE ,
@@ -43,24 +40,21 @@ set_encoding(
4340)
4441
4542set_fit(
46- model = model ,
47- eng = engine ,
48- mode = mode ,
43+ model = " gen_additive_mod " ,
44+ eng = " mgcv " ,
45+ mode = " regression " ,
4946 value = list (
5047 interface = " formula" ,
5148 protect = c(" formula" , " data" ),
5249 func = c(pkg = " mgcv" , fun = " gam" ),
53- defaults = list (
54- select = FALSE ,
55- gamma = 1
56- )
50+ defaults = list ()
5751 )
5852)
5953
6054set_pred(
61- model = model ,
62- eng = engine ,
63- mode = mode ,
55+ model = " gen_additive_mod " ,
56+ eng = " mgcv " ,
57+ mode = " regression " ,
6458 type = " numeric" ,
6559 value = list (
6660 pre = NULL ,
@@ -75,13 +69,14 @@ set_pred(
7569)
7670
7771set_pred(
78- model = model ,
79- eng = engine ,
80- mode = mode ,
72+ model = " gen_additive_mod " ,
73+ eng = " mgcv " ,
74+ mode = " regression " ,
8175 type = " conf_int" ,
8276 value = list (
8377 pre = NULL ,
8478 post = function (results , object ) {
79+ # TODO fix this; see the logistic regression code
8580 res <- tibble :: tibble(.pre_lower = results $ fit - 2 * results $ se.fit ,
8681 .pre_upper = results $ fit + 2 * results $ se.fit )
8782 },
@@ -96,9 +91,9 @@ set_pred(
9691)
9792
9893set_pred(
99- model = model ,
100- eng = engine ,
101- mode = mode ,
94+ model = " gen_additive_mod " ,
95+ eng = " mgcv " ,
96+ mode = " regression " ,
10297 type = " raw" ,
10398 value = list (
10499 pre = NULL ,
@@ -111,20 +106,16 @@ set_pred(
111106 )
112107)
113108
109+ # ------------------------------------------------------------------------------
114110# ### CLASSIFICATION
111+ set_model_engine(model = " gen_additive_mod" , mode = " classification" , eng = " mgcv" )
112+ set_dependency(model = " gen_additive_mod" , eng = " mgcv" , pkg = " mgcv" )
115113
116- model = " gen_additive_mod"
117- mode = " classification"
118- engine = " gam"
119-
120- set_model_engine(model = model , mode = mode , eng = engine )
121- set_dependency(model = model , eng = engine , pkg = " mgcv" )
122- set_dependency(model = model , eng = engine , pkg = " parnsip" )
123114
124115set_encoding(
125- model = model ,
126- eng = engine ,
127- mode = mode ,
116+ model = " gen_additive_mod " ,
117+ eng = " mgcv " ,
118+ mode = " classification " ,
128119 options = list (
129120 predictor_indicators = " none" ,
130121 compute_intercept = FALSE ,
@@ -134,46 +125,39 @@ set_encoding(
134125)
135126
136127set_fit(
137- model = model ,
138- eng = engine ,
139- mode = mode ,
128+ model = " gen_additive_mod " ,
129+ eng = " mgcv " ,
130+ mode = " classification " ,
140131 value = list (
141132 interface = " formula" ,
142133 protect = c(" formula" , " data" ),
143134 func = c(pkg = " mgcv" , fun = " gam" ),
144135 defaults = list (
145- select = FALSE ,
146- gamma = 1 ,
147136 family = stats :: binomial(link = " logit" )
148137 )
149138 )
150139)
151140
152- prob_to_class_2 <- function (x , object ){
153-
154- x <- ifelse(x > = 0.5 , object $ lvl [2 ], object $ lvl [1 ])
155- unname(x )
156- }
157-
158141set_pred(
159- model = model ,
160- eng = engine ,
161- mode = mode ,
142+ model = " gen_additive_mod " ,
143+ eng = " mgcv " ,
144+ mode = " classification " ,
162145 type = " class" ,
163146 value = list (
164147 pre = NULL ,
165148 post = function (results , object ) {
166149
167150 tbl <- tibble :: as_tibble(results )
168151
169- if (ncol(tbl )== 1 ) {
170- res <- prob_to_class_2(tbl , object ) %> %
152+ if (ncol(tbl ) == 1 ) {
153+ res <- prob_to_class_2(tbl , object ) %> %
171154 tibble :: as_tibble() %> %
172155 stats :: setNames(" values" ) %> %
173156 dplyr :: mutate(values = as.factor(values ))
174157 } else {
175158 res <- tbl %> %
176- apply(. ,1 ,function (x ) which(max(x )== x )[1 ])- 1 %> % # modify in the future for something more elegant when gets the formula ok
159+ apply(. , 1 , function (x )
160+ which(max(x ) == x )[1 ]) - 1 %> % # modify in the future for something more elegant when gets the formula ok
177161 tibble :: as_tibble()
178162 }
179163
@@ -188,14 +172,14 @@ set_pred(
188172)
189173
190174set_pred(
191- model = model ,
192- eng = engine ,
193- mode = mode ,
175+ model = " gen_additive_mod " ,
176+ eng = " mgcv " ,
177+ mode = " classification " ,
194178 type = " prob" ,
195179 value = list (
196180 pre = NULL ,
197181 post = function (results , object ) {
198- res <- tibble :: as_tibble(results )
182+ res <- tibble :: as_tibble(results )
199183 },
200184 func = c(fun = " predict" ),
201185 args = list (
@@ -207,9 +191,9 @@ set_pred(
207191)
208192
209193set_pred(
210- model = model ,
211- eng = engine ,
212- mode = mode ,
194+ model = " gen_additive_mod " ,
195+ eng = " mgcv " ,
196+ mode = " classification " ,
213197 type = " raw" ,
214198 value = list (
215199 pre = NULL ,
0 commit comments