Skip to content

Commit 078ed72

Browse files
committed
Updated predict interface in R and Python and tests / demos that call it
1 parent 13de68f commit 078ed72

File tree

10 files changed

+293
-127
lines changed

10 files changed

+293
-127
lines changed

R/bart.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,7 +1961,7 @@ bart <- function(
19611961
#' y_train <- y[train_inds]
19621962
#' bart_model <- bart(X_train = X_train, y_train = y_train,
19631963
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
1964-
#' y_hat_test <- predict(bart_model, X_test)$y_hat
1964+
#' y_hat_test <- predict(bart_model, X=X_test)$y_hat
19651965
predict.bartmodel <- function(
19661966
object,
19671967
X,
@@ -2843,7 +2843,7 @@ createBARTModelFromJsonFile <- function(json_filename) {
28432843
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
28442844
#' bart_json <- saveBARTModelToJsonString(bart_model)
28452845
#' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
2846-
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
2846+
#' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X=X_train)$y_hat)
28472847
createBARTModelFromJsonString <- function(json_string) {
28482848
# Load a `CppJson` object from string
28492849
bart_json <- createCppJsonString(json_string)

R/kernel.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ computeForestLeafIndices <- function(
129129
propensity <- rowMeans(
130130
predict(
131131
model_object$bart_propensity_model,
132-
covariates
132+
X = covariates
133133
)$y_hat
134134
)
135135
}

man/createBARTModelFromJsonString.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/predict.bartmodel.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

test/R/testthat/test-bart.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ test_that("Random Effects BART", {
584584
)
585585
preds <- predict(
586586
bart_model,
587-
covariates = X_test,
587+
X = X_test,
588588
leaf_basis = W_test,
589589
rfx_group_ids = rfx_group_ids_test,
590590
type = "posterior",

test/R/testthat/test-predict.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,20 +216,20 @@ test_that("BART predictions with pre-summarization", {
216216
)
217217

218218
# Check that the default predict method returns a list
219-
pred <- predict(bart_model, X_test)
219+
pred <- predict(bart_model, X = X_test)
220220
y_hat_posterior_test <- pred$y_hat
221221
expect_equal(dim(y_hat_posterior_test), c(20, 10))
222222

223223
# Check that the pre-aggregated predictions match with those computed by rowMeans
224-
pred_mean <- predict(bart_model, X_test, type = "mean")
224+
pred_mean <- predict(bart_model, X = X_test, type = "mean")
225225
y_hat_mean_test <- pred_mean$y_hat
226226
expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
227227

228228
# Check that we warn and return a NULL when requesting terms that weren't fit
229229
expect_warning({
230230
pred_mean <- predict(
231231
bart_model,
232-
X_test,
232+
X = X_test,
233233
type = "mean",
234234
terms = c("rfx", "variance_forest")
235235
)
@@ -248,7 +248,7 @@ test_that("BART predictions with pre-summarization", {
248248
)
249249

250250
# Check that the default predict method returns a list
251-
pred <- predict(het_bart_model, X_test)
251+
pred <- predict(het_bart_model, X = X_test)
252252
y_hat_posterior_test <- pred$y_hat
253253
sigma2_hat_posterior_test <- pred$variance_forest_predictions
254254

@@ -257,7 +257,7 @@ test_that("BART predictions with pre-summarization", {
257257
expect_equal(dim(sigma2_hat_posterior_test), c(20, 10))
258258

259259
# Check that the pre-aggregated predictions match with those computed by rowMeans
260-
pred_mean <- predict(het_bart_model, X_test, type = "mean")
260+
pred_mean <- predict(het_bart_model, X = X_test, type = "mean")
261261
y_hat_mean_test <- pred_mean$y_hat
262262
sigma2_hat_mean_test <- pred_mean$variance_forest_predictions
263263

@@ -269,13 +269,13 @@ test_that("BART predictions with pre-summarization", {
269269
# match those computed by pre-aggregated predictions returned in a list
270270
y_hat_mean_test_single_term <- predict(
271271
het_bart_model,
272-
X_test,
272+
X = X_test,
273273
type = "mean",
274274
terms = "y_hat"
275275
)
276276
sigma2_hat_mean_test_single_term <- predict(
277277
het_bart_model,
278-
X_test,
278+
X = X_test,
279279
type = "mean",
280280
terms = "variance_forest"
281281
)

test/R/testthat/test-serialization.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ test_that("BART Serialization", {
3434
num_mcmc = 10,
3535
general_params = general_param_list
3636
)
37-
y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat)
37+
y_hat_orig <- rowMeans(predict(bart_model, X = X_test)$y_hat)
3838

3939
# Save to JSON
4040
bart_json_string <- saveBARTModelToJsonString(bart_model)
@@ -43,7 +43,7 @@ test_that("BART Serialization", {
4343
bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string)
4444

4545
# Predict from the roundtrip BART model
46-
y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat)
46+
y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X = X_test)$y_hat)
4747

4848
# Assertion
4949
expect_equal(y_hat_orig, y_hat_reloaded)

tools/debug/bart_predict_debug.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@ bart_model <- bart(
3838
)
3939

4040
# Check several predict approaches
41-
y_hat_posterior_test <- predict(bart_model, X_test)$y_hat
41+
y_hat_posterior_test <- predict(bart_model, X = X_test)$y_hat
4242
y_hat_mean_test <- predict(
4343
bart_model,
44-
X_test,
44+
X = X_test,
4545
type = "mean",
4646
terms = c("y_hat")
4747
)
4848
y_hat_test <- predict(
4949
bart_model,
50-
X_test,
50+
X = X_test,
5151
type = "mean",
5252
terms = c("rfx", "variance")
5353
)
@@ -117,18 +117,18 @@ bart_model <- bart(
117117
# Predict on latent scale
118118
y_hat_post <- predict(
119119
object = bart_model,
120+
X = X_test,
120121
type = "posterior",
121122
terms = c("y_hat"),
122-
X = X_test,
123123
scale = "linear"
124124
)
125125

126126
# Predict on probability scale
127127
y_hat_post_prob <- predict(
128128
object = bart_model,
129+
X = X_test,
129130
type = "posterior",
130131
terms = c("y_hat"),
131-
X = X_test,
132132
scale = "probability"
133133
)
134134

tools/debug/parallel_warmstart.R

Lines changed: 129 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,55 @@ num_trees <- 100
1414
n <- 500
1515
p_x <- 20
1616
snr <- 2
17-
X <- matrix(runif(n*p_x), ncol = p_x)
18-
f_XW <- sin(4*pi*X[,1]) + cos(4*pi*X[,2]) + sin(4*pi*X[,3]) +cos(4*pi*X[,4])
17+
X <- matrix(runif(n * p_x), ncol = p_x)
18+
f_XW <- sin(4 * pi * X[, 1]) +
19+
cos(4 * pi * X[, 2]) +
20+
sin(4 * pi * X[, 3]) +
21+
cos(4 * pi * X[, 4])
1922
noise_sd <- sd(f_XW) / snr
20-
y <- f_XW + rnorm(n, 0, 1)*noise_sd
23+
y <- f_XW + rnorm(n, 0, 1) * noise_sd
2124

2225
# Split data into test and train sets
2326
test_set_pct <- 0.2
24-
n_test <- round(test_set_pct*n)
27+
n_test <- round(test_set_pct * n)
2528
n_train <- n - n_test
2629
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
2730
train_inds <- (1:n)[!((1:n) %in% test_inds)]
28-
X_test <- as.data.frame(X[test_inds,])
29-
X_train <- as.data.frame(X[train_inds,])
31+
X_test <- as.data.frame(X[test_inds, ])
32+
X_train <- as.data.frame(X[train_inds, ])
3033
y_test <- y[test_inds]
3134
y_train <- y[train_inds]
3235

3336
# Run the GFR algorithm
34-
xbart_params <- list(sample_sigma_global = T,
35-
num_trees_mean = num_trees, alpha_mean = 0.99,
36-
beta_mean = 1, max_depth_mean = -1,
37-
min_samples_leaf_mean = 1, sample_sigma_leaf = F,
38-
sigma_leaf_init = 1/num_trees)
37+
xbart_params <- list(
38+
sample_sigma_global = T,
39+
num_trees_mean = num_trees,
40+
alpha_mean = 0.99,
41+
beta_mean = 1,
42+
max_depth_mean = -1,
43+
min_samples_leaf_mean = 1,
44+
sample_sigma_leaf = F,
45+
sigma_leaf_init = 1 / num_trees
46+
)
3947
xbart_model <- stochtree::bart(
40-
X_train = X_train, y_train = y_train, X_test = X_test,
41-
num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, params = xbart_params
48+
X_train = X_train,
49+
y_train = y_train,
50+
X_test = X_test,
51+
num_gfr = num_gfr,
52+
num_burnin = 0,
53+
num_mcmc = 0,
54+
params = xbart_params
4255
)
43-
plot(rowMeans(xbart_model$y_hat_test), y_test); abline(0,1)
56+
plot(rowMeans(xbart_model$y_hat_test), y_test)
57+
abline(0, 1)
4458
cat(sqrt(mean((rowMeans(xbart_model$y_hat_test) - y_test)^2)), "\n")
45-
cat(mean((apply(xbart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(xbart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n")
59+
cat(
60+
mean(
61+
(apply(xbart_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) &
62+
(apply(xbart_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test)
63+
),
64+
"\n"
65+
)
4666
xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model)
4767

4868
# Parallel setup
@@ -51,20 +71,32 @@ cl <- makeCluster(ncores)
5171
registerDoParallel(cl)
5272

5373
# Run the parallel BART MCMC samplers
54-
bart_model_outputs <- foreach (i = 1:num_chains) %dopar% {
74+
bart_model_outputs <- foreach(i = 1:num_chains) %dopar%
75+
{
5576
random_seed <- i
56-
bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T,
57-
num_trees_mean = num_trees, random_seed = random_seed,
58-
alpha_mean = 0.999, beta_mean = 1)
77+
bart_params <- list(
78+
sample_sigma_global = T,
79+
sample_sigma_leaf = T,
80+
num_trees_mean = num_trees,
81+
random_seed = random_seed,
82+
alpha_mean = 0.999,
83+
beta_mean = 1
84+
)
5985
bart_model <- stochtree::bart(
60-
X_train = X_train, y_train = y_train, X_test = X_test,
61-
num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bart_params,
62-
previous_model_json = xbart_model_string, warmstart_sample_num = num_gfr - i + 1,
86+
X_train = X_train,
87+
y_train = y_train,
88+
X_test = X_test,
89+
num_gfr = 0,
90+
num_burnin = num_burnin,
91+
num_mcmc = num_mcmc,
92+
params = bart_params,
93+
previous_model_json = xbart_model_string,
94+
warmstart_sample_num = num_gfr - i + 1,
6395
)
6496
bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model)
6597
y_hat_test <- bart_model$y_hat_test
66-
list(model=bart_model_string, yhat=y_hat_test)
67-
}
98+
list(model = bart_model_string, yhat = y_hat_test)
99+
}
68100

69101
# Close the cluster connection
70102
stopCluster(cl)
@@ -73,43 +105,89 @@ stopCluster(cl)
73105
bart_model_strings <- list()
74106
bart_model_yhats <- matrix(NA, nrow = length(y_test), ncol = num_chains)
75107
for (i in 1:length(bart_model_outputs)) {
76-
bart_model_strings[[i]] <- bart_model_outputs[[i]]$model
77-
bart_model_yhats[,i] <- rowMeans(bart_model_outputs[[i]]$yhat)
108+
bart_model_strings[[i]] <- bart_model_outputs[[i]]$model
109+
bart_model_yhats[, i] <- rowMeans(bart_model_outputs[[i]]$yhat)
78110
}
79111
combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings)
80112

81113
# Inspect the results
82-
yhat_combined <- predict(combined_bart, X_test)$y_hat
83-
par(mfrow = c(1,2))
114+
yhat_combined <- predict(combined_bart, X = X_test)$y_hat
115+
par(mfrow = c(1, 2))
84116
for (i in 1:num_chains) {
85-
offset <- (i-1)*num_mcmc
86-
inds_start <- offset + 1
87-
inds_end <- offset + num_mcmc
88-
plot(rowMeans(yhat_combined[,inds_start:inds_end]), bart_model_yhats[,i],
89-
xlab = "deserialized", ylab = "original",
90-
main = paste0("Chain ", i, "\nPredictions"))
91-
abline(0,1,col="red",lty=3,lwd=3)
117+
offset <- (i - 1) * num_mcmc
118+
inds_start <- offset + 1
119+
inds_end <- offset + num_mcmc
120+
plot(
121+
rowMeans(yhat_combined[, inds_start:inds_end]),
122+
bart_model_yhats[, i],
123+
xlab = "deserialized",
124+
ylab = "original",
125+
main = paste0("Chain ", i, "\nPredictions")
126+
)
127+
abline(0, 1, col = "red", lty = 3, lwd = 3)
92128
}
93129
for (i in 1:num_chains) {
94-
offset <- (i-1)*num_mcmc
95-
inds_start <- offset + 1
96-
inds_end <- offset + num_mcmc
97-
plot(rowMeans(yhat_combined[,inds_start:inds_end]), y_test,
98-
xlab = "predicted", ylab = "actual",
99-
main = paste0("Chain ", i, "\nPredictions"))
100-
abline(0,1,col="red",lty=3,lwd=3)
101-
cat(sqrt(mean((rowMeans(yhat_combined[,inds_start:inds_end]) - y_test)^2)), "\n")
102-
cat(mean((apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.05) <= y_test) & (apply(yhat_combined[,inds_start:inds_end], 1, quantile, probs=0.95) >= y_test)), "\n")
130+
offset <- (i - 1) * num_mcmc
131+
inds_start <- offset + 1
132+
inds_end <- offset + num_mcmc
133+
plot(
134+
rowMeans(yhat_combined[, inds_start:inds_end]),
135+
y_test,
136+
xlab = "predicted",
137+
ylab = "actual",
138+
main = paste0("Chain ", i, "\nPredictions")
139+
)
140+
abline(0, 1, col = "red", lty = 3, lwd = 3)
141+
cat(
142+
sqrt(mean((rowMeans(yhat_combined[, inds_start:inds_end]) - y_test)^2)),
143+
"\n"
144+
)
145+
cat(
146+
mean(
147+
(apply(yhat_combined[, inds_start:inds_end], 1, quantile, probs = 0.05) <=
148+
y_test) &
149+
(apply(
150+
yhat_combined[, inds_start:inds_end],
151+
1,
152+
quantile,
153+
probs = 0.95
154+
) >=
155+
y_test)
156+
),
157+
"\n"
158+
)
103159
}
104-
par(mfrow = c(1,1))
160+
par(mfrow = c(1, 1))
105161

106162
# Compare to a single chain of MCMC samples initialized at root
107-
bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T,
108-
num_trees_mean = num_trees, alpha_mean = 0.95, beta_mean = 2)
163+
bart_params <- list(
164+
sample_sigma_global = T,
165+
sample_sigma_leaf = T,
166+
num_trees_mean = num_trees,
167+
alpha_mean = 0.95,
168+
beta_mean = 2
169+
)
109170
bart_model <- stochtree::bart(
110-
X_train = X_train, y_train = y_train, X_test = X_test,
111-
num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, params = bart_params
171+
X_train = X_train,
172+
y_train = y_train,
173+
X_test = X_test,
174+
num_gfr = 0,
175+
num_burnin = 0,
176+
num_mcmc = num_mcmc,
177+
params = bart_params
178+
)
179+
plot(
180+
rowMeans(bart_model$y_hat_test),
181+
y_test,
182+
xlab = "predicted",
183+
ylab = "actual"
112184
)
113-
plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual"); abline(0,1)
185+
abline(0, 1)
114186
cat(sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2)), "\n")
115-
cat(mean((apply(bart_model$y_hat_test, 1, quantile, probs=0.05) <= y_test) & (apply(bart_model$y_hat_test, 1, quantile, probs=0.95) >= y_test)), "\n")
187+
cat(
188+
mean(
189+
(apply(bart_model$y_hat_test, 1, quantile, probs = 0.05) <= y_test) &
190+
(apply(bart_model$y_hat_test, 1, quantile, probs = 0.95) >= y_test)
191+
),
192+
"\n"
193+
)

0 commit comments

Comments
 (0)