Skip to content

Commit f03548b

Browse files
committed
allow objects with correct names to pass through
1 parent 2ead20c commit f03548b

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

R/predict.R

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,16 @@ format_class <- function(x) {
243243
if (inherits(x, "tbl_spark"))
244244
return(x)
245245

246-
tibble(.pred_class = unname(x))
246+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
247+
x <- as_tibble(x, .name_repair = "minimal")
248+
if (!any(grepl("^\\.pred_class", names(x)))) {
249+
names(x) <- ".pred_class"
250+
}
251+
} else {
252+
x <- tibble(.pred_class = unname(x))
253+
}
254+
255+
x
247256
}
248257

249258
#' @rdname format-internals
@@ -277,7 +286,9 @@ format_time <- function(x) {
277286
format_survival <- function(x) {
278287
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
279288
x <- as_tibble(x, .name_repair = "minimal")
280-
names(x) <- ".pred"
289+
if (!any(grepl("^\\.pred", names(x)))) {
290+
names(x) <- ".pred"
291+
}
281292
} else {
282293
x <- tibble(.pred = unname(x))
283294
}
@@ -293,7 +304,9 @@ format_linear_pred <- function(x) {
293304

294305
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
295306
x <- as_tibble(x, .name_repair = "minimal")
296-
names(x) <- ".pred_linear_pred"
307+
if (!any(grepl("^\\.pred_linear_pred", names(x)))) {
308+
names(x) <- ".pred_linear_pred"
309+
}
297310
} else {
298311
x <- tibble(.pred_linear_pred = unname(x))
299312
}
@@ -306,8 +319,11 @@ format_linear_pred <- function(x) {
306319
format_hazard <- function(x) {
307320
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
308321
x <- as_tibble(x, .name_repair = "minimal")
309-
names(x) <- ".pred"
310-
} else {
322+
if (!any(grepl("^\\.pred", names(x)))) {
323+
names(x) <- ".pred"
324+
}
325+
}
326+
else {
311327
x <- tibble(.pred_hazard = unname(x))
312328
}
313329

0 commit comments

Comments
 (0)