You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* Added model comparator
tweaked DANet default virtual batch size
* added test cases
* added lite config
added model sweep
added model sweep tutorial
* added documentation
* update documentation for model sweep
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Copy file name to clipboardExpand all lines: docs/tabular_model.md
+39Lines changed: 39 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -30,6 +30,45 @@ tabular_model = TabularModel(
30
30
)
31
31
```
32
32
33
+
### Model Sweep
34
+
35
+
PyTorch Tabular also provides an easy way to check performance of different models and configurations on a given dataset. This is done through the `model_sweep` function. It takes in a list of model configs or one of the presets defined in ``pytorch_tabular.MODEL_PRESETS`` and trains them on the data. It then ranks the models based on the metric provided and returns the best model.
36
+
37
+
These are the major args:
38
+
-``task``: The type of prediction task. Either 'classification' or 'regression'
39
+
-``train``: The training data
40
+
-``test``: The test data on which performance is evaluated
41
+
- all the config objects can be passed as either the object or the path to the yaml file.
42
+
-``models``: The list of models to compare. This can be one of the presets defined in ``pytorch_tabular.MODEL_SWEEP_PRESETS`` or a list of ``ModelConfig`` objects.
43
+
-``metrics``: the list of metrics you need to track during training. The metrics should be one of the functional metrics implemented in ``torchmetrics``. By default, it is accuracy if classification and mean_squared_error for regression
44
+
-``metrics_prob_input``: Is a mandatory parameter for classification metrics defined in the config. This defines whether the input to the metric function is the probability or the class. Length should be same as the number of metrics. Defaults to None.
45
+
-``metrics_params``: The parameters to be passed to the metrics function.
46
+
-``rank_metric``: The metric to use for ranking the models. The first element of the tuple is the metric name and the second element is the direction. Defaults to ('loss', "lower_is_better").
47
+
-``return_best_model``: If True, will return the best model. Defaults to True.
48
+
49
+
#### Usage Example
50
+
51
+
```python
52
+
sweep_df, best_model = model_sweep(
53
+
task="classification", # One of "classification", "regression"
54
+
train=train,
55
+
test=test,
56
+
data_config=data_config,
57
+
optimizer_config=optimizer_config,
58
+
trainer_config=trainer_config,
59
+
model_list="lite", # One of the presets defined in pytorch_tabular.MODEL_SWEEP_PRESETS
metrics=['accuracy', "f1_score"], # The metrics to track during training
62
+
metrics_params=[{}, {"average": "weighted"}],
63
+
metrics_prob_input=[False, True],
64
+
rank_metric=("accuracy", "higher_is_better"), # The metric to use for ranking the models.
65
+
progress_bar=True, # If True, will show a progress bar
66
+
verbose=False# If True, will print the results of each model
67
+
)
68
+
```
69
+
70
+
For more examples, check out the tutorial notebook - [Model Sweep]("tutorials/13-Model Sweep.ipynb") for example usage.
71
+
33
72
### Advanced Usage
34
73
35
74
-`config`: DictConfig: Another way of initializing `TabularModel` is with an `Dictconfig` from `omegaconf`. Although not recommended, you can create a normal dictionary with all the parameters dumped into it and create a `DictConfig` from `omegaconf` and pass it here. The downside is that you'll be skipping all the validation(both type validation and logical validations). This is primarily used internally to load a saved model from a checkpoint.
0 commit comments