Skip to content

Commit bb7c0e3

Browse files
committed
Added multi-chain support for Python BCF and updated multi-chain implementation for python BART
1 parent c130a86 commit bb7c0e3

File tree

7 files changed

+674
-17
lines changed

7 files changed

+674
-17
lines changed

demo/notebooks/multi_chain.ipynb

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "fbb58475",
6+
"metadata": {},
7+
"source": [
8+
"# Running Multiple Chains (Sequentially or in Parallel) in StochTree\n",
9+
"\n",
10+
"Mixing of an MCMC sampler is a perennial concern for complex Bayesian models. BART and BCF are no exception. On common way to address such concerns is to run multiple independent \"chains\" of an MCMC sampler, so that if each chain gets stuck in a different region of the posterior, their combined samples attain better coverage of the full posterior.\n",
11+
"\n",
12+
"This idea works with the classic \"root-initialized\" MCMC sampler of Chipman et al (2010), but a key insight of He and Hahn (2023) and Krantsevich et al (2023) is that the GFR algorithm may be used to warm-start initialize multiple chains of the BART / BCF MCMC sampler.\n",
13+
"\n",
14+
"Operationally, the above two approaches have the same implementation (setting `num_gfr > 0` if warm-start initialization is desired), so this vignette will demonstrate how to run a multi-chain sampler sequentially.\n",
15+
"\n",
16+
"To begin, load `stochtree` and other relevant libraries"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"id": "1310a192",
23+
"metadata": {},
24+
"outputs": [],
25+
"source": [
26+
"import numpy as np\n",
27+
"import matplotlib.pyplot as plt\n",
28+
"import arviz as az\n",
29+
"from sklearn.model_selection import train_test_split\n",
30+
"from stochtree import BARTModel, BCFModel"
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"id": "7ea3a091",
36+
"metadata": {},
37+
"source": [
38+
"# Demo 1: Supervised Learning"
39+
]
40+
},
41+
{
42+
"cell_type": "markdown",
43+
"id": "c85b6eac",
44+
"metadata": {},
45+
"source": [
46+
"## Data Simulation"
47+
]
48+
},
49+
{
50+
"cell_type": "markdown",
51+
"id": "ea171e49",
52+
"metadata": {},
53+
"source": [
54+
"Simulate a simple partitioned linear model"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"id": "0e80f3b7",
61+
"metadata": {},
62+
"outputs": [],
63+
"source": [
64+
"# Generate the data\n",
65+
"random_seed = 1111\n",
66+
"rng = np.random.default_rng(random_seed)\n",
67+
"n = 500\n",
68+
"p_x = 10\n",
69+
"p_w = 1\n",
70+
"snr = 3\n",
71+
"X = rng.uniform(size=(n, p_x))\n",
72+
"leaf_basis = rng.uniform(size=(n, p_w))\n",
73+
"f_XW = (((0 <= X[:, 0]) & (0.25 > X[:, 0])) *\n",
74+
" (-7.5 * leaf_basis[:, 0]) +\n",
75+
" ((0.25 <= X[:, 0]) & (0.5 > X[:, 0])) * (-2.5 * leaf_basis[:, 0]) +\n",
76+
" ((0.5 <= X[:, 0]) & (0.75 > X[:, 0])) * (2.5 * leaf_basis[:, 0]) +\n",
77+
" ((0.75 <= X[:, 0]) & (1 > X[:, 0])) * (7.5 * leaf_basis[:, 0]))\n",
78+
"noise_sd = np.std(f_XW) / snr\n",
79+
"y = f_XW + rng.normal(0, noise_sd, size=n)\n",
80+
"\n",
81+
"# Split data into test and train sets\n",
82+
"test_set_pct = 0.2\n",
83+
"train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=random_seed)\n",
84+
"n_train = len(train_inds)\n",
85+
"n_test = len(test_inds)\n",
86+
"X_train = X[train_inds]\n",
87+
"X_test = X[test_inds]\n",
88+
"leaf_basis_train = leaf_basis[train_inds]\n",
89+
"leaf_basis_test = leaf_basis[test_inds]\n",
90+
"y_train = y[train_inds]\n",
91+
"y_test = y[test_inds]"
92+
]
93+
},
94+
{
95+
"cell_type": "markdown",
96+
"id": "dfb36dbe",
97+
"metadata": {},
98+
"source": [
99+
"## Sampling Multiple Chains Sequentially\n",
100+
"\n",
101+
"The simplest way to sample multiple chains of a stochtree model is to do so \"sequentially,\" that is, after chain 1 is sampled, chain 2 is sampled from a different starting state, and similarly for each of the requested chains. This is supported internally in both the `bart()` and `bcf()` functions, with the `num_chains` parameter in the `general_params` list.\n",
102+
"\n",
103+
"Define some high-level parameters, including number of chains to run and number of samples per chain. Here we run 4 independent chains with 5000 MCMC iterations, each of which is initialized by a different \"grow-from-root\" sample (the last 4 of 5 GFR samples) and burned in for 2000 iterations after warm-start."
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": null,
109+
"id": "e3e978d6",
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
"num_chains = 4\n",
114+
"num_gfr = 5\n",
115+
"num_burnin = 2000\n",
116+
"num_mcmc = 5000"
117+
]
118+
},
119+
{
120+
"cell_type": "markdown",
121+
"id": "89e67fe1",
122+
"metadata": {},
123+
"source": [
124+
"Run the sampler"
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": null,
130+
"id": "c1af3f2d",
131+
"metadata": {},
132+
"outputs": [],
133+
"source": [
134+
"bart_model = BARTModel()\n",
135+
"bart_model.sample(\n",
136+
" X_train = X_train,\n",
137+
" leaf_basis_train = leaf_basis_train,\n",
138+
" y_train = y_train,\n",
139+
" num_gfr = num_gfr,\n",
140+
" num_burnin = num_burnin,\n",
141+
" num_mcmc = num_mcmc,\n",
142+
" general_params = {'num_chains' : num_chains}\n",
143+
")"
144+
]
145+
},
146+
{
147+
"cell_type": "markdown",
148+
"id": "5bbcc00a",
149+
"metadata": {},
150+
"source": [
151+
"Now we have a `BARTModel` object with `num_chains * num_mcmc` samples stored internally. These samples are arranged sequentially, with the first `num_mcmc` samples corresponding to chain 1, the next `num_mcmc` samples to chain 2, etc...\n",
152+
"\n",
153+
"Since each chain is a set of samples of the same model, we can analyze the samples collectively, for example, by looking at out-of-sample predictions."
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"id": "0cfd1a28",
160+
"metadata": {},
161+
"outputs": [],
162+
"source": [
163+
"y_hat_test = bart_model.predict(\n",
164+
" covariates = X_test,\n",
165+
" basis = leaf_basis_test, \n",
166+
" type = \"mean\", \n",
167+
" terms = \"y_hat\"\n",
168+
")\n",
169+
"plt.scatter(y_hat_test, y_test)\n",
170+
"plt.xlabel(\"Estimated conditional mean\")\n",
171+
"plt.ylabel(\"Actual outcome\")\n",
172+
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3, 3)))"
173+
]
174+
},
175+
{
176+
"cell_type": "markdown",
177+
"id": "89015a31",
178+
"metadata": {},
179+
"source": [
180+
"Now, suppose we want to analyze each of the chains separately to assess mixing / convergence.\n",
181+
"\n",
182+
"We can use our knowledge of the internal arrangement of the chain samples to construct a an `mcmc.list` in the `coda` package, from which we can perform various diagnostics."
183+
]
184+
},
185+
{
186+
"cell_type": "code",
187+
"execution_count": null,
188+
"id": "96cee0e4",
189+
"metadata": {},
190+
"outputs": [],
191+
"source": [
192+
"sigma2_samples = bart_model.global_var_samples\n",
193+
"sigma2_samples_by_chain = {\"sigma2\": np.reshape(sigma2_samples, (num_chains, num_mcmc))}\n",
194+
"az.plot_trace(sigma2_samples_by_chain)"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"id": "08137cda",
201+
"metadata": {},
202+
"outputs": [],
203+
"source": [
204+
"az.ess(sigma2_samples_by_chain)"
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": null,
210+
"id": "552ba09c",
211+
"metadata": {},
212+
"outputs": [],
213+
"source": [
214+
"az.rhat(sigma2_samples_by_chain)"
215+
]
216+
},
217+
{
218+
"cell_type": "code",
219+
"execution_count": null,
220+
"id": "2a0c65f0",
221+
"metadata": {},
222+
"outputs": [],
223+
"source": [
224+
"az.plot_autocorr(sigma2_samples_by_chain)"
225+
]
226+
},
227+
{
228+
"cell_type": "code",
229+
"execution_count": null,
230+
"id": "5a0659ff",
231+
"metadata": {},
232+
"outputs": [],
233+
"source": [
234+
"az.plot_violin(sigma2_samples_by_chain)"
235+
]
236+
}
237+
],
238+
"metadata": {
239+
"language_info": {
240+
"name": "python"
241+
}
242+
},
243+
"nbformat": 4,
244+
"nbformat_minor": 5
245+
}

src/R_random_effects.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ cpp11::list rfx_label_mapper_to_list_cpp(cpp11::external_pointer<StochTree::Labe
321321
void reset_rfx_model_cpp(cpp11::external_pointer<StochTree::MultivariateRegressionRandomEffectsModel> rfx_model,
322322
cpp11::external_pointer<StochTree::RandomEffectsContainer> rfx_container,
323323
int sample_num) {
324-
// Reet the RFX tracker
324+
// Reset the RFX model from a previous sample
325325
rfx_model->ResetFromSample(*rfx_container, sample_num);
326326
}
327327

@@ -330,7 +330,7 @@ void reset_rfx_tracker_cpp(cpp11::external_pointer<StochTree::RandomEffectsTrack
330330
cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset,
331331
cpp11::external_pointer<StochTree::ColumnVector> residual,
332332
cpp11::external_pointer<StochTree::MultivariateRegressionRandomEffectsModel> rfx_model) {
333-
// Reset the RFX tracker
333+
// Reset the RFX tracker from a previous sample
334334
tracker->ResetFromSample(*rfx_model, *dataset, *residual);
335335
}
336336

@@ -339,6 +339,6 @@ void root_reset_rfx_tracker_cpp(cpp11::external_pointer<StochTree::RandomEffects
339339
cpp11::external_pointer<StochTree::RandomEffectsDataset> dataset,
340340
cpp11::external_pointer<StochTree::ColumnVector> residual,
341341
cpp11::external_pointer<StochTree::MultivariateRegressionRandomEffectsModel> rfx_model) {
342-
// Reset the RFX tracker
342+
// Reset the RFX tracker from root
343343
tracker->RootReset(*rfx_model, *dataset, *residual);
344344
}

src/py_stochtree.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,8 @@ class RandomEffectsTrackerCpp {
15261526
StochTree::RandomEffectsTracker* GetTracker() {
15271527
return rfx_tracker_.get();
15281528
}
1529+
void Reset(RandomEffectsModelCpp& rfx_model, RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual);
1530+
void RootReset(RandomEffectsModelCpp& rfx_model, RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual);
15291531

15301532
private:
15311533
std::unique_ptr<StochTree::RandomEffectsTracker> rfx_tracker_;
@@ -1630,6 +1632,9 @@ class RandomEffectsModelCpp {
16301632
void SetVariancePriorScale(double scale) {
16311633
rfx_model_->SetVariancePriorScale(scale);
16321634
}
1635+
void Reset(RandomEffectsContainerCpp& rfx_container, int sample_num) {
1636+
rfx_model_->ResetFromSample(*rfx_container.GetRandomEffectsContainer(), sample_num);
1637+
}
16331638

16341639
private:
16351640
std::unique_ptr<StochTree::MultivariateRegressionRandomEffectsModel> rfx_model_;
@@ -2144,6 +2149,14 @@ void RandomEffectsModelCpp::SampleRandomEffects(RandomEffectsDatasetCpp& rfx_dat
21442149
if (keep_sample) rfx_container.AddSample(*this);
21452150
}
21462151

2152+
void RandomEffectsTrackerCpp::Reset(RandomEffectsModelCpp& rfx_model, RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual) {
2153+
rfx_tracker_->ResetFromSample(*rfx_model.GetModel(), *rfx_dataset.GetDataset(), *residual.GetData());
2154+
}
2155+
2156+
void RandomEffectsTrackerCpp::RootReset(RandomEffectsModelCpp& rfx_model, RandomEffectsDatasetCpp& rfx_dataset, ResidualCpp& residual) {
2157+
rfx_tracker_->RootReset(*rfx_model.GetModel(), *rfx_dataset.GetDataset(), *residual.GetData());
2158+
}
2159+
21472160
PYBIND11_MODULE(stochtree_cpp, m) {
21482161
m.def("cppComputeForestContainerLeafIndices", &cppComputeForestContainerLeafIndices, "Compute leaf indices of the forests in a forest container");
21492162
m.def("cppComputeForestMaxLeafIndex", &cppComputeForestMaxLeafIndex, "Compute max leaf index of a forest in a forest container");
@@ -2369,7 +2382,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
23692382
py::class_<RandomEffectsTrackerCpp>(m, "RandomEffectsTrackerCpp")
23702383
.def(py::init<py::array_t<int>>())
23712384
.def("GetUniqueGroupIds", &RandomEffectsTrackerCpp::GetUniqueGroupIds)
2372-
.def("GetTracker", &RandomEffectsTrackerCpp::GetTracker);
2385+
.def("GetTracker", &RandomEffectsTrackerCpp::GetTracker)
2386+
.def("Reset", &RandomEffectsTrackerCpp::Reset)
2387+
.def("RootReset", &RandomEffectsTrackerCpp::RootReset);
23732388

23742389
py::class_<RandomEffectsLabelMapperCpp>(m, "RandomEffectsLabelMapperCpp")
23752390
.def(py::init<>())
@@ -2391,7 +2406,8 @@ PYBIND11_MODULE(stochtree_cpp, m) {
23912406
.def("SetWorkingParameterCovariance", &RandomEffectsModelCpp::SetWorkingParameterCovariance)
23922407
.def("SetGroupParameterCovariance", &RandomEffectsModelCpp::SetGroupParameterCovariance)
23932408
.def("SetVariancePriorShape", &RandomEffectsModelCpp::SetVariancePriorShape)
2394-
.def("SetVariancePriorScale", &RandomEffectsModelCpp::SetVariancePriorScale);
2409+
.def("SetVariancePriorScale", &RandomEffectsModelCpp::SetVariancePriorScale)
2410+
.def("Reset", &RandomEffectsModelCpp::Reset);
23952411

23962412
py::class_<GlobalVarianceModelCpp>(m, "GlobalVarianceModelCpp")
23972413
.def(py::init<>())

0 commit comments

Comments
 (0)