
Commit 1c23046

[skip actions] minor update readme and logging
1 parent 399d726 commit 1c23046

File tree

3 files changed: 5 additions, 10 deletions

- README.md
- examples/covertype_classification.py
- src/pytorch_tabular/tabular_model.py

README.md

Lines changed: 2 additions & 8 deletions
@@ -58,13 +58,7 @@ git clone git://github.com/manujosephv/pytorch_tabular
 Once you have a copy of the source, you can install it with:
 
 ```bash
-pip install .
-```
-
-or
-
-```bash
-python setup.py install
+pip install .[extra]
 ```
 
 ## Documentation
@@ -142,11 +136,11 @@ loaded_model = TabularModel.load_from_checkpoint("examples/basic")
 
 ## Future Roadmap(Contributions are Welcome)
 
-1. Add GaussRank as Feature Transformation
 1. Integrate Optuna Hyperparameter Tuning
 1. Integrate SHAP for interpretability
 1. Add Variable Importance
 1. Add ability to use custom activations in CategoryEmbeddingModel
+1. Add GaussRank as Feature Transformation
 1. ~~Add differential dropouts(layer-wise) in CategoryEmbeddingModel~~
 1. ~~Add Fourier Encoding for cyclic time variables~~
 1. ~~Add Text and Image Modalities for mixed modal problems~~

examples/covertype_classification.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@
 model_config = CategoryEmbeddingModelConfig(
     task="classification", metrics=["f1_score", "accuracy"], metrics_params=[{"num_classes": num_classes}, {}]
 )
-trainer_config = TrainerConfig(auto_select_gpus=True, fast_dev_run=False, max_epochs=5, batch_size=512)
+trainer_config = TrainerConfig(auto_lr_find=True, fast_dev_run=False, max_epochs=5, batch_size=512)
 optimizer_config = OptimizerConfig()
 tabular_model = TabularModel(
     data_config=data_config,
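The one-line change above switches the example from `auto_select_gpus` to `auto_lr_find`, so the script now exercises the learning-rate finder that the `tabular_model.py` change below logs. A rough sketch of how the updated example hangs together; the import paths, `DataConfig` fields, file paths, target column, and the `fit(train=..., validation=...)` call are assumptions for illustration, not part of this commit:

```python
# Hedged sketch of the updated example; everything outside TrainerConfig /
# CategoryEmbeddingModelConfig is assumed from the wider library, not this diff.
import pandas as pd
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

train_df = pd.read_csv("covertype_train.csv")  # hypothetical paths
valid_df = pd.read_csv("covertype_valid.csv")
num_classes = train_df["Cover_Type"].nunique()  # hypothetical target column

data_config = DataConfig(
    target=["Cover_Type"],
    continuous_cols=[c for c in train_df.columns if c != "Cover_Type"],
)
model_config = CategoryEmbeddingModelConfig(
    task="classification",
    metrics=["f1_score", "accuracy"],
    metrics_params=[{"num_classes": num_classes}, {}],
)
# auto_lr_find=True replaces auto_select_gpus=True: the LR finder now runs before fitting
trainer_config = TrainerConfig(auto_lr_find=True, fast_dev_run=False, max_epochs=5, batch_size=512)
optimizer_config = OptimizerConfig()

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=train_df, validation=valid_df)
```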

src/pytorch_tabular/tabular_model.py

Lines changed: 2 additions & 1 deletion
@@ -556,7 +556,8 @@ def train(
         self.model.train()
         if self.config.auto_lr_find and (not self.config.fast_dev_run):
             logger.info("Auto LR Find Started")
-            self.trainer.tune(self.model, train_loader, val_loader)
+            result = self.trainer.tune(self.model, train_loader, val_loader)
+            logger.info(f"Suggested LR: {result['lr_find'].suggestion()}. For plot and detailed analysis, use `find_learning_rate` method.")
             # Parameters in models needs to be initialized again after LR find
             self.model.data_aware_initialization(self.datamodule)
         self.model.train()
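For context, the new log line relies on PyTorch Lightning's tuner: with `auto_lr_find` enabled, `Trainer.tune()` returns a dict whose `'lr_find'` entry is an LR-finder object exposing `.suggestion()` (and a `.plot()` that the `find_learning_rate` method mentioned in the log presumably wraps). A minimal, self-contained sketch of that pattern, assuming the 1.x-era Lightning API this code targets; the toy module and data are hypothetical:

```python
# Minimal sketch of the Lightning LR-finder pattern used above (1.x-era API assumed).
import logging

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

logger = logging.getLogger(__name__)


class TinyModule(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr  # the tuner looks for an `lr`/`learning_rate` attribute to overwrite
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


loader = DataLoader(TensorDataset(torch.randn(256, 8), torch.randn(256, 1)), batch_size=32)
trainer = pl.Trainer(auto_lr_find=True, max_epochs=1)

# tune() runs the LR finder (and batch-size scaler, if enabled) before any training
result = trainer.tune(TinyModule(), loader)
logger.info(f"Suggested LR: {result['lr_find'].suggestion()}")
# result['lr_find'].plot(suggest=True) would give the loss-vs-LR curve for inspection
```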
