Skip to content

Commit e784b9f

Browse files
committed
Update inp_spec mapping
1 parent af201ae commit e784b9f

File tree

9 files changed

+86
-45
lines changed

9 files changed

+86
-45
lines changed

R/create_keras_spec_helpers.R

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,16 @@ collect_spec_args <- function(
125125
#' to the provided string. This is the common case for blocks that take a
126126
#' single tensor input.
127127
#' 2. **Multiple Input Mapping**: If `input_map` is a named character vector,
128-
#' it provides an explicit mapping from new argument names (the names of the
129-
#' vector) to the original argument names in the `block` function (the values
130-
#' of the vector). This is used for blocks with multiple inputs, like a
131-
#' concatenation layer.
128+
#' the **names must match the argument names of `block`** and each value
129+
#' must be the name of an upstream layer block whose output should be fed
130+
#' into that argument. This orientation matches the
131+
#' argument-first syntax (e.g., `c(numeric = "processed_numerical")`). This is used for
132+
#' blocks with multiple inputs, like a concatenation layer.
133+
#'
134+
#' _Note_: Prior releases accepted the opposite orientation
135+
#' (`c(processed_numerical = "numeric")`). Existing code written in that style
136+
#' must flip the names/values when upgrading to this version.
137+
132138
#'
133139
#' @param block A function that defines a Keras layer or a set of layers. The
134140
#' first arguments should be the input tensor(s).
@@ -164,7 +170,7 @@ collect_spec_args <- function(
164170
#' path_b = inp_spec(dense_block, "main_input"),
165171
#' concatenated = inp_spec(
166172
#' concat_block,
167-
#' c(path_a = "input_a", path_b = "input_b")
173+
#' c(input_a = "path_a", input_b = "path_b")
168174
#' ),
169175
#' output = inp_spec(output_block, "concatenated")
170176
#' )
@@ -188,19 +194,19 @@ inp_spec <- function(block, input_map) {
188194
# Case 1: Single string, rename first argument
189195
names(new_formals)[1] <- input_map
190196
} else if (is.character(input_map) && !is.null(names(input_map))) {
191-
# Case 2: Named vector for mapping
192-
if (!all(input_map %in% original_names)) {
193-
missing_args <- input_map[!input_map %in% original_names]
197+
# Case 2: Named vector for mapping (argument-first)
198+
if (!all(names(input_map) %in% original_names)) {
199+
missing_args <- names(input_map)[!names(input_map) %in% original_names]
194200
stop(paste(
195201
"Argument(s)",
196202
paste(shQuote(missing_args), collapse = ", "),
197203
"not found in the block function."
198204
))
199205
}
200-
# Use match() for a more concise, vectorized replacement of names
206+
201207
new_names <- original_names
202-
match_indices <- match(input_map, original_names)
203-
new_names[match_indices] <- names(input_map)
208+
match_indices <- match(names(input_map), original_names)
209+
new_names[match_indices] <- unname(input_map)
204210
names(new_formals) <- new_names
205211
} else {
206212
stop("`input_map` must be a single string or a named character vector.")

README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
21
# kerasnip
32

43
<!-- badges: start -->
4+
55
[![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental)
66
[![R-CMD-check](https://github.com/davidrsch/kerasnip/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/davidrsch/kerasnip/actions/workflows/R-CMD-check.yaml)
77
[![Codecov test
88
coverage](https://codecov.io/gh/davidrsch/kerasnip/graph/badge.svg)](https://app.codecov.io/gh/davidrsch/kerasnip)
99
[![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/kerasnip)](https://cran.r-project.org/package=kerasnip)
1010
[![Downloads](https://cranlogs.r-pkg.org/badges/kerasnip)](https://cran.r-project.org/package=kerasnip)
11+
1112
<!-- badges: end -->
1213

1314
The goal of `kerasnip` is to provide a seamless bridge between the `keras` and `tidymodels` frameworks. It allows for the dynamic creation of `parsnip` model specifications for Keras models, making them fully compatible with `tidymodels` workflows.
@@ -112,8 +113,7 @@ create_keras_functional_spec(
112113
main_input = input_block,
113114
path_a = inp_spec(path_block, "main_input"),
114115
path_b = inp_spec(path_block, "main_input"),
115-
concatenated = inp_spec(concat_block, c(path_a = "input_a", path_b = "input_b")),
116-
# The output block must be named 'output'.
116+
concatenated = inp_spec(concat_block, c(input_a = "path_a", input_b = "path_b")),
117117
output = inp_spec(output_block, "concatenated")
118118
),
119119
mode = "regression"
@@ -136,6 +136,7 @@ fit(spec, mpg ~ ., data = mtcars) |>
136136
#> 4 18.6
137137
#> 5 17.9
138138
```
139+
139140
### Example 3: Tuning a Sequential MLP Architecture
140141

141142
This example demonstrates how to tune the number of dense layers and the rate of a final dropout layer, showcasing how to tune both architecture and block hyperparameters simultaneously.
@@ -210,11 +211,11 @@ tune_res <- tune_grid(
210211
# 6. Show the best architecture
211212
show_best(tune_res, metric = "rmse")
212213
#> # A tibble: 5 × 7
213-
#> num_dense dense_units dropout_rate .metric .estimator .mean .config
214-
#> <int> <int> <dbl> <chr> <chr> <dbl> <chr>
214+
#> num_dense dense_units dropout_rate .metric .estimator .mean .config
215+
#> <int> <int> <dbl> <chr> <chr> <dbl> <chr>
215216
#> 1 1 64 0.1 rmse standard 2.92 Preprocessor1_Model02
216217
#> 2 1 64 0.5 rmse standard 3.02 Preprocessor1_Model08
217218
#> 3 3 64 0.1 rmse standard 3.15 Preprocessor1_Model04
218219
#> 4 1 8 0.1 rmse standard 3.20 Preprocessor1_Model01
219220
#> 5 3 8 0.1 rmse standard 3.22 Preprocessor1_Model03
220-
```
221+
```

man/inp_spec.Rd

Lines changed: 10 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_e2e_func_classification.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ test_that("E2E: Functional spec (classification) works", {
2828
path_b = inp_spec(path_block, "main_input"),
2929
concatenated = inp_spec(
3030
concat_block,
31-
c(path_a = "input_a", path_b = "input_b")
31+
c(input_a = "path_a", input_b = "path_b")
3232
),
3333
output = inp_spec(output_block_class, "concatenated")
3434
),
@@ -165,7 +165,7 @@ test_that("E2E: Multi-input, single-output functional classification works", {
165165
path_b = inp_spec(dense_path, "flatten_b"),
166166
concatenated = inp_spec(
167167
concat_block,
168-
c(path_a = "in_1", path_b = "in_2")
168+
c(in_1 = "path_a", in_2 = "path_b")
169169
),
170170
output = inp_spec(output_block_class, "concatenated")
171171
),
@@ -230,7 +230,7 @@ test_that("E2E: Functional spec with pre-constructed optimizer works", {
230230
path_b = inp_spec(path_block, "main_input"),
231231
concatenated = inp_spec(
232232
concat_block,
233-
c(path_a = "input_a", path_b = "input_b")
233+
c(input_a = "path_a", input_b = "path_b")
234234
),
235235
output = inp_spec(output_block_class, "concatenated")
236236
),
@@ -283,7 +283,7 @@ test_that("E2E: Functional spec with string loss works", {
283283
path_b = inp_spec(path_block, "main_input"),
284284
concatenated = inp_spec(
285285
concat_block,
286-
c(path_a = "input_a", path_b = "input_b")
286+
c(input_a = "path_a", input_b = "path_b")
287287
),
288288
output = inp_spec(output_block_class, "concatenated")
289289
),

tests/testthat/test_e2e_func_regression.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ test_that("E2E: Functional spec (regression) works", {
2323
path_b = inp_spec(path_block, "main_input"),
2424
concatenated = inp_spec(
2525
concat_block,
26-
c(path_a = "input_a", path_b = "input_b")
26+
c(input_a = "path_a", input_b = "path_b")
2727
),
2828
output = inp_spec(output_block_reg, "concatenated")
2929
),
@@ -74,7 +74,7 @@ test_that("E2E: Functional regression works with named predictors in formula", {
7474
path_b = inp_spec(path_block, "main_input"),
7575
concatenated = inp_spec(
7676
concat_block,
77-
c(path_a = "input_a", path_b = "input_b")
77+
c(input_a = "path_a", input_b = "path_b")
7878
),
7979
output = inp_spec(output_block_reg, "concatenated")
8080
),
@@ -185,7 +185,7 @@ test_that("E2E: Multi-input, multi-output functional regression works", {
185185
path_b = inp_spec(dense_path, "input_b"),
186186
concatenated = inp_spec(
187187
concat_block,
188-
c(path_a = "in_1", path_b = "in_2")
188+
c(in_1 = "path_a", in_2 = "path_b")
189189
),
190190
output_1 = inp_spec(output_block_1, "concatenated"),
191191
output_2 = inp_spec(output_block_2, "concatenated")

tests/testthat/test_inp_spec.R

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,35 @@ test_that("inp_spec throws error for mismatched input_map names", {
1616
})
1717

1818

19+
test_that("inp_spec supports argument-first mapping", {
20+
block_with_args <- function(numeric, categorical) {
21+
list(numeric = numeric, categorical = categorical)
22+
}
23+
mapper <- c(
24+
numeric = "processed_numeric",
25+
categorical = "processed_categorical"
26+
)
27+
wrapped <- kerasnip:::inp_spec(block_with_args, mapper)
28+
29+
expect_identical(
30+
names(formals(wrapped))[1:2],
31+
c("processed_numeric", "processed_categorical")
32+
)
33+
res <- wrapped(processed_numeric = 10, processed_categorical = 20)
34+
expect_identical(res$numeric, 10)
35+
expect_identical(res$categorical, 20)
36+
})
37+
38+
test_that("inp_spec rejects the legacy input_map orientation", {
39+
block_with_args <- function(input_a, input_b) {}
40+
legacy_mapper <- c(processed_a = "input_a", processed_b = "input_b")
41+
expect_error(
42+
kerasnip:::inp_spec(block_with_args, legacy_mapper),
43+
"not found in the block function"
44+
)
45+
})
46+
47+
1948
test_that("inp_spec throws error for invalid input_map type", {
2049
block_with_args <- function(a) {}
2150
expect_error(

vignettes/functional_api.Rmd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ create_keras_functional_spec(
115115
processed_2 = inp_spec(dense_path_2, "input_2"),
116116
concatenated = inp_spec(
117117
concat_block,
118-
c(processed_1 = "input_a", processed_2 = "input_b")
118+
c(input_a = "processed_1", input_b = "processed_2")
119119
),
120120
output_1 = inp_spec(output_block_1, "concatenated"), # New output block 1
121-
output_2 = inp_spec(output_block_2, "concatenated") # New output block 2
121+
output_2 = inp_spec(output_block_2, "concatenated") # New output block 2
122122
),
123123
mode = "regression" # Still regression, but will have two columns in y
124124
)
@@ -159,7 +159,7 @@ train_df <- tibble::tibble(
159159
function(i) x_data_2[i, , drop = FALSE]
160160
),
161161
output_1 = y_data_1, # Named output 1
162-
output_2 = y_data_2 # Named output 2
162+
output_2 = y_data_2 # Named output 2
163163
)
164164
165165
rec <- recipe(output_1 + output_2 ~ input_1 + input_2, data = train_df)

vignettes/workflows_functional.Rmd

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ First, we load the necessary packages.
3232
library(kerasnip)
3333
library(tidymodels)
3434
library(keras3)
35-
library(dplyr) # For data manipulation
36-
library(ggplot2) # For plotting
37-
library(future) # For parallel processing
38-
library(finetune) # For racing
35+
library(dplyr) # For data manipulation
36+
library(ggplot2) # For plotting
37+
library(future) # For parallel processing
38+
library(finetune) # For racing
3939
```
4040

4141
## Data Preparation
@@ -149,10 +149,10 @@ create_keras_functional_spec(
149149
combined_features = inp_spec(
150150
concatenate_features,
151151
c(
152-
processed_numerical = "numeric",
153-
processed_neighborhood = "neighborhood",
154-
processed_bldg = "bldg",
155-
processed_condition = "condition"
152+
numeric = "processed_numerical",
153+
neighborhood = "processed_neighborhood",
154+
bldg = "processed_bldg",
155+
condition = "processed_condition"
156156
)
157157
),
158158
output = inp_spec(output_regression, "combined_features")
@@ -288,7 +288,7 @@ final_ames_fit |>
288288
plot(show_shapes = TRUE)
289289
```
290290

291-
![Model](images/model_plot_shapes_fs.png){fig-alt="A picture showing the model shape"}
291+
![Model](images/model_plot_shapes_wf.png){fig-alt="A picture showing the model shape"}
292292

293293
```{r inspect-final-keras-model-history}
294294
# Plot the training history

vignettes/workflows_sequential.Rmd

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ First, we load the necessary packages.
3232
library(kerasnip)
3333
library(tidymodels)
3434
library(keras3)
35-
library(dplyr) # For data manipulation
36-
library(ggplot2) # For plotting
37-
library(future) # For parallel processing
38-
library(finetune) # For racing
35+
library(dplyr) # For data manipulation
36+
library(ggplot2) # For plotting
37+
library(future) # For parallel processing
38+
library(finetune) # For racing
3939
```
4040

4141
## Data Preparation

0 commit comments

Comments
 (0)