Skip to content

Commit 455ba21

Browse files
committed
Updated python BCF predictions
1 parent 5bbcab9 commit 455ba21

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

stochtree/bcf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,7 +2267,12 @@ def sample(
22672267
adaptive_coding_weights = np.expand_dims(
22682268
self.b1_samples - self.b0_samples, axis=(0, 2)
22692269
)
2270+
b0_weights = np.expand_dims(
2271+
self.b0_samples, axis=(0, 2)
2272+
)
2273+
control_adj_train = self.tau_hat_train * b0_weights * self.y_std
22702274
self.tau_hat_train = self.tau_hat_train * adaptive_coding_weights
2275+
self.mu_hat_train = self.mu_hat_train + np.squeeze(control_adj_train)
22712276
self.tau_hat_train = np.squeeze(self.tau_hat_train * self.y_std)
22722277
if self.multivariate_treatment:
22732278
treatment_term_train = np.multiply(
@@ -2289,7 +2294,12 @@ def sample(
22892294
adaptive_coding_weights_test = np.expand_dims(
22902295
self.b1_samples - self.b0_samples, axis=(0, 2)
22912296
)
2297+
b0_weights = np.expand_dims(
2298+
self.b0_samples, axis=(0, 2)
2299+
)
2300+
control_adj_test = self.tau_hat_test * b0_weights * self.y_std
22922301
self.tau_hat_test = self.tau_hat_test * adaptive_coding_weights_test
2302+
self.mu_hat_test = self.mu_hat_test + np.squeeze(control_adj_test)
22932303
self.tau_hat_test = np.squeeze(self.tau_hat_test * self.y_std)
22942304
if self.multivariate_treatment:
22952305
treatment_term_test = np.multiply(
@@ -2594,7 +2604,12 @@ def predict(
25942604
adaptive_coding_weights = np.expand_dims(
25952605
self.b1_samples - self.b0_samples, axis=(0, 2)
25962606
)
2607+
b0_weights = np.expand_dims(
2608+
self.b0_samples, axis=(0, 2)
2609+
)
2610+
control_adj = tau_raw * b0_weights * self.y_std
25972611
tau_raw = tau_raw * adaptive_coding_weights
2612+
mu_x_forest = mu_x_forest + np.squeeze(control_adj)
25982613
tau_x_forest = np.squeeze(tau_raw * self.y_std)
25992614
if Z.shape[1] > 1:
26002615
treatment_term = np.multiply(

0 commit comments

Comments
 (0)