Skip to content

Commit 9a2401d

Browse files
committed
RStudio add-in to write parsnip model spec code
1 parent a82ed40 commit 9a2401d

File tree

10 files changed

+301
-0
lines changed

10 files changed

+301
-0
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ export(update_engine_parameters)
188188
export(update_main_parameters)
189189
export(varying)
190190
export(varying_args)
191+
export(write_parsnip_specs)
191192
export(xgb_train)
192193
importFrom(dplyr,arrange)
193194
importFrom(dplyr,as_tibble)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* An RStudio add-in is availble that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE from the `Tools > Addins > Browse Addins` menu or by calling `write_parsnip_specs()`.
4+
35
# parsnip 0.1.4
46

57
* `show_engines()` will provide information on the current set for a model.

R/add_in.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#' Start an RStudio Addin that can write model specifications
2+
#'
3+
#' `write_parsnip_specs()` starts a process in the RStudio IDE Viewer window
4+
#' that allows users to write code for `parsnip` model specifications from
5+
#' various R packages. The new code are written to the current document at the
6+
#' location of the cursor.
7+
#'
8+
#' @export
9+
write_parsnip_specs <- function() {
10+
sys.source(
11+
system.file("add-in", "gadget.R", package = "parsnip", mustWork = TRUE),
12+
envir = rlang::new_environment(parent = rlang::global_env()),
13+
keep.source = FALSE
14+
)
15+
}
16+

R/data.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#' parsnip model specification database
2+
#'
3+
#' This is used in the RStudio add-in and captures information about mode
4+
#' specifications in various R packages.
5+
#'
6+
#' @name model_db
7+
#' @aliases model_db
8+
#' @docType data
9+
#' @return \item{model_db}{a data frame}
10+
#' @keywords datasets
11+
#' @examples
12+
#' data(model_db)
13+
NULL
14+

data/model_db.rda

1.98 KB
Binary file not shown.

inst/add-in/gadget.R

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
parsnip_spec_add_in <- function() {
2+
# ------------------------------------------------------------------------------
3+
# check installs
4+
5+
libs <- c("shiny", "miniUI", "rstudioapi")
6+
is_inst <- purrr::map_lgl(libs, parsnip:::is_installed)
7+
if (any(!is_inst)) {
8+
missing_pkg <- libs[!is_inst]
9+
missing_pkg <- paste0(missing_pkg, collapse = ", ")
10+
rlang::abort(
11+
glue::glue(
12+
"The add-in requires some CRAN package installs: ",
13+
glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ")
14+
)
15+
)
16+
}
17+
18+
library(shiny)
19+
library(miniUI)
20+
library(rstudioapi)
21+
22+
data(model_db, package = "parsnip")
23+
24+
# ------------------------------------------------------------------------------
25+
26+
make_spec <- function(x, tune_args) {
27+
if (tune_args) {
28+
nms <- x$parameters[[1]]$parameter
29+
args <- purrr::map(nms, ~ rlang::call2("tune"))
30+
names(args) <- nms
31+
} else {
32+
args <- NULL
33+
}
34+
35+
if (x$package != "parsnip") {
36+
pkg <- x$package
37+
} else {
38+
pkg <- NULL
39+
}
40+
41+
if (length(args) > 0) {
42+
cl_1 <- rlang::call2(.ns = pkg, .fn = x$model, !!!args)
43+
} else {
44+
cl_1 <- rlang::call2(.ns = pkg, .fn = x$model)
45+
}
46+
47+
obj_nm <- paste0(x$model,"_", x$engine, "_spec")
48+
chr_1 <- rlang::expr_text(cl_1, width = 500)
49+
chr_1 <- paste0(chr_1, collapse = " ")
50+
# chr_1 <- gsub("(...)", "()", chr_1, fixed = TRUE)
51+
chr_1 <- paste(obj_nm, "<-\n ", chr_1)
52+
chr_2 <- paste0("set_engine('", x$engine, "')")
53+
54+
res <- paste0(chr_1, " %>%\n ", chr_2)
55+
56+
if (!x$single_mode) {
57+
chr_3 <- paste0("set_mode('", x$mode, "')")
58+
res <- paste0(res, " %>%\n ", chr_3)
59+
}
60+
61+
res
62+
}
63+
64+
ui <-
65+
miniPage(
66+
gadgetTitleBar("Write out model specifications"),
67+
miniContentPanel(
68+
radioButtons(
69+
"model_mode",
70+
label = h3("Type of Model"),
71+
choices = c("Classification", "Regression")
72+
),
73+
checkboxInput(
74+
"tune_args",
75+
label = "Tag parameters for tuning (if any)?",
76+
value = TRUE
77+
),
78+
textInput(
79+
"pattern",
80+
label = "Match on (regex)"
81+
),
82+
tags$br(),
83+
uiOutput("model_choices")
84+
),
85+
miniButtonBlock(
86+
actionButton("write", "Write specification code", class = "btn-success")
87+
)
88+
)
89+
90+
91+
server <-
92+
function(input, output) {
93+
get_models <- reactive({
94+
req(input$model_mode)
95+
96+
models <- model_db[model_db$mode == tolower(input$model_mode),]
97+
if (nchar(input$pattern) > 0) {
98+
incld <- grepl(input$pattern, models$model) | grepl(input$pattern, models$engine)
99+
models <- models[incld,]
100+
101+
}
102+
models
103+
}) # get_models
104+
105+
output$model_choices <- renderUI({
106+
107+
model_list <- get_models()
108+
109+
choices <- paste0(model_list$model, " (", model_list$engine, ")")
110+
111+
checkboxGroupInput(
112+
inputId = "model_name",
113+
label = "Model",
114+
choices = c(unique(choices))
115+
)
116+
}) # model_choices
117+
118+
119+
create_code <- reactive({
120+
121+
req(input$model_name)
122+
req(input$model_mode)
123+
124+
model_mode <- tolower(input$model_mode)
125+
selected <- model_db[model_db$label %in% input$model_name,]
126+
selected <- selected[selected$mode %in% model_mode,]
127+
128+
res <- purrr::map_chr(1:nrow(selected), ~ make_spec(selected[.x,], tune_args = input$tune_args))
129+
130+
paste0(res, sep = "\n\n")
131+
132+
}) # create_code
133+
134+
observeEvent(input$write, {
135+
res <- create_code()
136+
for (txt in res) {
137+
rstudioapi::insertText(txt)
138+
}
139+
})
140+
141+
observeEvent(input$done, {
142+
stopApp()
143+
})
144+
}
145+
146+
viewer <- paneViewer(300)
147+
runGadget(ui, server, viewer = viewer)
148+
}
149+
150+
parsnip_spec_add_in()
151+

