@@ -260,8 +260,8 @@ compute_contrast_bcf_model <- function(
260260# ' Only valid when there is either a mean forest or a random effects term in the BART model.
261261# '
262262# ' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
263- # ' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
264- # ' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
263+ # ' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
264+ # ' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
265265# ' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`.
266266# ' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`.
267267# ' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
@@ -306,17 +306,17 @@ compute_contrast_bcf_model <- function(
306306# ' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
307307# ' contrast_test <- compute_contrast_bart_model(
308308# ' bart_model,
309- # ' covariates_0 = X_test,
310- # ' covariates_1 = X_test,
309+ # ' X_0 = X_test,
310+ # ' X_1 = X_test,
311311# ' leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
312312# ' leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
313313# ' type = "posterior",
314314# ' scale = "linear"
315315# ' )
316316compute_contrast_bart_model <- function (
317317 object ,
318- covariates_0 ,
319- covariates_1 ,
318+ X_0 ,
319+ X_1 ,
320320 leaf_basis_0 = NULL ,
321321 leaf_basis_1 = NULL ,
322322 rfx_group_ids_0 = NULL ,
@@ -360,11 +360,11 @@ compute_contrast_bart_model <- function(
360360 }
361361
362362 # Check that covariates are matrix or data frame
363- if ((! is.data.frame(covariates_0 )) && (! is.matrix(covariates_0 ))) {
364- stop(" covariates_0 must be a matrix or dataframe" )
363+ if ((! is.data.frame(X_0 )) && (! is.matrix(X_0 ))) {
364+ stop(" X_0 must be a matrix or dataframe" )
365365 }
366- if ((! is.data.frame(covariates_1 )) && (! is.matrix(covariates_1 ))) {
367- stop(" covariates_1 must be a matrix or dataframe" )
366+ if ((! is.data.frame(X_1 )) && (! is.matrix(X_1 ))) {
367+ stop(" X_1 must be a matrix or dataframe" )
368368 }
369369
370370 # Convert all input data to matrices if not already converted
@@ -388,20 +388,20 @@ compute_contrast_bart_model <- function(
388388 ) {
389389 stop(" leaf_basis_0 and leaf_basis_1 must be provided for this model" )
390390 }
391- if ((! is.null(leaf_basis_0 )) && (nrow(covariates_0 ) != nrow(leaf_basis_0 ))) {
392- stop(" covariates_0 and leaf_basis_0 must have the same number of rows" )
391+ if ((! is.null(leaf_basis_0 )) && (nrow(X_0 ) != nrow(leaf_basis_0 ))) {
392+ stop(" X_0 and leaf_basis_0 must have the same number of rows" )
393393 }
394- if ((! is.null(leaf_basis_1 )) && (nrow(covariates_1 ) != nrow(leaf_basis_1 ))) {
395- stop(" covariates_1 and leaf_basis_1 must have the same number of rows" )
394+ if ((! is.null(leaf_basis_1 )) && (nrow(X_1 ) != nrow(leaf_basis_1 ))) {
395+ stop(" X_1 and leaf_basis_1 must have the same number of rows" )
396396 }
397- if (object $ model_params $ num_covariates != ncol(covariates_0 )) {
397+ if (object $ model_params $ num_covariates != ncol(X_0 )) {
398398 stop(
399- " covariates_0 must contain the same number of columns as the BART model's training dataset"
399+ " X_0 must contain the same number of columns as the BART model's training dataset"
400400 )
401401 }
402- if (object $ model_params $ num_covariates != ncol(covariates_1 )) {
402+ if (object $ model_params $ num_covariates != ncol(X_1 )) {
403403 stop(
404- " covariates_1 must contain the same number of columns as the BART model's training dataset"
404+ " X_1 must contain the same number of columns as the BART model's training dataset"
405405 )
406406 }
407407 if ((has_rfx ) && (is.null(rfx_group_ids_0 ) || is.null(rfx_group_ids_1 ))) {
@@ -427,7 +427,7 @@ compute_contrast_bart_model <- function(
427427 # Predict for the control arm
428428 control_preds <- predict(
429429 object = object ,
430- covariates = covariates_0 ,
430+ X = X_0 ,
431431 leaf_basis = leaf_basis_0 ,
432432 rfx_group_ids = rfx_group_ids_0 ,
433433 rfx_basis = rfx_basis_0 ,
@@ -439,7 +439,7 @@ compute_contrast_bart_model <- function(
439439 # Predict for the treatment arm
440440 treatment_preds <- predict(
441441 object = object ,
442- covariates = covariates_1 ,
442+ X = X_1 ,
443443 leaf_basis = leaf_basis_1 ,
444444 rfx_group_ids = rfx_group_ids_1 ,
445445 rfx_basis = rfx_basis_1 ,
0 commit comments