Skip to content

Commit 44213b2

Browse files
committed
confidence intervals and other model info changes
1 parent b535a07 commit 44213b2

File tree

1 file changed

+70
-22
lines changed

1 file changed

+70
-22
lines changed

R/gen_additive_mod_data.R

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,26 @@ set_pred(
7676
value = list(
7777
pre = NULL,
7878
post = function(results, object) {
79-
# TODO fix this; see the logistic regression code
80-
res <-tibble::tibble(.pred_lower = results$fit - 2*results$se.fit,
81-
.pred_upper = results$fit + 2*results$se.fit)
79+
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
80+
const <-
81+
qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
82+
trans <- object$fit$family$linkinv
83+
res <-
84+
tibble(
85+
.pred_lower = trans(results$fit - const * results$se.fit),
86+
.pred_upper = trans(results$fit + const * results$se.fit)
87+
)
88+
# In case of inverse or other links
89+
if (any(res$.pred_upper < res$.pred_lower)) {
90+
nms <- names(res)
91+
res <- res[, 2:1]
92+
names(res) <- nms
93+
}
94+
95+
if (object$spec$method$pred$conf_int$extras$std_error) {
96+
res$.std_error <- results$se.fit
97+
}
98+
res
8299
},
83100
func = c(fun = "predict"),
84101
args = list(
@@ -145,22 +162,9 @@ set_pred(
145162
type = "class",
146163
value = list(
147164
pre = NULL,
148-
post = function(results, object) {
149-
150-
tbl <-tibble::as_tibble(results)
151-
152-
if (ncol(tbl) == 1) {
153-
res <- prob_to_class_2(tbl, object) %>%
154-
tibble::as_tibble() %>%
155-
stats::setNames("values") %>%
156-
dplyr::mutate(values = as.factor(values))
157-
} else{
158-
res <- tbl %>%
159-
apply(., 1, function(x)
160-
which(max(x) == x)[1]) - 1 %>% #modify in the future for something more elegant when gets the formula ok
161-
tibble::as_tibble()
162-
}
163-
165+
post = function(x, object) {
166+
x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
167+
unname(x)
164168
},
165169
func = c(fun = "predict"),
166170
args = list(
@@ -177,9 +181,11 @@ set_pred(
177181
mode = "classification",
178182
type = "prob",
179183
value = list(
180-
pre = NULL,
181-
post = function(results, object) {
182-
res <- tibble::as_tibble(results)
184+
pre = NULL,
185+
post = function(x, object) {
186+
x <- tibble(v1 = 1 - x, v2 = x)
187+
colnames(x) <- object$lvl
188+
x
183189
},
184190
func = c(fun = "predict"),
185191
args = list(
@@ -207,3 +213,45 @@ set_pred(
207213
)
208214

209215

216+
set_pred(
217+
model = "gen_additive_mod",
218+
eng = "mgcv",
219+
mode = "classification",
220+
type = "conf_int",
221+
value = list(
222+
pre = NULL,
223+
post = function(results, object) {
224+
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
225+
const <-
226+
qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
227+
trans <- object$fit$family$linkinv
228+
res_2 <-
229+
tibble(
230+
lo = trans(results$fit - const * results$se.fit),
231+
hi = trans(results$fit + const * results$se.fit)
232+
)
233+
res_1 <- res_2
234+
res_1$lo <- 1 - res_2$hi
235+
res_1$hi <- 1 - res_2$lo
236+
lo_nms <- paste0(".pred_lower_", object$lvl)
237+
hi_nms <- paste0(".pred_upper_", object$lvl)
238+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
239+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
240+
res <- bind_cols(res_1, res_2)
241+
242+
if (object$spec$method$pred$conf_int$extras$std_error) {
243+
res$.std_error <- results$se.fit
244+
}
245+
res
246+
},
247+
func = c(fun = "predict"),
248+
args =
249+
list(
250+
object = rlang::expr(object$fit),
251+
newdata = rlang::expr(new_data),
252+
type = "link",
253+
se.fit = TRUE
254+
)
255+
)
256+
)
257+

0 commit comments

Comments
 (0)