@@ -223,8 +223,8 @@ multi_predict_glmnet <- function(object,
223223 " multinom_reg" = format_glmnet_multi_multinom_reg(pred ,
224224 penalty = penalty ,
225225 type = type ,
226- n_rows = nrow( new_data ) ,
227- lvl = object $ lvl )
226+ lvl = object $ lvl ,
227+ n_obs = nrow( new_data ) )
228228 )
229229
230230 res
@@ -248,26 +248,28 @@ multi_predict._multnet <- multi_predict_glmnet
248248multi_predict._glmnetfit <- multi_predict_glmnet
249249
250250format_glmnet_multi_linear_reg <- function (pred , penalty ) {
251- param_key <- tibble(group = colnames(pred ), penalty = penalty )
251+ penalty_key <- tibble(s = colnames(pred ), penalty = penalty )
252+
252253 pred <- as_tibble(pred )
253- pred $ .row <- 1 : nrow(pred )
254- pred <- gather(pred , group , .pred , - .row )
254+ pred $ .row <- seq_len(nrow(pred ))
255+ pred <- tidyr :: pivot_longer(pred , - .row , names_to = " s" , values_to = " .pred" )
256+
255257 if (utils :: packageVersion(" dplyr" ) > = " 1.0.99.9000" ) {
256- pred <- full_join(param_key , pred , by = " group " , multiple = " all" )
258+ pred <- dplyr :: full_join(penalty_key , pred , by = " s " , multiple = " all" )
257259 } else {
258- pred <- full_join(param_key , pred , by = " group " )
260+ pred <- dplyr :: full_join(penalty_key , pred , by = " s " )
259261 }
260- pred $ group <- NULL
261- pred <- arrange(pred , .row , penalty )
262- .row <- pred $ .row
263- pred $ .row <- NULL
264- pred <- split(pred , .row )
265- names(pred ) <- NULL
266- tibble(.pred = pred )
262+
263+ pred <- pred %> %
264+ dplyr :: select(- s ) %> %
265+ dplyr :: arrange(penalty ) %> %
266+ tidyr :: nest(.by = .row , .key = " .pred" ) %> %
267+ dplyr :: select(- .row )
268+
269+ pred
267270}
268271
269272format_glmnet_multi_logistic_reg <- function (pred , penalty , type , lvl ) {
270-
271273 type <- rlang :: arg_match(type , c(" class" , " prob" ))
272274
273275 penalty_key <- tibble(s = colnames(pred ), penalty = penalty )
@@ -303,36 +305,46 @@ format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
303305 pred
304306}
305307
306- format_glmnet_multi_multinom_reg <- function (pred , penalty , type , n_rows , lvl ) {
307- format_probs <- function (x ) {
308- x <- as_tibble(x )
309- names(x ) <- paste0(" .pred_" , names(x ))
310- nms <- names(x )
311- x $ .row <- 1 : nrow(x )
312- x [, c(" .row" , nms )]
313- }
308+ format_glmnet_multi_multinom_reg <- function (pred , penalty , type , lvl , n_obs ) {
309+ type <- rlang :: arg_match(type , c(" class" , " prob" ))
314310
315- if (type == " prob" ) {
316- pred <- apply(pred , 3 , format_probs )
317- names(pred ) <- NULL
318- pred <- map_dfr(pred , function (x ) x )
319- pred $ penalty <- rep(penalty , each = n_rows )
320- pred <- dplyr :: relocate(pred , penalty )
321- } else {
322- pred <-
323- tibble(
324- .row = rep(1 : n_rows , length(penalty )),
325- penalty = rep(penalty , each = n_rows ),
326- .pred_class = factor (as.vector(pred ), levels = lvl )
327- )
328- }
311+ pred <- switch (
312+ type ,
313+ prob = format_glmnet_multinom_prob(pred , penalty , lvl , n_obs ),
314+ class = format_glmnet_multinom_class(pred , penalty , lvl , n_obs )
315+ )
316+
317+ pred <- pred %> %
318+ dplyr :: arrange(.row , penalty ) %> %
319+ tidyr :: nest(.by = .row , .key = " .pred" ) %> %
320+ dplyr :: select(- .row )
321+
322+ pred
323+ }
324+
325+ format_glmnet_multinom_prob <- function (pred , penalty , lvl , n_obs ) {
326+ # pred is an array with
327+ # dim 1 = observations
328+ # dim 2 = levels of the response
329+ # dim 3 = penalty values
330+ apply(pred , 3 , as_tibble ) %> %
331+ purrr :: list_rbind() %> %
332+ rlang :: set_names(paste0(" .pred_" , lvl )) %> %
333+ dplyr :: mutate(
334+ .row = rep(seq_len(n_obs ), times = length(penalty )),
335+ penalty = rep(penalty , each = n_obs )
336+ ) %> %
337+ dplyr :: relocate(penalty )
338+ }
329339
330- pred <- arrange(pred , .row , penalty )
331- .row <- pred $ .row
332- pred $ .row <- NULL
333- pred <- split(pred , .row )
334- names(pred ) <- NULL
335- tibble(.pred = pred )
340+ format_glmnet_multinom_class <- function (pred , penalty , lvl , n_obs ) {
341+ # pred is a matrix n_obs x n_penalty
342+ # unless n_obs == 1, then it's a vector of length n_penalty
343+ tibble(
344+ .row = rep(seq_len(n_obs ), times = length(penalty )),
345+ penalty = rep(penalty , each = n_obs ),
346+ .pred_class = factor (as.vector(pred ), levels = lvl )
347+ )
336348}
337349
338350# -------------------------------------------------------------------------
0 commit comments