@@ -1924,7 +1924,7 @@ bart <- function(
19241924# ' Predict from a sampled BART model on new data
19251925# '
19261926# ' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
1927- # ' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
1927+ # ' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
19281928# ' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`.
19291929# ' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model.
19301930# ' We do not currently support (but plan to in the near future), test set evaluation for group labels
@@ -1964,7 +1964,7 @@ bart <- function(
19641964# ' y_hat_test <- predict(bart_model, X_test)$y_hat
19651965predict.bartmodel <- function (
19661966 object ,
1967- covariates ,
1967+ X ,
19681968 leaf_basis = NULL ,
19691969 rfx_group_ids = NULL ,
19701970 rfx_basis = NULL ,
@@ -2047,8 +2047,8 @@ predict.bartmodel <- function(
20472047 }
20482048
20492049 # Check that covariates are matrix or data frame
2050- if ((! is.data.frame(covariates )) && (! is.matrix(covariates ))) {
2051- stop(" covariates must be a matrix or dataframe" )
2050+ if ((! is.data.frame(X )) && (! is.matrix(X ))) {
2051+ stop(" X must be a matrix or dataframe" )
20522052 }
20532053
20542054 # Convert all input data to matrices if not already converted
@@ -2063,12 +2063,12 @@ predict.bartmodel <- function(
20632063 if ((object $ model_params $ requires_basis ) && (is.null(leaf_basis ))) {
20642064 stop(" Basis (leaf_basis) must be provided for this model" )
20652065 }
2066- if ((! is.null(leaf_basis )) && (nrow(covariates ) != nrow(leaf_basis ))) {
2067- stop(" covariates and leaf_basis must have the same number of rows" )
2066+ if ((! is.null(leaf_basis )) && (nrow(X ) != nrow(leaf_basis ))) {
2067+ stop(" X and leaf_basis must have the same number of rows" )
20682068 }
2069- if (object $ model_params $ num_covariates != ncol(covariates )) {
2069+ if (object $ model_params $ num_covariates != ncol(X )) {
20702070 stop(
2071- " covariates must contain the same number of columns as the BART model's training dataset"
2071+ " X must contain the same number of columns as the BART model's training dataset"
20722072 )
20732073 }
20742074 if ((predict_rfx ) && (is.null(rfx_group_ids ))) {
@@ -2089,7 +2089,7 @@ predict.bartmodel <- function(
20892089
20902090 # Preprocess covariates
20912091 train_set_metadata <- object $ train_set_metadata
2092- covariates <- preprocessPredictionData(covariates , train_set_metadata )
2092+ X <- preprocessPredictionData(X , train_set_metadata )
20932093
20942094 # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
20952095 has_rfx <- FALSE
@@ -2119,8 +2119,8 @@ predict.bartmodel <- function(
21192119 # Only construct a basis if user-provided basis missing
21202120 if (is.null(rfx_basis )) {
21212121 rfx_basis <- matrix (
2122- rep(1 , nrow(covariates )),
2123- nrow = nrow(covariates ),
2122+ rep(1 , nrow(X )),
2123+ nrow = nrow(X ),
21242124 ncol = 1
21252125 )
21262126 }
@@ -2129,9 +2129,9 @@ predict.bartmodel <- function(
21292129
21302130 # Create prediction dataset
21312131 if (! is.null(leaf_basis )) {
2132- prediction_dataset <- createForestDataset(covariates , leaf_basis )
2132+ prediction_dataset <- createForestDataset(X , leaf_basis )
21332133 } else {
2134- prediction_dataset <- createForestDataset(covariates )
2134+ prediction_dataset <- createForestDataset(X )
21352135 }
21362136
21372137 # Compute variance forest predictions
0 commit comments