@@ -222,37 +222,19 @@ check_pred_type <- function(object, type, ...) {
222222# ' @export
223223
224224format_num <- function (x ) {
225- if (inherits(x , " tbl_spark" ))
225+ if (inherits(x , " tbl_spark" )) {
226226 return (x )
227-
228- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
229- x <- as_tibble(x , .name_repair = " minimal" )
230- if (! any(grepl(" ^\\ .pred" , names(x )))) {
231- names(x ) <- paste0(" .pred_" , names(x ))
232- }
233- } else {
234- x <- tibble(.pred = unname(x ))
235227 }
236-
237- x
228+ ensure_parsnip_format(x , " .pred" , overwrite = FALSE )
238229}
239230
240231# ' @rdname format-internals
241232# ' @export
242233format_class <- function (x ) {
243- if (inherits(x , " tbl_spark" ))
234+ if (inherits(x , " tbl_spark" )) {
244235 return (x )
245-
246- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
247- x <- as_tibble(x , .name_repair = " minimal" )
248- if (! any(grepl(" ^\\ .pred_class" , names(x )))) {
249- names(x ) <- " .pred_class"
250- }
251- } else {
252- x <- tibble(.pred_class = unname(x ))
253236 }
254-
255- x
237+ ensure_parsnip_format(x , " .pred_class" )
256238}
257239
258240# ' @rdname format-internals
@@ -269,64 +251,45 @@ format_classprobs <- function(x) {
269251# ' @rdname format-internals
270252# ' @export
271253format_time <- function (x ) {
272- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
273- x <- as_tibble(x , .name_repair = " minimal" )
274- if (! any(grepl(" ^\\ .pred_time" , names(x )))) {
275- names(x ) <- paste0(" .pred_time_" , names(x ))
276- }
277- } else {
278- x <- tibble(.pred_time = unname(x ))
279- }
280-
281- x
254+ ensure_parsnip_format(x , " .pred_time" , overwrite = FALSE )
282255}
283256
284257# ' @rdname format-internals
285258# ' @export
286259format_survival <- function (x ) {
287- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
288- x <- as_tibble(x , .name_repair = " minimal" )
289- if (! any(grepl(" ^\\ .pred" , names(x )))) {
290- names(x ) <- " .pred"
291- }
292- } else {
293- x <- tibble(.pred = unname(x ))
294- }
295-
296- x
260+ ensure_parsnip_format(x , " .pred" )
297261}
298262
299263# ' @rdname format-internals
300264# ' @export
301265format_linear_pred <- function (x ) {
302- if (inherits(x , " tbl_spark" ))
266+ if (inherits(x , " tbl_spark" )){
303267 return (x )
304-
305- if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
306- x <- as_tibble(x , .name_repair = " minimal" )
307- if (! any(grepl(" ^\\ .pred_linear_pred" , names(x )))) {
308- names(x ) <- " .pred_linear_pred"
309- }
310- } else {
311- x <- tibble(.pred_linear_pred = unname(x ))
312268 }
313-
314- x
269+ ensure_parsnip_format(x , " .pred_linear_pred" )
315270}
316271
317272# ' @rdname format-internals
318273# ' @export
319274format_hazard <- function (x ) {
275+ ensure_parsnip_format(x , " .pred" )
276+ }
277+
278+ ensure_parsnip_format <- function (x , col_name , overwrite = TRUE ) {
320279 if (isTRUE(ncol(x ) > 1 ) | is.data.frame(x )) {
321280 x <- as_tibble(x , .name_repair = " minimal" )
322- if (! any(grepl(" ^\\ .pred" , names(x )))) {
323- names(x ) <- " .pred"
281+ if (! any(grepl(paste0(" ^\\ " , col_name ), names(x )))) {
282+ if (overwrite ) {
283+ names(x ) <- col_name
284+ } else {
285+ names(x ) <- paste(col_name , names(x ), sep = " _" )
286+ }
324287 }
288+ } else {
289+ x <- tibble(unname(x ))
290+ names(x ) <- col_name
291+ x
325292 }
326- else {
327- x <- tibble(.pred_hazard = unname(x ))
328- }
329-
330293 x
331294}
332295
0 commit comments