Skip to content

Commit 396931c

Browse files
committed
Updated compute_contrast in Python and corrected other R package issues
1 parent 9597886 commit 396931c

File tree

5 files changed

+42
-72
lines changed

5 files changed

+42
-72
lines changed

R/posterior_transformation.R

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -751,7 +751,7 @@ sample_bart_posterior_predictive <- function(
751751
# Compute posterior samples
752752
bart_preds <- predict(
753753
model_object,
754-
covariates = covariates,
754+
X = covariates,
755755
leaf_basis = basis,
756756
rfx_group_ids = rfx_group_ids,
757757
rfx_basis = rfx_basis,
@@ -1188,7 +1188,7 @@ compute_bart_posterior_interval <- function(
11881188
# Compute posterior matrices for the requested model terms
11891189
predictions <- predict(
11901190
model_object,
1191-
covariates = covariates,
1191+
X = covariates,
11921192
leaf_basis = basis,
11931193
rfx_group_ids = rfx_group_ids,
11941194
rfx_basis = rfx_basis,

R/utils.R

Lines changed: 0 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1092,33 +1092,3 @@ expand_dims_2d_diag <- function(input, output_size) {
10921092
}
10931093
return(output)
10941094
}
1095-
1096-
1097-
gfr_tie_checks <- function(covariates) {
1098-
num_vars <- ncol(covariates)
1099-
for (j in 1:num_vars) {
1100-
x_j <- covariates[, j]
1101-
if (has_few_unique_values(x_j)) {
1102-
warning_message <- paste0(
1103-
"Covariate column ",
1104-
j,
1105-
" has relatively few unique values. ",
1106-
"This may lead to tied values when sampling split points in BART/BCF, ",
1107-
"which can cause errors during model fitting. ",
1108-
"Consider adding small amounts of noise to this variable to break ties."
1109-
)
1110-
warning(warning_message)
1111-
}
1112-
}
1113-
}
1114-
1115-
1116-
has_few_unique_values <- function(
1117-
x,
1118-
count_threshold = 15
1119-
) {
1120-
x_unique <- unique(x)
1121-
num_unique_values <- length(unique_values)
1122-
unique_to_total_count_ratio <- num_unique_values / length(x)
1123-
return(num_unique_values <= threshold)
1124-
}

demo/debug/bart_contrast_debug.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -55,10 +55,10 @@
5555

5656
# Compute contrast posterior
5757
contrast_posterior_test = bart_model.compute_contrast(
58-
covariates_0=X_test,
59-
covariates_1=X_test,
60-
basis_0=np.zeros((n_test, 1)),
61-
basis_1=np.ones((n_test, 1)),
58+
X_0=X_test,
59+
X_1=X_test,
60+
leaf_basis_0=np.zeros((n_test, 1)),
61+
leaf_basis_1=np.ones((n_test, 1)),
6262
type="posterior",
6363
scale="linear",
6464
)
@@ -143,10 +143,10 @@
143143

144144
# Compute contrast posterior
145145
contrast_posterior_test = bart_model.compute_contrast(
146-
covariates_0=X_test,
147-
covariates_1=X_test,
148-
basis_0=np.zeros((n_test, 1)),
149-
basis_1=np.ones((n_test, 1)),
146+
X_0=X_test,
147+
X_1=X_test,
148+
leaf_basis_0=np.zeros((n_test, 1)),
149+
leaf_basis_1=np.ones((n_test, 1)),
150150
rfx_group_ids_0=group_ids_test,
151151
rfx_group_ids_1=group_ids_test,
152152
rfx_basis_0=rfx_basis_test,

demo/debug/probit_bart_rfx_debug.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -72,10 +72,10 @@
7272

7373
# Compute contrast posterior
7474
contrast_posterior_test = bart_model.compute_contrast(
75-
covariates_0=X_test,
76-
covariates_1=X_test,
77-
basis_0=np.zeros((n_test, 1)),
78-
basis_1=np.ones((n_test, 1)),
75+
X_0=X_test,
76+
X_1=X_test,
77+
leaf_basis_0=np.zeros((n_test, 1)),
78+
leaf_basis_1=np.ones((n_test, 1)),
7979
rfx_group_ids_0=group_ids_test,
8080
rfx_group_ids_1=group_ids_test,
8181
rfx_basis_0=rfx_basis_test,

stochtree/bart.py

Lines changed: 28 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -2046,10 +2046,10 @@ def predict(
20462046

20472047
def compute_contrast(
20482048
self,
2049-
covariates_0: Union[np.array, pd.DataFrame],
2050-
covariates_1: Union[np.array, pd.DataFrame],
2051-
basis_0: np.array = None,
2052-
basis_1: np.array = None,
2049+
X_0: Union[np.array, pd.DataFrame],
2050+
X_1: Union[np.array, pd.DataFrame],
2051+
leaf_basis_0: np.array = None,
2052+
leaf_basis_1: np.array = None,
20532053
rfx_group_ids_0: np.array = None,
20542054
rfx_group_ids_1: np.array = None,
20552055
rfx_basis_0: np.array = None,
@@ -2068,13 +2068,13 @@ def compute_contrast(
20682068
20692069
Parameters
20702070
----------
2071-
covariates_0 : np.array or pd.DataFrame
2071+
X_0 : np.array or pd.DataFrame
20722072
Covariates used for prediction in the "control" case. Must be a numpy array or dataframe.
2073-
covariates_1 : np.array or pd.DataFrame
2073+
X_1 : np.array or pd.DataFrame
20742074
Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe.
2075-
basis_0 : np.array, optional
2075+
leaf_basis_0 : np.array, optional
20762076
Bases used for prediction in the "control" case (by e.g. dot product with leaf values).
2077-
basis_1 : np.array, optional
2077+
leaf_basis_1 : np.array, optional
20782078
Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values).
20792079
rfx_group_ids_0 : np.array, optional
20802080
Test set group labels used for prediction from an additive random effects model in the "control" case.
@@ -2135,33 +2135,33 @@ def compute_contrast(
21352135
raise NotSampledError(msg)
21362136

21372137
# Data checks
2138-
if not isinstance(covariates_0, pd.DataFrame) and not isinstance(
2139-
covariates_0, np.ndarray
2138+
if not isinstance(X_0, pd.DataFrame) and not isinstance(
2139+
X_0, np.ndarray
21402140
):
2141-
raise ValueError("covariates_0 must be a pandas dataframe or numpy array")
2142-
if not isinstance(covariates_1, pd.DataFrame) and not isinstance(
2143-
covariates_1, np.ndarray
2141+
raise ValueError("X_0 must be a pandas dataframe or numpy array")
2142+
if not isinstance(X_1, pd.DataFrame) and not isinstance(
2143+
X_1, np.ndarray
21442144
):
2145-
raise ValueError("covariates_1 must be a pandas dataframe or numpy array")
2146-
if basis_0 is not None:
2147-
if not isinstance(basis_0, np.ndarray):
2148-
raise ValueError("basis_0 must be a numpy array")
2149-
if basis_0.shape[0] != covariates_0.shape[0]:
2145+
raise ValueError("X_1 must be a pandas dataframe or numpy array")
2146+
if leaf_basis_0 is not None:
2147+
if not isinstance(leaf_basis_0, np.ndarray):
2148+
raise ValueError("leaf_basis_0 must be a numpy array")
2149+
if leaf_basis_0.shape[0] != X_0.shape[0]:
21502150
raise ValueError(
2151-
"covariates_0 and basis_0 must have the same number of rows"
2151+
"X_0 and leaf_basis_0 must have the same number of rows"
21522152
)
2153-
if basis_1 is not None:
2154-
if not isinstance(basis_1, np.ndarray):
2155-
raise ValueError("basis_1 must be a numpy array")
2156-
if basis_1.shape[0] != covariates_1.shape[0]:
2153+
if leaf_basis_1 is not None:
2154+
if not isinstance(leaf_basis_1, np.ndarray):
2155+
raise ValueError("leaf_basis_1 must be a numpy array")
2156+
if leaf_basis_1.shape[0] != X_1.shape[0]:
21572157
raise ValueError(
2158-
"covariates_1 and basis_1 must have the same number of rows"
2158+
"X_1 and leaf_basis_1 must have the same number of rows"
21592159
)
21602160

21612161
# Predict for the control arm
21622162
control_preds = self.predict(
2163-
covariates=covariates_0,
2164-
basis=basis_0,
2163+
X=X_0,
2164+
leaf_basis=leaf_basis_0,
21652165
rfx_group_ids=rfx_group_ids_0,
21662166
rfx_basis=rfx_basis_0,
21672167
type="posterior",
@@ -2171,8 +2171,8 @@ def compute_contrast(
21712171

21722172
# Predict for the treatment arm
21732173
treatment_preds = self.predict(
2174-
covariates=covariates_1,
2175-
basis=basis_1,
2174+
X=X_1,
2175+
leaf_basis=leaf_basis_1,
21762176
rfx_group_ids=rfx_group_ids_1,
21772177
rfx_basis=rfx_basis_1,
21782178
type="posterior",

0 commit comments

Comments
 (0)