inst/add-in/parsnip_model_db.R

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# ------------------------------------------------------------------------------
2+
# code to make the parsnip model database used by the RStudio addin
3+
4+
# ------------------------------------------------------------------------------
5+
6+
library(tidymodels)
7+
library(usethis)
8+
9+
# also requires installation of:
10+
packages <- c("parsnip", "discrim", "plsmod", "rules", "baguette", "poissonreg", "modeltime", "modeltime.gluonts")
11+
12+
# ------------------------------------------------------------------------------
13+
14+
# Detects model specifications via their print methods
15+
print_methods <- function(x) {
16+
require(x, character.only = TRUE)
17+
ns <- asNamespace(ns = x)
18+
mthds <- ls(envir = ns, pattern = "^print\\.")
19+
mthds <- gsub("^print\\.", "", mthds)
20+
purrr::map_dfr(mthds, get_engines) %>% dplyr::mutate(package = x)
21+
}
22+
get_engines <- function(x) {
23+
eng <- try(parsnip::show_engines(x), silent = TRUE)
24+
if (inherits(eng, "try-error")) {
25+
eng <- tibble::tibble(engine = NA_character_, mode = NA_character_, model = x)
26+
} else {
27+
eng$model <- x
28+
}
29+
eng
30+
}
31+
get_tunable_param <- function(mode, package, model, engine) {
32+
cl <- rlang::call2(.ns = package, .fn = model)
33+
obj <- rlang::eval_tidy(cl)
34+
obj <- parsnip::set_engine(obj, engine)
35+
obj <- parsnip::set_mode(obj, mode)
36+
res <-
37+
tune::tunable(obj) %>%
38+
dplyr::select(parameter = name)
39+
40+
# ------------------------------------------------------------------------------
41+
# Edit some model parameters
42+
43+
if (model == "rand_forest") {
44+
res <- res[res$parameter != "trees",]
45+
}
46+
if (model == "mars") {
47+
res <- res[res$parameter == "prod_degree",]
48+
}
49+
if (engine %in% c("rule_fit", "xgboost")) {
50+
res <- res[res$parameter != "mtry",]
51+
}
52+
if (model %in% c("bag_tree", "bag_mars")) {
53+
res <- res[0,]
54+
}
55+
res
56+
57+
}
58+
59+
# ------------------------------------------------------------------------------
60+
61+
model_db <-
62+
purrr::map_dfr(packages, print_methods) %>%
63+
dplyr::filter(!is.na(engine)) %>%
64+
dplyr::mutate(label = paste0(model, " (", engine, ")")) %>%
65+
dplyr::arrange(model, engine, mode)
66+
67+
num_modes <-
68+
model_db %>%
69+
dplyr::group_by(package, model, engine) %>%
70+
dplyr::count() %>%
71+
dplyr::ungroup() %>%
72+
dplyr::mutate(single_mode = n == 1) %>%
73+
dplyr::select(package, model, engine, single_mode)
74+
75+
model_db <-
76+
dplyr::left_join(model_db, num_modes, by = c("package", "model", "engine")) %>%
77+
dplyr::filter(engine != "spark") %>%
78+
dplyr::mutate(parameters = purrr::pmap(list(mode, package, model, engine), get_tunable_param))
79+
80+
usethis::use_data(model_db, overwrite = TRUE)
81+

inst/rstudio/addins.dcf

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
Name: Write parsnip model specifications
3+
Description: Create many parsnip model specifications automatically.
4+
Binding: write_parsnip_specs
5+
Interactive: true

man/model_db.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/write_parsnip_specs.Rd

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)