Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions .github/workflows/slow-api-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
name: Unit Tests and Slow Running API Integration Tests for R and Python

on:
workflow_dispatch:

jobs:
testing:
name: test-slow-api-combinations
runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- name: Prevent conversion of line endings on Windows
if: startsWith(matrix.os, 'windows')
shell: pwsh
run: git config --global core.autocrlf false

- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: 'recursive'

- name: Setup Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"

- name: Set up openmp (macos)
# Set up openMP on MacOS since it doesn't ship with the apple clang compiler suite
if: matrix.os == 'macos-latest'
run: |
brew install libomp

- name: Install Package with Relevant Dependencies
run: |
pip install --upgrade pip
pip install -r requirements.txt
pip install .

- name: Run Pytest with Slow Running API Tests Enabled
run: |
pytest --runslow test/python

- name: Setup Pandoc for R
uses: r-lib/actions/setup-pandoc@v2

- name: Setup R
uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- name: Setup R Package Dependencies
uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::testthat, any::decor, any::rcmdcheck
needs: check

- name: Create a CRAN-ready version of the R package
run: |
Rscript cran-bootstrap.R 0 0 1

- name: Run CRAN Checks with Slow Running API Tests Enabled
uses: r-lib/actions/check-r-package@v2
env:
RUN_SLOW_TESTS: true
with:
working-directory: 'stochtree_cran'
12 changes: 4 additions & 8 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
# Changelog

# stochtree (development version)

## New Features

## Computational Improvements
# stochtree 0.2.1

## Bug Fixes

* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248))

## Documentation Improvements
* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))

## Other Changes

* Encode expectations about which combinations of BART / BCF features work together and ensure warning ([#250](https://github.com/StochasticTree/stochtree/pull/250))

# stochtree 0.2.0

## New Features
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: stochtree
Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
Version: 0.2.0.9000
Version: 0.2.1
Authors@R:
c(
person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),
Expand Down
2 changes: 1 addition & 1 deletion Doxyfile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree"
# could be handy for archiving the generated documentation or if some version
# control system is used.

PROJECT_NUMBER = 0.2.0.9000
PROJECT_NUMBER = 0.2.1

# Using the PROJECT_BRIEF tag one can provide an optional one line description
# for a project that appears at the top of each page and should give viewer a
Expand Down
12 changes: 4 additions & 8 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
# stochtree (development version)

## New Features

## Computational Improvements
# stochtree 0.2.1

## Bug Fixes

* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248))

## Documentation Improvements
* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))

## Other Changes

