diff --git a/.github/workflows/slow-api-test.yml b/.github/workflows/slow-api-test.yml
new file mode 100644
index 00000000..308a5bf8
--- /dev/null
+++ b/.github/workflows/slow-api-test.yml
@@ -0,0 +1,72 @@
+name: Unit Tests and Slow Running API Integration Tests for R and Python
+
+on:
+  workflow_dispatch:
+
+jobs:
+  testing:
+    name: test-slow-api-combinations
+    runs-on: ${{ matrix.os }}
+
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+
+    steps:
+      - name: Prevent conversion of line endings on Windows
+        if: startsWith(matrix.os, 'windows')
+        shell: pwsh
+        run: git config --global core.autocrlf false
+
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          submodules: 'recursive'
+
+      - name: Setup Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+          cache: "pip"
+
+      - name: Set up OpenMP (macOS)
+        # Set up OpenMP on macOS since it doesn't ship with the Apple Clang compiler suite
+        if: matrix.os == 'macos-latest'
+        run: |
+          brew install libomp
+
+      - name: Install Package with Relevant Dependencies
+        run: |
+          pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install .
+
+      - name: Run Pytest with Slow Running API Tests Enabled
+        run: |
+          pytest --runslow test/python
+
+      - name: Setup Pandoc for R
+        uses: r-lib/actions/setup-pandoc@v2
+
+      - name: Setup R
+        uses: r-lib/actions/setup-r@v2
+        with:
+          use-public-rspm: true
+
+      - name: Setup R Package Dependencies
+        uses: r-lib/actions/setup-r-dependencies@v2
+        with:
+          extra-packages: any::testthat, any::decor, any::rcmdcheck
+          needs: check
+
+      - name: Create a CRAN-ready version of the R package
+        run: |
+          Rscript cran-bootstrap.R 0 0 1
+
+      - name: Run CRAN Checks with Slow Running API Tests Enabled
+        uses: r-lib/actions/check-r-package@v2
+        env:
+          RUN_SLOW_TESTS: true
+        with:
+          working-directory: 'stochtree_cran'
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e3be9f5..af065869 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,19 +1,15 @@
 # Changelog
 
-# stochtree (development version)
-
-## New Features
-
-## Computational Improvements
+# stochtree 0.2.1
 
 ## Bug Fixes
 
-* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248))
-
-## Documentation Improvements
+* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
 
 ## Other Changes
 
+* Encode expectations about which combinations of BART / BCF features work together and warn when incompatible combinations are requested ([#250](https://github.com/StochasticTree/stochtree/pull/250))
+
 # stochtree 0.2.0
 
 ## New Features
diff --git a/DESCRIPTION b/DESCRIPTION
index aa7ae91a..02cc3bc6 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: stochtree
 Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
-Version: 0.2.0.9000
+Version: 0.2.1
 Authors@R: c(
     person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),
diff --git a/Doxyfile b/Doxyfile
index 4c8168ae..4230f847 100644
--- a/Doxyfile
+++ b/Doxyfile
@@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree"
 # could be handy for archiving the generated documentation or if some version
 # control system is used.
 
-PROJECT_NUMBER = 0.2.0.9000
+PROJECT_NUMBER = 0.2.1
 
 # Using the PROJECT_BRIEF tag one can provide an optional one line description
 # for a project that appears at the top of each page and should give viewer a
diff --git a/NEWS.md b/NEWS.md
index 8844d3c1..644e6f72 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,17 +1,13 @@
-# stochtree (development version)
-
-## New Features
-
-## Computational Improvements
+# stochtree 0.2.1
 
 ## Bug Fixes
 
-* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248))
-
-## Documentation Improvements
+* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
 
 ## Other Changes
 
