@@ -302,6 +302,9 @@ check_args.boost_tree <- function(object) {
 #' training iterations without improvement before stopping. If `validation` is
 #' used, performance is based on the validation set; otherwise, the training set
 #' is used.
+#' @param objective A single string (or NULL) that defines the loss function that
+#' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
+#' NULL, an appropriate loss function is chosen.
 #' @param ... Other options to pass to `xgb.train`.
 #' @return A fitted `xgboost` object.
 #' @keywords internal
@@ -310,7 +313,7 @@ xgb_train <- function(
   x, y,
   max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
   min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
-  early_stop = NULL, ...) {
+  early_stop = NULL, objective = NULL, ...) {
 
   others <- list(...)
 
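A hedged usage sketch of the changed signature: `xgb_train()` is an internal function, so the call below is illustrative only, with placeholder data; when `objective` is left `NULL`, a default is picked from the outcome type (see the next hunk).

```r
# Illustrative only: xgb_train() is internal, so this call is shown for the
# signature change rather than as a supported entry point.
x <- as.matrix(mtcars[, -1])   # placeholder predictors
y <- mtcars$mpg                # placeholder numeric outcome
# fit <- xgb_train(x, y, nrounds = 20, objective = "reg:squarederror")
```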
@@ -329,14 +332,14 @@ xgb_train <- function(
   }
 
 
-  if (!any(names(others) == "objective")) {
+  if (is.null(objective)) {
     if (is.numeric(y)) {
-      others$objective <- "reg:squarederror"
+      objective <- "reg:squarederror"
     } else {
       if (num_class == 2) {
-        others$objective <- "binary:logistic"
+        objective <- "binary:logistic"
       } else {
-        others$objective <- "multi:softprob"
+        objective <- "multi:softprob"
       }
     }
   }
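A minimal sketch of the fallback rule this hunk implements (the helper name is hypothetical, not part of the package): numeric outcomes get squared-error regression, two-level factors get logistic classification, and anything else gets softprob.

```r
# Hypothetical helper mirroring the default-objective logic above.
default_objective <- function(y, num_class = length(levels(y))) {
  if (is.numeric(y)) {
    "reg:squarederror"
  } else if (num_class == 2) {
    "binary:logistic"
  } else {
    "multi:softprob"
  }
}

default_objective(mtcars$mpg)            # "reg:squarederror"
default_objective(factor(c("a", "b")))   # "binary:logistic"
default_objective(iris$Species)          # "multi:softprob"
```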
@@ -374,7 +377,8 @@ xgb_train <- function(
     gamma = gamma,
     colsample_bytree = colsample_bytree,
     min_child_weight = min(min_child_weight, n),
-    subsample = subsample
+    subsample = subsample,
+    objective = objective
   )
 
   main_args <- list(
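With this change the objective travels with the other boosting parameters rather than through `...`. A sketch of the shape of the resulting call, assuming the xgboost package is available; the data and parameter values here are placeholders:

```r
library(xgboost)

# Placeholder data; in the real function these come from x, y, and the
# user-supplied arguments.
dtrain <- xgb.DMatrix(as.matrix(mtcars[, -1]), label = mtcars$mpg)

arg_list <- list(
  eta = 0.3,
  max_depth = 6,
  gamma = 0,
  colsample_bytree = 1,
  min_child_weight = 1,
  subsample = 1,
  objective = "reg:squarederror"
)

fit <- xgb.train(params = arg_list, data = dtrain, nrounds = 15)
```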
@@ -413,13 +417,12 @@ xgb_pred <- function(object, newdata, ...) {
 
   res <- predict(object, newdata, ...)
 
-  x = switch(
+  x <- switch(
     object$params$objective,
-    "reg:squarederror" = , "reg:logistic" = , "binary:logistic" = res,
     "binary:logitraw" = stats::binomial()$linkinv(res),
     "multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
-    res
-  )
+    res)
+
   x
 }
 
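A sketch of the post-processing the revised `switch()` performs (the wrapper name is hypothetical): raw logits are mapped back to probabilities with the binomial inverse link, softprob output is reshaped into one row per observation, and every other objective's predictions pass through unchanged.

```r
# Hypothetical wrapper around the same switch() logic, for illustration.
post_process <- function(res, objective, num_class = NULL) {
  switch(
    objective,
    "binary:logitraw" = stats::binomial()$linkinv(res),
    "multi:softprob" = matrix(res, ncol = num_class, byrow = TRUE),
    res
  )
}

post_process(c(-1.2, 0.4), "binary:logitraw")            # inverse-logit -> probabilities
post_process(runif(6), "multi:softprob", num_class = 3)  # reshaped to a 2 x 3 matrix
post_process(c(10.2, 8.7), "reg:squarederror")           # passed through unchanged
```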