|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "fbb58475", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Running Multiple Chains (Sequentially or in Parallel) in StochTree\n", |
| 9 | + "\n", |
| 10 | + "Mixing of an MCMC sampler is a perennial concern for complex Bayesian models. BART and BCF are no exception. On common way to address such concerns is to run multiple independent \"chains\" of an MCMC sampler, so that if each chain gets stuck in a different region of the posterior, their combined samples attain better coverage of the full posterior.\n", |
| 11 | + "\n", |
| 12 | + "This idea works with the classic \"root-initialized\" MCMC sampler of Chipman et al (2010), but a key insight of He and Hahn (2023) and Krantsevich et al (2023) is that the GFR algorithm may be used to warm-start initialize multiple chains of the BART / BCF MCMC sampler.\n", |
| 13 | + "\n", |
| 14 | + "Operationally, the above two approaches have the same implementation (setting `num_gfr > 0` if warm-start initialization is desired), so this vignette will demonstrate how to run a multi-chain sampler sequentially.\n", |
| 15 | + "\n", |
| 16 | + "To begin, load `stochtree` and other relevant libraries" |
| 17 | + ] |
| 18 | + }, |
| 19 | + { |
| 20 | + "cell_type": "code", |
| 21 | + "execution_count": null, |
| 22 | + "id": "1310a192", |
| 23 | + "metadata": {}, |
| 24 | + "outputs": [], |
| 25 | + "source": [ |
| 26 | + "import numpy as np\n", |
| 27 | + "import matplotlib.pyplot as plt\n", |
| 28 | + "import arviz as az\n", |
| 29 | + "from sklearn.model_selection import train_test_split\n", |
| 30 | + "from stochtree import BARTModel, BCFModel" |
| 31 | + ] |
| 32 | + }, |
| 33 | + { |
| 34 | + "cell_type": "markdown", |
| 35 | + "id": "7ea3a091", |
| 36 | + "metadata": {}, |
| 37 | + "source": [ |
| 38 | + "# Demo 1: Supervised Learning" |
| 39 | + ] |
| 40 | + }, |
| 41 | + { |
| 42 | + "cell_type": "markdown", |
| 43 | + "id": "c85b6eac", |
| 44 | + "metadata": {}, |
| 45 | + "source": [ |
| 46 | + "## Data Simulation" |
| 47 | + ] |
| 48 | + }, |
| 49 | + { |
| 50 | + "cell_type": "markdown", |
| 51 | + "id": "ea171e49", |
| 52 | + "metadata": {}, |
| 53 | + "source": [ |
| 54 | + "Simulate a simple partitioned linear model" |
| 55 | + ] |
| 56 | + }, |
| 57 | + { |
| 58 | + "cell_type": "code", |
| 59 | + "execution_count": null, |
| 60 | + "id": "0e80f3b7", |
| 61 | + "metadata": {}, |
| 62 | + "outputs": [], |
| 63 | + "source": [ |
| 64 | + "# Generate the data\n", |
| 65 | + "random_seed = 1111\n", |
| 66 | + "rng = np.random.default_rng(random_seed)\n", |
| 67 | + "n = 500\n", |
| 68 | + "p_x = 10\n", |
| 69 | + "p_w = 1\n", |
| 70 | + "snr = 3\n", |
| 71 | + "X = rng.uniform(size=(n, p_x))\n", |
| 72 | + "leaf_basis = rng.uniform(size=(n, p_w))\n", |
| 73 | + "f_XW = (((X[:, 0] >= 0) & (X[:, 0] < 0.25)) *\n",
| 74 | + " (-7.5 * leaf_basis[:, 0]) +\n",
| 75 | + " ((X[:, 0] >= 0.25) & (X[:, 0] < 0.5)) * (-2.5 * leaf_basis[:, 0]) +\n",
| 76 | + " ((X[:, 0] >= 0.5) & (X[:, 0] < 0.75)) * (2.5 * leaf_basis[:, 0]) +\n",
| 77 | + " ((X[:, 0] >= 0.75) & (X[:, 0] < 1)) * (7.5 * leaf_basis[:, 0]))\n",
| 78 | + "noise_sd = np.std(f_XW) / snr\n", |
| 79 | + "y = f_XW + rng.normal(0, noise_sd, size=n)\n", |
| 80 | + "\n", |
| 81 | + "# Split data into test and train sets\n", |
| 82 | + "test_set_pct = 0.2\n", |
| 83 | + "train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=random_seed)\n", |
| 84 | + "n_train = len(train_inds)\n", |
| 85 | + "n_test = len(test_inds)\n", |
| 86 | + "X_train = X[train_inds]\n", |
| 87 | + "X_test = X[test_inds]\n", |
| 88 | + "leaf_basis_train = leaf_basis[train_inds]\n", |
| 89 | + "leaf_basis_test = leaf_basis[test_inds]\n", |
| 90 | + "y_train = y[train_inds]\n", |
| 91 | + "y_test = y[test_inds]" |
| 92 | + ] |
| 93 | + }, |
| 94 | + { |
| 95 | + "cell_type": "markdown", |
| 96 | + "id": "dfb36dbe", |
| 97 | + "metadata": {}, |
| 98 | + "source": [ |
| 99 | + "## Sampling Multiple Chains Sequentially\n", |
| 100 | + "\n", |
| 101 | + "The simplest way to sample multiple chains of a stochtree model is to do so \"sequentially\": after chain 1 is sampled, chain 2 is sampled from a different starting state, and so on for each of the requested chains. This is supported internally by both `BARTModel.sample()` and `BCFModel.sample()`, via the `num_chains` entry of the `general_params` dictionary.\n",
| 102 | + "\n", |
| 103 | + "Define some high-level parameters, including the number of chains to run and the number of samples per chain. Here we run 4 independent chains of 5000 MCMC iterations each; every chain is initialized from a different \"grow-from-root\" sample (the last 4 of the 5 GFR samples) and burned in for 2000 iterations after the warm-start."
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "code", |
| 108 | + "execution_count": null, |
| 109 | + "id": "e3e978d6", |
| 110 | + "metadata": {}, |
| 111 | + "outputs": [], |
| 112 | + "source": [ |
| 113 | + "num_chains = 4\n", |
| 114 | + "num_gfr = 5\n", |
| 115 | + "num_burnin = 2000\n", |
| 116 | + "num_mcmc = 5000" |
| 117 | + ] |
| 118 | + }, |
| 119 | + { |
| 120 | + "cell_type": "markdown", |
| 121 | + "id": "89e67fe1", |
| 122 | + "metadata": {}, |
| 123 | + "source": [ |
| 124 | + "Run the sampler" |
| 125 | + ] |
| 126 | + }, |
| 127 | + { |
| 128 | + "cell_type": "code", |
| 129 | + "execution_count": null, |
| 130 | + "id": "c1af3f2d", |
| 131 | + "metadata": {}, |
| 132 | + "outputs": [], |
| 133 | + "source": [ |
| 134 | + "bart_model = BARTModel()\n", |
| 135 | + "bart_model.sample(\n", |
| 136 | + " X_train = X_train,\n", |
| 137 | + " leaf_basis_train = leaf_basis_train,\n", |
| 138 | + " y_train = y_train,\n", |
| 139 | + " num_gfr = num_gfr,\n", |
| 140 | + " num_burnin = num_burnin,\n", |
| 141 | + " num_mcmc = num_mcmc,\n", |
| 142 | + " general_params = {'num_chains' : num_chains}\n", |
| 143 | + ")" |
| 144 | + ] |
| 145 | + }, |
| 146 | + { |
| 147 | + "cell_type": "markdown", |
| 148 | + "id": "5bbcc00a", |
| 149 | + "metadata": {}, |
| 150 | + "source": [ |
| 151 | + "Now we have a `BARTModel` object with `num_chains * num_mcmc` samples stored internally. These samples are arranged sequentially: the first `num_mcmc` samples correspond to chain 1, the next `num_mcmc` samples to chain 2, and so on (a short illustration of this layout follows the prediction plot below).\n",
| 152 | + "\n", |
| 153 | + "Since each chain is a set of samples of the same model, we can analyze the samples collectively, for example, by looking at out-of-sample predictions." |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | + "cell_type": "code", |
| 158 | + "execution_count": null, |
| 159 | + "id": "0cfd1a28", |
| 160 | + "metadata": {}, |
| 161 | + "outputs": [], |
| 162 | + "source": [ |
| 163 | + "y_hat_test = bart_model.predict(\n", |
| 164 | + " covariates = X_test,\n", |
| 165 | + " basis = leaf_basis_test, \n", |
| 166 | + " type = \"mean\", \n", |
| 167 | + " terms = \"y_hat\"\n", |
| 168 | + ")\n", |
| 169 | + "plt.scatter(y_hat_test, y_test)\n", |
| 170 | + "plt.xlabel(\"Estimated conditional mean\")\n", |
| 171 | + "plt.ylabel(\"Actual outcome\")\n", |
| 172 | + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3, 3)))" |
| 173 | + ] |
| 174 | + }, |
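| | + {
| | + "cell_type": "markdown",
| | + "id": "3f9a7c21",
| | + "metadata": {},
| | + "source": [
| | + "As a quick illustration of the sequential layout described above, we can slice a single chain's draws out of the flattened `global_var_samples` array (which also appears in the diagnostics below). This is a minimal sketch; it assumes only the chain-major ordering just described."
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "id": "4b8d2e57",
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "# Slice out one chain's sigma^2 draws from the flattened sample array\n",
| | + "sigma2_flat = bart_model.global_var_samples\n",
| | + "chain_id = 1 # second chain (0-indexed)\n",
| | + "chain_draws = sigma2_flat[chain_id * num_mcmc:(chain_id + 1) * num_mcmc]\n",
| | + "chain_draws.shape # (num_mcmc,)"
| | + ]
| | + },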
| 175 | + { |
| 176 | + "cell_type": "markdown", |
| 177 | + "id": "89015a31", |
| 178 | + "metadata": {}, |
| 179 | + "source": [ |
| 180 | + "Now, suppose we want to analyze each of the chains separately to assess mixing / convergence.\n", |
| 181 | + "\n", |
| 182 | + "We can use our knowledge of the internal arrangement of the chain samples to reshape them into a `(num_chains, num_mcmc)` array for `arviz`, from which we can compute various mixing and convergence diagnostics."
| 183 | + ] |
| 184 | + }, |
| 185 | + { |
| 186 | + "cell_type": "code", |
| 187 | + "execution_count": null, |
| 188 | + "id": "96cee0e4", |
| 189 | + "metadata": {}, |
| 190 | + "outputs": [], |
| 191 | + "source": [ |
| 192 | + "sigma2_samples = bart_model.global_var_samples\n", |
| 193 | + "sigma2_samples_by_chain = {\"sigma2\": np.reshape(sigma2_samples, (num_chains, num_mcmc))}\n", |
| 194 | + "az.plot_trace(sigma2_samples_by_chain)" |
| 195 | + ] |
| 196 | + }, |
| 197 | + { |
| 198 | + "cell_type": "code", |
| 199 | + "execution_count": null, |
| 200 | + "id": "08137cda", |
| 201 | + "metadata": {}, |
| 202 | + "outputs": [], |
| 203 | + "source": [ |
| 204 | + "az.ess(sigma2_samples_by_chain)" |
| 205 | + ] |
| 206 | + }, |
| 207 | + { |
| 208 | + "cell_type": "code", |
| 209 | + "execution_count": null, |
| 210 | + "id": "552ba09c", |
| 211 | + "metadata": {}, |
| 212 | + "outputs": [], |
| 213 | + "source": [ |
| 214 | + "az.rhat(sigma2_samples_by_chain)" |
| 215 | + ] |
| 216 | + }, |
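| | + {
| | + "cell_type": "markdown",
| | + "id": "7c1d9b30",
| | + "metadata": {},
| | + "source": [
| | + "For intuition about what `az.rhat` measures, here is a hand-rolled version of the classic (non-split, non-rank-normalized) Gelman-Rubin statistic, which compares between-chain and within-chain variance. This is a sketch for illustration only; `arviz` computes a rank-normalized, split-chain R-hat by default, so the two numbers may differ slightly."
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "id": "8e2f0a46",
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "# Classic Gelman-Rubin potential scale reduction factor (for intuition only)\n",
| | + "draws = sigma2_samples_by_chain[\"sigma2\"] # shape (num_chains, num_mcmc)\n",
| | + "n_draws = draws.shape[1]\n",
| | + "B = n_draws * draws.mean(axis=1).var(ddof=1) # between-chain variance\n",
| | + "W = draws.var(axis=1, ddof=1).mean() # within-chain variance\n",
| | + "var_plus = (n_draws - 1) / n_draws * W + B / n_draws # pooled variance estimate\n",
| | + "np.sqrt(var_plus / W) # values near 1 suggest the chains agree"
| | + ]
| | + },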
| 217 | + { |
| 218 | + "cell_type": "code", |
| 219 | + "execution_count": null, |
| 220 | + "id": "2a0c65f0", |
| 221 | + "metadata": {}, |
| 222 | + "outputs": [], |
| 223 | + "source": [ |
| 224 | + "az.plot_autocorr(sigma2_samples_by_chain)" |
| 225 | + ] |
| 226 | + }, |
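| | + {
| | + "cell_type": "markdown",
| | + "id": "9d4e1f88",
| | + "metadata": {},
| | + "source": [
| | + "`az.summary` consolidates several of these diagnostics (posterior mean, credible interval, ESS, and R-hat) into a single table; we include it here as an optional extra."
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "id": "0a5b6c77",
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "az.summary(sigma2_samples_by_chain)"
| | + ]
| | + },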
| 227 | + { |
| 228 | + "cell_type": "code", |
| 229 | + "execution_count": null, |
| 230 | + "id": "5a0659ff", |
| 231 | + "metadata": {}, |
| 232 | + "outputs": [], |
| 233 | + "source": [ |
| 234 | + "az.plot_violin(sigma2_samples_by_chain)" |
| 235 | + ] |
| 236 | + } |
| 237 | + ], |
| 238 | + "metadata": { |
| 239 | + "language_info": { |
| 240 | + "name": "python" |
| 241 | + } |
| 242 | + }, |
| 243 | + "nbformat": 4, |
| 244 | + "nbformat_minor": 5 |
| 245 | +} |