Skip to content

Commit 13de68f

Browse files
committed
Harmonized BART sampler arguments between R and Python (and wherever they are called)
1 parent 51c7f40 commit 13de68f

18 files changed

+93
-93
lines changed

R/bart.R

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,7 +1924,7 @@ bart <- function(
19241924
#' Predict from a sampled BART model on new data
19251925
#'
19261926
#' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs.
1927-
#' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
1927+
#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.
19281928
#' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`.
19291929
#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model.
19301930
#' We do not currently support (but plan to in the near future), test set evaluation for group labels
@@ -1964,7 +1964,7 @@ bart <- function(
19641964
#' y_hat_test <- predict(bart_model, X_test)$y_hat
19651965
predict.bartmodel <- function(
19661966
object,
1967-
covariates,
1967+
X,
19681968
leaf_basis = NULL,
19691969
rfx_group_ids = NULL,
19701970
rfx_basis = NULL,
@@ -2047,8 +2047,8 @@ predict.bartmodel <- function(
20472047
}
20482048

20492049
# Check that covariates are matrix or data frame
2050-
if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) {
2051-
stop("covariates must be a matrix or dataframe")
2050+
if ((!is.data.frame(X)) && (!is.matrix(X))) {
2051+
stop("X must be a matrix or dataframe")
20522052
}
20532053

20542054
# Convert all input data to matrices if not already converted
@@ -2063,12 +2063,12 @@ predict.bartmodel <- function(
20632063
if ((object$model_params$requires_basis) && (is.null(leaf_basis))) {
20642064
stop("Basis (leaf_basis) must be provided for this model")
20652065
}
2066-
if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) {
2067-
stop("covariates and leaf_basis must have the same number of rows")
2066+
if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) {
2067+
stop("X and leaf_basis must have the same number of rows")
20682068
}
2069-
if (object$model_params$num_covariates != ncol(covariates)) {
2069+
if (object$model_params$num_covariates != ncol(X)) {
20702070
stop(
2071-
"covariates must contain the same number of columns as the BART model's training dataset"
2071+
"X must contain the same number of columns as the BART model's training dataset"
20722072
)
20732073
}
20742074
if ((predict_rfx) && (is.null(rfx_group_ids))) {
@@ -2089,7 +2089,7 @@ predict.bartmodel <- function(
20892089

20902090
# Preprocess covariates
20912091
train_set_metadata <- object$train_set_metadata
2092-
covariates <- preprocessPredictionData(covariates, train_set_metadata)
2092+
X <- preprocessPredictionData(X, train_set_metadata)
20932093

20942094
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
20952095
has_rfx <- FALSE
@@ -2119,8 +2119,8 @@ predict.bartmodel <- function(
21192119
# Only construct a basis if user-provided basis missing
21202120
if (is.null(rfx_basis)) {
21212121
rfx_basis <- matrix(
2122-
rep(1, nrow(covariates)),
2123-
nrow = nrow(covariates),
2122+
rep(1, nrow(X)),
2123+
nrow = nrow(X),
21242124
ncol = 1
21252125
)
21262126
}
@@ -2129,9 +2129,9 @@ predict.bartmodel <- function(
21292129

21302130
# Create prediction dataset
21312131
if (!is.null(leaf_basis)) {
2132-
prediction_dataset <- createForestDataset(covariates, leaf_basis)
2132+
prediction_dataset <- createForestDataset(X, leaf_basis)
21332133
} else {
2134-
prediction_dataset <- createForestDataset(covariates)
2134+
prediction_dataset <- createForestDataset(X)
21352135
}
21362136

21372137
# Compute variance forest predictions

demo/debug/bart_contrast_debug.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@
6565

6666
# Compute the same quantity via two predict calls
6767
y_hat_posterior_test_0 = bart_model.predict(
68-
covariates=X_test,
69-
basis=np.zeros((n_test, 1)),
68+
X=X_test,
69+
leaf_basis=np.zeros((n_test, 1)),
7070
type="posterior",
7171
terms="y_hat",
7272
scale="linear",
7373
)
7474
y_hat_posterior_test_1 = bart_model.predict(
75-
covariates=X_test,
76-
basis=np.ones((n_test, 1)),
75+
X=X_test,
76+
leaf_basis=np.ones((n_test, 1)),
7777
type="posterior",
7878
terms="y_hat",
7979
scale="linear",
@@ -157,17 +157,17 @@
157157

158158
# Compute the same quantity via two predict calls
159159
y_hat_posterior_test_0 = bart_model.predict(
160-
covariates=X_test,
161-
basis=np.zeros((n_test, 1)),
160+
X=X_test,
161+
leaf_basis=np.zeros((n_test, 1)),
162162
rfx_group_ids=group_ids_test,
163163
rfx_basis=rfx_basis_test,
164164
type="posterior",
165165
terms="y_hat",
166166
scale="linear",
167167
)
168168
y_hat_posterior_test_1 = bart_model.predict(
169-
covariates=X_test,
170-
basis=np.ones((n_test, 1)),
169+
X=X_test,
170+
leaf_basis=np.ones((n_test, 1)),
171171
rfx_group_ids=group_ids_test,
172172
rfx_basis=rfx_basis_test,
173173
type="posterior",

demo/debug/bart_predict_debug.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@
4646
)
4747

4848
# # Check several predict approaches
49-
bart_preds = bart_model.predict(covariates=X_test)
50-
y_hat_posterior_test = bart_model.predict(covariates=X_test)["y_hat"]
51-
y_hat_mean_test = bart_model.predict(covariates=X_test, type="mean", terms=["y_hat"])
49+
bart_preds = bart_model.predict(X=X_test)
50+
y_hat_posterior_test = bart_model.predict(X=X_test)["y_hat"]
51+
y_hat_mean_test = bart_model.predict(X=X_test, type="mean", terms=["y_hat"])
5252
y_hat_test = bart_model.predict(
53-
covariates=X_test, type="mean", terms=["rfx", "variance"]
53+
X=X_test, type="mean", terms=["rfx", "variance"]
5454
)
5555

5656
# Plot predicted versus actual

demo/debug/gfr_ties_debug.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939

4040
# Inspect the model fit
41-
y_hat_test = xbart_model.predict(X_test, type="mean", terms="y_hat")
41+
y_hat_test = xbart_model.predict(X=X_test, type="mean", terms="y_hat")
4242
plt.scatter(y_hat_test, y_test)
4343
plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5)
4444
plt.xlabel("Predicted Outcome Mean")
@@ -54,7 +54,7 @@
5454
)
5555

5656
# Inspect the model fit
57-
y_hat_test = bart_model.predict(X_test, type="mean", terms="y_hat")
57+
y_hat_test = bart_model.predict(X=X_test, type="mean", terms="y_hat")
5858
plt.clf()
5959
plt.scatter(y_hat_test, y_test)
6060
plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5)
@@ -95,7 +95,7 @@
9595
)
9696

9797
# Inspect the model fit
98-
y_hat_test = xbart_model.predict(X_test, type="mean", terms="y_hat")
98+
y_hat_test = xbart_model.predict(X=X_test, type="mean", terms="y_hat")
9999
plt.clf()
100100
plt.scatter(y_hat_test, y_test)
101101
plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5)
@@ -112,7 +112,7 @@
112112
)
113113

114114
# Inspect the model fit
115-
y_hat_test = bart_model.predict(X_test, type="mean", terms="y_hat")
115+
y_hat_test = bart_model.predict(X=X_test, type="mean", terms="y_hat")
116116
plt.clf()
117117
plt.scatter(y_hat_test, y_test)
118118
plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5)

demo/debug/multi_chain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def outcome_mean(X, W):
8989

9090
# Analyze model predictions collectively across all chains
9191
y_hat_test = bart_model.predict(
92-
covariates = X_test,
93-
basis = basis_test,
92+
X = X_test,
93+
leaf_basis = basis_test,
9494
type = "mean",
9595
terms = "y_hat"
9696
)

demo/debug/multiple_initializations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,14 @@ def outcome_mean(X, W):
118118
)
119119

120120
# Inspect the model outputs
121-
bart_preds_2 = bart_model_2.predict(X_test, basis_test)
121+
bart_preds_2 = bart_model_2.predict(X=X_test, basis_test)
122122
y_hat_mcmc_2 = bart_preds_2['y_hat']
123123
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
124124
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
125-
bart_preds_3 = bart_model_3.predict(X_test, basis_test)
125+
bart_preds_3 = bart_model_3.predict(X=X_test, basis_test)
126126
y_hat_mcmc_3 = bart_preds_3['y_hat']
127127
y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True)
128-
bart_preds_4 = bart_model_4.predict(X_test, basis_test)
128+
bart_preds_4 = bart_model_4.predict(X=X_test, basis_test)
129129
y_hat_mcmc_4 = bart_preds_4['y_hat']
130130
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
131131
y_df = pd.DataFrame(

demo/debug/parallel_multi_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def outcome_mean(X, W):
145145
)
146146

