Skip to content

Commit 9ddc9ba

Browse files
committed
Updated multi-chain in R to initialize different chains from different forests if a previous model json was provided
1 parent bb7c0e3 commit 9ddc9ba

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

R/bart.R

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
2929
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
3030
#' @param previous_model_json (Optional) JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`.
31-
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`.
31+
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`. If `num_chains` in the `general_params` list is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is `NULL`, a warning will be raised and the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples available when counting backwards from `previous_model_warmstart_sample_num` (i.e. `num_chains > previous_model_warmstart_sample_num`), a warning will be raised and every chain will be initialized from the same sample, namely sample `previous_model_warmstart_sample_num`.
3232
#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
3333
#'
3434
#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`.
@@ -293,10 +293,36 @@ bart <- function(
293293

294294
# Check if previous model JSON is provided and parse it if so
295295
has_prev_model <- !is.null(previous_model_json)
296+
has_prev_model_index <- !is.null(previous_model_warmstart_sample_num)
296297
if (has_prev_model) {
297298
previous_bart_model <- createBARTModelFromJsonString(
298299
previous_model_json
299300
)
301+
prev_num_samples <- previous_bart_model$model_params$num_samples
302+
if (!has_prev_model_index) {
303+
previous_model_warmstart_sample_num <- prev_num_samples
304+
warning(
305+
"`previous_model_warmstart_sample_num` was not provided alongside `previous_model_json`, so it will be set to the number of samples available in `previous_model_json`"
306+
)
307+
} else {
308+
if (previous_model_warmstart_sample_num < 1) {
309+
stop(
310+
"`previous_model_warmstart_sample_num` must be a positive integer"
311+
)
312+
}
313+
if (previous_model_warmstart_sample_num > prev_num_samples) {
314+
stop(
315+
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
316+
)
317+
}
318+
}
319+
previous_model_decrement <- T
320+
if (num_chains > previous_model_warmstart_sample_num) {
321+
warning(
322+
"The number of chains being sampled exceeds the number of previous model samples available from the requested position in `previous_model_json`. All chains will be initialized from the same sample."
323+
)
324+
previous_model_decrement <- F
325+
}
300326
previous_y_bar <- previous_bart_model$model_params$outcome_mean
301327
previous_y_scale <- previous_bart_model$model_params$outcome_scale
302328
if (previous_bart_model$model_params$include_mean_forest) {
@@ -1375,11 +1401,16 @@ bart <- function(
13751401
)
13761402
}
13771403
} else if (has_prev_model) {
1404+
warmstart_index <- ifelse(
1405+
previous_model_decrement,
1406+
previous_model_warmstart_sample_num - chain_num + 1,
1407+
previous_model_warmstart_sample_num
1408+
)
13781409
if (include_mean_forest) {
13791410
resetActiveForest(
13801411
active_forest_mean,
13811412
previous_forest_samples_mean,
1382-
previous_model_warmstart_sample_num - 1
1413+
warmstart_index - 1
13831414
)
13841415
resetForestModel(
13851416
forest_model_mean,
@@ -1393,7 +1424,7 @@ bart <- function(
13931424
(!is.null(previous_leaf_var_samples))
13941425
) {
13951426
leaf_scale_double <- previous_leaf_var_samples[
1396-
previous_model_warmstart_sample_num
1427+
warmstart_index
13971428
]
13981429
current_leaf_scale <- as.matrix(leaf_scale_double)
13991430
forest_model_config_mean$update_leaf_model_scale(
@@ -1405,7 +1436,7 @@ bart <- function(
14051436
resetActiveForest(
14061437
active_forest_variance,
14071438
previous_forest_samples_variance,
1408-
previous_model_warmstart_sample_num - 1
1439+
warmstart_index - 1
14091440
)
14101441
resetForestModel(
14111442
forest_model_variance,
@@ -1439,7 +1470,7 @@ bart <- function(
14391470
resetRandomEffectsModel(
14401471
rfx_model,
14411472
previous_rfx_samples,
1442-
previous_model_warmstart_sample_num - 1,
1473+
warmstart_index - 1,
14431474
sigma_alpha_init
14441475
)
14451476
resetRandomEffectsTracker(
@@ -1454,7 +1485,7 @@ bart <- function(
14541485
if (sample_sigma2_global) {
14551486
if (!is.null(previous_global_var_samples)) {
14561487
current_sigma2 <- previous_global_var_samples[
1457-
previous_model_warmstart_sample_num
1488+
warmstart_index
14581489
]
14591490
global_model_config$update_global_error_variance(
14601491
current_sigma2

R/bcf.R

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
2626
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
2727
#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`.
28-
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`.
28+
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`. If `num_chains` in the `general_params` list is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is `NULL`, a warning will be raised and the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples available when counting backwards from `previous_model_warmstart_sample_num` (i.e. `num_chains > previous_model_warmstart_sample_num`), a warning will be raised and every chain will be initialized from the same sample, namely sample `previous_model_warmstart_sample_num`.
2929
#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
3030
#'
3131
#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`.
@@ -397,8 +397,34 @@ bcf <- function(
397397

398398
# Check if previous model JSON is provided and parse it if so
399399
has_prev_model <- !is.null(previous_model_json)
400+
has_prev_model_index <- !is.null(previous_model_warmstart_sample_num)
400401
if (has_prev_model) {
401402
previous_bcf_model <- createBCFModelFromJsonString(previous_model_json)
403+
prev_num_samples <- previous_bcf_model$model_params$num_samples
404+
if (!has_prev_model_index) {
405+
previous_model_warmstart_sample_num <- prev_num_samples
406+
warning(
407+
"`previous_model_warmstart_sample_num` was not provided alongside `previous_model_json`, so it will be set to the number of samples available in `previous_model_json`"
408+
)
409+
} else {
410+
if (previous_model_warmstart_sample_num < 1) {
411+
stop(
412+
"`previous_model_warmstart_sample_num` must be a positive integer"
413+
)
414+
}
415+
if (previous_model_warmstart_sample_num > prev_num_samples) {
416+
stop(
417+
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
418+
)
419+
}
420+
}
421+
previous_model_decrement <- T
422+
if (num_chains > previous_model_warmstart_sample_num) {
423+
warning(
424+
"The number of chains being sampled exceeds the number of previous model samples available from the requested position in `previous_model_json`. All chains will be initialized from the same sample."
425+
)
426+
previous_model_decrement <- F
427+
}
402428
previous_y_bar <- previous_bcf_model$model_params$outcome_mean
403429
previous_y_scale <- previous_bcf_model$model_params$outcome_scale
404430
previous_forest_samples_mu <- previous_bcf_model$forests_mu
@@ -1910,10 +1936,15 @@ bcf <- function(
19101936
)
19111937
}
19121938
} else if (has_prev_model) {
1939+
warmstart_index <- ifelse(
1940+
previous_model_decrement,
1941+
previous_model_warmstart_sample_num - chain_num + 1,
1942+
previous_model_warmstart_sample_num
1943+
)
19131944
resetActiveForest(
19141945
active_forest_mu,
19151946
previous_forest_samples_mu,
1916-
previous_model_warmstart_sample_num - 1
1947+
warmstart_index - 1
19171948
)
19181949
resetForestModel(
19191950
forest_model_mu,
@@ -1925,7 +1956,7 @@ bcf <- function(
19251956
resetActiveForest(
19261957
active_forest_tau,
19271958
previous_forest_samples_tau,
1928-
previous_model_warmstart_sample_num - 1
1959+
warmstart_index - 1
19291960
)
19301961
resetForestModel(
19311962
forest_model_tau,
@@ -1938,7 +1969,7 @@ bcf <- function(
19381969
resetActiveForest(
19391970
active_forest_variance,
19401971
previous_forest_samples_variance,
1941-
previous_model_warmstart_sample_num - 1
1972+
warmstart_index - 1
19421973
)
19431974
resetForestModel(
19441975
forest_model_variance,
@@ -1953,7 +1984,7 @@ bcf <- function(
19531984
(!is.null(previous_leaf_var_mu_samples))
19541985
) {
19551986
leaf_scale_mu_double <- previous_leaf_var_mu_samples[
1956-
previous_model_warmstart_sample_num
1987+
warmstart_index
19571988
]
19581989
current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
19591990
forest_model_config_mu$update_leaf_model_scale(
@@ -1965,7 +1996,7 @@ bcf <- function(
19651996
(!is.null(previous_leaf_var_tau_samples))
19661997
) {
19671998
leaf_scale_tau_double <- previous_leaf_var_tau_samples[
1968-
previous_model_warmstart_sample_num
1999+
warmstart_index
19692000
]
19702001
current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
19712002
forest_model_config_tau$update_leaf_model_scale(
@@ -1975,12 +2006,12 @@ bcf <- function(
19752006
if (adaptive_coding) {
19762007
if (!is.null(previous_b_1_samples)) {
19772008
current_b_1 <- previous_b_1_samples[
1978-
previous_model_warmstart_sample_num
2009+
warmstart_index
19792010
]
19802011
}
19812012
if (!is.null(previous_b_0_samples)) {
19822013
current_b_0 <- previous_b_0_samples[
1983-
previous_model_warmstart_sample_num
2014+
warmstart_index
19842015
]
19852016
}
19862017
tau_basis_train <- (1 - Z_train) *
@@ -2023,7 +2054,7 @@ bcf <- function(
20232054
resetRandomEffectsModel(
20242055
rfx_model,
20252056
previous_rfx_samples,
2026-
previous_model_warmstart_sample_num - 1,
2057+
warmstart_index - 1,
20272058
sigma_alpha_init
20282059
)
20292060
resetRandomEffectsTracker(
@@ -2038,7 +2069,7 @@ bcf <- function(
20382069
if (sample_sigma2_global) {
20392070
if (!is.null(previous_global_var_samples)) {
20402071
current_sigma2 <- previous_global_var_samples[
2041-
previous_model_warmstart_sample_num
2072+
warmstart_index
20422073
]
20432074
}
20442075
global_model_config$update_global_error_variance(

0 commit comments

Comments
 (0)