Skip to content

Commit cbb93d2

Browse files
committed
modularize confidence interval code
1 parent a4b293f commit cbb93d2

File tree

4 files changed

+57
-69
lines changed

4 files changed

+57
-69
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ importFrom(stats,na.omit)
300300
importFrom(stats,na.pass)
301301
importFrom(stats,predict)
302302
importFrom(stats,qnorm)
303+
importFrom(stats,qt)
303304
importFrom(stats,quantile)
304305
importFrom(stats,setNames)
305306
importFrom(stats,terms)

R/aaa.R

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,59 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
3030
res
3131
}
3232

33+
# ------------------------------------------------------------------------------
34+
35+
#' @importFrom stats qt
36+
# used by logistic_reg() and gen_additive_mod()
37+
logistic_lp_to_conf_int <- function(results, object) {
38+
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
39+
const <-
40+
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
41+
trans <- object$fit$family$linkinv
42+
res_2 <-
43+
tibble(
44+
lo = trans(results$fit - const * results$se.fit),
45+
hi = trans(results$fit + const * results$se.fit)
46+
)
47+
res_1 <- res_2
48+
res_1$lo <- 1 - res_2$hi
49+
res_1$hi <- 1 - res_2$lo
50+
lo_nms <- paste0(".pred_lower_", object$lvl)
51+
hi_nms <- paste0(".pred_upper_", object$lvl)
52+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
53+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
54+
res <- bind_cols(res_1, res_2)
55+
56+
if (object$spec$method$pred$conf_int$extras$std_error)
57+
res$.std_error <- results$se.fit
58+
res
59+
}
60+
61+
# used by gen_additive_mod()
62+
linear_lp_to_conf_int <-
63+
function(results, object) {
64+
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
65+
const <-
66+
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
67+
trans <- object$fit$family$linkinv
68+
res <-
69+
tibble(
70+
.pred_lower = trans(results$fit - const * results$se.fit),
71+
.pred_upper = trans(results$fit + const * results$se.fit)
72+
)
73+
# In case of inverse or other links
74+
if (any(res$.pred_upper < res$.pred_lower)) {
75+
nms <- names(res)
76+
res <- res[, 2:1]
77+
names(res) <- nms
78+
}
79+
80+
if (object$spec$method$pred$conf_int$extras$std_error) {
81+
res$.std_error <- results$se.fit
82+
}
83+
res
84+
}
85+
3386
# ------------------------------------------------------------------------------
3487
# nocov
3588

R/gen_additive_mod_data.R

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,7 @@ set_pred(
7575
type = "conf_int",
7676
value = list(
7777
pre = NULL,
78-
post = function(results, object) {
79-
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
80-
const <-
81-
qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
82-
trans <- object$fit$family$linkinv
83-
res <-
84-
tibble(
85-
.pred_lower = trans(results$fit - const * results$se.fit),
86-
.pred_upper = trans(results$fit + const * results$se.fit)
87-
)
88-
# In case of inverse or other links
89-
if (any(res$.pred_upper < res$.pred_lower)) {
90-
nms <- names(res)
91-
res <- res[, 2:1]
92-
names(res) <- nms
93-
}
94-
95-
if (object$spec$method$pred$conf_int$extras$std_error) {
96-
res$.std_error <- results$se.fit
97-
}
98-
res
99-
},
78+
post = linear_lp_to_conf_int,
10079
func = c(fun = "predict"),
10180
args = list(
10281
object = rlang::expr(object$fit),
@@ -220,30 +199,7 @@ set_pred(
220199
type = "conf_int",
221200
value = list(
222201
pre = NULL,
223-
post = function(results, object) {
224-
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
225-
const <-
226-
qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
227-
trans <- object$fit$family$linkinv
228-
res_2 <-
229-
tibble(
230-
lo = trans(results$fit - const * results$se.fit),
231-
hi = trans(results$fit + const * results$se.fit)
232-
)
233-
res_1 <- res_2
234-
res_1$lo <- 1 - res_2$hi
235-
res_1$hi <- 1 - res_2$lo
236-
lo_nms <- paste0(".pred_lower_", object$lvl)
237-
hi_nms <- paste0(".pred_upper_", object$lvl)
238-
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
239-
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
240-
res <- bind_cols(res_1, res_2)
241-
242-
if (object$spec$method$pred$conf_int$extras$std_error) {
243-
res$.std_error <- results$se.fit
244-
}
245-
res
246-
},
202+
post = logistic_lp_to_conf_int,
247203
func = c(fun = "predict"),
248204
args =
249205
list(

R/logistic_reg_data.R

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -95,29 +95,7 @@ set_pred(
9595
type = "conf_int",
9696
value = list(
9797
pre = NULL,
98-
post = function(results, object) {
99-
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
100-
const <-
101-
qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
102-
trans <- object$fit$family$linkinv
103-
res_2 <-
104-
tibble(
105-
lo = trans(results$fit - const * results$se.fit),
106-
hi = trans(results$fit + const * results$se.fit)
107-
)
108-
res_1 <- res_2
109-
res_1$lo <- 1 - res_2$hi
110-
res_1$hi <- 1 - res_2$lo
111-
lo_nms <- paste0(".pred_lower_", object$lvl)
112-
hi_nms <- paste0(".pred_upper_", object$lvl)
113-
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
114-
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
115-
res <- bind_cols(res_1, res_2)
116-
117-
if (object$spec$method$pred$conf_int$extras$std_error)
118-
res$.std_error <- results$se.fit
119-
res
120-
},
98+
post = logistic_lp_to_conf_int,
12199
func = c(fun = "predict"),
122100
args =
123101
list(

0 commit comments

Comments
 (0)