@@ -76,9 +76,26 @@ set_pred(
7676 value = list (
7777 pre = NULL ,
7878 post = function (results , object ) {
79- # TODO fix this; see the logistic regression code
80- res <- tibble :: tibble(.pred_lower = results $ fit - 2 * results $ se.fit ,
81- .pred_upper = results $ fit + 2 * results $ se.fit )
79+ hf_lvl <- (1 - object $ spec $ method $ pred $ conf_int $ extras $ level )/ 2
80+ const <-
81+ qt(hf_lvl , df = object $ fit $ df.residual , lower.tail = FALSE )
82+ trans <- object $ fit $ family $ linkinv
83+ res <-
84+ tibble(
85+ .pred_lower = trans(results $ fit - const * results $ se.fit ),
86+ .pred_upper = trans(results $ fit + const * results $ se.fit )
87+ )
88+ # In case of inverse or other links
89+ if (any(res $ .pred_upper < res $ .pred_lower )) {
90+ nms <- names(res )
91+ res <- res [, 2 : 1 ]
92+ names(res ) <- nms
93+ }
94+
95+ if (object $ spec $ method $ pred $ conf_int $ extras $ std_error ) {
96+ res $ .std_error <- results $ se.fit
97+ }
98+ res
8299 },
83100 func = c(fun = " predict" ),
84101 args = list (
@@ -145,22 +162,9 @@ set_pred(
145162 type = " class" ,
146163 value = list (
147164 pre = NULL ,
148- post = function (results , object ) {
149-
150- tbl <- tibble :: as_tibble(results )
151-
152- if (ncol(tbl ) == 1 ) {
153- res <- prob_to_class_2(tbl , object ) %> %
154- tibble :: as_tibble() %> %
155- stats :: setNames(" values" ) %> %
156- dplyr :: mutate(values = as.factor(values ))
157- } else {
158- res <- tbl %> %
159- apply(. , 1 , function (x )
160- which(max(x ) == x )[1 ]) - 1 %> % # modify in the future for something more elegant when gets the formula ok
161- tibble :: as_tibble()
162- }
163-
165+ post = function (x , object ) {
166+ x <- ifelse(x > = 0.5 , object $ lvl [2 ], object $ lvl [1 ])
167+ unname(x )
164168 },
165169 func = c(fun = " predict" ),
166170 args = list (
@@ -177,9 +181,11 @@ set_pred(
177181 mode = " classification" ,
178182 type = " prob" ,
179183 value = list (
180- pre = NULL ,
181- post = function (results , object ) {
182- res <- tibble :: as_tibble(results )
184+ pre = NULL ,
185+ post = function (x , object ) {
186+ x <- tibble(v1 = 1 - x , v2 = x )
187+ colnames(x ) <- object $ lvl
188+ x
183189 },
184190 func = c(fun = " predict" ),
185191 args = list (
@@ -207,3 +213,45 @@ set_pred(
207213)
208214
209215
216+ set_pred(
217+ model = " gen_additive_mod" ,
218+ eng = " mgcv" ,
219+ mode = " classification" ,
220+ type = " conf_int" ,
221+ value = list (
222+ pre = NULL ,
223+ post = function (results , object ) {
224+ hf_lvl <- (1 - object $ spec $ method $ pred $ conf_int $ extras $ level )/ 2
225+ const <-
226+ qt(hf_lvl , df = object $ fit $ df.residual , lower.tail = FALSE )
227+ trans <- object $ fit $ family $ linkinv
228+ res_2 <-
229+ tibble(
230+ lo = trans(results $ fit - const * results $ se.fit ),
231+ hi = trans(results $ fit + const * results $ se.fit )
232+ )
233+ res_1 <- res_2
234+ res_1 $ lo <- 1 - res_2 $ hi
235+ res_1 $ hi <- 1 - res_2 $ lo
236+ lo_nms <- paste0(" .pred_lower_" , object $ lvl )
237+ hi_nms <- paste0(" .pred_upper_" , object $ lvl )
238+ colnames(res_1 ) <- c(lo_nms [1 ], hi_nms [1 ])
239+ colnames(res_2 ) <- c(lo_nms [2 ], hi_nms [2 ])
240+ res <- bind_cols(res_1 , res_2 )
241+
242+ if (object $ spec $ method $ pred $ conf_int $ extras $ std_error ) {
243+ res $ .std_error <- results $ se.fit
244+ }
245+ res
246+ },
247+ func = c(fun = " predict" ),
248+ args =
249+ list (
250+ object = rlang :: expr(object $ fit ),
251+ newdata = rlang :: expr(new_data ),
252+ type = " link" ,
253+ se.fit = TRUE
254+ )
255+ )
256+ )
257+
0 commit comments