File tree Expand file tree Collapse file tree 1 file changed +2
-1
lines changed
Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Original file line number Diff line number Diff line change @@ -456,7 +456,7 @@ def fit(
456456 # Parameters in NODE needs to be initialized again
457457 if self .config ._model_name in ["CategoryEmbeddingNODEModel" , "NODEModel" ]:
458458 self .data_aware_initialization ()
459-
459+ self . model . train ()
460460 self .trainer .fit (self .model , train_loader , val_loader )
461461 logger .info ("Training the model completed..." )
462462 if self .config .load_best :
@@ -581,6 +581,7 @@ def predict(self, test: pd.DataFrame) -> pd.DataFrame:
581581 pd.DataFrame: Returns a dataframe with predictions and features.
582582 If classification, it returns probabilities and final prediction
583583 """
584+ self .model .eval ()
584585 inference_dataloader = self .datamodule .prepare_inference_dataloader (test )
585586 predictions = []
586587 for sample in tqdm (inference_dataloader , desc = "Generating Predictions..." ):
You can’t perform that action at this time.
0 commit comments