Skip to content

Commit c874231

Browse files
authored
Refactor formatting functions for glmnet methods for multi_predict() (#930)
* refactor `format_glmnet_multi_linear_reg()` to match the pattern of `format_glmnet_multi_logistic_reg()`
* refactor `format_glmnet_multi_multinom_reg()` to more closely match the other formatting functions for glmnet `multi_predict()` methods
* fix column name
* required for `list_rbind()`
1 parent e2b9baa commit c874231

File tree

2 files changed

+56
-44
lines changed

2 files changed

+56
-44
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Imports:
3030
magrittr,
3131
pillar,
3232
prettyunits,
33-
purrr,
33+
purrr (>= 1.0.0),
3434
rlang (>= 0.3.1),
3535
stats,
3636
tibble (>= 2.1.1),

R/glmnet-engines.R

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ multi_predict_glmnet <- function(object,
223223
"multinom_reg" = format_glmnet_multi_multinom_reg(pred,
224224
penalty = penalty,
225225
type = type,
226-
n_rows = nrow(new_data),
227-
lvl = object$lvl)
226+
lvl = object$lvl,
227+
n_obs = nrow(new_data))
228228
)
229229

230230
res
@@ -248,26 +248,28 @@ multi_predict._multnet <- multi_predict_glmnet
248248
multi_predict._glmnetfit <- multi_predict_glmnet
249249

250250
format_glmnet_multi_linear_reg <- function(pred, penalty) {
251-
param_key <- tibble(group = colnames(pred), penalty = penalty)
251+
penalty_key <- tibble(s = colnames(pred), penalty = penalty)
252+
252253
pred <- as_tibble(pred)
253-
pred$.row <- 1:nrow(pred)
254-
pred <- gather(pred, group, .pred, -.row)
254+
pred$.row <- seq_len(nrow(pred))
255+
pred <- tidyr::pivot_longer(pred, -.row, names_to = "s", values_to = ".pred")
256+
255257
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
256-
pred <- full_join(param_key, pred, by = "group", multiple = "all")
258+
pred <- dplyr::full_join(penalty_key, pred, by = "s", multiple = "all")
257259
} else {
258-
pred <- full_join(param_key, pred, by = "group")
260+
pred <- dplyr::full_join(penalty_key, pred, by = "s")
259261
}
260-
pred$group <- NULL
261-
pred <- arrange(pred, .row, penalty)
262-
.row <- pred$.row
263-
pred$.row <- NULL
264-
pred <- split(pred, .row)
265-
names(pred) <- NULL
266-
tibble(.pred = pred)
262+
263+
pred <- pred %>%
264+
dplyr::select(-s) %>%
265+
dplyr::arrange(penalty) %>%
266+
tidyr::nest(.by = .row, .key = ".pred") %>%
267+
dplyr::select(-.row)
268+
269+
pred
267270
}
268271

269272
format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
270-
271273
type <- rlang::arg_match(type, c("class", "prob"))
272274

273275
penalty_key <- tibble(s = colnames(pred), penalty = penalty)
@@ -303,36 +305,46 @@ format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
303305
pred
304306
}
305307

306-
format_glmnet_multi_multinom_reg <- function(pred, penalty, type, n_rows, lvl) {
307-
format_probs <- function(x) {
308-
x <- as_tibble(x)
309-
names(x) <- paste0(".pred_", names(x))
310-
nms <- names(x)
311-
x$.row <- 1:nrow(x)
312-
x[, c(".row", nms)]
313-
}
308+
format_glmnet_multi_multinom_reg <- function(pred, penalty, type, lvl, n_obs) {
309+
type <- rlang::arg_match(type, c("class", "prob"))
314310

315-
if (type == "prob") {
316-
pred <- apply(pred, 3, format_probs)
317-
names(pred) <- NULL
318-
pred <- map_dfr(pred, function(x) x)
319-
pred$penalty <- rep(penalty, each = n_rows)
320-
pred <- dplyr::relocate(pred, penalty)
321-
} else {
322-
pred <-
323-
tibble(
324-
.row = rep(1:n_rows, length(penalty)),
325-
penalty = rep(penalty, each = n_rows),
326-
.pred_class = factor(as.vector(pred), levels = lvl)
327-
)
328-
}
311+
pred <- switch(
312+
type,
313+
prob = format_glmnet_multinom_prob(pred, penalty, lvl, n_obs),
314+
class = format_glmnet_multinom_class(pred, penalty, lvl, n_obs)
315+
)
316+
317+
pred <- pred %>%
318+
dplyr::arrange(.row, penalty) %>%
319+
tidyr::nest(.by = .row, .key = ".pred") %>%
320+
dplyr::select(-.row)
321+
322+
pred
323+
}
324+
325+
format_glmnet_multinom_prob <- function(pred, penalty, lvl, n_obs) {
326+
# pred is an array with
327+
# dim 1 = observations
328+
# dim 2 = levels of the response
329+
# dim 3 = penalty values
330+
apply(pred, 3, as_tibble) %>%
331+
purrr::list_rbind() %>%
332+
rlang::set_names(paste0(".pred_", lvl)) %>%
333+
dplyr::mutate(
334+
.row = rep(seq_len(n_obs), times = length(penalty)),
335+
penalty = rep(penalty, each = n_obs)
336+
) %>%
337+
dplyr::relocate(penalty)
338+
}
329339

330-
pred <- arrange(pred, .row, penalty)
331-
.row <- pred$.row
332-
pred$.row <- NULL
333-
pred <- split(pred, .row)
334-
names(pred) <- NULL
335-
tibble(.pred = pred)
340+
format_glmnet_multinom_class <- function(pred, penalty, lvl, n_obs) {
341+
# pred is a matrix n_obs x n_penalty
342+
# unless n_obs == 1, then it's a vector of length n_penalty
343+
tibble(
344+
.row = rep(seq_len(n_obs), times = length(penalty)),
345+
penalty = rep(penalty, each = n_obs),
346+
.pred_class = factor(as.vector(pred), levels = lvl)
347+
)
336348
}
337349

338350
# -------------------------------------------------------------------------

0 commit comments

Comments
 (0)