@@ -222,28 +222,19 @@ check_pred_type <- function(object, type, ...) {
222222# ' @export
223223
224224format_num <- function (x ) {
225- if (inherits(x , " tbl_spark" ))
225+ if (inherits(x , " tbl_spark" )) {
226226 return (x )
227-
228- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
229- x <- as_tibble(x , .name_repair = " minimal" )
230- if (! any(grepl(" ^\\ .pred" , names(x )))) {
231- names(x ) <- paste0(" .pred_" , names(x ))
232- }
233- } else {
234- x <- tibble(.pred = unname(x ))
235227 }
236-
237- x
228+ ensure_parsnip_format(x , " .pred" , overwrite = FALSE )
238229}
239230
240231# ' @rdname format-internals
241232# ' @export
242233format_class <- function (x ) {
243- if (inherits(x , " tbl_spark" ))
234+ if (inherits(x , " tbl_spark" )) {
244235 return (x )
245-
246- tibble( .pred_class = unname( x ) )
236+ }
237+ ensure_parsnip_format( x , " .pred_class" )
247238}
248239
249240# ' @rdname format-internals
@@ -260,57 +251,45 @@ format_classprobs <- function(x) {
260251# ' @rdname format-internals
261252# ' @export
262253format_time <- function (x ) {
263- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
264- x <- as_tibble(x , .name_repair = " minimal" )
265- if (! any(grepl(" ^\\ .pred_time" , names(x )))) {
266- names(x ) <- paste0(" .pred_time_" , names(x ))
267- }
268- } else {
269- x <- tibble(.pred_time = unname(x ))
270- }
271-
272- x
254+ ensure_parsnip_format(x , " .pred_time" , overwrite = FALSE )
273255}
274256
275257# ' @rdname format-internals
276258# ' @export
277259format_survival <- function (x ) {
278- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
279- x <- as_tibble(x , .name_repair = " minimal" )
280- names(x ) <- " .pred"
281- } else {
282- x <- tibble(.pred = unname(x ))
283- }
284-
285- x
260+ ensure_parsnip_format(x , " .pred" )
286261}
287262
288263# ' @rdname format-internals
289264# ' @export
290265format_linear_pred <- function (x ) {
291- if (inherits(x , " tbl_spark" ))
266+ if (inherits(x , " tbl_spark" )){
292267 return (x )
293-
294- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
295- x <- as_tibble(x , .name_repair = " minimal" )
296- names(x ) <- " .pred_linear_pred"
297- } else {
298- x <- tibble(.pred_linear_pred = unname(x ))
299268 }
300-
301- x
269+ ensure_parsnip_format(x , " .pred_linear_pred" )
302270}
303271
304272# ' @rdname format-internals
305273# ' @export
306274format_hazard <- function (x ) {
275+ ensure_parsnip_format(x , " .pred" )
276+ }
277+
278+ ensure_parsnip_format <- function (x , col_name , overwrite = TRUE ) {
307279 if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
308280 x <- as_tibble(x , .name_repair = " minimal" )
309- names(x ) <- " .pred"
281+ if (! any(grepl(paste0(" ^\\ " , col_name ), names(x )))) {
282+ if (overwrite ) {
283+ names(x ) <- col_name
284+ } else {
285+ names(x ) <- paste(col_name , names(x ), sep = " _" )
286+ }
287+ }
310288 } else {
311- x <- tibble(.pred_hazard = unname(x ))
289+ x <- tibble(unname(x ))
290+ names(x ) <- col_name
291+ x
312292 }
313-
314293 x
315294}
316295
0 commit comments