+* Encode expectations about which combinations of BART / BCF features work together and warn when incompatible combinations are requested ([#250](https://github.com/StochasticTree/stochtree/pull/250))
+
 # stochtree 0.2.0
 
 ## New Features
diff --git a/R/bart.R b/R/bart.R
index f5068f96..23fee012 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -835,6 +835,16 @@ bart <- function(
     }
   }
 
+  # Runtime checks for variance forest
+  if (include_variance_forest) {
+    if (sample_sigma2_global) {
+      warning(
+        "Global error variance will not be sampled with a heteroskedasticity forest"
+      )
+      sample_sigma2_global <- FALSE
+    }
+  }
+
   # Handle standardization, prior calibration, and initialization of forest
   # differently for binary and continuous outcomes
   if (probit_outcome_model) {
@@ -2124,7 +2134,6 @@ predict.bartmodel <- function(
   X <- preprocessPredictionData(X, train_set_metadata)
 
   # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
-  has_rfx <- FALSE
   if (predict_rfx) {
     if (!is.null(rfx_group_ids)) {
       rfx_unique_group_ids <- object$rfx_unique_group_ids
@@ -2135,7 +2144,6 @@
         )
       }
       rfx_group_ids <- as.integer(group_ids_factor)
-      has_rfx <- TRUE
     }
   }
diff --git a/R/bcf.R b/R/bcf.R
index 765347cb..c634295b 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -897,14 +897,7 @@ bcf <- function(
   # Handle multivariate treatment
   has_multivariate_treatment <- ncol(Z_train) > 1
   if (has_multivariate_treatment) {
-    # Disable adaptive coding, internal propensity model, and
-    # leaf scale sampling if treatment is multivariate
-    if (adaptive_coding) {
-      warning(
-        "Adaptive coding is incompatible with multivariate treatment and will be ignored"
-      )
-      adaptive_coding <- FALSE
-    }
+    # Disable internal propensity model and leaf scale sampling if treatment is multivariate
     if (is.null(propensity_train)) {
       if (propensity_covariate != "none") {
         warning(
@@ -949,21 +942,31 @@ bcf <- function(
       }
       has_basis_rfx <- TRUE
       num_basis_rfx <- ncol(rfx_basis_train)
-    } else if (rfx_model_spec == "intercept_only") {
-      rfx_basis_train <- matrix(
-        rep(1, nrow(X_train)),
-        nrow = nrow(X_train),
-        ncol = 1
-      )
-      has_basis_rfx <- TRUE
-      num_basis_rfx <- 1
     } else if (rfx_model_spec == "intercept_plus_treatment") {
-      rfx_basis_train <- cbind(
-        rep(1, nrow(X_train)),
-        Z_train
-      )
-      has_basis_rfx <- TRUE
-      num_basis_rfx <- 1 + ncol(Z_train)
+      if (has_multivariate_treatment) {
+        warning(
+          "Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables."
+ ) + rfx_model_spec <- "intercept_only" + } + } + if (is.null(rfx_basis_train)) { + if (rfx_model_spec == "intercept_only") { + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + } else { + rfx_basis_train <- cbind( + rep(1, nrow(X_train)), + Z_train + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + ncol(Z_train) + } } num_rfx_groups <- length(unique(rfx_group_ids_train)) num_rfx_components <- ncol(rfx_basis_train) @@ -1021,15 +1024,21 @@ bcf <- function( y_train <- as.matrix(y_train) } - # Check whether treatment is binary (specifically 0-1 binary) - binary_treatment <- length(unique(Z_train)) == 2 - if (binary_treatment) { - unique_treatments <- sort(unique(Z_train)) - if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE + # Check whether treatment is binary and univariate (specifically 0-1 binary) + binary_treatment <- FALSE + if (!has_multivariate_treatment) { + binary_treatment <- length(unique(Z_train)) == 2 + if (binary_treatment) { + unique_treatments <- sort(unique(Z_train)) + if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE + } } # Adaptive coding will be ignored for continuous / ordered categorical treatments if ((!binary_treatment) && (adaptive_coding)) { + warning( + "Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model" + ) adaptive_coding <- FALSE } diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index dca34be3..be883198 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -409,19 +409,31 @@ compute_contrast_bart_model <- function( "rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model" ) } - if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) { - stop( - "rfx_basis_0 and rfx_basis_1 must be provided for this model" - ) - } - if ( - (object$model_params$num_rfx_basis > 0) && - ((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) || - (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)) - ) { - stop( - "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model" - ) + if (has_rfx) { + if (object$model_params$rfx_model_spec == "custom") { + if ((is.null(rfx_basis_0) || is.null(rfx_basis_1))) { + stop( + "A user-provided basis (`rfx_basis_0` and `rfx_basis_1`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + if (!is.matrix(rfx_basis_0) || !is.matrix(rfx_basis_1)) { + stop("'rfx_basis_0' and 'rfx_basis_1' must be matrices") + } + if ((nrow(rfx_basis_0) != nrow(X)) || (nrow(rfx_basis_1) != nrow(X))) { + stop( + "'rfx_basis_0' and 'rfx_basis_1' must have the same number of rows as 'X'" + ) + } + if ( + (object$model_params$num_rfx_basis > 0) && + ((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) || + (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)) + ) { + stop( + "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model" + ) + } + } } # Predict for the control arm @@ -574,16 +586,22 @@ sample_bcf_posterior_predictive <- function( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } - if (is.null(rfx_basis)) { - stop( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - } - if (!is.matrix(rfx_basis)) { - stop("'rfx_basis' must be a matrix") + + if (model_object$model_params$rfx_model_spec == 
"custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } } - if (nrow(rfx_basis) != nrow(X)) { - stop("'rfx_basis' must have the same number of rows as 'X'") + + if (!is.null(rfx_basis)) { + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") + } } } @@ -735,16 +753,18 @@ sample_bart_posterior_predictive <- function( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } - if (is.null(rfx_basis)) { - stop( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - } - if (!is.matrix(rfx_basis)) { - stop("'rfx_basis' must be a matrix") - } - if (nrow(rfx_basis) != nrow(X)) { - stop("'rfx_basis' must have the same number of rows as 'X'") + if (model_object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") + } } } @@ -1172,16 +1192,18 @@ compute_bart_posterior_interval <- function( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) } - if (is.null(rfx_basis)) { - stop( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - } - if (!is.matrix(rfx_basis)) { - stop("'rfx_basis' must be a matrix") - } - if (nrow(rfx_basis) != nrow(X)) { - stop("'rfx_basis' must have the same number of rows as 'X'") + if (model_object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(X)) { + stop("'rfx_basis' must have the same number of rows as 'X'") + } } } diff --git a/configure b/configure index d862d747..54bfbeb1 100755 --- a/configure +++ b/configure @@ -1,6 +1,6 @@ #! /bin/sh # Guess values for system-dependent variables and create Makefiles. -# Generated by GNU Autoconf 2.72 for stochtree 0.2.0.9000. +# Generated by GNU Autoconf 2.72 for stochtree 0.2.1. # # # Copyright (C) 1992-1996, 1998-2017, 2020-2023 Free Software Foundation, @@ -600,8 +600,8 @@ MAKEFLAGS= # Identity of this package. PACKAGE_NAME='stochtree' PACKAGE_TARNAME='stochtree' -PACKAGE_VERSION='0.2.0.9000' -PACKAGE_STRING='stochtree 0.2.0.9000' +PACKAGE_VERSION='0.2.1' +PACKAGE_STRING='stochtree 0.2.1' PACKAGE_BUGREPORT='' PACKAGE_URL='' @@ -1204,7 +1204,7 @@ if test "$ac_init_help" = "long"; then # Omit some internal or obsolete options to make the list less imposing. # This message is too long to be a string in the A/UX 3.1 sh. cat <<_ACEOF -'configure' configures stochtree 0.2.0.9000 to adapt to many kinds of systems. +'configure' configures stochtree 0.2.1 to adapt to many kinds of systems. Usage: $0 [OPTION]... [VAR=VALUE]... 
@@ -1266,7 +1266,7 @@ fi if test -n "$ac_init_help"; then case $ac_init_help in - short | recursive ) echo "Configuration of stochtree 0.2.0.9000:";; + short | recursive ) echo "Configuration of stochtree 0.2.1:";; esac cat <<\_ACEOF @@ -1334,7 +1334,7 @@ fi test -n "$ac_init_help" && exit $ac_status if $ac_init_version; then cat <<\_ACEOF -stochtree configure 0.2.0.9000 +stochtree configure 0.2.1 generated by GNU Autoconf 2.72 Copyright (C) 2023 Free Software Foundation, Inc. @@ -1371,7 +1371,7 @@ cat >config.log <<_ACEOF This file contains any messages produced by compilers while running configure, to aid debugging if configure makes a mistake. -It was created by stochtree $as_me 0.2.0.9000, which was +It was created by stochtree $as_me 0.2.1, which was generated by GNU Autoconf 2.72. Invocation command line was $ $0$ac_configure_args_raw @@ -2380,7 +2380,7 @@ cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 # report actual input values of CONFIG_FILES etc. instead of their # values after options handling. ac_log=" -This file was extended by stochtree $as_me 0.2.0.9000, which was +This file was extended by stochtree $as_me 0.2.1, which was generated by GNU Autoconf 2.72. Invocation command line was CONFIG_FILES = $CONFIG_FILES @@ -2435,7 +2435,7 @@ ac_cs_config_escaped=`printf "%s\n" "$ac_cs_config" | sed "s/^ //; s/'/'\\\\\\\\ cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 ac_cs_config='$ac_cs_config_escaped' ac_cs_version="\\ -stochtree config.status 0.2.0.9000 +stochtree config.status 0.2.1 configured by $0, generated by GNU Autoconf 2.72, with options \\"\$ac_cs_config\\" diff --git a/configure.ac b/configure.ac index 3d1143ba..33505e2c 100644 --- a/configure.ac +++ b/configure.ac @@ -3,7 +3,7 @@ # https://github.com/microsoft/LightGBM/blob/master/R-package/configure.ac AC_PREREQ(2.69) -AC_INIT([stochtree], [0.2.0.9000], [], [stochtree], []) +AC_INIT([stochtree], [0.2.1], [], [stochtree], []) # Note: consider making version number dynamic as in # https://github.com/microsoft/LightGBM/blob/195c26fc7b00eb0fec252dfe841e2e66d6833954/build-cran-package.sh diff --git a/pyproject.toml b/pyproject.toml index 0fe8a12a..2e992666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta" [project] name = "stochtree" -version = "0.2.0-dev" +version = "0.2.1" dynamic = ["readme", "optional-dependencies", "license"] description = "Stochastic Tree Ensembles for Machine Learning and Causal Inference" requires-python = ">=3.8.0" diff --git a/stochtree/bart.py b/stochtree/bart.py index b7bf1c88..832d9d00 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -70,15 +70,15 @@ def __init__(self) -> None: def sample( self, - X_train: Union[np.array, pd.DataFrame], - y_train: np.array, - leaf_basis_train: np.array = None, - rfx_group_ids_train: np.array = None, - rfx_basis_train: np.array = None, - X_test: Union[np.array, pd.DataFrame] = None, - leaf_basis_test: np.array = None, - rfx_group_ids_test: np.array = None, - rfx_basis_test: np.array = None, + X_train: Union[np.ndarray, pd.DataFrame], + y_train: np.ndarray, + leaf_basis_train: Optional[np.ndarray] = None, + rfx_group_ids_train: Optional[np.ndarray] = None, + rfx_basis_train: Optional[np.ndarray] = None, + X_test: Optional[Union[np.ndarray, pd.DataFrame]] = None, + leaf_basis_test: Optional[np.ndarray] = None, + rfx_group_ids_test: Optional[np.ndarray] = None, + rfx_basis_test: Optional[np.ndarray] = None, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, @@ -859,6 +859,13 @@ def 
sample( if num_features_subsample_variance is None: num_features_subsample_variance = X_train.shape[1] + # Runtime check for multivariate leaf regression + if sample_sigma2_leaf and self.num_basis > 1: + warnings.warn( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." + ) + sample_sigma2_leaf = False + # Preliminary runtime checks for probit link if not self.include_mean_forest: self.probit_outcome_model = False @@ -872,13 +879,21 @@ def sample( raise ValueError( "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" ) + if sample_sigma2_global: + warnings.warn( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma2_global = False if self.include_variance_forest: raise ValueError( "We do not support heteroskedasticity with a probit link" ) + + # Runtime checks for variance forest + if self.include_variance_forest: if sample_sigma2_global: warnings.warn( - "Global error variance will not be sampled with a probit link as it is fixed at 1" + "Sampling global error variance not yet supported for models with variance forests, so the global error variance parameter will not be sampled in this model." ) sample_sigma2_global = False @@ -1217,7 +1232,7 @@ def sample( else: leaf_model_mean_forest = 2 leaf_dimension_mean = self.num_basis - + # Sampling data structures global_model_config = GlobalModelConfig(global_error_variance=current_sigma2) if self.include_mean_forest: @@ -1900,6 +1915,9 @@ def predict( if leaf_basis is not None: if leaf_basis.ndim == 1: leaf_basis = np.expand_dims(leaf_basis, 1) + if rfx_basis is not None: + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) # Covariate preprocessing if not self._covariate_preprocessor._check_is_fitted(): @@ -1958,21 +1976,18 @@ def predict( mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar # Random effects data checks - if has_rfx: - if rfx_group_ids is None: - raise ValueError( - "rfx_group_ids must be provided if rfx_basis is provided" - ) - if rfx_basis is not None: - if rfx_basis.ndim == 1: - rfx_basis = np.expand_dims(rfx_basis, 1) - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError("X and rfx_basis must have the same number of rows") + if predict_rfx and rfx_group_ids is None: + raise ValueError( + "Random effect group labels (rfx_group_ids) must be provided for this model" + ) + if predict_rfx and rfx_basis is None and not rfx_intercept: + raise ValueError("Random effects basis (rfx_basis) must be provided for this model") + if self.num_rfx_basis > 0 and not rfx_intercept: if rfx_basis.shape[1] != self.num_rfx_basis: raise ValueError( - "rfx_basis must have the same number of columns as the random effects basis used to sample this model" + "Random effects basis has a different dimension than the basis used to train this model" ) - + # Random effects predictions if predict_rfx or predict_rfx_intermediate: if rfx_basis is not None: @@ -1983,7 +1998,7 @@ def predict( # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only" if not rfx_intercept: raise ValueError( - "rfx_basis must be provided for random effects models with random slopes" + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" ) # Extract the raw RFX samples and scale by train set outcome standard deviation @@ -2321,16 +2336,18 @@ def 
compute_posterior_interval( raise ValueError( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) - if rfx_basis is None: - raise ValueError( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - if not isinstance(rfx_basis, np.ndarray): - raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError( - "'rfx_basis' must have the same number of rows as 'X'" - ) + if self.rfx_model_spec == "custom": + if rfx_basis is None: + raise ValueError( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + if rfx_basis is not None: + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'X'" + ) # Compute posterior matrices for the requested model terms predictions = self.predict( @@ -2427,16 +2444,18 @@ def sample_posterior_predictive( raise ValueError( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) - if rfx_basis is None: - raise ValueError( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - if not isinstance(rfx_basis, np.ndarray): - raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError( - "'rfx_basis' must have the same number of rows as 'X'" - ) + if self.rfx_model_spec == "custom": + if rfx_basis is None: + raise ValueError( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + if rfx_basis is not None: + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'X'" + ) # Compute posterior predictive samples bart_preds = self.predict( diff --git a/stochtree/bcf.py b/stochtree/bcf.py index ac98fdbb..dfe0610e 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1320,25 +1320,38 @@ def sample( self.p_x = X_train_processed.shape[1] # Check whether treatment is binary - self.binary_treatment = np.unique(Z_train).size == 2 + self.binary_treatment = False + if not self.multivariate_treatment: + self.binary_treatment = np.unique(Z_train).size == 2 + if self.binary_treatment: + unique_treatments = np.squeeze(np.unique(Z_train)).tolist() + if not all(i in [0,1] for i in unique_treatments): + self.binary_treatment = False # Adaptive coding will be ignored for continuous / ordered categorical treatments self.adaptive_coding = adaptive_coding if adaptive_coding and not self.binary_treatment: - self.adaptive_coding = False - if adaptive_coding and self.multivariate_treatment: + warnings.warn( + "Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model" + ) self.adaptive_coding = False # Sampling sigma2_leaf_tau will be ignored for multivariate treatments if sample_sigma2_leaf_tau and self.multivariate_treatment: + warnings.warn( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." 
+ ) sample_sigma2_leaf_tau = False # Check if user has provided propensities that are needed in the model if propensity_train is None and propensity_covariate != "none": + # Disable internal propensity model if treatment is multivariate if self.multivariate_treatment: - raise ValueError( - "Propensities must be provided (via propensity_train and / or propensity_test parameters) or omitted by setting propensity_covariate = 'none' for multivariate treatments" + warnings.warn( + "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" ) + propensity_covariate = "none" + self.internal_propensity_model = True else: self.bart_propensity_model = BARTModel() num_gfr_propensity = 10 @@ -1373,6 +1386,64 @@ def sample( self.internal_propensity_model = True else: self.internal_propensity_model = False + + # Runtime checks on RFX group ids + self.has_rfx = False + has_rfx_test = False + if rfx_group_ids_train is not None: + self.has_rfx = True + if rfx_group_ids_test is not None: + has_rfx_test = True + if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): + raise ValueError( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + + # Handle the rfx basis matrices + self.has_rfx_basis = False + self.num_rfx_basis = 0 + if self.has_rfx: + if self.rfx_model_spec == "custom": + if rfx_basis_train is None: + raise ValueError( + "rfx_basis_train must be provided when rfx_model_spec = 'custom'" + ) + elif self.rfx_model_spec == "intercept_plus_treatment": + if self.multivariate_treatment: + warnings.warn( + "Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables." 
+ ) + self.rfx_model_spec = "intercept_only" + if rfx_basis_train is None: + if self.rfx_model_spec == "intercept_only": + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) + else: + rfx_basis_train = np.concatenate( + (np.ones((rfx_group_ids_train.shape[0], 1)), Z_train), axis=1 + ) + + self.has_rfx_basis = True + self.num_rfx_basis = rfx_basis_train.shape[1] + num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] + num_rfx_components = rfx_basis_train.shape[1] + if num_rfx_groups == 1: + warnings.warn( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) + if has_rfx_test: + if self.rfx_model_spec == "custom": + if rfx_basis_test is None: + raise ValueError( + "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_test is None: + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + elif self.rfx_model_spec == "intercept_plus_treatment": + if rfx_basis_test is None: + rfx_basis_test = np.concatenate( + (np.ones((rfx_group_ids_test.shape[0], 1)), Z_test), axis=1 + ) # Preliminary runtime checks for probit link if self.probit_outcome_model: @@ -1385,13 +1456,21 @@ def sample( raise ValueError( "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" ) + if sample_sigma2_global: + warnings.warn( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma2_global = False if self.include_variance_forest: raise ValueError( "We do not support heteroskedasticity with a probit link" ) + + # Runtime checks for variance forest + if self.include_variance_forest: if sample_sigma2_global: warnings.warn( - "Global error variance will not be sampled with a probit link as it is fixed at 1" + "Sampling global error variance not yet supported for models with variance forests, so the global error variance parameter will not be sampled in this model." 
) sample_sigma2_global = False @@ -1550,58 +1629,6 @@ def sample( if not b_forest: b_forest = 1.0 - # Runtime checks on RFX group ids - self.has_rfx = False - has_rfx_test = False - if rfx_group_ids_train is not None: - self.has_rfx = True - if rfx_group_ids_test is not None: - has_rfx_test = True - if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): - raise ValueError( - "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" - ) - - # Handle the rfx basis matrices - self.has_rfx_basis = False - self.num_rfx_basis = 0 - if self.has_rfx: - if self.rfx_model_spec == "custom": - if rfx_basis_train is None: - raise ValueError( - "rfx_basis_train must be provided when rfx_model_spec = 'custom'" - ) - elif self.rfx_model_spec == "intercept_only": - if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) - elif self.rfx_model_spec == "intercept_plus_treatment": - if rfx_basis_train is None: - rfx_basis_train = np.concatenate( - (np.ones((rfx_group_ids_train.shape[0], 1)), Z_train), axis=1 - ) - self.has_rfx_basis = True - self.num_rfx_basis = rfx_basis_train.shape[1] - num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] - num_rfx_components = rfx_basis_train.shape[1] - if num_rfx_groups == 1: - warnings.warn( - "Only one group was provided for random effect sampling, so the random effects model is likely overkill" - ) - if has_rfx_test: - if self.rfx_model_spec == "custom": - if rfx_basis_test is None: - raise ValueError( - "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" - ) - elif self.rfx_model_spec == "intercept_only": - if rfx_basis_test is None: - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) - elif self.rfx_model_spec == "intercept_plus_treatment": - if rfx_basis_test is None: - rfx_basis_test = np.concatenate( - (np.ones((rfx_group_ids_test.shape[0], 1)), Z_test), axis=1 - ) - # Set up random effects structures if self.has_rfx: # Prior parameters @@ -3570,14 +3597,18 @@ def sample_posterior_predictive( raise ValueError( "'rfx_group_ids' must have the same length as the number of rows in 'X'" ) - if rfx_basis is None: - raise ValueError( - "'rfx_basis' must be provided in order to compute the requested intervals" - ) - if not isinstance(rfx_basis, np.ndarray): - raise ValueError("'rfx_basis' must be a numpy array") - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError("'rfx_basis' must have the same number of rows as 'X'") + if self.rfx_model_spec == "custom": + if rfx_basis is None: + raise ValueError( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + if rfx_basis is not None: + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'X'" + ) # Compute posterior predictive samples bcf_preds = self.predict( diff --git a/test/R/testthat/test-api-combinations.R b/test/R/testthat/test-api-combinations.R new file mode 100644 index 00000000..d2c2f2ce --- /dev/null +++ b/test/R/testthat/test-api-combinations.R @@ -0,0 +1,810 @@ +run_bart_factorial <- function( + bart_data_train, + bart_data_test, + leaf_reg = "none", + variance_forest = FALSE, + random_effects = "none", + sampling_global_error_scale = FALSE, + sampling_leaf_scale = FALSE, + outcome_type = "continuous", + num_chains = 1 +) { + # 
Unpack BART training data + y <- bart_data_train[["y"]] + X <- bart_data_train[["X"]] + if (leaf_reg != "none") { + leaf_basis <- bart_data_train[["leaf_basis"]] + } else { + leaf_basis <- NULL + } + if (random_effects != "none") { + rfx_group_ids <- bart_data_train[["rfx_group_ids"]] + } else { + rfx_group_ids <- NULL + } + if (random_effects == "custom") { + rfx_basis <- bart_data_train[["rfx_basis"]] + } else { + rfx_basis <- NULL + } + + # Set BART model parameters + general_params <- list( + num_chains = num_chains, + sample_sigma2_global = sampling_global_error_scale, + probit_outcome_model = outcome_type == "binary" + ) + mean_forest_params <- list( + sample_sigma2_leaf = sampling_leaf_scale + ) + variance_forest_params <- list( + num_trees = ifelse(variance_forest, 20, 0) + ) + rfx_params <- list( + model_spec = ifelse(random_effects == "none", "custom", random_effects) + ) + + # Sample BART model + bart_model <- stochtree::bart( + X_train = X, + y_train = y, + leaf_basis_train = leaf_basis, + rfx_group_ids_train = rfx_group_ids, + rfx_basis_train = rfx_basis, + general_params = general_params, + mean_forest_params = mean_forest_params, + variance_forest_params = variance_forest_params, + random_effects_params = rfx_params + ) + + # Unpack test set data + y_test <- bart_data_test[["y"]] + X_test <- bart_data_test[["X"]] + if (leaf_reg != "none") { + leaf_basis_test <- bart_data_test[["leaf_basis"]] + } else { + leaf_basis_test <- NULL + } + if (random_effects != "none") { + rfx_group_ids_test <- bart_data_test[["rfx_group_ids"]] + } else { + rfx_group_ids_test <- NULL + } + if (random_effects == "custom") { + rfx_basis_test <- bart_data_test[["rfx_basis"]] + } else { + rfx_basis_test <- NULL + } + + # Predict on test set + mean_preds <- predict( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "mean", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + posterior_preds <- predict( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + + # Compute intervals + posterior_interval <- compute_bart_posterior_interval( + bart_model, + terms = "all", + level = 0.95, + scale = ifelse(outcome_type == "binary", "probability", "linear"), + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test + ) + + # Sample posterior predictive + posterior_predictive_draws <- sample_bart_posterior_predictive( + bart_model, + X = X_test, + leaf_basis = leaf_basis_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + num_draws_per_sample = 5 + ) +} + +run_bcf_factorial <- function( + bcf_data_train, + bcf_data_test, + treatment_type = "binary", + variance_forest = FALSE, + random_effects = "none", + sampling_global_error_scale = FALSE, + sampling_mu_leaf_scale = FALSE, + sampling_tau_leaf_scale = FALSE, + outcome_type = "continuous", + num_chains = 1, + adaptive_coding = TRUE, + include_propensity = TRUE +) { + # Unpack BART training data + y <- bcf_data_train[["y"]] + X <- bcf_data_train[["X"]] + Z <- bcf_data_train[["Z"]] + if (include_propensity) { + propensity_train <- bcf_data_train[["propensity"]] + } else { + propensity_train <- NULL + } + if (random_effects != "none") { + rfx_group_ids <- 
bcf_data_train[["rfx_group_ids"]] + } else { + rfx_group_ids <- NULL + } + if (random_effects == "custom") { + rfx_basis <- bcf_data_train[["rfx_basis"]] + } else { + rfx_basis <- NULL + } + + # Set BART model parameters + general_params <- list( + num_chains = num_chains, + sample_sigma2_global = sampling_global_error_scale, + probit_outcome_model = outcome_type == "binary", + adaptive_coding = adaptive_coding + ) + mu_forest_params <- list( + sample_sigma2_leaf = sampling_mu_leaf_scale + ) + tau_forest_params <- list( + sample_sigma2_leaf = sampling_tau_leaf_scale + ) + variance_forest_params <- list( + num_trees = ifelse(variance_forest, 20, 0) + ) + rfx_params <- list( + model_spec = ifelse(random_effects == "none", "custom", random_effects) + ) + + # Sample BART model + bcf_model <- stochtree::bcf( + X_train = X, + y_train = y, + Z_train = Z, + propensity_train = propensity_train, + rfx_group_ids_train = rfx_group_ids, + rfx_basis_train = rfx_basis, + general_params = general_params, + prognostic_forest_params = mu_forest_params, + treatment_effect_forest_params = tau_forest_params, + variance_forest_params = variance_forest_params, + random_effects_params = rfx_params + ) + + # Unpack test set data + y_test <- bcf_data_test[["y"]] + X_test <- bcf_data_test[["X"]] + Z_test <- bcf_data_test[["Z"]] + if (include_propensity) { + propensity_test <- bcf_data_test[["propensity"]] + } else { + propensity_test <- NULL + } + if (random_effects != "none") { + rfx_group_ids_test <- bcf_data_test[["rfx_group_ids"]] + } else { + rfx_group_ids_test <- NULL + } + if (random_effects == "custom") { + rfx_basis_test <- bcf_data_test[["rfx_basis"]] + } else { + rfx_basis_test <- NULL + } + + # Predict on test set + mean_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "mean", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + posterior_preds <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "all", + scale = ifelse(outcome_type == "binary", "probability", "linear") + ) + + # Compute intervals + posterior_interval <- compute_bcf_posterior_interval( + bcf_model, + terms = "all", + level = 0.95, + scale = ifelse(outcome_type == "binary", "probability", "linear"), + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test + ) + + # Sample posterior predictive + posterior_predictive_draws <- sample_bcf_posterior_predictive( + bcf_model, + X = X_test, + Z = Z_test, + propensity = propensity_test, + rfx_group_ids = rfx_group_ids_test, + rfx_basis = rfx_basis_test, + num_draws_per_sample = 5 + ) +} + +# Construct chained expectations without writing out every combination of function calls +construct_chained_expectation_bart <- function( + error_cond, + warning_cond_1, + warning_cond_2, + warning_cond_3 +) { + # Build the chain from innermost to outermost + function_text <- "x" + if (warning_cond_1) { + function_text <- paste0( + "warning_fun_1(", + function_text, + ")" + ) + } + if (warning_cond_2) { + function_text <- paste0( + "warning_fun_2(", + function_text, + ")" + ) + } + if (warning_cond_3) { + function_text <- paste0( + "warning_fun_3(", + function_text, + ")" + ) + } + if (error_cond) { + function_text <- paste0( + "expect_error(", + 
function_text, + ")" + ) + } + return(as.function( + c(alist(x = ), parse(text = function_text)[[1]]), + envir = parent.frame() + )) +} + +construct_chained_expectation_bcf <- function( + error_cond, + warning_cond_1, + warning_cond_2, + warning_cond_3, + warning_cond_4, + warning_cond_5, + warning_cond_6 +) { + # Build the chain from innermost to outermost + function_text <- "x" + if (warning_cond_1) { + function_text <- paste0( + "warning_fun_1(", + function_text, + ")" + ) + } + if (warning_cond_2) { + function_text <- paste0( + "warning_fun_2(", + function_text, + ")" + ) + } + if (warning_cond_3) { + function_text <- paste0( + "warning_fun_3(", + function_text, + ")" + ) + } + if (warning_cond_4) { + function_text <- paste0( + "warning_fun_4(", + function_text, + ")" + ) + } + if (warning_cond_5) { + function_text <- paste0( + "warning_fun_5(", + function_text, + ")" + ) + } + if (warning_cond_6) { + function_text <- paste0( + "warning_fun_6(", + function_text, + ")" + ) + } + if (error_cond) { + function_text <- paste0( + "expect_error(", + function_text, + ")" + ) + } + return(as.function( + c(alist(x = ), parse(text = function_text)[[1]]), + envir = parent.frame() + )) +} + +test_that("Quick check of interactions between components of BART functionality", { + skip_on_cran() + # Code from: https://github.com/r-lib/testthat/blob/main/R/skip.R + skip_if( + isTRUE(as.logical(Sys.getenv("RUN_SLOW_TESTS", "false"))), + "skipping slow tests" + ) + + # Overall, we have seven components of a BART sampler which can be on / off or set to different levels: + # 1. Leaf regression: none, univariate, multivariate + # 2. Variance forest: no, yes + # 3. Random effects: no, custom basis, `intercept_only` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on mean forest: no, yes (only available for constant leaf or univariate leaf regression) + # 6. Outcome type: continuous (identity link), binary (probit link) + # 7. Number of chains: 1, >1 + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BART models + # will run without error. Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. 
+ + # Generate data with random effects + n <- 50 + p <- 3 + num_basis <- 2 + num_rfx_groups <- 3 + num_rfx_basis <- 2 + X <- matrix(runif(n * p), ncol = p) + leaf_basis <- matrix(runif(n * num_basis), ncol = num_basis) + leaf_coefs <- runif(num_basis) + group_ids <- sample(1:num_rfx_groups, n, replace = T) + rfx_basis <- matrix(runif(n * num_rfx_basis), ncol = num_rfx_basis) + rfx_coefs <- matrix( + runif(num_rfx_groups * num_rfx_basis), + ncol = num_rfx_basis + ) + mean_term <- sin(X[, 1]) * rowSums(leaf_basis * leaf_coefs) + rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) + E_y <- mean_term + rfx_term + E_y <- E_y - mean(E_y) + epsilon <- rnorm(n, 0, 1) + y_continuous <- E_y + epsilon + y_binary <- 1 * (y_continuous > 0) + + # Split into test and train sets + test_set_pct <- 0.5 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + leaf_basis_test <- leaf_basis[test_inds, ] + leaf_basis_train <- leaf_basis[train_inds, ] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + group_ids_test <- group_ids[test_inds] + group_ids_train <- group_ids[train_inds] + y_continuous_test <- y_continuous[test_inds] + y_continuous_train <- y_continuous[train_inds] + y_binary_test <- y_binary[test_inds] + y_binary_train <- y_binary[train_inds] + + # Run the power set of models + leaf_reg_options <- c("none", "univariate", "multivariate") + variance_forest_options <- c(FALSE, TRUE) + random_effects_options <- c("none", "custom", "intercept_only") + sampling_global_error_scale_options <- c(FALSE, TRUE) + sampling_leaf_scale_options <- c(FALSE, TRUE) + outcome_type_options <- c("continuous", "binary") + num_chains_options <- c(1, 3) + model_options_df <- expand.grid( + leaf_reg = leaf_reg_options, + variance_forest = variance_forest_options, + random_effects = random_effects_options, + sampling_global_error_scale = sampling_global_error_scale_options, + sampling_leaf_scale = sampling_leaf_scale_options, + outcome_type = outcome_type_options, + num_chains = num_chains_options, + stringsAsFactors = FALSE + ) + for (i in 1:nrow(model_options_df)) { + # Determine which errors and warnings should be triggered + error_cond <- (model_options_df$variance_forest[i]) && + (model_options_df$outcome_type[i] == "binary") + warning_cond_1 <- (model_options_df$sampling_leaf_scale[i]) && + (model_options_df$leaf_reg[i] == "multivariate") + warning_fun_1 <- function(x) { + expect_warning( + x, + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." 
+ ) + } + warning_cond_2 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$outcome_type[i] == "binary") + warning_fun_2 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + } + warning_cond_3 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$variance_forest[i]) + warning_fun_3 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a heteroskedasticity" + ) + } + warning_cond <- warning_cond_1 || warning_cond_2 || warning_cond_3 + + if (error_cond || warning_cond) { + test_fun <- construct_chained_expectation_bart( + error_cond = error_cond, + warning_cond_1 = warning_cond_1, + warning_cond_2 = warning_cond_2, + warning_cond_3 = warning_cond_3 + ) + } else { + test_fun <- expect_no_error + } + + # Prepare test function arguments + bart_data_train <- list(X = X_train) + bart_data_test <- list(X = X_test) + if (model_options_df$outcome_type[i] == "continuous") { + bart_data_train[["y"]] <- y_continuous_train + bart_data_test[["y"]] <- y_continuous_test + } else { + bart_data_train[["y"]] <- y_binary_train + bart_data_test[["y"]] <- y_binary_test + } + if (model_options_df$leaf_reg[i] != "none") { + if (model_options_df$leaf_reg[i] == "univariate") { + bart_data_train[["leaf_basis"]] <- leaf_basis_train[, 1, drop = FALSE] + bart_data_test[["leaf_basis"]] <- leaf_basis_test[, 1, drop = FALSE] + } else { + bart_data_train[["leaf_basis"]] <- leaf_basis_train + bart_data_test[["leaf_basis"]] <- leaf_basis_test + } + } else { + bart_data_train[["leaf_basis"]] <- NULL + bart_data_test[["leaf_basis"]] <- NULL + } + if (model_options_df$random_effects[i] != "none") { + bart_data_train[["rfx_group_ids"]] <- group_ids_train + bart_data_test[["rfx_group_ids"]] <- group_ids_test + } else { + bart_data_train[["rfx_group_ids"]] <- NULL + bart_data_test[["rfx_group_ids"]] <- NULL + } + if (model_options_df$random_effects[i] == "custom") { + bart_data_train[["rfx_basis"]] <- rfx_basis_train + bart_data_test[["rfx_basis"]] <- rfx_basis_test + } else { + bart_data_train[["rfx_basis"]] <- NULL + bart_data_test[["rfx_basis"]] <- NULL + } + + # Apply testthat expectation(s) + test_fun({ + run_bart_factorial( + bart_data_train = bart_data_train, + bart_data_test = bart_data_test, + leaf_reg = model_options_df$leaf_reg[i], + variance_forest = model_options_df$variance_forest[i], + random_effects = model_options_df$random_effects[i], + sampling_global_error_scale = model_options_df$sampling_global_error_scale[ + i + ], + sampling_leaf_scale = model_options_df$sampling_leaf_scale[ + i + ], + outcome_type = model_options_df$outcome_type[i], + num_chains = model_options_df$num_chains[i] + ) + }) + } +}) + +test_that("Quick check of interactions between components of BCF functionality", { + skip_on_cran() + # Code from: https://github.com/r-lib/testthat/blob/main/R/skip.R + skip_if( + isTRUE(as.logical(Sys.getenv("RUN_SLOW_TESTS", "false"))), + "skipping slow tests" + ) + + # Overall, we have nine components of a BCF sampler which can be on / off or set to different levels: + # 1. treatment: binary, univariate continuous, multivariate + # 2. Variance forest: no, yes + # 3. Random effects: no, custom basis, `intercept_only`, `intercept_plus_treatment` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on prognostic forest: no, yes + # 6. Sampling leaf scale on treatment forest: no, yes (only available for univariate treatment) + # 7. 
Outcome type: continuous (identity link), binary (probit link) + # 8. Number of chains: 1, >1 + # 9. Adaptive coding: no, yes + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BCF models + # will run without error. Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. + + # Generate data with random effects + n <- 50 + p <- 3 + num_rfx_groups <- 3 + num_rfx_basis <- 2 + X <- matrix(runif(n * p), ncol = p) + binary_treatment <- rbinom(n, 1, 0.5) + continuous_treatment <- runif(n, 0, 1) + multivariate_treatment <- cbind( + binary_treatment, + continuous_treatment + ) + group_ids <- sample(1:num_rfx_groups, n, replace = T) + rfx_basis <- matrix(runif(n * num_rfx_basis), ncol = num_rfx_basis) + rfx_coefs <- matrix( + runif(num_rfx_groups * num_rfx_basis), + ncol = num_rfx_basis + ) + propensity <- runif(n) + prognostic_term <- sin(X[, 1]) + binary_treatment_effect <- X[, 2] + continuous_treatment_effect <- X[, 3] + rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) + E_y <- prognostic_term + + binary_treatment_effect * binary_treatment + + continuous_treatment_effect * continuous_treatment + + rfx_term + E_y <- E_y - mean(E_y) + epsilon <- rnorm(n, 0, 1) + y_continuous <- E_y + epsilon + y_binary <- 1 * (y_continuous > 0) + + # Split into test and train sets + test_set_pct <- 0.5 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + binary_treatment_test <- binary_treatment[test_inds] + binary_treatment_train <- binary_treatment[train_inds] + propensity_test <- propensity[test_inds] + propensity_train <- propensity[train_inds] + continuous_treatment_test <- continuous_treatment[test_inds] + continuous_treatment_train <- continuous_treatment[train_inds] + multivariate_treatment_test <- multivariate_treatment[test_inds, ] + multivariate_treatment_train <- multivariate_treatment[train_inds, ] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + group_ids_test <- group_ids[test_inds] + group_ids_train <- group_ids[train_inds] + y_continuous_test <- y_continuous[test_inds] + y_continuous_train <- y_continuous[train_inds] + y_binary_test <- y_binary[test_inds] + y_binary_train <- y_binary[train_inds] + + # Run the power set of models + treatment_options <- c("binary", "univariate_continuous", "multivariate") + variance_forest_options <- c(FALSE, TRUE) + random_effects_options <- c( + "none", + "custom", + "intercept_only", + "intercept_plus_treatment" + ) + sampling_global_error_scale_options <- c(FALSE, TRUE) + sampling_mu_leaf_scale_options <- c(FALSE, TRUE) + sampling_tau_leaf_scale_options <- c(FALSE, TRUE) + outcome_type_options <- c("continuous", "binary") + num_chains_options <- c(1, 3) + adaptive_coding_options <- c(FALSE, TRUE) + include_propensity_options <- c(FALSE, TRUE) + model_options_df <- expand.grid( + treatment_type = treatment_options, + variance_forest = variance_forest_options, + random_effects = random_effects_options, + sampling_global_error_scale = sampling_global_error_scale_options, + sampling_mu_leaf_scale = sampling_mu_leaf_scale_options, + sampling_tau_leaf_scale = sampling_tau_leaf_scale_options, + outcome_type = 
outcome_type_options, + num_chains = num_chains_options, + adaptive_coding = adaptive_coding_options, + include_propensity = include_propensity_options, + stringsAsFactors = FALSE + ) + for (i in 1:nrow(model_options_df)) { + # Determine which errors and warnings should be triggered + error_cond <- (model_options_df$variance_forest[i]) && + (model_options_df$outcome_type[i] == "binary") + warning_cond_1 <- (model_options_df$sampling_tau_leaf_scale[i]) && + (model_options_df$treatment_type[i] == "multivariate") + warning_fun_1 <- function(x) { + expect_warning( + x, + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.", + fixed = TRUE + ) + } + warning_cond_2 <- (!model_options_df$include_propensity[i]) && + (model_options_df$treatment_type[i] == "multivariate") + warning_fun_2 <- function(x) { + expect_warning( + x, + "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'", + fixed = TRUE + ) + } + warning_cond_3 <- (model_options_df$adaptive_coding[i]) && + (model_options_df$treatment_type[i] != "binary") + warning_fun_3 <- function(x) { + expect_warning( + x, + "Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model", + fixed = TRUE + ) + } + warning_cond_4 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$outcome_type[i] == "binary") + warning_fun_4 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a probit link as it is fixed at 1", + fixed = TRUE + ) + } + warning_cond_5 <- (model_options_df$sampling_global_error_scale[i]) && + (model_options_df$variance_forest[i]) + warning_fun_5 <- function(x) { + expect_warning( + x, + "Global error variance will not be sampled with a heteroskedasticity", + fixed = TRUE + ) + } + warning_cond_6 <- (model_options_df$treatment_type[i] == "multivariate") && + (model_options_df$random_effects[i] == "intercept_plus_treatment") + warning_fun_6 <- function(x) { + expect_warning( + x, + "Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. 
Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables.", + fixed = TRUE + ) + } + warning_cond <- (warning_cond_1 || + warning_cond_2 || + warning_cond_3 || + warning_cond_4 || + warning_cond_5 || + warning_cond_6) + + # Generate something like the below code but for all five warnings + if (error_cond || warning_cond) { + test_fun <- construct_chained_expectation_bcf( + error_cond = error_cond, + warning_cond_1 = warning_cond_1, + warning_cond_2 = warning_cond_2, + warning_cond_3 = warning_cond_3, + warning_cond_4 = warning_cond_4, + warning_cond_5 = warning_cond_5, + warning_cond_6 = warning_cond_6 + ) + } else { + test_fun <- expect_no_error + } + + # Prepare test function arguments + bcf_data_train <- list(X = X_train) + bcf_data_test <- list(X = X_test) + if (model_options_df$outcome_type[i] == "continuous") { + bcf_data_train[["y"]] <- y_continuous_train + bcf_data_test[["y"]] <- y_continuous_test + } else { + bcf_data_train[["y"]] <- y_binary_train + bcf_data_test[["y"]] <- y_binary_test + } + if (model_options_df$include_propensity[i]) { + bcf_data_train[["propensity"]] <- propensity_train + bcf_data_test[["propensity"]] <- propensity_test + } else { + bcf_data_train[["propensity"]] <- NULL + bcf_data_test[["propensity"]] <- NULL + } + if (model_options_df$treatment_type[i] == "binary") { + bcf_data_train[["Z"]] <- binary_treatment_train + bcf_data_test[["Z"]] <- binary_treatment_test + } else if (model_options_df$treatment_type[i] == "univariate_continuous") { + bcf_data_train[["Z"]] <- continuous_treatment_train + bcf_data_test[["Z"]] <- continuous_treatment_test + } else { + bcf_data_train[["Z"]] <- multivariate_treatment_train + bcf_data_test[["Z"]] <- multivariate_treatment_test + } + if (model_options_df$random_effects[i] != "none") { + bcf_data_train[["rfx_group_ids"]] <- group_ids_train + bcf_data_test[["rfx_group_ids"]] <- group_ids_test + } else { + bcf_data_train[["rfx_group_ids"]] <- NULL + bcf_data_test[["rfx_group_ids"]] <- NULL + } + if (model_options_df$random_effects[i] == "custom") { + bcf_data_train[["rfx_basis"]] <- rfx_basis_train + bcf_data_test[["rfx_basis"]] <- rfx_basis_test + } else { + bcf_data_train[["rfx_basis"]] <- NULL + bcf_data_test[["rfx_basis"]] <- NULL + } + + # Apply testthat expectation(s) + test_fun({ + run_bcf_factorial( + bcf_data_train = bcf_data_train, + bcf_data_test = bcf_data_test, + treatment_type = model_options_df$treatment_type[i], + variance_forest = model_options_df$variance_forest[i], + random_effects = model_options_df$random_effects[i], + sampling_global_error_scale = model_options_df$sampling_global_error_scale[ + i + ], + sampling_mu_leaf_scale = model_options_df$sampling_mu_leaf_scale[ + i + ], + sampling_tau_leaf_scale = model_options_df$sampling_tau_leaf_scale[ + i + ], + outcome_type = model_options_df$outcome_type[i], + num_chains = model_options_df$num_chains[i], + adaptive_coding = model_options_df$adaptive_coding[i], + include_propensity = model_options_df$include_propensity[i] + ) + }) + } +}) diff --git a/test/python/conftest.py b/test/python/conftest.py new file mode 100644 index 00000000..e446d0a1 --- /dev/null +++ b/test/python/conftest.py @@ -0,0 +1,21 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, 
items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) diff --git a/test/python/test_api_combinations.py b/test/python/test_api_combinations.py new file mode 100644 index 00000000..36b1ef5d --- /dev/null +++ b/test/python/test_api_combinations.py @@ -0,0 +1,618 @@ +import itertools +import pytest +import numpy as np +from sklearn.model_selection import train_test_split + +from stochtree import BARTModel, BCFModel + + +def run_bart_factorial( + bart_data_train, + bart_data_test, + leaf_reg="none", + variance_forest=False, + random_effects="none", + sampling_global_error_scale=False, + sampling_leaf_scale=False, + outcome_type="continuous", + num_chains=1, +): + # Unpack BART training data + y = bart_data_train["y"] + X = bart_data_train["X"] + if leaf_reg != "none": + leaf_basis = bart_data_train["leaf_basis"] + else: + leaf_basis = None + if random_effects != "none": + rfx_group_ids = bart_data_train["rfx_group_ids"] + else: + rfx_group_ids = None + if random_effects == "custom": + rfx_basis = bart_data_train["rfx_basis"] + else: + rfx_basis = None + + # Set BART model parameters + general_params = { + "num_chains": num_chains, + "sample_sigma2_global": sampling_global_error_scale, + "probit_outcome_model": outcome_type == "binary", + } + mean_forest_params = {"sample_sigma2_leaf": sampling_leaf_scale} + variance_forest_params = {"num_trees": 20 if variance_forest else 0} + rfx_params = { + "model_spec": "custom" if random_effects == "none" else random_effects + } + + # Sample BART model + bart_model = BARTModel() + bart_model.sample( + X_train=X, + y_train=y, + leaf_basis_train=leaf_basis, + rfx_group_ids_train=rfx_group_ids, + rfx_basis_train=rfx_basis, + general_params=general_params, + mean_forest_params=mean_forest_params, + variance_forest_params=variance_forest_params, + random_effects_params=rfx_params, + ) + + # Unpack test set data + y_test = bart_data_test["y"] + X_test = bart_data_test["X"] + if leaf_reg != "none": + leaf_basis_test = bart_data_test["leaf_basis"] + else: + leaf_basis_test = None + if random_effects != "none": + rfx_group_ids_test = bart_data_test["rfx_group_ids"] + else: + rfx_group_ids_test = None + if random_effects == "custom": + rfx_basis_test = bart_data_test["rfx_basis"] + else: + rfx_basis_test = None + + # Predict on test set + mean_preds = bart_model.predict( + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + ) + posterior_preds = bart_model.predict( + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + ) + + # Compute intervals + posterior_interval = bart_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="probability" if outcome_type == "binary" else "linear", + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + ) + + # Sample posterior predictive + posterior_predictive_draws = bart_model.sample_posterior_predictive( + X=X_test, + leaf_basis=leaf_basis_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + num_draws_per_sample=5, + ) + + +def run_bcf_factorial( + 
bcf_data_train, + bcf_data_test, + treatment_type="binary", + variance_forest=False, + random_effects="none", + sampling_global_error_scale=False, + sampling_mu_leaf_scale=False, + sampling_tau_leaf_scale=False, + outcome_type="continuous", + num_chains=1, + adaptive_coding=False, + include_propensity=False, +): + # Unpack BCF training data + y = bcf_data_train["y"] + X = bcf_data_train["X"] + Z = bcf_data_train["Z"] + if include_propensity: + propensity = bcf_data_train["propensity"] + else: + propensity = None + if random_effects != "none": + rfx_group_ids = bcf_data_train["rfx_group_ids"] + else: + rfx_group_ids = None + if random_effects == "custom": + rfx_basis = bcf_data_train["rfx_basis"] + else: + rfx_basis = None + + # Set BCF model parameters + general_params = { + "num_chains": num_chains, + "sample_sigma2_global": sampling_global_error_scale, + "probit_outcome_model": outcome_type == "binary", + "adaptive_coding": adaptive_coding, + } + mu_forest_params = {"sample_sigma2_leaf": sampling_mu_leaf_scale} + tau_forest_params = {"sample_sigma2_leaf": sampling_tau_leaf_scale} + variance_forest_params = {"num_trees": 20 if variance_forest else 0} + rfx_params = { + "model_spec": "custom" if random_effects == "none" else random_effects + } + + # Sample BCF model + bcf_model = BCFModel() + bcf_model.sample( + X_train=X, + y_train=y, + Z_train=Z, + propensity_train=propensity, + rfx_group_ids_train=rfx_group_ids, + rfx_basis_train=rfx_basis, + general_params=general_params, + prognostic_forest_params=mu_forest_params, + treatment_effect_forest_params=tau_forest_params, + variance_forest_params=variance_forest_params, + random_effects_params=rfx_params, + ) + + # Unpack test set data + y_test = bcf_data_test["y"] + X_test = bcf_data_test["X"] + Z_test = bcf_data_test["Z"] + if include_propensity: + propensity_test = bcf_data_test["propensity"] + else: + propensity_test = None + if random_effects != "none": + rfx_group_ids_test = bcf_data_test["rfx_group_ids"] + else: + rfx_group_ids_test = None + if random_effects == "custom": + rfx_basis_test = bcf_data_test["rfx_basis"] + else: + rfx_basis_test = None + + # Predict on test set + mean_preds = bcf_model.predict( + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + ) + posterior_preds = bcf_model.predict( + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="all", + scale="probability" if outcome_type == "binary" else "linear", + ) + + # Compute intervals + posterior_interval = bcf_model.compute_posterior_interval( + terms="all", + level=0.95, + scale="probability" if outcome_type == "binary" else "linear", + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + ) + + # Sample posterior predictive + posterior_predictive_draws = bcf_model.sample_posterior_predictive( + X=X_test, + Z=Z_test, + propensity=propensity_test, + rfx_group_ids=rfx_group_ids_test, + rfx_basis=rfx_basis_test, + num_draws_per_sample=5, + ) + + +class TestAPICombinations: + @pytest.mark.slow + def test_bart_api_combinations(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Overall, we have seven components of a BART sampler which can be on / off or set to different levels: + # 1. 
Leaf regression: none, univariate, multivariate + # 2. Variance forest: no, yes + # 3. Random effects: no, custom basis, `intercept_only` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on mean forest: no, yes (only available for constant leaf or univariate leaf regression) + # 6. Outcome type: continuous (identity link), binary (probit link) + # 7. Number of chains: 1, >1 + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BART models + # will run without error. Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. + + # Generate data with random effects + n = 50 + p = 3 + num_basis = 2 + num_rfx_groups = 3 + num_rfx_basis = 2 + X = rng.uniform(0, 1, (n, p)) + leaf_basis = rng.uniform(0, 1, (n, num_basis)) + leaf_coefs = rng.uniform(0, 1, num_basis) + group_ids = rng.choice(num_rfx_groups, size=n) + rfx_basis = rng.uniform(0, 1, (n, num_rfx_basis)) + rfx_coefs = rng.uniform(0, 1, (num_rfx_groups, num_rfx_basis)) + mean_term = np.sin(X[:, 0]) * np.sum(leaf_basis * leaf_coefs, axis=1) + rfx_term = np.sum(rfx_coefs[group_ids - 1, :] * rfx_basis, axis=1) + E_y = mean_term + rfx_term + E_y = E_y - np.mean(E_y) + epsilon = rng.normal(0, 1, n) + y_continuous = E_y + epsilon + y_binary = (y_continuous > 0).astype(int) + + # Split into test and train sets + test_set_pct = 0.5 + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + leaf_basis_train = leaf_basis[train_inds, :] + leaf_basis_test = leaf_basis[test_inds, :] + rfx_basis_train = rfx_basis[train_inds, :] + rfx_basis_test = rfx_basis[test_inds, :] + group_ids_train = group_ids[train_inds] + group_ids_test = group_ids[test_inds] + y_continuous_train = y_continuous[train_inds] + y_continuous_test = y_continuous[test_inds] + y_binary_train = y_binary[train_inds] + y_binary_test = y_binary[test_inds] + + # Run the power set of models + leaf_reg_options = ["none", "univariate", "multivariate"] + variance_forest_options = [False, True] + random_effects_options = ["none", "custom", "intercept_only"] + sampling_global_error_scale_options = [False, True] + sampling_leaf_scale_options = [False, True] + outcome_type_options = ["continuous", "binary"] + num_chains_options = [1, 3] + model_options_iter = itertools.product( + leaf_reg_options, + variance_forest_options, + random_effects_options, + sampling_global_error_scale_options, + sampling_leaf_scale_options, + outcome_type_options, + num_chains_options, + ) + for i, options in enumerate(model_options_iter): + # Unpack BART train and test data + bart_data_train = {} + bart_data_test = {} + bart_data_train["X"] = X_train + bart_data_test["X"] = X_test + if options[5] == "continuous": + bart_data_train["y"] = y_continuous_train + bart_data_test["y"] = y_continuous_test + else: + bart_data_train["y"] = y_binary_train + bart_data_test["y"] = y_binary_test + if options[0] != "none": + if options[0] == "univariate": + bart_data_train["leaf_basis"] = leaf_basis_train[:, 0] + bart_data_test["leaf_basis"] = leaf_basis_test[:, 0] + else: + bart_data_train["leaf_basis"] = leaf_basis_train + bart_data_test["leaf_basis"] = leaf_basis_test + else: + bart_data_train["leaf_basis"] = None + bart_data_test["leaf_basis"] = None + if options[2] != "none": + 
bart_data_train["rfx_group_ids"] = group_ids_train + bart_data_test["rfx_group_ids"] = group_ids_test + else: + bart_data_train["rfx_group_ids"] = None + bart_data_test["rfx_group_ids"] = None + if options[2] == "custom": + bart_data_train["rfx_basis"] = rfx_basis_train + bart_data_test["rfx_basis"] = rfx_basis_test + else: + bart_data_train["rfx_basis"] = None + bart_data_test["rfx_basis"] = None + + # Determine whether this combination should throw an error, raise a warning, or run as intended + error_cond = (options[1]) and (options[5] == "binary") + warning_cond_1 = (options[4]) and (options[0] == "multivariate") + warning_cond_2 = (options[3]) and (options[5] == "binary") + warning_cond_3 = (options[3]) and (options[1]) + warning_cond = warning_cond_1 or warning_cond_2 or warning_cond_3 + if error_cond and warning_cond: + with pytest.raises(ValueError) as excinfo: + with pytest.warns(UserWarning) as warninfo: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) + elif error_cond and not warning_cond: + with pytest.raises(ValueError) as excinfo: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) + elif not error_cond and warning_cond: + with pytest.warns(UserWarning) as warninfo: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) + else: + run_bart_factorial( + bart_data_train=bart_data_train, + bart_data_test=bart_data_test, + leaf_reg=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_leaf_scale=options[4], + outcome_type=options[5], + num_chains=options[6], + ) + + @pytest.mark.slow + def test_bcf_api_combinations(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Overall, we have nine components of a BCF sampler which can be on / off or set to different levels: + # 1. treatment: binary, univariate continuous, multivariate + # 2. Variance forest: no, yes + # 3. Random effects: no, custom basis, `intercept_only`, `intercept_plus_treatment` + # 4. Sampling global error scale: no, yes + # 5. Sampling leaf scale on prognostic forest: no, yes + # 6. Sampling leaf scale on treatment forest: no, yes (only available for univariate treatment) + # 7. Outcome type: continuous (identity link), binary (probit link) + # 8. Number of chains: 1, >1 + # 9. Adaptive coding: no, yes + # + # For each of the possible models this implies, + # we'd like to be sure that stochtree functions that operate on BCF models + # will run without error. Since there are so many possible models implied by the + # options above, this test is designed to be quick (small sample size, low dimensional data) + # and we are only interested in ensuring no errors are triggered. 
+ + # Generate data with random effects + n = 50 + p = 3 + num_rfx_groups = 3 + num_rfx_basis = 2 + X = rng.uniform(0, 1, (n, p)) + binary_treatment = rng.binomial(1, 0.5, n) + continuous_treatment = rng.uniform(0, 1, n) + multivariate_treatment = np.column_stack( + (binary_treatment, continuous_treatment) + ) + propensity = rng.uniform(0, 1, n) + group_ids = rng.choice(num_rfx_groups, size=n) + rfx_basis = rng.uniform(0, 1, (n, num_rfx_basis)) + rfx_coefs = rng.uniform(0, 1, (num_rfx_groups, num_rfx_basis)) + prognostic_term = np.sin(X[:, 0]) + binary_treatment_effect = X[:, 1] + continuous_treatment_effect = X[:, 2] + rfx_term = np.sum(rfx_coefs[group_ids - 1, :] * rfx_basis, axis=1) + E_y = (prognostic_term + + binary_treatment_effect * binary_treatment + + continuous_treatment_effect * continuous_treatment + + rfx_term) + E_y = E_y - np.mean(E_y) + epsilon = rng.normal(0, 1, n) + y_continuous = E_y + epsilon + y_binary = (y_continuous > 0).astype(int) + + # Split into test and train sets + test_set_pct = 0.5 + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + binary_treatment_test = binary_treatment[test_inds] + binary_treatment_train = binary_treatment[train_inds] + propensity_test = propensity[test_inds] + propensity_train = propensity[train_inds] + continuous_treatment_test = continuous_treatment[test_inds] + continuous_treatment_train = continuous_treatment[train_inds] + multivariate_treatment_test = multivariate_treatment[test_inds, ] + multivariate_treatment_train = multivariate_treatment[train_inds, ] + rfx_basis_train = rfx_basis[train_inds, :] + rfx_basis_test = rfx_basis[test_inds, :] + group_ids_train = group_ids[train_inds] + group_ids_test = group_ids[test_inds] + y_continuous_train = y_continuous[train_inds] + y_continuous_test = y_continuous[test_inds] + y_binary_train = y_binary[train_inds] + y_binary_test = y_binary[test_inds] + + # Run the power set of models + treatment_options = ["binary", "univariate_continuous", "multivariate"] + variance_forest_options = [False, True] + random_effects_options = ["none", "custom", "intercept_only", "intercept_plus_treatment"] + sampling_global_error_scale_options = [False, True] + sampling_mu_leaf_scale_options = [False, True] + sampling_tau_leaf_scale_options = [False, True] + outcome_type_options = ["continuous", "binary"] + num_chains_options = [1, 3] + adaptive_coding_options = [False, True] + include_propensity_options = [False, True] + model_options_iter = itertools.product( + treatment_options, + variance_forest_options, + random_effects_options, + sampling_global_error_scale_options, + sampling_mu_leaf_scale_options, + sampling_tau_leaf_scale_options, + outcome_type_options, + num_chains_options, + adaptive_coding_options, + include_propensity_options + ) + for i, options in enumerate(model_options_iter): + # Unpack BCF train and test data + bcf_data_train = {} + bcf_data_test = {} + bcf_data_train["X"] = X_train + bcf_data_test["X"] = X_test + bcf_data_train["propensity"] = propensity_train + bcf_data_test["propensity"] = propensity_test + if options[6] == "continuous": + bcf_data_train["y"] = y_continuous_train + bcf_data_test["y"] = y_continuous_test + else: + bcf_data_train["y"] = y_binary_train + bcf_data_test["y"] = y_binary_test + if options[0] == "binary": + bcf_data_train["Z"] = binary_treatment_train + bcf_data_test["Z"] = binary_treatment_test + elif options[0] == "univariate_continuous": 
bcf_data_train["Z"] = continuous_treatment_train + bcf_data_test["Z"] = continuous_treatment_test + else: + bcf_data_train["Z"] = multivariate_treatment_train + bcf_data_test["Z"] = multivariate_treatment_test + if options[2] != "none": + bcf_data_train["rfx_group_ids"] = group_ids_train + bcf_data_test["rfx_group_ids"] = group_ids_test + else: + bcf_data_train["rfx_group_ids"] = None + bcf_data_test["rfx_group_ids"] = None + if options[2] == "custom": + bcf_data_train["rfx_basis"] = rfx_basis_train + bcf_data_test["rfx_basis"] = rfx_basis_test + else: + bcf_data_train["rfx_basis"] = None + bcf_data_test["rfx_basis"] = None + + # Determine whether this combination should throw an error, raise a warning, or run as intended + error_cond = (options[1]) and (options[6] == "binary") + warning_cond_1 = (options[5]) and (options[0] == "multivariate") + warning_cond_2 = (options[3]) and (options[6] == "binary") + warning_cond_3 = (options[3]) and (options[1]) + warning_cond_4 = (options[8]) and (options[0] != "binary") + warning_cond_5 = (not options[9]) and (options[0] == "multivariate") + warning_cond_6 = (options[2] == "intercept_plus_treatment") and (options[0] == "multivariate") + warning_cond = warning_cond_1 or warning_cond_2 or warning_cond_3 or warning_cond_4 or warning_cond_5 or warning_cond_6 + print(f"error_cond: {error_cond}, warning_cond_1: {warning_cond_1}, warning_cond_2: {warning_cond_2}, warning_cond_3: {warning_cond_3}, warning_cond_4: {warning_cond_4}, warning_cond_5: {warning_cond_5}, warning_cond_6: {warning_cond_6}") + if error_cond and warning_cond: + with pytest.raises(ValueError) as excinfo: + with pytest.warns(UserWarning) as warninfo: + run_bcf_factorial( + bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) + elif error_cond and not warning_cond: + with pytest.raises(ValueError) as excinfo: + run_bcf_factorial( + bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) + elif not error_cond and warning_cond: + with pytest.warns(UserWarning) as warninfo: + run_bcf_factorial( + bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) + else: + run_bcf_factorial( + bcf_data_train=bcf_data_train, + bcf_data_test=bcf_data_test, + treatment_type=options[0], + variance_forest=options[1], + random_effects=options[2], + sampling_global_error_scale=options[3], + sampling_mu_leaf_scale=options[4], + sampling_tau_leaf_scale=options[5], + outcome_type=options[6], + num_chains=options[7], + adaptive_coding=options[8], + include_propensity=options[9], + ) diff --git a/test/python/test_bart.py 
b/test/python/test_bart.py index b182524b..7f49f5b4 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -429,13 +429,13 @@ def conditional_stddev(X): sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.sigma2_x_train, ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) def test_bart_univariate_leaf_regression_heteroskedastic(self): # RNG @@ -554,13 +554,13 @@ def conditional_stddev(X): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) def test_bart_multivariate_leaf_regression_heteroskedastic(self): # RNG @@ -679,13 +679,13 @@ def conditional_stddev(X): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) def test_bart_constant_leaf_heteroskedastic_rfx(self): # RNG @@ -836,13 +836,13 @@ def rfx_term(group_labels, basis): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) np.testing.assert_allclose( rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 @@ -1010,13 +1010,13 @@ def conditional_stddev(X): np.testing.assert_allclose( y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples - ) - np.testing.assert_allclose( - bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], - bart_model_2.global_var_samples, - ) + # np.testing.assert_allclose( + # 
bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + # ) + # np.testing.assert_allclose( + # bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + # bart_model_2.global_var_samples, + # ) np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) np.testing.assert_allclose( rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index c5a1446f..bbfd55d5 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -645,7 +645,7 @@ def test_multivariate_bcf(self): assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF with test set and without propensity score - with pytest.raises(ValueError): + with pytest.warns(UserWarning): bcf_model = BCFModel() variance_forest_params = {"num_trees": 0} bcf_model.sample( @@ -661,7 +661,7 @@ def test_multivariate_bcf(self): ) # Run BCF without test set and without propensity score - with pytest.raises(ValueError): + with pytest.warns(UserWarning): bcf_model = BCFModel() variance_forest_params = {"num_trees": 0} bcf_model.sample(