@@ -2267,7 +2267,12 @@ def sample(
22672267 adaptive_coding_weights = np .expand_dims (
22682268 self .b1_samples - self .b0_samples , axis = (0 , 2 )
22692269 )
2270+ b0_weights = np .expand_dims (
2271+ self .b0_samples , axis = (0 , 2 )
2272+ )
2273+ control_adj_train = self .tau_hat_train * b0_weights * self .y_std
22702274 self .tau_hat_train = self .tau_hat_train * adaptive_coding_weights
2275+ self .mu_hat_train = self .mu_hat_train + np .squeeze (control_adj_train )
22712276 self .tau_hat_train = np .squeeze (self .tau_hat_train * self .y_std )
22722277 if self .multivariate_treatment :
22732278 treatment_term_train = np .multiply (
@@ -2289,7 +2294,12 @@ def sample(
22892294 adaptive_coding_weights_test = np .expand_dims (
22902295 self .b1_samples - self .b0_samples , axis = (0 , 2 )
22912296 )
2297+ b0_weights = np .expand_dims (
2298+ self .b0_samples , axis = (0 , 2 )
2299+ )
2300+ control_adj_test = self .tau_hat_test * b0_weights * self .y_std
22922301 self .tau_hat_test = self .tau_hat_test * adaptive_coding_weights_test
2302+ self .mu_hat_test = self .mu_hat_test + np .squeeze (control_adj_test )
22932303 self .tau_hat_test = np .squeeze (self .tau_hat_test * self .y_std )
22942304 if self .multivariate_treatment :
22952305 treatment_term_test = np .multiply (
@@ -2594,7 +2604,12 @@ def predict(
25942604 adaptive_coding_weights = np .expand_dims (
25952605 self .b1_samples - self .b0_samples , axis = (0 , 2 )
25962606 )
2607+ b0_weights = np .expand_dims (
2608+ self .b0_samples , axis = (0 , 2 )
2609+ )
2610+ control_adj = tau_raw * b0_weights * self .y_std
25972611 tau_raw = tau_raw * adaptive_coding_weights
2612+ mu_x_forest = mu_x_forest + np .squeeze (control_adj )
25982613 tau_x_forest = np .squeeze (tau_raw * self .y_std )
25992614 if Z .shape [1 ] > 1 :
26002615 treatment_term = np .multiply (
0 commit comments