* Encode expectations about which combinations of BART / BCF features work together and ensure warning ([#250](https://github.com/StochasticTree/stochtree/pull/250))

# stochtree 0.2.0

## New Features
Expand Down
12 changes: 10 additions & 2 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,16 @@ bart <- function(
}
}

# Runtime checks for variance forest
if (include_variance_forest) {
if (sample_sigma2_global) {
warning(
"Global error variance will not be sampled with a heteroskedasticity forest"
)
sample_sigma2_global <- F
}
}

# Handle standardization, prior calibration, and initialization of forest
# differently for binary and continuous outcomes
if (probit_outcome_model) {
Expand Down Expand Up @@ -2124,7 +2134,6 @@ predict.bartmodel <- function(
X <- preprocessPredictionData(X, train_set_metadata)

# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
has_rfx <- FALSE
if (predict_rfx) {
if (!is.null(rfx_group_ids)) {
rfx_unique_group_ids <- object$rfx_unique_group_ids
Expand All @@ -2135,7 +2144,6 @@ predict.bartmodel <- function(
)
}
rfx_group_ids <- as.integer(group_ids_factor)
has_rfx <- TRUE
}
}

Expand Down
63 changes: 36 additions & 27 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -897,14 +897,7 @@ bcf <- function(
# Handle multivariate treatment
has_multivariate_treatment <- ncol(Z_train) > 1
if (has_multivariate_treatment) {
# Disable adaptive coding, internal propensity model, and
# leaf scale sampling if treatment is multivariate
if (adaptive_coding) {
warning(
"Adaptive coding is incompatible with multivariate treatment and will be ignored"
)
adaptive_coding <- FALSE
}
# Disable internal propensity model and leaf scale sampling if treatment is multivariate
if (is.null(propensity_train)) {
if (propensity_covariate != "none") {
warning(
Expand Down Expand Up @@ -949,21 +942,31 @@ bcf <- function(
}
has_basis_rfx <- TRUE
num_basis_rfx <- ncol(rfx_basis_train)
} else if (rfx_model_spec == "intercept_only") {
rfx_basis_train <- matrix(
rep(1, nrow(X_train)),
nrow = nrow(X_train),
ncol = 1
)
has_basis_rfx <- TRUE
num_basis_rfx <- 1
} else if (rfx_model_spec == "intercept_plus_treatment") {
rfx_basis_train <- cbind(
rep(1, nrow(X_train)),
Z_train
)
has_basis_rfx <- TRUE
num_basis_rfx <- 1 + ncol(Z_train)
if (has_multivariate_treatment) {
warning(
"Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables."
)
rfx_model_spec <- "intercept_only"
}
}
if (is.null(rfx_basis_train)) {
if (rfx_model_spec == "intercept_only") {
rfx_basis_train <- matrix(
rep(1, nrow(X_train)),
nrow = nrow(X_train),
ncol = 1
)
has_basis_rfx <- TRUE
num_basis_rfx <- 1
} else {
rfx_basis_train <- cbind(
rep(1, nrow(X_train)),
Z_train
)
has_basis_rfx <- TRUE
num_basis_rfx <- 1 + ncol(Z_train)
}
}
num_rfx_groups <- length(unique(rfx_group_ids_train))
num_rfx_components <- ncol(rfx_basis_train)
Expand Down Expand Up @@ -1021,15 +1024,21 @@ bcf <- function(
y_train <- as.matrix(y_train)
}

# Check whether treatment is binary (specifically 0-1 binary)
binary_treatment <- length(unique(Z_train)) == 2
if (binary_treatment) {
unique_treatments <- sort(unique(Z_train))
if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE
# Check whether treatment is binary and univariate (specifically 0-1 binary)
binary_treatment <- FALSE
if (!has_multivariate_treatment) {
binary_treatment <- length(unique(Z_train)) == 2
if (binary_treatment) {
unique_treatments <- sort(unique(Z_train))
if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE
}
}

# Adaptive coding will be ignored for continuous / ordered categorical treatments
if ((!binary_treatment) && (adaptive_coding)) {
warning(
"Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model"
)
adaptive_coding <- FALSE
}

Expand Down
106 changes: 64 additions & 42 deletions R/posterior_transformation.R
Original file line number Diff line number Diff line change
Expand Up @@ -409,19 +409,31 @@ compute_contrast_bart_model <- function(
"rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model"
)
}
if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
stop(
"rfx_basis_0 and rfx_basis_1 must be provided for this model"
)
}
if (
(object$model_params$num_rfx_basis > 0) &&
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
) {
stop(
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
)
if (has_rfx) {
if (object$model_params$rfx_model_spec == "custom") {
if ((is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
stop(
"A user-provided basis (`rfx_basis_0` and `rfx_basis_1`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
)
}
if (!is.matrix(rfx_basis_0) || !is.matrix(rfx_basis_1)) {
stop("'rfx_basis_0' and 'rfx_basis_1' must be matrices")
}
if ((nrow(rfx_basis_0) != nrow(X)) || (nrow(rfx_basis_1) != nrow(X))) {
stop(
"'rfx_basis_0' and 'rfx_basis_1' must have the same number of rows as 'X'"
)
}
if (
(object$model_params$num_rfx_basis > 0) &&
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
) {
stop(
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
)
}
}
}

# Predict for the control arm
Expand Down Expand Up @@ -574,16 +586,22 @@ sample_bcf_posterior_predictive <- function(
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
)
}
if (is.null(rfx_basis)) {
stop(
"'rfx_basis' must be provided in order to compute the requested intervals"
)
}
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")

if (model_object$model_params$rfx_model_spec == "custom") {
if (is.null(rfx_basis)) {
stop(
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
)
}
}
if (nrow(rfx_basis) != nrow(X)) {
stop("'rfx_basis' must have the same number of rows as 'X'")

if (!is.null(rfx_basis)) {
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")
}
if (nrow(rfx_basis) != nrow(X)) {
stop("'rfx_basis' must have the same number of rows as 'X'")
}
}
}

Expand Down Expand Up @@ -735,16 +753,18 @@ sample_bart_posterior_predictive <- function(
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
)
}
if (is.null(rfx_basis)) {
stop(
"'rfx_basis' must be provided in order to compute the requested intervals"
)
}
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")
}
if (nrow(rfx_basis) != nrow(X)) {
stop("'rfx_basis' must have the same number of rows as 'X'")
if (model_object$model_params$rfx_model_spec == "custom") {
if (is.null(rfx_basis)) {
stop(
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
)
}
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")
}
if (nrow(rfx_basis) != nrow(X)) {
stop("'rfx_basis' must have the same number of rows as 'X'")
}
}
}

Expand Down Expand Up @@ -1172,16 +1192,18 @@ compute_bart_posterior_interval <- function(
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
)
}
if (is.null(rfx_basis)) {
stop(
"'rfx_basis' must be provided in order to compute the requested intervals"
)
}
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")
}
if (nrow(rfx_basis) != nrow(X)) {
stop("'rfx_basis' must have the same number of rows as 'X'")
if (model_object$model_params$rfx_model_spec == "custom") {
if (is.null(rfx_basis)) {
stop(
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
)
}
if (!is.matrix(rfx_basis)) {
stop("'rfx_basis' must be a matrix")
}
if (nrow(rfx_basis) != nrow(X)) {
stop("'rfx_basis' must have the same number of rows as 'X'")
}
}
}

Expand Down
Loading
Loading