Skip to content

Commit 0ad9706

Browse files
authored
flag for predict df (#139)
1 parent 5d650aa commit 0ad9706

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/pytorch_tabular/tabular_model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ def predict(
10601060
quantiles: Optional[List] = [0.25, 0.5, 0.75],
10611061
n_samples: Optional[int] = 100,
10621062
ret_logits=False,
1063+
include_input_features: bool = True,
10631064
) -> pd.DataFrame:
10641065
"""Uses the trained model to predict on new data and return as a dataframe
10651066
@@ -1072,11 +1073,17 @@ def predict(
10721073
Ignored for non-probabilistic models. Defaults to 100
10731074
ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
10741075
with the dataframe. Defaults to False
1076+
include_input_features (bool): Flag to include the input features in the returned dataframe.
1077+
Defaults to True
10751078
10761079
Returns:
1077-
pd.DataFrame: Returns a dataframe with predictions and features.
1080+
pd.DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
10781081
If classification, it returns probabilities and final prediction
10791082
"""
1083+
warnings.warn(
1084+
"Default for `include_input_features` will change from True to False in the next release. Please set it explicitly.",
1085+
DeprecationWarning,
1086+
)
10801087
assert all([q <= 1 and q >= 0 for q in quantiles]), "Quantiles should be a decimal between 0 and 1"
10811088
self.model.eval()
10821089
inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
@@ -1113,7 +1120,10 @@ def predict(
11131120
quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze(-1)
11141121
if quantile_predictions.ndim == 2:
11151122
quantile_predictions = quantile_predictions.unsqueeze(-1)
1116-
pred_df = test.copy() # TODO Add option to switch between including the entire input DF or not.
1123+
if include_input_features:
1124+
pred_df = test.copy() # TODO Add option to switch between including the entire input DF or not.
1125+
else:
1126+
pred_df = pd.DataFrame(index=test.index)
11171127
if self.config.task == "regression":
11181128
point_predictions = point_predictions.numpy()
11191129
# Probabilistic Models are only implemented for Regression

0 commit comments

Comments
 (0)