Skip to content

Commit c282f3a

Browse files
committed
Refactoring dials to separate file
1 parent 3571233 commit c282f3a

File tree

2 files changed

+38
-37
lines changed

2 files changed

+38
-37
lines changed

R/dials.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#' Dials Parameter for Keras Optimizers
2+
#' @param values A character vector of possible optimizers. Defaults to all
3+
#' known optimizers (keras defaults + custom registered).
4+
#' @keywords internal
5+
#' @export
6+
#' @return A `dials` parameter object for Keras optimizers.
7+
optimizer_function <- function(values = NULL) {
8+
if (is.null(values)) {
9+
values <- unique(c(
10+
keras_optimizers,
11+
names(.kerasnip_custom_objects$optimizers)
12+
))
13+
}
14+
dials::new_qual_param(
15+
type = "character",
16+
values = values,
17+
label = c(optimizer_function = "Optimizer Function"),
18+
finalize = NULL
19+
)
20+
}
21+
22+
#' Dials Parameter for Keras Loss Functions
23+
#' @param values A character vector of possible loss functions. Defaults to all
24+
#' known losses (keras defaults + custom registered).
25+
#' @keywords internal
26+
#' @export
27+
#' @return A `dials` parameter object for Keras loss.
28+
loss_function_keras <- function(values = NULL) {
29+
if (is.null(values)) {
30+
values <- unique(c(keras_losses, names(.kerasnip_custom_objects$losses)))
31+
}
32+
dials::new_qual_param(
33+
type = "character",
34+
values = values,
35+
label = c(loss_function_keras = "Loss Function"),
36+
finalize = NULL
37+
)
38+
}

R/utils.R

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -142,44 +142,7 @@ get_keras_object <- function(
142142
name
143143
}
144144

145-
#' Dials Parameter for Keras Optimizers
146-
#' @param values A character vector of possible optimizers. Defaults to all
147-
#' known optimizers (keras defaults + custom registered).
148-
#' @keywords internal
149-
#' @export
150-
#' @return A `dials` parameter object for Keras optimizers.
151-
optimizer_function <- function(values = NULL) {
152-
if (is.null(values)) {
153-
values <- unique(c(
154-
keras_optimizers,
155-
names(.kerasnip_custom_objects$optimizers)
156-
))
157-
}
158-
dials::new_qual_param(
159-
type = "character",
160-
values = values,
161-
label = c(optimizer_function = "Optimizer Function"),
162-
finalize = NULL
163-
)
164-
}
165145

166-
#' Dials Parameter for Keras Loss Functions
167-
#' @param values A character vector of possible loss functions. Defaults to all
168-
#' known losses (keras defaults + custom registered).
169-
#' @keywords internal
170-
#' @export
171-
#' @return A `dials` parameter object for Keras loss.
172-
loss_function_keras <- function(values = NULL) {
173-
if (is.null(values)) {
174-
values <- unique(c(keras_losses, names(.kerasnip_custom_objects$losses)))
175-
}
176-
dials::new_qual_param(
177-
type = "character",
178-
values = values,
179-
label = c(loss_function_keras = "Loss Function"),
180-
finalize = NULL
181-
)
182-
}
183146

184147
#' Process Predictor Input for Keras (Functional API)
185148
#'

0 commit comments

Comments
 (0)