
Commit 24ce9e5

Model Sweep (#366)
* Added model comparator; tweaked DANet default virtual batch size
* Added test cases
* Added lite config, model sweep, and a model sweep tutorial
* Added documentation
* Updated documentation for model sweep
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 96d6c06 commit 24ce9e5

17 files changed: +1437 -1795 lines

docs/apidocs_coreclasses.md (6 additions, 0 deletions)

@@ -6,3 +6,9 @@
 ::: pytorch_tabular.TabularDatamodule
     options:
         heading_level: 3
+::: pytorch_tabular.TabularModelTuner
+    options:
+        heading_level: 3
+::: pytorch_tabular.model_sweep
+    options:
+        heading_level: 3
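
Since `TabularModelTuner` enters the API reference here, a rough sketch of how it is driven may help. The constructor and `tune` arguments below are assumptions based on the library's tuner tutorial, not something this diff confirms:

```python
from pytorch_tabular import TabularModelTuner

# Assumed usage: the tuner wraps the same four config objects TabularModel takes.
tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)

# Assumed search-space format: "<config>__<param>" keys mapped to candidate values.
result = tuner.tune(
    train=train,
    validation=val,
    search_space={"model_config__layers": ["64-32", "128-64-32"]},
    strategy="grid_search",  # or "random_search"
    metric="accuracy",
    mode="max",
)
```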

docs/apidocs_utils.md (22 additions, 1 deletion)

@@ -18,6 +18,15 @@
 ::: pytorch_tabular.utils.get_gaussian_centers
     options:
         heading_level: 3
+::: pytorch_tabular.utils.load_covertype_dataset
+    options:
+        heading_level: 3
+::: pytorch_tabular.utils.make_mixed_dataset
+    options:
+        heading_level: 3
+::: pytorch_tabular.utils.print_metrics
+    options:
+        heading_level: 3
 
 ## NN Utilities
 ::: pytorch_tabular.utils._initialize_layers
@@ -38,7 +47,10 @@
 ::: pytorch_tabular.utils.to_one_hot
     options:
         heading_level: 3
-
+::: pytorch_tabular.utils.count_parameters
+    options:
+        heading_level: 3
+
 ## Python Utilities
 ::: pytorch_tabular.utils.getattr_nested
     options:
@@ -55,3 +67,12 @@
 ::: pytorch_tabular.utils.generate_doc_dataclass
     options:
         heading_level: 3
+::: pytorch_tabular.utils.suppress_lightning_logs
+    options:
+        heading_level: 3
+::: pytorch_tabular.utils.enable_lightning_logs
+    options:
+        heading_level: 3
+::: pytorch_tabular.utils.int_to_human_readable
+    options:
+        heading_level: 3
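
One of the newly documented helpers, `make_mixed_dataset`, generates the synthetic frames used throughout the tutorials. A minimal sketch, assuming the signature shown in those tutorials:

```python
from pytorch_tabular.utils import make_mixed_dataset

# Build a synthetic classification frame with both numeric and categorical
# columns; the keyword names follow the tutorials and are an assumption here.
data, cat_col_names, num_col_names = make_mixed_dataset(
    task="classification", n_samples=10000, n_features=20, n_categories=4
)
print(data.shape, cat_col_names, num_col_names)
```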

docs/tabular_model.md (39 additions, 0 deletions)

Hunk `@@ -30,6 +30,45 @@` adds a new "Model Sweep" section right after the closing fence of the `TabularModel(...)` construction example:

### Model Sweep

PyTorch Tabular also provides an easy way to check the performance of different models and configurations on a given dataset. This is done through the `model_sweep` function. It takes a list of model configs, or one of the presets defined in ``pytorch_tabular.MODEL_SWEEP_PRESETS``, trains each model on the data, ranks the models based on the metric provided, and returns the best model.
These are the major args:

- ``task``: The type of prediction task. Either "classification" or "regression".
- ``train``: The training data.
- ``test``: The test data on which performance is evaluated.
- All the config objects (``data_config``, ``optimizer_config``, ``trainer_config``, etc.) can be passed either as the object or as the path to a YAML file.
- ``model_list``: The list of models to compare. This can be one of the presets defined in ``pytorch_tabular.MODEL_SWEEP_PRESETS`` or a list of ``ModelConfig`` objects (a sketch with an explicit config list follows the usage example below).
- ``metrics``: The list of metrics to track during training. Each metric should be one of the functional metrics implemented in ``torchmetrics``. Defaults to accuracy for classification and mean_squared_error for regression.
- ``metrics_prob_input``: A mandatory parameter for classification metrics defined in the config. It defines whether the input to each metric function is the probability or the class. Length should be the same as the number of metrics. Defaults to None.
- ``metrics_params``: The parameters to be passed to the metric functions.
- ``rank_metric``: The metric to use for ranking the models. The first element of the tuple is the metric name and the second element is the direction. Defaults to ``("loss", "lower_is_better")``.
- ``return_best_model``: If True, will return the best model. Defaults to True.
#### Usage Example

```python
from pytorch_tabular import model_sweep

sweep_df, best_model = model_sweep(
    task="classification",  # one of "classification" or "regression"
    train=train,
    test=test,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="lite",  # one of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS
    common_model_args=dict(head="LinearHead", head_config=head_config),
    metrics=["accuracy", "f1_score"],  # the metrics to track during training
    metrics_params=[{}, {"average": "weighted"}],
    metrics_prob_input=[False, True],
    rank_metric=("accuracy", "higher_is_better"),  # the metric used to rank the models
    progress_bar=True,  # if True, shows a progress bar
    verbose=False,  # if True, prints the results of each model
)
```
For more example usage, check out the tutorial notebook: [Model Sweep](tutorials/13-Model Sweep.ipynb).
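
The `"lite"` preset above can be swapped for an explicit list of model configs, as the `model_list` description notes. A minimal sketch, assuming `CategoryEmbeddingModelConfig` and `TabNetModelConfig` suit the dataset and reusing the config objects from the example above:

```python
from pytorch_tabular import model_sweep
from pytorch_tabular.models import CategoryEmbeddingModelConfig, TabNetModelConfig

# Each entry is a fully specified ModelConfig; model_sweep trains and ranks all of them.
custom_models = [
    CategoryEmbeddingModelConfig(task="classification", layers="128-64"),
    TabNetModelConfig(task="classification"),
]

sweep_df, best_model = model_sweep(
    task="classification",
    train=train,
    test=test,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list=custom_models,  # a list of ModelConfig objects instead of a preset
    rank_metric=("accuracy", "higher_is_better"),
)
```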
The hunk's trailing context shows the existing content that follows the new section:

### Advanced Usage

- `config`: DictConfig: Another way of initializing `TabularModel` is with a `DictConfig` from `omegaconf`. Although not recommended, you can create a normal dictionary with all the parameters dumped into it, create a `DictConfig` from `omegaconf`, and pass it here. The downside is that you'll be skipping all the validation (both type validation and logical validation). This is primarily used internally to load a saved model from a checkpoint.
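
A minimal sketch of that pattern; the keys shown are illustrative, since a real dict must contain every field the Data/Model/Optimizer/Trainer configs would otherwise validate and default for you:

```python
from omegaconf import OmegaConf
from pytorch_tabular import TabularModel

# Illustrative only: an incomplete dict like this will fail at runtime; dump
# the full set of config fields before relying on this path.
raw_config = {
    "task": "classification",
    "target": ["target"],
    "continuous_cols": ["num_col_1", "num_col_2"],
    "categorical_cols": ["cat_col_1"],
    # ... all remaining config fields ...
}

tabular_model = TabularModel(config=OmegaConf.create(raw_config))
```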

docs/tutorials/13-Model Leaderboard copy.ipynb (402 additions, 0 deletions)

Large diffs are not rendered by default.
