44
55# ' @export
66tunable.model_spec <- function (x , ... ) {
7- mod_env <- rlang :: ns_env(" parsnip" )$ parsnip
7+
8+ mod_env <- get_model_env()
89
910 if (is.null(x $ engine )) {
1011 stop(" Please declare an engine first using `set_engine()`." , call. = FALSE )
@@ -17,27 +18,35 @@ tunable.model_spec <- function(x, ...) {
1718 sep = " " , call. = FALSE )
1819 }
1920
20- arg_vals <-
21- mod_env [[ arg_name ]] % > %
22- dplyr :: filter( engine == x $ engine ) % > %
23- dplyr :: select( name = parsnip , call_info = func ) % > %
24- dplyr :: full_join(
25- tibble :: tibble( name = c(names(x $ args ), names(x $ eng_args ))),
26- by = " name"
27- ) % > %
28- dplyr :: mutate(
29- source = " model_spec " ,
30- component = mod_type( x ),
31- component_id = dplyr :: if_else( name %in% names( x $ args ), " main " , " engine " )
21+ arg_vals <- mod_env [[ arg_name ]]
22+ arg_vals <- arg_vals [ arg_vals $ engine == x $ engine , c( " parsnip " , " func " )]
23+ names( arg_vals )[names( arg_vals ) == " parsnip " ] <- " name "
24+ names( arg_vals )[names( arg_vals ) == " func" ] <- " call_info "
25+
26+ extra_args <- c(names(x $ args ), names(x $ eng_args ))
27+ extra_args <- extra_args [ ! extra_args %in% arg_vals $ name ]
28+
29+ extra_args_tbl <-
30+ tibble :: new_tibble(
31+ list ( name = extra_args , call_info = vector( " list " , vctrs :: vec_size( extra_args )) ),
32+ nrow = vctrs :: vec_size( extra_args )
3233 )
3334
34- if (nrow(arg_vals ) > 0 ) {
35- has_info <- purrr :: map_lgl(arg_vals $ call_info , is.null )
36- rm_list <- ! (has_info & (arg_vals $ component_id == " main" ))
35+ res <- vctrs :: vec_rbind(arg_vals , extra_args_tbl )
3736
38- arg_vals <- arg_vals [rm_list ,]
37+ res $ source <- " model_spec"
38+ res $ component <- mod_type(x )
39+ res $ component_id <- " main"
40+ res $ component_id [! res $ name %in% names(x $ args )] <- " engine"
41+
42+ if (nrow(res ) > 0 ) {
43+ has_info <- purrr :: map_lgl(res $ call_info , is.null )
44+ rm_list <- ! (has_info & (res $ component_id == " main" ))
45+
46+ res <- res [rm_list , ]
3947 }
40- arg_vals %> % dplyr :: select(name , call_info , source , component , component_id )
48+
49+ res [, c(" name" , " call_info" , " source" , " component" , " component_id" )]
4150}
4251
4352mod_type <- function (.mod ) class(.mod )[class(.mod ) != " model_spec" ][1 ]
0 commit comments