@@ -1060,6 +1060,7 @@ def predict(
10601060 quantiles : Optional [List ] = [0.25 , 0.5 , 0.75 ],
10611061 n_samples : Optional [int ] = 100 ,
10621062 ret_logits = False ,
1063+ include_input_features : bool = True ,
10631064 ) -> pd .DataFrame :
10641065 """Uses the trained model to predict on new data and return as a dataframe
10651066
@@ -1072,11 +1073,17 @@ def predict(
10721073 Ignored for non-probabilistic models. Defaults to 100
10731074 ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
10741075 with the dataframe. Defaults to False
1076+ include_input_features (bool): Flag to include the input features in the returned dataframe.
1077+ Defaults to True
10751078
10761079 Returns:
1077- pd.DataFrame: Returns a dataframe with predictions and features.
1080+ pd.DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`) .
10781081 If classification, it returns probabilities and final prediction
10791082 """
1083+ warnings .warn (
1084+ "Default for `include_input_features` will change from True to False in the next release. Please set it explicitly." ,
1085+ DeprecationWarning ,
1086+ )
10801087 assert all ([q <= 1 and q >= 0 for q in quantiles ]), "Quantiles should be a decimal between 0 and 1"
10811088 self .model .eval ()
10821089 inference_dataloader = self .datamodule .prepare_inference_dataloader (test )
@@ -1113,7 +1120,10 @@ def predict(
11131120 quantile_predictions = torch .cat (quantile_predictions , dim = 0 ).unsqueeze (- 1 )
11141121 if quantile_predictions .ndim == 2 :
11151122 quantile_predictions = quantile_predictions .unsqueeze (- 1 )
1116- pred_df = test .copy () # TODO Add option to switch between including the entire input DF or not.
1123+ if include_input_features :
1124+ pred_df = test .copy () # TODO Add option to switch between including the entire input DF or not.
1125+ else :
1126+ pred_df = pd .DataFrame (index = test .index )
11171127 if self .config .task == "regression" :
11181128 point_predictions = point_predictions .numpy ()
11191129 # Probabilistic Models are only implemented for Regression
0 commit comments