Skip to content

Commit 4644fdf

Browse files
committed
Reformatted python code and fixed (most) issues surfaced by ruff code analysis
1 parent b1ee7e0 commit 4644fdf

File tree

5 files changed

+28
-40
lines changed

5 files changed

+28
-40
lines changed

stochtree/bart.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -765,19 +765,8 @@ def sample(
765765
)
766766
previous_bart_model = BARTModel()
767767
previous_bart_model.from_json(previous_model_json)
768-
previous_y_bar = previous_bart_model.y_bar
769768
previous_y_scale = previous_bart_model.y_std
770769
previous_model_num_samples = previous_bart_model.num_samples
771-
if previous_bart_model.include_mean_forest:
772-
previous_forest_samples_mean = previous_bart_model.forest_container_mean
773-
else:
774-
previous_forest_samples_mean = None
775-
if previous_bart_model.include_variance_forest:
776-
previous_forest_samples_variance = (
777-
previous_bart_model.forest_container_variance
778-
)
779-
else:
780-
previous_forest_samples_variance = None
781770
if previous_bart_model.sample_sigma2_global:
782771
previous_global_var_samples = previous_bart_model.global_var_samples / (
783772
previous_y_scale * previous_y_scale
@@ -788,22 +777,14 @@ def sample(
788777
previous_leaf_var_samples = previous_bart_model.leaf_scale_samples
789778
else:
790779
previous_leaf_var_samples = None
791-
if previous_bart_model.has_rfx:
792-
previous_rfx_samples = previous_bart_model.rfx_container
793-
else:
794-
previous_rfx_samples = None
795780
if previous_model_warmstart_sample_num + 1 > previous_model_num_samples:
796781
raise ValueError(
797782
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
798783
)
799784
else:
800-
previous_y_bar = None
801785
previous_y_scale = None
802786
previous_global_var_samples = None
803787
previous_leaf_var_samples = None
804-
previous_rfx_samples = None
805-
previous_forest_samples_mean = None
806-
previous_forest_samples_variance = None
807788
previous_model_num_samples = 0
808789

809790
# Update variable weights if the covariates have been resized (by e.g. one-hot encoding)
@@ -1772,7 +1753,6 @@ def predict(
17721753
rfx_intercept = rfx_model_spec == "intercept_only"
17731754
if not isinstance(terms, str) and not isinstance(terms, list):
17741755
raise ValueError("type must be a string or list of strings")
1775-
num_terms = 1 if isinstance(terms, str) else len(terms)
17761756
has_mean_forest = self.include_mean_forest
17771757
has_variance_forest = self.include_variance_forest
17781758
has_rfx = self.has_rfx

stochtree/bcf.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,7 +1656,6 @@ def sample(
16561656
if sample_sigma2_leaf_tau:
16571657
self.leaf_scale_tau_samples = np.empty(self.num_samples, dtype=np.float64)
16581658
muhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
1659-
tauhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64)
16601659
if self.include_variance_forest:
16611660
sigma2_x_train_raw = np.empty(
16621661
(self.n_train, self.num_samples), dtype=np.float64
@@ -2442,7 +2441,6 @@ def predict(
24422441
raise ValueError(
24432442
f"term '{term}' was requested. Valid terms are 'y_hat', 'prognostic_function', 'mu', 'cate', 'tau', 'rfx', 'variance_forest', and 'all'"
24442443
)
2445-
num_terms = 1 if isinstance(terms, str) else len(terms)
24462444
has_mu_forest = True
24472445
has_tau_forest = True
24482446
has_variance_forest = self.include_variance_forest

stochtree/preprocessing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,9 @@ def _fit_numpy(self, covariates: np.array) -> None:
339339
self._onehot_feature_index = np.array(
340340
[-1 for i in range(self._num_original_features)], dtype=int
341341
)
342-
self._original_feature_types = np.array(
343-
["float" for i in range(self._num_original_features)]
344-
)
342+
self._original_feature_types = np.array([
343+
"float" for i in range(self._num_original_features)
344+
])
345345

346346
# Check whether the array is numeric
347347
cov_dtype = covariates.dtype
@@ -443,9 +443,9 @@ def _transform_numpy(self, covariates: np.array) -> np.array:
443443
raise ValueError(
444444
"Attempting to call transform from a CovariateTransformer that was fit on a dataset with different dimensionality"
445445
)
446-
self._original_feature_indices = np.array(
447-
[i for i in range(covariates.shape[1])]
448-
)
446+
self._original_feature_indices = np.array([
447+
i for i in range(covariates.shape[1])
448+
])
449449
return covariates
450450

451451
def _transform(self, covariates: Union[pd.DataFrame, np.array]) -> np.array:

stochtree/random_effects.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -367,20 +367,20 @@ def predict(self, group_labels: np.array, basis: np.array) -> np.ndarray:
367367
return self.rfx_container_cpp.Predict(
368368
rfx_dataset.rfx_dataset_cpp, self.rfx_label_mapper_cpp
369369
)
370-
370+
371371
def extract_parameter_samples(self) -> dict[str, np.ndarray]:
372372
"""
373-
Extract the random effects parameters sampled. With the "redundant parameterization" of Gelman et al (2008),
374-
this includes four parameters: alpha (the "working parameter" shared across every group), xi
375-
(the "group parameter" sampled separately for each group), beta (the product of alpha and xi,
376-
which corresponds to the overall group-level random effects), and sigma (group-independent prior
373+
Extract the random effects parameters sampled. With the "redundant parameterization" of Gelman et al (2008),
374+
this includes four parameters: alpha (the "working parameter" shared across every group), xi
375+
(the "group parameter" sampled separately for each group), beta (the product of alpha and xi,
376+
which corresponds to the overall group-level random effects), and sigma (group-independent prior
377377
variance for each component of xi).
378378
379379
Returns
380380
-------
381381
dict[str, np.ndarray]
382-
dict of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
383-
The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`.
382+
dict of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
383+
The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`.
384384
The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
385385
"""
386386
# num_samples = self.rfx_container_cpp.NumSamples()
@@ -391,10 +391,10 @@ def extract_parameter_samples(self) -> dict[str, np.ndarray]:
391391
alpha_samples = np.squeeze(self.rfx_container_cpp.GetAlpha())
392392
sigma_samples = np.squeeze(self.rfx_container_cpp.GetSigma())
393393
output = {
394-
"beta_samples": beta_samples,
394+
"beta_samples": beta_samples,
395395
"xi_samples": xi_samples,
396396
"alpha_samples": alpha_samples,
397-
"sigma_samples": sigma_samples
397+
"sigma_samples": sigma_samples,
398398
}
399399
return output
400400

stochtree/utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,21 +334,31 @@ def _expand_dims_2d_diag(
334334
)
335335
return output
336336

337-
def _posterior_predictive_heuristic_multiplier(num_samples: int, num_observations: int) -> int:
337+
338+
def _posterior_predictive_heuristic_multiplier(
339+
num_samples: int, num_observations: int
340+
) -> int:
338341
if num_samples >= 1000:
339342
return 1
340343
else:
341344
return math.ceil(1000 / num_samples)
342345

343-
def _summarize_interval(array: np.ndarray, sample_dim: int = 2, level: float = 0.95) -> dict:
346+
347+
def _summarize_interval(
348+
array: np.ndarray, sample_dim: int = 2, level: float = 0.95
349+
) -> dict:
344350
# Check that the array is numeric and at least 2 dimensional
345351
if not isinstance(array, np.ndarray):
346352
raise ValueError("`array` must be a numpy array")
347353
if not _check_array_numeric(array):
348354
raise ValueError("`array` must be a numeric numpy array")
349355
if not len(array.shape) >= 2:
350356
raise ValueError("`array` must be at least a 2-dimensional numpy array")
351-
if not _check_is_int(sample_dim) or (sample_dim < 0) or (sample_dim >= len(array.shape)):
357+
if (
358+
not _check_is_int(sample_dim)
359+
or (sample_dim < 0)
360+
or (sample_dim >= len(array.shape))
361+
):
352362
raise ValueError(
353363
"`sample_dim` must be an integer between 0 and the number of dimensions of `array` - 1"
354364
)

0 commit comments

Comments
 (0)