Skip to content

Commit 9679068

Browse files
committed
Updated tests and BART and BCF R functions
1 parent 078ed72 commit 9679068

File tree

6 files changed

+75
-20
lines changed

6 files changed

+75
-20
lines changed

R/bart.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,11 @@ bart <- function(
418418
# Raise a warning if the data have ties and only GFR is being run
419419
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
420420
num_values <- nrow(X_train)
421-
max_grid_size <- floor(num_values / cutpoint_grid_size)
421+
max_grid_size <- ifelse(
422+
num_values > cutpoint_grid_size,
423+
floor(num_values / cutpoint_grid_size),
424+
1
425+
)
422426
covs_warning_1 <- NULL
423427
covs_warning_2 <- NULL
424428
covs_warning_3 <- NULL

R/bcf.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,11 @@ bcf <- function(
522522
# Raise a warning if the data have ties and only GFR is being run
523523
if ((num_gfr > 0) && (num_mcmc == 0) && (num_burnin == 0)) {
524524
num_values <- nrow(X_train)
525-
max_grid_size <- floor(num_values / cutpoint_grid_size)
525+
max_grid_size <- ifelse(
526+
num_values > cutpoint_grid_size,
527+
floor(num_values / cutpoint_grid_size),
528+
1
529+
)
526530
covs_warning_1 <- NULL
527531
covs_warning_2 <- NULL
528532
covs_warning_3 <- NULL

test/R/testthat/test-bart.R

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,19 @@ test_that("Warmstart BART", {
312312
# Run a new BART chain from the existing (X)BART model
313313
general_param_list <- list(num_chains = 3, keep_every = 5)
314314
expect_no_error(
315+
bart_model <- bart(
316+
X_train = X_train,
317+
y_train = y_train,
318+
X_test = X_test,
319+
num_gfr = 0,
320+
num_burnin = 10,
321+
num_mcmc = 10,
322+
previous_model_json = bart_model_json_string,
323+
previous_model_warmstart_sample_num = 10,
324+
general_params = general_param_list
325+
)
326+
)
327+
expect_warning(
315328
bart_model <- bart(
316329
X_train = X_train,
317330
y_train = y_train,
@@ -376,6 +389,23 @@ test_that("Warmstart BART", {
376389
# Run a new BART chain from the existing (X)BART model
377390
general_param_list <- list(num_chains = 4, keep_every = 5)
378391
expect_no_error(
392+
bart_model <- bart(
393+
X_train = X_train,
394+
y_train = y_train,
395+
X_test = X_test,
396+
rfx_group_ids_train = rfx_group_ids_train,
397+
rfx_group_ids_test = rfx_group_ids_test,
398+
rfx_basis_train = rfx_basis_train,
399+
rfx_basis_test = rfx_basis_test,
400+
num_gfr = 0,
401+
num_burnin = 10,
402+
num_mcmc = 10,
403+
previous_model_json = bart_model_json_string,
404+
previous_model_warmstart_sample_num = 10,
405+
general_params = general_param_list
406+
)
407+
)
408+
expect_warning(
379409
bart_model <- bart(
380410
X_train = X_train,
381411
y_train = y_train,

test/R/testthat/test-bcf.R

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,23 @@ test_that("Warmstart BCF", {
375375
# Run a new BCF chain from the existing (X)BCF model
376376
general_param_list <- list(num_chains = 3, keep_every = 5)
377377
expect_no_error(
378+
bcf_model <- bcf(
379+
X_train = X_train,
380+
y_train = y_train,
381+
Z_train = Z_train,
382+
propensity_train = pi_train,
383+
X_test = X_test,
384+
Z_test = Z_test,
385+
propensity_test = pi_test,
386+
num_gfr = 0,
387+
num_burnin = 10,
388+
num_mcmc = 10,
389+
previous_model_json = bcf_model_json_string,
390+
previous_model_warmstart_sample_num = 10,
391+
general_params = general_param_list
392+
)
393+
)
394+
expect_warning(
378395
bcf_model <- bcf(
379396
X_train = X_train,
380397
y_train = y_train,
@@ -482,7 +499,7 @@ test_that("Warmstart BCF", {
482499
num_burnin = 10,
483500
num_mcmc = 10,
484501
previous_model_json = bcf_model_json_string,
485-
previous_model_warmstart_sample_num = 1,
502+
previous_model_warmstart_sample_num = 10,
486503
general_params = general_param_list
487504
)
488505
)

test/python/test_bart.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def outcome_mean(X):
8383
bart_model_3.from_json_string_list(bart_models_json)
8484

8585
# Assertions
86-
bart_preds_combined = bart_model_3.predict(covariates=X_train)
86+
bart_preds_combined = bart_model_3.predict(X=X_train)
8787
y_hat_train_combined = bart_preds_combined["y_hat"]
8888
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
8989
np.testing.assert_allclose(
@@ -190,7 +190,7 @@ def outcome_mean(X, W):
190190

191191
# Assertions
192192
bart_preds_combined = bart_model_3.predict(
193-
covariates=X_train, basis=basis_train
193+
X=X_train, leaf_basis=basis_train
194194
)
195195
y_hat_train_combined = bart_preds_combined["y_hat"]
196196
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
@@ -298,7 +298,7 @@ def outcome_mean(X, W):
298298

299299
# Assertions
300300
bart_preds_combined = bart_model_3.predict(
301-
covariates=X_train, basis=basis_train
301+
X=X_train, leaf_basis=basis_train
302302
)
303303
y_hat_train_combined = bart_preds_combined["y_hat"]
304304
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
@@ -410,7 +410,7 @@ def conditional_stddev(X):
410410
bart_model_3.from_json_string_list(bart_models_json)
411411

412412
# Assertions
413-
bart_preds_combined = bart_model_3.predict(covariates=X_train)
413+
bart_preds_combined = bart_model_3.predict(X=X_train)
414414
y_hat_train_combined, sigma2_x_train_combined = (
415415
bart_preds_combined["y_hat"],
416416
bart_preds_combined["variance_forest_predictions"],
@@ -545,7 +545,7 @@ def conditional_stddev(X):
545545

546546
# Assertions
547547
bart_preds_combined = bart_model_3.predict(
548-
covariates=X_train, basis=basis_train
548+
X=X_train, leaf_basis=basis_train
549549
)
550550
y_hat_train_combined = bart_preds_combined["y_hat"]
551551
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
@@ -670,7 +670,7 @@ def conditional_stddev(X):
670670

671671
# Assertions
672672
bart_preds_combined = bart_model_3.predict(
673-
covariates=X_train, basis=basis_train
673+
X=X_train, leaf_basis=basis_train
674674
)
675675
y_hat_train_combined = bart_preds_combined["y_hat"]
676676
assert y_hat_train_combined.shape == (n_train, num_mcmc * 2)
@@ -825,7 +825,7 @@ def rfx_term(group_labels, basis):
825825

826826
# Assertions
827827
bart_preds_combined = bart_model_3.predict(
828-
covariates=X_train,
828+
X=X_train,
829829
rfx_group_ids=group_labels_train,
830830
rfx_basis=rfx_basis_train,
831831
)
@@ -998,8 +998,8 @@ def conditional_stddev(X):
998998

999999
# Assertions
10001000
bart_preds_combined = bart_model_3.predict(
1001-
covariates=X_train,
1002-
basis=basis_train,
1001+
X=X_train,
1002+
leaf_basis=basis_train,
10031003
rfx_group_ids=group_labels_train,
10041004
rfx_basis=rfx_basis_train,
10051005
)
@@ -1196,8 +1196,8 @@ def conditional_stddev(X):
11961196
random_effects_params=rfx_params,
11971197
)
11981198
preds = bart_model_4.predict(
1199-
covariates=X_test,
1200-
basis=basis_test,
1199+
X=X_test,
1200+
leaf_basis=basis_test,
12011201
rfx_group_ids=group_labels_test,
12021202
type="posterior",
12031203
terms="rfx",

test/python/test_predict.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,12 @@ def test_bart_prediction(self):
221221
)
222222

223223
# Check that the default predict method returns a dictionary
224-
pred = bart_model.predict(covariates=X_test)
224+
pred = bart_model.predict(X=X_test)
225225
y_hat_posterior_test = pred["y_hat"]
226226
assert y_hat_posterior_test.shape == (20, 10)
227227

228228
# Check that the pre-aggregated predictions match with those computed by np.mean
229-
pred_mean = bart_model.predict(covariates=X_test, type="mean")
229+
pred_mean = bart_model.predict(X=X_test, type="mean")
230230
y_hat_mean_test = pred_mean["y_hat"]
231231
np.testing.assert_almost_equal(
232232
y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
@@ -245,14 +245,14 @@ def test_bart_prediction(self):
245245
)
246246

247247
# Check that the default predict method returns a dictionary
248-
pred = het_bart_model.predict(covariates=X_test)
248+
pred = het_bart_model.predict(X=X_test)
249249
y_hat_posterior_test = pred["y_hat"]
250250
sigma2_hat_posterior_test = pred["variance_forest_predictions"]
251251
assert y_hat_posterior_test.shape == (20, 10)
252252
assert sigma2_hat_posterior_test.shape == (20, 10)
253253

254254
# Check that the pre-aggregated predictions match with those computed by np.mean
255-
pred_mean = het_bart_model.predict(covariates=X_test, type="mean")
255+
pred_mean = het_bart_model.predict(X=X_test, type="mean")
256256
y_hat_mean_test = pred_mean["y_hat"]
257257
sigma2_hat_mean_test = pred_mean["variance_forest_predictions"]
258258
np.testing.assert_almost_equal(
@@ -265,10 +265,10 @@ def test_bart_prediction(self):
265265
# Check that the "single-term" pre-aggregated predictions
266266
# match those computed by pre-aggregated predictions returned in a dictionary
267267
y_hat_mean_test_single_term = het_bart_model.predict(
268-
covariates=X_test, type="mean", terms="y_hat"
268+
X=X_test, type="mean", terms="y_hat"
269269
)
270270
sigma2_hat_mean_test_single_term = het_bart_model.predict(
271-
covariates=X_test, type="mean", terms="variance_forest"
271+
X=X_test, type="mean", terms="variance_forest"
272272
)
273273
np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term)
274274
np.testing.assert_almost_equal(

0 commit comments

Comments
 (0)