Skip to content

Commit 8db1d8d

Browse files
committed
Updated BART posterior interval function / method in R and Python
1 parent 396931c commit 8db1d8d

File tree

4 files changed

+52
-52
lines changed

4 files changed

+52
-52
lines changed

R/posterior_transformation.R

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,8 +1068,8 @@ compute_bcf_posterior_interval <- function(
10681068
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`.
10691069
#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
10701070
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
1071-
#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).
1072-
#' @param basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
1071+
#' @param X A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).
1072+
#' @param leaf_basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
10731073
#' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects.
10741074
#' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
10751075
#'
@@ -1085,7 +1085,7 @@ compute_bcf_posterior_interval <- function(
10851085
#' intervals <- compute_bart_posterior_interval(
10861086
#' model_object = bart_model,
10871087
#' terms = c("mean_forest", "y_hat"),
1088-
#' covariates = X,
1088+
#' X = X,
10891089
#' level = 0.90
10901090
#' )
10911091
#' @export
@@ -1094,8 +1094,8 @@ compute_bart_posterior_interval <- function(
10941094
terms,
10951095
level = 0.95,
10961096
scale = "linear",
1097-
covariates = NULL,
1098-
basis = NULL,
1097+
X = NULL,
1098+
leaf_basis = NULL,
10991099
rfx_group_ids = NULL,
11001100
rfx_basis = NULL
11011101
) {
@@ -1129,30 +1129,30 @@ compute_bart_posterior_interval <- function(
11291129
if (needs_covariates) {
11301130
if (is.null(covariates)) {
11311131
stop(
1132-
"'covariates' must be provided in order to compute the requested intervals"
1132+
"'X' must be provided in order to compute the requested intervals"
11331133
)
11341134
}
1135-
if (!is.matrix(covariates) && !is.data.frame(covariates)) {
1136-
stop("'covariates' must be a matrix or data frame")
1135+
if (!is.matrix(X) && !is.data.frame(X)) {
1136+
stop("'X' must be a matrix or data frame")
11371137
}
11381138
}
11391139
needs_basis <- needs_covariates && model_object$model_params$has_basis
11401140
if (needs_basis) {
1141-
if (is.null(basis)) {
1141+
if (is.null(leaf_basis)) {
11421142
stop(
1143-
"'basis' must be provided in order to compute the requested intervals"
1143+
"'leaf_basis' must be provided in order to compute the requested intervals"
11441144
)
11451145
}
1146-
if (!is.matrix(basis)) {
1147-
stop("'basis' must be a matrix")
1146+
if (!is.matrix(leaf_basis)) {
1147+
stop("'leaf_basis' must be a matrix")
11481148
}
1149-
if (is.matrix(basis)) {
1150-
if (nrow(basis) != nrow(covariates)) {
1151-
stop("'basis' must have the same number of rows as 'covariates'")
1149+
if (is.matrix(leaf_basis)) {
1150+
if (nrow(leaf_basis) != nrow(X)) {
1151+
stop("'leaf_basis' must have the same number of rows as 'X'")
11521152
}
11531153
} else {
1154-
if (length(basis) != nrow(covariates)) {
1155-
stop("'basis' must have the same number of elements as 'covariates'")
1154+
if (length(leaf_basis) != nrow(X)) {
1155+
stop("'leaf_basis' must have the same number of elements as 'X'")
11561156
}
11571157
}
11581158
}
@@ -1167,9 +1167,9 @@ compute_bart_posterior_interval <- function(
11671167
"'rfx_group_ids' must be provided in order to compute the requested intervals"
11681168
)
11691169
}
1170-
if (length(rfx_group_ids) != nrow(covariates)) {
1170+
if (length(rfx_group_ids) != nrow(X)) {
11711171
stop(
1172-
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
1172+
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
11731173
)
11741174
}
11751175
if (is.null(rfx_basis)) {
@@ -1180,16 +1180,16 @@ compute_bart_posterior_interval <- function(
11801180
if (!is.matrix(rfx_basis)) {
11811181
stop("'rfx_basis' must be a matrix")
11821182
}
1183-
if (nrow(rfx_basis) != nrow(covariates)) {
1184-
stop("'rfx_basis' must have the same number of rows as 'covariates'")
1183+
if (nrow(rfx_basis) != nrow(X)) {
1184+
stop("'rfx_basis' must have the same number of rows as 'X'")
11851185
}
11861186
}
11871187

11881188
# Compute posterior matrices for the requested model terms
11891189
predictions <- predict(
11901190
model_object,
1191-
X = covariates,
1192-
leaf_basis = basis,
1191+
X = X,
1192+
leaf_basis = leaf_basis,
11931193
rfx_group_ids = rfx_group_ids,
11941194
rfx_basis = rfx_basis,
11951195
type = "posterior",

man/compute_bart_posterior_interval.Rd

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

stochtree/bart.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2196,8 +2196,8 @@ def compute_posterior_interval(
21962196
terms: Union[list[str], str] = "all",
21972197
level: float = 0.95,
21982198
scale: str = "linear",
2199-
covariates: np.array = None,
2200-
basis: np.array = None,
2199+
X: np.array = None,
2200+
leaf_basis: np.array = None,
22012201
rfx_group_ids: np.array = None,
22022202
rfx_basis: np.array = None,
22032203
) -> dict:
@@ -2212,9 +2212,9 @@ def compute_posterior_interval(
22122212
Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`.
22132213
level : float, optional
22142214
A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval.
2215-
covariates : np.array, optional
2215+
X : np.array, optional
22162216
Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).
2217-
basis : np.array, optional
2217+
leaf_basis : np.array, optional
22182218
Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
22192219
rfx_group_ids : np.array, optional
22202220
Optional vector of group IDs for random effects. Required if the requested term includes random effects.
@@ -2266,25 +2266,25 @@ def compute_posterior_interval(
22662266
or needs_covariates_intermediate
22672267
)
22682268
if needs_covariates:
2269-
if covariates is None:
2269+
if X is None:
22702270
raise ValueError(
2271-
"'covariates' must be provided in order to compute the requested intervals"
2271+
"'X' must be provided in order to compute the requested intervals"
22722272
)
2273-
if not isinstance(covariates, np.ndarray) and not isinstance(
2274-
covariates, pd.DataFrame
2273+
if not isinstance(X, np.ndarray) and not isinstance(
2274+
X, pd.DataFrame
22752275
):
2276-
raise ValueError("'covariates' must be a matrix or data frame")
2276+
raise ValueError("'X' must be a matrix or data frame")
22772277
needs_basis = needs_covariates and self.has_basis
22782278
if needs_basis:
2279-
if basis is None:
2279+
if leaf_basis is None:
22802280
raise ValueError(
2281-
"'basis' must be provided in order to compute the requested intervals"
2281+
"'leaf_basis' must be provided in order to compute the requested intervals"
22822282
)
2283-
if not isinstance(basis, np.ndarray):
2284-
raise ValueError("'basis' must be a numpy array")
2285-
if basis.shape[0] != covariates.shape[0]:
2283+
if not isinstance(leaf_basis, np.ndarray):
2284+
raise ValueError("'leaf_basis' must be a numpy array")
2285+
if leaf_basis.shape[0] != X.shape[0]:
22862286
raise ValueError(
2287-
"'basis' must have the same number of rows as 'covariates'"
2287+
"'leaf_basis' must have the same number of rows as 'X'"
22882288
)
22892289
needs_rfx_data_intermediate = (
22902290
("y_hat" in terms) or ("all" in terms)
@@ -2297,25 +2297,25 @@ def compute_posterior_interval(
22972297
)
22982298
if not isinstance(rfx_group_ids, np.ndarray):
22992299
raise ValueError("'rfx_group_ids' must be a numpy array")
2300-
if rfx_group_ids.shape[0] != covariates.shape[0]:
2300+
if rfx_group_ids.shape[0] != X.shape[0]:
23012301
raise ValueError(
2302-
"'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
2302+
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
23032303
)
23042304
if rfx_basis is None:
23052305
raise ValueError(
23062306
"'rfx_basis' must be provided in order to compute the requested intervals"
23072307
)
23082308
if not isinstance(rfx_basis, np.ndarray):
23092309
raise ValueError("'rfx_basis' must be a numpy array")
2310-
if rfx_basis.shape[0] != covariates.shape[0]:
2310+
if rfx_basis.shape[0] != X.shape[0]:
23112311
raise ValueError(
2312-
"'rfx_basis' must have the same number of rows as 'covariates'"
2312+
"'rfx_basis' must have the same number of rows as 'X'"
23132313
)
23142314

23152315
# Compute posterior matrices for the requested model terms
23162316
predictions = self.predict(
2317-
covariates=covariates,
2318-
basis=basis,
2317+
X=X,
2318+
leaf_basis=leaf_basis,
23192319
rfx_group_ids=rfx_group_ids,
23202320
rfx_basis=rfx_basis,
23212321
type="posterior",

tools/debug/bart_predict_debug.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ y_hat_intervals <- compute_bart_posterior_interval(
5656
model_object = bart_model,
5757
transform = function(x) x,
5858
terms = c("y_hat", "mean_forest"),
59-
covariates = X_test,
59+
X = X_test,
6060
level = 0.95
6161
)
6262

@@ -137,7 +137,7 @@ y_hat_intervals <- compute_bart_posterior_interval(
137137
model_object = bart_model,
138138
scale = "linear",
139139
terms = c("y_hat"),
140-
covariates = X_test,
140+
X = X_test,
141141
level = 0.95
142142
)
143143

@@ -146,7 +146,7 @@ y_hat_prob_intervals <- compute_bart_posterior_interval(
146146
model_object = bart_model,
147147
scale = "probability",
148148
terms = c("y_hat"),
149-
covariates = X_test,
149+
X = X_test,
150150
level = 0.95
151151
)
152152

0 commit comments

Comments
 (0)