@@ -14,35 +14,55 @@ num_trees <- 100
# Simulated regression data set.
#   n     - number of observations
#   p_x   - number of covariates (columns of X)
#   snr   - divisor applied to sd(f_XW) to obtain the noise standard deviation
# The old/new line pairs in this section differ only in spacing and in how
# the f_XW expression is wrapped across lines; the computed values are the same.
# NOTE(review): no set.seed() call is visible in these lines, so runif()/rnorm()
# draws are presumably not reproducible - TODO confirm against the file header.
1414n <- 500
1515p_x <- 20
1616snr <- 2
17- X <- matrix (runif(n * p_x ), ncol = p_x )
18- f_XW <- sin(4 * pi * X [,1 ]) + cos(4 * pi * X [,2 ]) + sin(4 * pi * X [,3 ]) + cos(4 * pi * X [,4 ])
# X is an n x p_x matrix of uniform draws; only its first four columns enter
# the mean function f_XW below, the remaining columns are unused by f_XW.
17+ X <- matrix (runif(n * p_x ), ncol = p_x )
18+ f_XW <- sin(4 * pi * X [, 1 ]) +
19+ cos(4 * pi * X [, 2 ]) +
20+ sin(4 * pi * X [, 3 ]) +
21+ cos(4 * pi * X [, 4 ])
# Noise scale is tied to the spread of the signal: sd(f_XW) / snr.
1922noise_sd <- sd(f_XW ) / snr
20- y <- f_XW + rnorm(n , 0 , 1 )* noise_sd
# Outcome = mean function + standard normal noise rescaled by noise_sd.
23+ y <- f_XW + rnorm(n , 0 , 1 ) * noise_sd
2124
# Hold out round(test_set_pct * n) rows as a test set; every other row is
# used for training. The changes in this section are spacing only.
2225# Split data into test and train sets
2326test_set_pct <- 0.2
24- n_test <- round(test_set_pct * n )
27+ n_test <- round(test_set_pct * n )
2528n_train <- n - n_test
# Test indices are sampled without replacement and kept in ascending order.
# NOTE(review): 1:n evaluates to c(1, 0) when n == 0; seq_len(n) is the safer
# spelling if n is ever made configurable.
2629test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
# Training indices are whatever is left of 1:n after removing test_inds
# (same result as setdiff(1:n, test_inds)).
2730train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
28- X_test <- as.data.frame(X [test_inds ,])
29- X_train <- as.data.frame(X [train_inds ,])
# Covariate subsets are converted from matrix rows to data frames.
# NOTE(review): X[i, ] drops to a plain vector when a single row is selected;
# add drop = FALSE if n_test or n_train can be 1.
31+ X_test <- as.data.frame(X [test_inds , ])
32+ X_train <- as.data.frame(X [train_inds , ])
3033y_test <- y [test_inds ]
3134y_train <- y [train_inds ]
3235
# Initial fit that runs only GFR iterations (num_gfr of them, with
# num_burnin = 0 and num_mcmc = 0). Its serialized form, xbart_model_string,
# is what the later chains receive as previous_model_json.
# The old/new versions of this section pass the same arguments; the change
# puts one argument per line and splits the `plot(...); abline(...)` one-liner.
3336# Run the GFR algorithm
34- xbart_params <- list (sample_sigma_global = T ,
35- num_trees_mean = num_trees , alpha_mean = 0.99 ,
36- beta_mean = 1 , max_depth_mean = - 1 ,
37- min_samples_leaf_mean = 1 , sample_sigma_leaf = F ,
38- sigma_leaf_init = 1 / num_trees )
# NOTE(review): T and F are ordinary variables that can be reassigned;
# TRUE / FALSE are the reserved constants and should be preferred here.
37+ xbart_params <- list (
38+ sample_sigma_global = T ,
39+ num_trees_mean = num_trees ,
40+ alpha_mean = 0.99 ,
41+ beta_mean = 1 ,
42+ max_depth_mean = - 1 ,
43+ min_samples_leaf_mean = 1 ,
44+ sample_sigma_leaf = F ,
45+ sigma_leaf_init = 1 / num_trees
46+ )
3947xbart_model <- stochtree :: bart(
40- X_train = X_train , y_train = y_train , X_test = X_test ,
41- num_gfr = num_gfr , num_burnin = 0 , num_mcmc = 0 , params = xbart_params
48+ X_train = X_train ,
49+ y_train = y_train ,
50+ X_test = X_test ,
51+ num_gfr = num_gfr ,
52+ num_burnin = 0 ,
53+ num_mcmc = 0 ,
54+ params = xbart_params
4255)
43- plot(rowMeans(xbart_model $ y_hat_test ), y_test ); abline(0 ,1 )
# Diagnostics on the held-out set. rowMeans() is compared element-wise with
# y_test, so y_hat_test is presumably (test observations x retained samples)
# - TODO confirm against the stochtree documentation.
56+ plot(rowMeans(xbart_model $ y_hat_test ), y_test )
57+ abline(0 , 1 )
# Root mean squared error of the sample-averaged test prediction.
4458cat(sqrt(mean((rowMeans(xbart_model $ y_hat_test ) - y_test )^ 2 )), " \n " )
45- cat(mean((apply(xbart_model $ y_hat_test , 1 , quantile , probs = 0.05 ) < = y_test ) & (apply(xbart_model $ y_hat_test , 1 , quantile , probs = 0.95 ) > = y_test )), " \n " )
# Share of test outcomes lying between the per-row 5% and 95% sample
# quantiles, i.e. empirical coverage of a nominal 90% interval.
59+ cat(
60+ mean(
61+ (apply(xbart_model $ y_hat_test , 1 , quantile , probs = 0.05 ) < = y_test ) &
62+ (apply(xbart_model $ y_hat_test , 1 , quantile , probs = 0.95 ) > = y_test )
63+ ),
64+ " \n "
65+ )
# Serialize the fitted model to a JSON string for the parallel chains.
4666xbart_model_string <- stochtree :: saveBARTModelToJsonString(xbart_model )
4767
4868# Parallel setup
@@ -51,20 +71,32 @@ cl <- makeCluster(ncores)
# Register the cluster `cl` (created above this hunk) as the foreach backend,
# run one BART MCMC chain per foreach iteration, then shut the cluster down.
5171registerDoParallel(cl )
5272
5373# Run the parallel BART MCMC samplers
54- bart_model_outputs <- foreach (i = 1 : num_chains ) %dopar % {
# The new version moves the opening brace to its own line. The infix operator
# still ends the previous line, so R keeps reading and the braced block remains
# the right-hand operand of the operator.
74+ bart_model_outputs <- foreach(i = 1 : num_chains ) %dopar %
75+ {
# The chain index doubles as the seed, so each chain gets a distinct seed.
5576 random_seed <- i
56- bart_params <- list (sample_sigma_global = T , sample_sigma_leaf = T ,
57- num_trees_mean = num_trees , random_seed = random_seed ,
58- alpha_mean = 0.999 , beta_mean = 1 )
# NOTE(review): prefer TRUE over the reassignable shorthand T.
77+ bart_params <- list (
78+ sample_sigma_global = T ,
79+ sample_sigma_leaf = T ,
80+ num_trees_mean = num_trees ,
81+ random_seed = random_seed ,
82+ alpha_mean = 0.999 ,
83+ beta_mean = 1
84+ )
5985 bart_model <- stochtree :: bart(
60- X_train = X_train , y_train = y_train , X_test = X_test ,
61- num_gfr = 0 , num_burnin = num_burnin , num_mcmc = num_mcmc , params = bart_params ,
62- previous_model_json = xbart_model_string , warmstart_sample_num = num_gfr - i + 1 ,
# No new GFR iterations here (num_gfr = 0); each chain is instead initialized
# from the serialized GFR fit. Chain i uses sample index num_gfr - i + 1
# (chain 1 -> num_gfr, chain 2 -> num_gfr - 1, ...), so different chains start
# from different samples.
# NOTE(review): that index falls below 1 once i > num_gfr; this presumably
# requires num_chains <= num_gfr - TODO confirm or add a guard.
# NOTE(review): the comma after warmstart_sample_num leaves an empty trailing
# argument in the call (present in the old version too); it looks
# unintentional - confirm bart() tolerates it or drop the comma.
86+ X_train = X_train ,
87+ y_train = y_train ,
88+ X_test = X_test ,
89+ num_gfr = 0 ,
90+ num_burnin = num_burnin ,
91+ num_mcmc = num_mcmc ,
92+ params = bart_params ,
93+ previous_model_json = xbart_model_string ,
94+ warmstart_sample_num = num_gfr - i + 1 ,
6395 )
# Each iteration returns the chain as a JSON string plus its test predictions
# (presumably because the fitted object itself cannot be shipped back from a
# worker process - TODO confirm).
6496 bart_model_string <- stochtree :: saveBARTModelToJsonString(bart_model )
6597 y_hat_test <- bart_model $ y_hat_test
66- list (model = bart_model_string , yhat = y_hat_test )
67- }
98+ list (model = bart_model_string , yhat = y_hat_test )
99+ }
68100
# NOTE(review): stopCluster() is only reached if the foreach call above
# succeeds; wrapping the run in tryCatch(..., finally = stopCluster(cl)) would
# avoid leaving worker processes behind on error.
69101# Close the cluster connection
70102stopCluster(cl )
@@ -73,43 +105,89 @@ stopCluster(cl)
# Unpack the per-chain results: the JSON strings go into a list, and each
# chain's sample-averaged test prediction fills one column of an
# (length(y_test) x num_chains) matrix that is pre-filled with NA.
73105bart_model_strings <- list ()
74106bart_model_yhats <- matrix (NA , nrow = length(y_test ), ncol = num_chains )
# NOTE(review): 1:length(x) iterates over c(1, 0) when x is empty;
# seq_along(bart_model_outputs) is the safe form.
75107for (i in 1 : length(bart_model_outputs )) {
76- bart_model_strings [[i ]] <- bart_model_outputs [[i ]]$ model
77- bart_model_yhats [,i ] <- rowMeans(bart_model_outputs [[i ]]$ yhat )
108+ bart_model_strings [[i ]] <- bart_model_outputs [[i ]]$ model
109+ bart_model_yhats [, i ] <- rowMeans(bart_model_outputs [[i ]]$ yhat )
78110}
# Build a single model object from the list of per-chain JSON strings.
# NOTE(review): unlike the other stochtree calls in this script, this one is
# not namespace-qualified; it presumably relies on library(stochtree) earlier
# in the file - TODO confirm.
79111combined_bart <- createBARTModelFromCombinedJsonString(bart_model_strings )
80112
# Check the combined (deserialized) model against the per-chain results and
# against the held-out outcomes.
81113# Inspect the results
82- yhat_combined <- predict(combined_bart , X_test )$ y_hat
83- par(mfrow = c(1 ,2 ))
# This is the one non-cosmetic change in the diff: the covariates are now
# passed by name (X = X_test) instead of positionally.
114+ yhat_combined <- predict(combined_bart , X = X_test )$ y_hat
# Two plots side by side; reset to a single panel after the second loop.
115+ par(mfrow = c(1 , 2 ))
# First pass: for each chain, plot the row means of its slice of yhat_combined
# against the row means saved from the original run (bart_model_yhats[, i]).
# The slice for chain i is columns (i - 1) * num_mcmc + 1 through i * num_mcmc,
# i.e. the code assumes the combined model stores num_mcmc samples per chain,
# in the order the JSON strings were supplied - TODO confirm.
# NOTE(review): with num_mcmc == 1 the column subset drops to a vector and
# rowMeans() would fail; add drop = FALSE if that case matters.
84116for (i in 1 : num_chains ) {
85- offset <- (i - 1 )* num_mcmc
86- inds_start <- offset + 1
87- inds_end <- offset + num_mcmc
88- plot(rowMeans(yhat_combined [,inds_start : inds_end ]), bart_model_yhats [,i ],
89- xlab = " deserialized" , ylab = " original" ,
90- main = paste0(" Chain " , i , " \n Predictions" ))
91- abline(0 ,1 ,col = " red" ,lty = 3 ,lwd = 3 )
117+ offset <- (i - 1 ) * num_mcmc
118+ inds_start <- offset + 1
119+ inds_end <- offset + num_mcmc
120+ plot(
121+ rowMeans(yhat_combined [, inds_start : inds_end ]),
122+ bart_model_yhats [, i ],
123+ xlab = " deserialized" ,
124+ ylab = " original" ,
125+ main = paste0(" Chain " , i , " \n Predictions" )
126+ )
127+ abline(0 , 1 , col = " red" , lty = 3 , lwd = 3 )
92128}
# Second pass: same per-chain column slices, now compared with y_test. For
# each chain this draws predicted vs. actual, then prints the RMSE and the
# share of test outcomes inside the per-row 5%-95% sample quantile band.
93129for (i in 1 : num_chains ) {
94- offset <- (i - 1 )* num_mcmc
95- inds_start <- offset + 1
96- inds_end <- offset + num_mcmc
97- plot(rowMeans(yhat_combined [,inds_start : inds_end ]), y_test ,
98- xlab = " predicted" , ylab = " actual" ,
99- main = paste0(" Chain " , i , " \n Predictions" ))
100- abline(0 ,1 ,col = " red" ,lty = 3 ,lwd = 3 )
101- cat(sqrt(mean((rowMeans(yhat_combined [,inds_start : inds_end ]) - y_test )^ 2 )), " \n " )
102- cat(mean((apply(yhat_combined [,inds_start : inds_end ], 1 , quantile , probs = 0.05 ) < = y_test ) & (apply(yhat_combined [,inds_start : inds_end ], 1 , quantile , probs = 0.95 ) > = y_test )), " \n " )
130+ offset <- (i - 1 ) * num_mcmc
131+ inds_start <- offset + 1
132+ inds_end <- offset + num_mcmc
133+ plot(
134+ rowMeans(yhat_combined [, inds_start : inds_end ]),
135+ y_test ,
136+ xlab = " predicted" ,
137+ ylab = " actual" ,
138+ main = paste0(" Chain " , i , " \n Predictions" )
139+ )
140+ abline(0 , 1 , col = " red" , lty = 3 , lwd = 3 )
141+ cat(
142+ sqrt(mean((rowMeans(yhat_combined [, inds_start : inds_end ]) - y_test )^ 2 )),
143+ " \n "
144+ )
145+ cat(
146+ mean(
147+ (apply(yhat_combined [, inds_start : inds_end ], 1 , quantile , probs = 0.05 ) < =
148+ y_test ) &
149+ (apply(
150+ yhat_combined [, inds_start : inds_end ],
151+ 1 ,
152+ quantile ,
153+ probs = 0.95
154+ ) > =
155+ y_test )
156+ ),
157+ " \n "
158+ )
103159}
104- par(mfrow = c(1 ,1 ))
# NOTE(review): this hard-codes a single-panel layout instead of restoring
# whatever par(mfrow = ...) returned before the change above.
160+ par(mfrow = c(1 , 1 ))
105161
# Baseline for comparison: one chain with no GFR iterations, no burn-in and
# no previous_model_json, evaluated with the same plot / RMSE / coverage
# diagnostics as the fits above.
# Differences from the parallel chains that are visible in the arguments:
# alpha_mean = 0.95 and beta_mean = 2 here (0.999 and 1 there), and no
# random_seed is passed.
# NOTE(review): the differing prior settings mean the comparison is not purely
# about initialization - confirm this is intended.
106162# Compare to a single chain of MCMC samples initialized at root
107- bart_params <- list (sample_sigma_global = T , sample_sigma_leaf = T ,
108- num_trees_mean = num_trees , alpha_mean = 0.95 , beta_mean = 2 )
# NOTE(review): prefer TRUE over the reassignable shorthand T.
163+ bart_params <- list (
164+ sample_sigma_global = T ,
165+ sample_sigma_leaf = T ,
166+ num_trees_mean = num_trees ,
167+ alpha_mean = 0.95 ,
168+ beta_mean = 2
169+ )
109170bart_model <- stochtree :: bart(
110- X_train = X_train , y_train = y_train , X_test = X_test ,
111- num_gfr = 0 , num_burnin = 0 , num_mcmc = num_mcmc , params = bart_params
171+ X_train = X_train ,
172+ y_train = y_train ,
173+ X_test = X_test ,
174+ num_gfr = 0 ,
175+ num_burnin = 0 ,
176+ num_mcmc = num_mcmc ,
177+ params = bart_params
178+ )
# Sample-averaged test prediction vs. held-out outcome, with a y = x reference
# line added by abline(0, 1).
179+ plot(
180+ rowMeans(bart_model $ y_hat_test ),
181+ y_test ,
182+ xlab = " predicted" ,
183+ ylab = " actual"
112184)
113- plot(rowMeans( bart_model $ y_hat_test ), y_test , xlab = " predicted " , ylab = " actual " ); abline(0 ,1 )
185+ abline(0 , 1 )
# Root mean squared error of the sample-averaged test prediction.
114186cat(sqrt(mean((rowMeans(bart_model $ y_hat_test ) - y_test )^ 2 )), " \n " )
115- cat(mean((apply(bart_model $ y_hat_test , 1 , quantile , probs = 0.05 ) < = y_test ) & (apply(bart_model $ y_hat_test , 1 , quantile , probs = 0.95 ) > = y_test )), " \n " )
# Empirical coverage of the per-row 5%-95% sample quantile band.
187+ cat(
188+ mean(
189+ (apply(bart_model $ y_hat_test , 1 , quantile , probs = 0.05 ) < = y_test ) &
190+ (apply(bart_model $ y_hat_test , 1 , quantile , probs = 0.95 ) > = y_test )
191+ ),
192+ " \n "
193+ )
0 commit comments