Skip to content

Commit 8a9a4e7

Browse files
committed
-- fixed a bug with not calling eval on model
1 parent 1d4042c commit 8a9a4e7

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pytorch_tabular/tabular_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def fit(
456456
# Parameters in NODE needs to be initialized again
457457
if self.config._model_name in ["CategoryEmbeddingNODEModel", "NODEModel"]:
458458
self.data_aware_initialization()
459-
459+
self.model.train()
460460
self.trainer.fit(self.model, train_loader, val_loader)
461461
logger.info("Training the model completed...")
462462
if self.config.load_best:
@@ -581,6 +581,7 @@ def predict(self, test: pd.DataFrame) -> pd.DataFrame:
581581
pd.DataFrame: Returns a dataframe with predictions and features.
582582
If classification, it returns probabilities and final prediction
583583
"""
584+
self.model.eval()
584585
inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
585586
predictions = []
586587
for sample in tqdm(inference_dataloader, desc="Generating Predictions..."):

0 commit comments

Comments
 (0)