
Commit 1f3c14e

Updating Torchmetrics dependency (#142)

* [skip actions] updated history
* updated torchmetrics

1 parent d925a89

8 files changed (+87, -15 lines)


docs/history.md

Lines changed: 28 additions & 0 deletions
@@ -43,3 +43,31 @@
 - Made the temp folder pytorch tabular specific to avoid conflicts with other tmp folders.
 - Some bug fixes
 - Edited an error out of Advanced Tutorial in docs
+
+## 1.0.0 (2023-01-18)
+
+- Added a new task - Self-Supervised Learning (SSL) - and a separate training API for it.
+- Added a new SOTA model - Gated Additive Tree Ensembles (GATE).
+- Added one SSL model - Denoising AutoEncoder.
+- Added lots of new tutorials and updated the entire documentation.
+- Improved code documentation and type hints.
+- Separated a Model into separate Embedding, Backbone, and Head.
+- Refactored all models to separate the Backbone as a native PyTorch model (nn.Module).
+- Refactored commonly used modules (layers, activations, etc.) into a common module.
+- Changed MixedDensityNetworks completely (breaking change). Now MDN is a head you can use with any model.
+- Enabled a low-level API for training models.
+- Enabled saving and loading of the datamodule.
+- Added trainer_kwargs to pass any trainer argument PyTorch Lightning supports.
+- Added Early Stopping and Model Checkpoint kwargs to use all the arguments in PyTorch Lightning.
+- Enabled prediction using GPUs in the predict method.
+- Added `reset_model` to reset model weights to random.
+- Added many save and load functions, including ONNX (experimental).
+- Added random seed as a parameter.
+- Switched over completely to Rich progress bars from tqdm.
+- Fixed class-balancing / mu propagation and set the default to 1.0.
+- Added PyTorch Profiler for debugging performance issues.
+- Fixed bugs with FTTransformer and TabTransformer.
+- Updated MixedDensityNetworks, fixing a bug with lambda_pi.
+- Many CI/CD improvements, including complete integration with GitHub Actions.
+- Upgraded all dependencies, including PyTorch Lightning and pandas, to the latest versions and added dependabot to manage them going forward.
+- Added pre-commit to ensure code integrity and standardization.

examples/__only_for_dev__/to_test_dae.py

Lines changed: 2 additions & 2 deletions
@@ -140,7 +140,7 @@ def print_metrics(y_true, y_pred, tag):
 from pytorch_tabular.models import CategoryEmbeddingModelConfig  # noqa: E402
 from pytorch_tabular.ssl_models.dae import DenoisingAutoEncoderConfig  # noqa: E402

-max_epochs = 5
+max_epochs = 1
 batch_size = 1024
 lr = 1e-3

@@ -159,7 +159,7 @@ def print_metrics(y_true, y_pred, tag):
     auto_lr_find=False,  # Runs the LRFinder to automatically derive a learning rate
     batch_size=batch_size,
     max_epochs=max_epochs,
-    fast_dev_run=False,
+    fast_dev_run=True,
 )
 optimizer_config = OptimizerConfig()
 encoder_config = CategoryEmbeddingModelConfig(

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ pandas>=1.1.5
 scikit-learn>=1.0.0
 pytorch-lightning==1.8.*
 omegaconf>=2.0.1
-torchmetrics==0.10.*
+torchmetrics==0.11.*
 tensorboard>=2.2.0, !=2.5.0
 protobuf<=3.20.*
 pytorch-tabnet==4.0
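
This version bump is why the commit also touches model code: torchmetrics 0.11 reworked the classification metrics so functional calls such as `accuracy` expect an explicit `task` (and `num_classes` for multiclass). A minimal sketch of the difference, with made-up tensors:

import torch
import torchmetrics.functional as F

preds = torch.tensor([0, 2, 1, 2])
target = torch.tensor([0, 1, 1, 2])

# torchmetrics 0.10.x accepted a bare call; 0.11.x needs the task spelled out.
acc = F.accuracy(preds, target, task="multiclass", num_classes=3)
print(acc)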

src/pytorch_tabular/config/config.py

Lines changed: 3 additions & 0 deletions
@@ -193,6 +193,9 @@ class InferredConfig:

         embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
             list of tuples (cardinality, embedding_dim).
+
+        embedded_cat_dim (int): The number of features or dimensions of the embedded categorical features
+
     """

     categorical_dim: int = field(
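
The diff only adds the docstring entry; presumably `embedded_cat_dim` is derived from `embedding_dims` as the total width of the concatenated categorical embeddings. Under that assumption, the relationship is simply:

# Assumption for illustration: embedded_cat_dim is the summed embedding width.
embedding_dims = [(10, 5), (4, 3), (7, 4)]  # (cardinality, embedding_dim) per categorical column
embedded_cat_dim = sum(dim for _, dim in embedding_dims)  # 5 + 3 + 4 = 12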

src/pytorch_tabular/models/base_model.py

Lines changed: 27 additions & 4 deletions
@@ -4,6 +4,7 @@
 """Base Model"""
 import warnings
 from abc import ABCMeta, abstractmethod
+from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import pytorch_lightning as pl
@@ -86,8 +87,27 @@ def __init__(
         if self.custom_loss is not None:
             config.loss = str(self.custom_loss)
         if self.custom_metrics is not None:
-            config.metrics = [str(m) for m in self.custom_metrics]
-            config.metrics_params = [vars(m) for m in self.custom_metrics]
+            # Adding metrics to config for hparams logging and tracking
+            config.metrics = []
+            config.metrics_params = []
+            for metric in self.custom_metrics:
+                if isinstance(metric, partial):
+                    # extracting func names from partial functions
+                    config.metrics.append(metric.func.__name__)
+                    config.metrics_params.append(metric.keywords)
+                else:
+                    config.metrics.append(metric.__name__)
+                    config.metrics_params.append(vars(metric))
+        else:  # Updating default metrics in config
+            if config.task == "classification":
+                # Adding metric_params to config for classification task
+                for i, metric_params in enumerate(config.metrics_params):
+                    if "task" not in metric_params:
+                        # For classification task, output_dim == number of classes
+                        config.metrics_params[i]["task"] = "binary" if inferred_config.output_dim == 2 else "multiclass"
+                    if "num_classes" not in metric_params:
+                        config.metrics_params[i]["num_classes"] = inferred_config.output_dim
+
         if self.custom_optimizer is not None:
             config.optimizer = str(self.custom_optimizer.__class__.__name__)
         if len(self.custom_optimizer_params) > 0:
@@ -167,7 +187,6 @@ def _setup_metrics(self):
                 raise e
         else:
             self.metrics = self.custom_metrics
-            self.hparams.metrics = [m.__name__ for m in self.custom_metrics]

     def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
         """Calculates the loss for the model
@@ -241,7 +260,11 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
             if self.hparams.task == "regression":
                 _metrics = []
                 for i in range(self.hparams.output_dim):
-                    if metric.__name__ == torchmetrics.functional.mean_squared_log_error.__name__:
+                    if isinstance(metric, partial):
+                        name = metric.func.__name__
+                    else:
+                        name = metric.__name__
+                    if name == torchmetrics.functional.mean_squared_log_error.__name__:
                         # MSLE should only be used in strictly positive targets. It is undefined otherwise
                         _metric = metric(
                             torch.clamp(y_hat[:, i], min=0),
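
A small standalone sketch of the pattern introduced above (the helper name is invented for illustration; only the `functools.partial` attributes and the torchmetrics functional API are real): a metric may now arrive either as a plain function or as a `partial` carrying pre-bound keyword arguments, and its display name and parameters are recovered accordingly.

from functools import partial

import torchmetrics.functional as F


def metric_name_and_params(metric):
    # functools.partial exposes the wrapped function (.func) and bound kwargs (.keywords)
    if isinstance(metric, partial):
        return metric.func.__name__, dict(metric.keywords)
    return metric.__name__, {}


m = partial(F.accuracy, task="multiclass", num_classes=4)
print(metric_name_and_params(m))                      # ('accuracy', {'task': 'multiclass', 'num_classes': 4})
print(metric_name_and_params(F.mean_squared_error))   # ('mean_squared_error', {})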

src/pytorch_tabular/models/ft_transformer/config.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ class FTTransformerConfig(ModelConfig):
         default=4,
         metadata={"help": "Multiple by which the Positionwise FF layer scales the input. Defaults to 4"},
     )
-    # TODO improve documentation
+
     transformer_activation: str = field(
         default="GEGLU",
         metadata={

src/pytorch_tabular/models/tab_transformer/config.py

Lines changed: 0 additions & 1 deletion
@@ -172,7 +172,6 @@ class TabTransformerConfig(ModelConfig):
         default=4,
         metadata={"help": "Multiple by which the Positionwise FF layer scales the input. Defaults to 4"},
     )
-    # TODO improve documentation
     transformer_activation: str = field(
         default="GEGLU",
         metadata={

src/pytorch_tabular/tabular_model.py

Lines changed: 25 additions & 6 deletions
@@ -7,6 +7,7 @@
 import os
 import warnings
 from collections import defaultdict
+from functools import partial
 from pathlib import Path
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

@@ -370,7 +371,7 @@ def load_model(cls, dir: str, map_location=None, strict=True):
             model_args["loss"] = "MSELoss"  # For compatibility. Not Used
         if custom_params.get("custom_metrics") is not None:
             model_args["metrics"] = ["mean_squared_error"]  # For compatibility. Not Used
-            model_args["metric_params"] = [{}]  # For compatibility. Not Used
+            model_args["metrics_params"] = [{}]  # For compatibility. Not Used
         if custom_params.get("custom_optimizer") is not None:
             model_args["optimizer"] = "Adam"  # For compatibility. Not Used
         if custom_params.get("custom_optimizer_params") is not None:
@@ -815,6 +816,14 @@ def create_finetune_model(
         datamodule.target = config.target
         datamodule.batch_size = config.batch_size
         datamodule.seed = config.seed
+        model_callable = _GenericModel
+        inferred_config = self.datamodule.update_config(config)
+        inferred_config = OmegaConf.structured(inferred_config)
+        # Adding dummy attributes for compatibility. Not used because custom metrics are provided
+        if not hasattr(config, "metrics"):
+            config.metrics = "dummy"
+        if not hasattr(config, "metrics_params"):
+            config.metrics_params = {}
         if metrics is not None:
             assert len(metrics) == len(metrics_params), "Number of metrics and metrics_params should be same"
             metrics = [getattr(torchmetrics.functional, m) if isinstance(m, str) else m for m in metrics]
@@ -827,11 +836,21 @@ def create_finetune_model(
             loss = loss if loss is not None else torch.nn.CrossEntropyLoss()
             if metrics is None:
                 metrics = [torchmetrics.functional.accuracy]
-                metrics_params = [{}]
-
-        model_callable = _GenericModel
-        inferred_config = self.datamodule.update_config(config)
-        inferred_config = OmegaConf.structured(inferred_config)
+                metrics_params = [
+                    {
+                        "task": "binary" if inferred_config.output_dim == 2 else "multiclass",
+                        "num_classes": inferred_config.output_dim,
+                    }
+                ]
+            else:
+                for i, mp in enumerate(metrics_params):
+                    if "task" not in mp:
+                        # For classification task, output_dim == number of classes
+                        metrics_params[i]["task"] = "binary" if inferred_config.output_dim == 2 else "multiclass"
+                    if "num_classes" not in mp:
+                        metrics_params[i]["num_classes"] = inferred_config.output_dim
+            # Forming partial callables using metrics and metric params
+            metrics = [partial(m, **mp) for m, mp in zip(metrics, metrics_params)]
         self.model.mode = "finetune"
         if learning_rate is not None:
             config.learning_rate = learning_rate
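
As a rough illustration of what the new classification branch produces (values are made up; `output_dim` stands in for the inferred number of classes), the metrics end up as ready-to-call partials that already satisfy torchmetrics 0.11's required arguments:

from functools import partial

import torch
import torchmetrics.functional as F

output_dim = 3                       # hypothetical number of classes
metrics = [F.accuracy, F.f1_score]
metrics_params = [{}, {"average": "macro"}]

# Fill in the fields torchmetrics 0.11 requires, unless the caller already set them
for mp in metrics_params:
    mp.setdefault("task", "binary" if output_dim == 2 else "multiclass")
    mp.setdefault("num_classes", output_dim)

# Form partial callables, mirroring the diff above
metrics = [partial(m, **mp) for m, mp in zip(metrics, metrics_params)]

y_hat = torch.tensor([0, 2, 1, 2])
y = torch.tensor([0, 1, 1, 2])
print([m(y_hat, y).item() for m in metrics])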
