Skip to content

Commit 5bbcab9

Browse files
committed
Fixed yhat bug in adaptive coding BCF for the R interface
1 parent 4f1c3e8 commit 5bbcab9

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

R/bcf.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2474,6 +2474,8 @@ bcf <- function(
24742474
)
24752475
tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) *
24762476
y_std_train
2477+
control_adj_train <- t(t(tau_hat_train_raw) * b_0_samples) * y_std_train
2478+
mu_hat_train <- mu_hat_train + control_adj_train
24772479
} else {
24782480
tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) *
24792481
y_std_train
@@ -2508,6 +2510,8 @@ bcf <- function(
25082510
t(tau_hat_test_raw) * (b_1_samples - b_0_samples)
25092511
) *
25102512
y_std_train
2513+
control_adj_test <- t(t(tau_hat_test_raw) * b_0_samples) * y_std_train
2514+
mu_hat_test <- mu_hat_test + control_adj_test
25112515
} else {
25122516
tau_hat_test <- forest_samples_tau$predict_raw(
25132517
forest_dataset_test
@@ -2849,10 +2853,11 @@ predict.bcfmodel <- function(
28492853
"all"
28502854
))
28512855
) {
2852-
stop(paste0(
2856+
warning(paste0(
28532857
"Term '",
28542858
term,
2855-
"' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'."
2859+
"' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'.",
2860+
" This term will be ignored and prediction will only proceed if other requested terms are available in the model."
28562861
))
28572862
}
28582863
}
@@ -3056,6 +3061,8 @@ predict.bcfmodel <- function(
30563061
t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples)
30573062
) *
30583063
y_std
3064+
control_adj <- t(t(tau_hat_raw) * object$b_0_samples) * y_std
3065+
mu_hat_forest <- mu_hat_forest + control_adj
30593066
} else {
30603067
tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) *
30613068
y_std

tools/simulations/bcf-pred-rmse.R

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
# Simulation study: fit a BCF model on repeated synthetic datasets and compare
# the test-set RMSE of the predictions cached on the fitted model
# (`bcf_model$y_hat_test`) with the RMSE of the predictions returned by
# `predict()` for the same test set. Both RMSEs are measured against the true
# conditional mean E[y | X, Z], so the two prediction paths can be compared
# directly.

# Load library
library(stochtree)

# Simulation parameters
n <- 500 # observations per simulated dataset
p <- 5 # number of covariates
n_sim <- 100 # number of simulation replications
test_set_pct <- 0.2 # share of observations held out for testing
snr <- 2 # signal-to-noise ratio used to scale the outcome noise

# Quantities that do not change between replications
n_test <- round(test_set_pct * n)
all_inds <- seq_len(n)

# Simulation containers (one RMSE per replication)
rmses_cached <- rep(NA_real_, n_sim)
rmses_pred <- rep(NA_real_, n_sim)

# Run the simulation
for (i in seq_len(n_sim)) {
  # Generate data: prognostic term from X1, treatment effect from X2, and a
  # propensity score that also depends on X1
  X <- matrix(rnorm(n * p), ncol = p)
  mu_x <- X[, 1]
  tau_x <- 0.25 * X[, 2]
  pi_x <- pnorm(0.5 * X[, 1])
  Z <- rbinom(n, 1, pi_x)
  E_XZ <- mu_x + Z * tau_x
  y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

  # Train-test split
  test_inds <- sort(sample(all_inds, n_test, replace = FALSE))
  train_inds <- setdiff(all_inds, test_inds)
  # drop = FALSE keeps the covariates in matrix form regardless of the number
  # of selected rows
  X_test <- X[test_inds, , drop = FALSE]
  X_train <- X[train_inds, , drop = FALSE]
  pi_test <- pi_x[test_inds]
  pi_train <- pi_x[train_inds]
  Z_test <- Z[test_inds]
  Z_train <- Z[train_inds]
  y_train <- y[train_inds]
  E_XZ_test <- E_XZ[test_inds]

  # Fit a simple BCF model; passing the test set here is what populates the
  # cached `y_hat_test` predictions used below
  bcf_model <- bcf(
    X_train = X_train,
    Z_train = Z_train,
    y_train = y_train,
    propensity_train = pi_train,
    X_test = X_test,
    Z_test = Z_test,
    propensity_test = pi_test
  )

  # Predict out of sample through the predict() method
  y_hat_test <- predict(
    bcf_model,
    X = X_test,
    Z = Z_test,
    propensity = pi_test,
    type = "mean",
    terms = "y_hat"
  )

  # Compute RMSE using both cached predictions and those returned by predict()
  rmses_cached[i] <- sqrt(mean(
    (rowMeans(bcf_model$y_hat_test) - E_XZ_test)^2
  ))
  rmses_pred[i] <- sqrt(mean((y_hat_test - E_XZ_test)^2))
}

# Inspect results
mean(rmses_cached)
mean(rmses_pred)

0 commit comments

Comments
 (0)