Skip to content

Commit a4b293f

Browse files
committed
test cases
1 parent 44213b2 commit a4b293f

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
S3method(augment,model_fit)
44
S3method(fit,model_spec)
5+
S3method(fit_xy,gen_additive_mod)
56
S3method(fit_xy,model_spec)
67
S3method(glance,model_fit)
78
S3method(has_multi_predict,default)

R/gen_additive_mod.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,9 @@ translate.gen_additive_mod <- function(x, engine = x$engine, ...) {
158158

159159
x
160160
}
161+
162+
#' @export
163+
#' @keywords internal
164+
fit_xy.gen_additive_mod <- function(object, ...) {
165+
rlang::abort("`fit()` must be used with GAM models (due to its use of formulas).")
166+
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
library(testthat)
2+
library(parsnip)
3+
library(rlang)
4+
library(tibble)
5+
library(mgcv)
6+
7+
data(two_class_dat, package = "modeldata")
8+
9+
# ------------------------------------------------------------------------------
10+
11+
context("generalized additive models")
12+
13+
# ------------------------------------------------------------------------------
14+
15+
reg_mod <- gen_additive_mod(select_features = TRUE) %>% set_engine("mgcv") %>% set_mode("regression")
16+
17+
test_that('regression', {
18+
skip_if_not_installed("mgcv")
19+
20+
expect_error(
21+
f_res <- fit(
22+
reg_mod,
23+
mpg ~ s(disp) + wt + gear,
24+
data = mtcars
25+
),
26+
regexp = NA
27+
)
28+
expect_error(
29+
xy_res <- fit_xy(
30+
reg_mod,
31+
x = mtcars[, 1:5],
32+
y = mtcars$mpg,
33+
control = ctrl
34+
),
35+
regexp = "must be used with GAM models"
36+
)
37+
mgcv_mod <- mgcv::gam(mpg ~ s(disp) + wt + gear, data = mtcars, select = TRUE)
38+
expect_equal(coef(mgcv_mod), coef(f_res$fit))
39+
40+
f_pred <- predict(f_res, head(mtcars))
41+
mgcv_pred <- predict(mgcv_mod, head(mtcars), type = "response")
42+
expect_equal(names(f_pred), ".pred")
43+
expect_equivalent(f_pred[[".pred"]], unname(mgcv_pred))
44+
45+
f_ci <- predict(f_res, head(mtcars), type = "conf_int", std_error = TRUE)
46+
mgcv_ci <- predict(mgcv_mod, head(mtcars), type = "link", se.fit = TRUE)
47+
expect_equivalent(f_ci[[".std_error"]], unname(mgcv_ci$se.fit))
48+
lower <-
49+
mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit
50+
expect_equivalent(f_ci[[".pred_lower"]], unname(lower))
51+
52+
})
53+
54+
55+
56+
# ------------------------------------------------------------------------------
57+
58+
cls_mod <- gen_additive_mod(adjust_deg_free = 1.5) %>% set_engine("mgcv") %>% set_mode("classification")
59+
60+
test_that('classification', {
61+
skip_if_not_installed("mgcv")
62+
expect_error(
63+
f_res <- fit(
64+
cls_mod,
65+
Class ~ s(A, k = 10) + B,
66+
data = two_class_dat
67+
),
68+
regexp = NA
69+
)
70+
expect_error(
71+
xy_res <- fit_xy(
72+
cls_mod,
73+
x = two_class_dat[, 2:3],
74+
y = two_class_dat$Class,
75+
control = ctrl
76+
),
77+
regexp = "must be used with GAM models"
78+
)
79+
mgcv_mod <-
80+
mgcv::gam(Class ~ s(A, k = 10) + B,
81+
data = two_class_dat,
82+
gamma = 1.5,
83+
family = binomial)
84+
expect_equal(coef(mgcv_mod), coef(f_res$fit))
85+
86+
f_pred <- predict(f_res, head(two_class_dat), type = "prob")
87+
mgcv_pred <- predict(mgcv_mod, head(two_class_dat), type = "response")
88+
expect_equal(names(f_pred), c(".pred_Class1", ".pred_Class2"))
89+
expect_equivalent(f_pred[[".pred_Class2"]], unname(mgcv_pred))
90+
91+
f_ci <- predict(f_res, head(two_class_dat), type = "conf_int", std_error = TRUE)
92+
mgcv_ci <- predict(mgcv_mod, head(two_class_dat), type = "link", se.fit = TRUE)
93+
expect_equivalent(f_ci[[".std_error"]], unname(mgcv_ci$se.fit))
94+
lower <-
95+
mgcv_ci$fit - qt(0.025, df = mgcv_mod$df.residual, lower.tail = FALSE) * mgcv_ci$se.fit
96+
lower <- binomial()$linkinv(lower)
97+
expect_equivalent(f_ci[[".pred_lower_Class2"]], unname(lower))
98+
99+
})

0 commit comments

Comments
 (0)