Skip to content

Commit 6f0576c

Browse files
committed
add gen_additive_mod
1 parent 7d78009 commit 6f0576c

File tree

4 files changed

+469
-0
lines changed

4 files changed

+469
-0
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ S3method(predict_time,model_fit)
4646
S3method(print,boost_tree)
4747
S3method(print,control_parsnip)
4848
S3method(print,decision_tree)
49+
S3method(print,gen_additive_mod)
4950
S3method(print,linear_reg)
5051
S3method(print,logistic_reg)
5152
S3method(print,mars)
@@ -76,6 +77,7 @@ S3method(tidy,nullmodel)
7677
S3method(translate,boost_tree)
7778
S3method(translate,decision_tree)
7879
S3method(translate,default)
80+
S3method(translate,gen_additive_mod)
7981
S3method(translate,linear_reg)
8082
S3method(translate,logistic_reg)
8183
S3method(translate,mars)
@@ -91,6 +93,7 @@ S3method(type_sum,model_fit)
9193
S3method(type_sum,model_spec)
9294
S3method(update,boost_tree)
9395
S3method(update,decision_tree)
96+
S3method(update,gen_additive_mod)
9497
S3method(update,linear_reg)
9598
S3method(update,logistic_reg)
9699
S3method(update,mars)
@@ -136,6 +139,7 @@ export(fit.model_spec)
136139
export(fit_control)
137140
export(fit_xy)
138141
export(fit_xy.model_spec)
142+
export(gen_additive_mod)
139143
export(get_dependency)
140144
export(get_encoding)
141145
export(get_fit)
@@ -248,6 +252,7 @@ importFrom(generics,tidy)
248252
importFrom(generics,varying_args)
249253
importFrom(glue,glue_collapse)
250254
importFrom(magrittr,"%>%")
255+
importFrom(parsnip,translate)
251256
importFrom(purrr,as_vector)
252257
importFrom(purrr,imap)
253258
importFrom(purrr,imap_lgl)

R/gen_additive_mod.R

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# gen_additive_mod() - General Interface to Linear GAM Models
2+
# - backend: gam
3+
# - prediction:
4+
# - mode = "regression" (default) uses
5+
# - mode = "classification"
6+
7+
#' Interface for Generalized Additive Models (GAM)
8+
#'
9+
#' @param mode A single character string for the type of model.
10+
#' @param select_features TRUE or FALSE. If this is TRUE then can add an
11+
#' extra penalty to each term so that it can be penalized to zero.
12+
#' This means that the smoothing parameter estimation that is part of
13+
#' fitting can completely remove terms from the model. If the corresponding
14+
#' smoothing parameter is estimated as zero then the extra penalty has no effect.
15+
#' Use `adjust_deg_free` to increase level of penalization.
16+
#' @param adjust_deg_free If `select_features = TRUE`, then acts as a multiplier for smoothness.
17+
#' Increase this beyond 1 to produce smoother models.
18+
#'
19+
#'
20+
#' @return
21+
#' A `parsnip` model specification
22+
#'
23+
#' @details
24+
#'
25+
#' __Available Engines:__
26+
#' - __gam__: Connects to `mgcv::gam()`
27+
#'
28+
#' __Parameter Mapping:__
29+
#'
30+
#' ```{r echo = FALSE}
31+
#' tibble::tribble(
32+
#' ~ "modelgam", ~ "mgcv::gam",
33+
#' "select_features", "select (FALSE)",
34+
#' "adjust_deg_free", "gamma (1)"
35+
#' ) %>% knitr::kable()
36+
#' ```
37+
#'
38+
#' @section Engine Details:
39+
#'
40+
#' __gam__
41+
#'
42+
#' This engine uses [mgcv::gam()] and has the following parameters,
43+
#' which can be modified through the [set_engine()] function.
44+
#'
45+
#' ``` {r echo=F}
46+
#' str(mgcv::gam)
47+
#' ```
48+
#'
49+
#' @section Fit Details:
50+
#'
51+
#' __MGCV Formula Interface__
52+
#'
53+
#' Fitting GAMs is accomplished using parameters including:
54+
#'
55+
#' - [mgcv::s()]: GAM spline smooths
56+
#' - [mgcv::te()]: GAM tensor product smooths
57+
#'
58+
#' These are applied in the `fit()` function:
59+
#'
60+
#' ``` r
61+
#' fit(value ~ s(date_mon, k = 12) + s(date_num), data = df)
62+
#' ```
63+
#'
64+
#'
65+
#' @examples
66+
#'
67+
#' show_engines("gen_additive_mod")
68+
#'
69+
#' gen_additive_mod()
70+
#'
71+
#'
72+
#' @export
73+
gen_additive_mod <- function(mode = "regression",
74+
select_features = NULL,
75+
adjust_deg_free = NULL) {
76+
77+
args <- list(
78+
select_features = rlang::enquo(select_features),
79+
adjust_deg_free = rlang::enquo(adjust_deg_free)
80+
)
81+
82+
new_model_spec(
83+
"gen_additive_mod",
84+
args = args,
85+
eng_args = NULL,
86+
mode = mode,
87+
method = NULL,
88+
engine = NULL
89+
)
90+
91+
}
92+
93+
#' @export
94+
print.gen_additive_mod <- function(x, ...) {
95+
cat("GAM Model Specification (", x$mode, ")\n\n", sep = "")
96+
model_printer(x, ...)
97+
98+
if(!is.null(x$method$fit$args)) {
99+
cat("Model fit template:\n")
100+
print(show_call(x))
101+
}
102+
103+
invisible(x)
104+
}
105+
106+
#' @export
107+
#' @importFrom stats update
108+
update.gen_additive_mod <- function(object,
109+
select_features = NULL,
110+
adjust_deg_free = NULL,
111+
parameters = NULL,
112+
fresh = FALSE, ...) {
113+
114+
update_dot_check(...)
115+
116+
if (!is.null(parameters)) {
117+
parameters <- check_final_param(parameters)
118+
}
119+
120+
args <- list(
121+
select_features = rlang::enquo(select_features),
122+
adjust_deg_free = rlang::enquo(adjust_deg_free)
123+
)
124+
125+
args <- update_main_parameters(args, parameters)
126+
127+
if (fresh) {
128+
object$args <- args
129+
} else {
130+
null_args <- purrr::map_lgl(args, null_value)
131+
if (any(null_args))
132+
args <- args[!null_args]
133+
if (length(args) > 0)
134+
object$args[names(args)] <- args
135+
}
136+
137+
new_model_spec(
138+
"gen_additive_mod",
139+
args = object$args,
140+
eng_args = object$eng_args,
141+
mode = object$mode,
142+
method = NULL,
143+
engine = object$engine
144+
)
145+
}
146+
147+
148+
#' @export
149+
#' @importFrom parsnip translate
150+
translate.gen_additive_mod <- function(x, engine = x$engine, ...) {
151+
if (is.null(engine)) {
152+
message("Used `engine = 'gam'` for translation.")
153+
engine <- "gam"
154+
}
155+
x <- translate.default(x, engine, ...)
156+
157+
x
158+
}

0 commit comments

Comments
 (0)