147147
# Inspect the model outputs
148-
bart_preds = combined_bart.predict(X_test, basis_test)
148+
bart_preds = combined_bart.predict(X=X_test, leaf_basis=basis_test)
149149
y_hat_mcmc = bart_preds['y_hat']
150150
y_avg_mcmc = np.squeeze(y_hat_mcmc).mean(axis=1, keepdims=True)
151151
y_df = pd.DataFrame(

demo/debug/probit_bart_rfx_debug.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,17 @@
8686

8787
# Compute the same quantity via two predict calls
8888
y_hat_posterior_test_0 = bart_model.predict(
89-
covariates=X_test,
90-
basis=np.zeros((n_test, 1)),
89+
X=X_test,
90+
leaf_basis=np.zeros((n_test, 1)),
9191
rfx_group_ids=group_ids_test,
9292
rfx_basis=rfx_basis_test,
9393
type="posterior",
9494
terms="y_hat",
9595
scale="linear",
9696
)
9797
y_hat_posterior_test_1 = bart_model.predict(
98-
covariates=X_test,
99-
basis=np.ones((n_test, 1)),
98+
X=X_test,
99+
leaf_basis=np.ones((n_test, 1)),
100100
rfx_group_ids=group_ids_test,
101101
rfx_basis=rfx_basis_test,
102102
type="posterior",
@@ -111,8 +111,8 @@
111111

112112
# Plot predicted versus actual outcome
113113
Z_hat_test = bart_model.predict(
114-
covariates=X_test,
115-
basis=W_test,
114+
X=X_test,
115+
leaf_basis=W_test,
116116
rfx_group_ids=group_ids_test,
117117
rfx_basis=rfx_basis_test,
118118
type="mean",

demo/debug/rfx_serialization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ def rfx_mean(group_labels, basis):
6060
rfx_basis_train=basis, num_gfr=10, num_mcmc=10)
6161

6262
# Extract predictions from the sampler
63-
bart_preds_orig = bart_orig.predict(X, W, group_labels, basis)
63+
bart_preds_orig = bart_orig.predict(X=X, leaf_basis=W, rfx_group_ids=group_labels, rfx_basis=basis)
6464
y_hat_orig = bart_preds_orig['y_hat']
6565

6666
# "Round-trip" the model to JSON string and back and check that the predictions agree
6767
bart_json_string = bart_orig.to_json()
6868
bart_reloaded = BARTModel()
6969
bart_reloaded.from_json(bart_json_string)
70-
bart_preds_reloaded = bart_reloaded.predict(X, W, group_labels, basis)
70+
bart_preds_reloaded = bart_reloaded.predict(X=X, leaf_basis=W, rfx_group_ids=group_labels, rfx_basis=basis)
7171
y_hat_reloaded = bart_preds_reloaded['y_hat']
7272
np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded)

demo/notebooks/multi_chain.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@
161161
"outputs": [],
162162
"source": [
163163
"y_hat_test = bart_model.predict(\n",
164-
" covariates = X_test,\n",
165-
" basis = leaf_basis_test, \n",
164+
" X = X_test,\n",
165+
" leaf_basis = leaf_basis_test, \n",
166166
" type = \"mean\", \n",
167167
" terms = \"y_hat\"\n",
168168
")\n",
@@ -321,8 +321,8 @@
321321
"outputs": [],
322322
"source": [
323323
"y_hat_test = bart_model.predict(\n",
324-
" covariates = X_test,\n",
325-
" basis = leaf_basis_test, \n",
324+
" X = X_test,\n",
325+
" leaf_basis = leaf_basis_test, \n",
326326
" type = \"mean\", \n",
327327
" terms = \"y_hat\"\n",
328328
")\n",

0 commit comments

Comments
 (0)