Skip to content

Commit 9597886

Browse files
committed
Updated compute_contrast in R
1 parent 9679068 commit 9597886

File tree

3 files changed

+30
-30
lines changed

3 files changed

+30
-30
lines changed

R/posterior_transformation.R

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ compute_contrast_bcf_model <- function(
260260
#' Only valid when there is either a mean forest or a random effects term in the BART model.
261261
#'
262262
#' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
263-
#' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
264-
#' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
263+
#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
264+
#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
265265
#' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`.
266266
#' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`.
267267
#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
@@ -306,17 +306,17 @@ compute_contrast_bcf_model <- function(
306306
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
307307
#' contrast_test <- compute_contrast_bart_model(
308308
#' bart_model,
309-
#' covariates_0 = X_test,
310-
#' covariates_1 = X_test,
309+
#' X_0 = X_test,
310+
#' X_1 = X_test,
311311
#' leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
312312
#' leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
313313
#' type = "posterior",
314314
#' scale = "linear"
315315
#' )
316316
compute_contrast_bart_model <- function(
317317
object,
318-
covariates_0,
319-
covariates_1,
318+
X_0,
319+
X_1,
320320
leaf_basis_0 = NULL,
321321
leaf_basis_1 = NULL,
322322
rfx_group_ids_0 = NULL,
@@ -360,11 +360,11 @@ compute_contrast_bart_model <- function(
360360
}
361361

362362
# Check that covariates are matrix or data frame
363-
if ((!is.data.frame(covariates_0)) && (!is.matrix(covariates_0))) {
364-
stop("covariates_0 must be a matrix or dataframe")
363+
if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) {
364+
stop("X_0 must be a matrix or dataframe")
365365
}
366-
if ((!is.data.frame(covariates_1)) && (!is.matrix(covariates_1))) {
367-
stop("covariates_1 must be a matrix or dataframe")
366+
if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) {
367+
stop("X_1 must be a matrix or dataframe")
368368
}
369369

370370
# Convert all input data to matrices if not already converted
@@ -388,20 +388,20 @@ compute_contrast_bart_model <- function(
388388
) {
389389
stop("leaf_basis_0 and leaf_basis_1 must be provided for this model")
390390
}
391-
if ((!is.null(leaf_basis_0)) && (nrow(covariates_0) != nrow(leaf_basis_0))) {
392-
stop("covariates_0 and leaf_basis_0 must have the same number of rows")
391+
if ((!is.null(leaf_basis_0)) && (nrow(X_0) != nrow(leaf_basis_0))) {
392+
stop("X_0 and leaf_basis_0 must have the same number of rows")
393393
}
394-
if ((!is.null(leaf_basis_1)) && (nrow(covariates_1) != nrow(leaf_basis_1))) {
395-
stop("covariates_1 and leaf_basis_1 must have the same number of rows")
394+
if ((!is.null(leaf_basis_1)) && (nrow(X_1) != nrow(leaf_basis_1))) {
395+
stop("X_1 and leaf_basis_1 must have the same number of rows")
396396
}
397-
if (object$model_params$num_covariates != ncol(covariates_0)) {
397+
if (object$model_params$num_covariates != ncol(X_0)) {
398398
stop(
399-
"covariates_0 must contain the same number of columns as the BART model's training dataset"
399+
"X_0 must contain the same number of columns as the BART model's training dataset"
400400
)
401401
}
402-
if (object$model_params$num_covariates != ncol(covariates_1)) {
402+
if (object$model_params$num_covariates != ncol(X_1)) {
403403
stop(
404-
"covariates_1 must contain the same number of columns as the BART model's training dataset"
404+
"X_1 must contain the same number of columns as the BART model's training dataset"
405405
)
406406
}
407407
if ((has_rfx) && (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1))) {
@@ -427,7 +427,7 @@ compute_contrast_bart_model <- function(
427427
# Predict for the control arm
428428
control_preds <- predict(
429429
object = object,
430-
covariates = covariates_0,
430+
X = X_0,
431431
leaf_basis = leaf_basis_0,
432432
rfx_group_ids = rfx_group_ids_0,
433433
rfx_basis = rfx_basis_0,
@@ -439,7 +439,7 @@ compute_contrast_bart_model <- function(
439439
# Predict for the treatment arm
440440
treatment_preds <- predict(
441441
object = object,
442-
covariates = covariates_1,
442+
X = X_1,
443443
leaf_basis = leaf_basis_1,
444444
rfx_group_ids = rfx_group_ids_1,
445445
rfx_basis = rfx_basis_1,

man/compute_contrast_bart_model.Rd

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tools/debug/bart_contrast_debug.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ bart_model <- bart(
4545
# Compute contrast posterior
4646
contrast_posterior_test <- compute_contrast_bart_model(
4747
bart_model,
48-
covariates_0 = X_test,
49-
covariates_1 = X_test,
48+
X_0 = X_test,
49+
X_1 = X_test,
5050
leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
5151
leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
5252
type = "posterior",
@@ -128,8 +128,8 @@ bart_model <- bart(
128128
# Compute contrast posterior
129129
contrast_posterior_test <- compute_contrast_bart_model(
130130
bart_model,
131-
covariates_0 = X_test,
132-
covariates_1 = X_test,
131+
X_0 = X_test,
132+
X_1 = X_test,
133133
leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
134134
leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
135135
rfx_group_ids_0 = group_ids_test,

0 commit comments

Comments
 (0)