
Commit 75dc0cd

committed
- added shared embeddings for TabTransformer
- added extract_embeddings for NODE, AutoInt
- updated README
1 parent d6e80c5 commit 75dc0cd

7 files changed: +147 −38 lines changed

README.md

Lines changed: 11 additions & 5 deletions

@@ -69,6 +69,7 @@ For complete Documentation with tutorials visit []
 * [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) is another model coming out of Google Research which uses Sparse Attention in multiple steps of decision making to model the output.
 * [Mixture Density Networks](https://publications.aston.ac.uk/id/eprint/373/1/NCRG_94_004.pdf) is a regression model which uses gaussian components to approximate the target function and provide a probabilistic prediction out of the box.
 * [AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) is a model which tries to learn interactions between the features in an automated way and create a better representation and then use this representation in downstream task
+* [TabTransformer](https://arxiv.org/abs/2012.06678) is an adaptation of the Transformer model for Tabular Data which creates contextual representations for categorical features.

 To implement new models, see the [How to implement new models tutorial](https://github.com/manujosephv/pytorch_tabular/blob/main/docs/04-Implementing%20New%20Architectures.ipynb). It covers basic as well as advanced architectures.

@@ -112,9 +113,9 @@ loaded_model = TabularModel.load_from_checkpoint("examples/basic")
 ```
 ## Blogs

-[PyTorch Tabular – A Framework for Deep Learning for Tabular Data](https://deep-and-shallow.com/2021/01/27/pytorch-tabular-a-framework-for-deep-learning-for-tabular-data/)
-[Neural Oblivious Decision Ensembles(NODE) – A State-of-the-Art Deep Learning Algorithm for Tabular Data](https://deep-and-shallow.com/2021/02/25/neural-oblivious-decision-ensemblesnode-a-state-of-the-art-deep-learning-algorithm-for-tabular-data/)
-[Mixture Density Networks: Probabilistic Regression for Uncertainty Estimation](https://deep-and-shallow.com/2021/03/20/mixture-density-networks-probabilistic-regression-for-uncertainty-estimation/)
+- [PyTorch Tabular – A Framework for Deep Learning for Tabular Data](https://deep-and-shallow.com/2021/01/27/pytorch-tabular-a-framework-for-deep-learning-for-tabular-data/)
+- [Neural Oblivious Decision Ensembles (NODE) – A State-of-the-Art Deep Learning Algorithm for Tabular Data](https://deep-and-shallow.com/2021/02/25/neural-oblivious-decision-ensemblesnode-a-state-of-the-art-deep-learning-algorithm-for-tabular-data/)
+- [Mixture Density Networks: Probabilistic Regression for Uncertainty Estimation](https://deep-and-shallow.com/2021/03/20/mixture-density-networks-probabilistic-regression-for-uncertainty-estimation/)

 ## Future Roadmap(Contributions are Welcome)

@@ -124,8 +125,13 @@ loaded_model = TabularModel.load_from_checkpoint("examples/basic")
 4. Add Fourier Encoding for cyclic time variables
 5. Integrate Optuna Hyperparameter Tuning
 6. Add Text and Image Modalities for mixed modal problems
-7. Integrate Wide and Deep model
-8. Integrate TabTransformer
+7. Add Variable Importance
+8. Integrate SHAP for interpretability
+**DL Models**
+9. [DNF-Net: A Neural Architecture for Tabular Data](https://www.semanticscholar.org/paper/DNF-Net%3A-A-Neural-Architecture-for-Tabular-Data-Abutbul-Elidan/99c49f3a917815eed2144bfb5d064623ff09ade5)
+10. [Attention augmented differentiable forest for tabular data](https://www.semanticscholar.org/paper/Attention-augmented-differentiable-forest-for-data-Chen/57990b40affc5f34f4029dab39bc78e44e7d3b10)
+11. [XBNet: An Extremely Boosted Neural Network](https://arxiv.org/abs/2106.05239v2)
+12. [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959)
 ## Citation
 If you use PyTorch Tabular for a scientific publication, we would appreciate citations to the published software and the following paper:


examples/to_test_classification.py

Lines changed: 12 additions & 9 deletions

@@ -90,19 +90,22 @@
     normalize_continuous_features=False,
 )
 # model_config = CategoryEmbeddingModelConfig(task="classification", metrics=["f1","accuracy"], metrics_params=[{"num_classes":num_classes},{}])
-model_config = NodeConfig(
-    task="classification",
-    depth=4,
-    num_trees=1024,
-    input_dropout=0.0,
-    metrics=["f1", "accuracy"],
-    metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
-)
-# model_config = TabTransformerConfig(
+# model_config = NodeConfig(
 #     task="classification",
+#     depth=4,
+#     num_trees=1024,
+#     input_dropout=0.0,
 #     metrics=["f1", "accuracy"],
 #     metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
 # )
+model_config = TabTransformerConfig(
+    task="classification",
+    metrics=["f1", "accuracy"],
+    share_embedding=True,
+    share_embedding_strategy="fraction",
+    shared_embedding_fraction=0.25,
+    metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
+)
 trainer_config = TrainerConfig(gpus=-1, auto_select_gpus=True, fast_dev_run=False, max_epochs=5, batch_size=512)
 experiment_config = ExperimentConfig(project_name="PyTorch Tabular Example",
                                      run_name="node_forest_cov",
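For orientation, here is a minimal end-to-end sketch (not part of this commit) of how the new shared-embedding options plug into TabularModel; the DataConfig column names and the `train` DataFrame are placeholders:

```python
# Minimal sketch, assuming the standard pytorch_tabular workflow; column names
# and the `train` DataFrame are hypothetical.
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import TabTransformerConfig

data_config = DataConfig(
    target=["target"],                    # hypothetical target column
    continuous_cols=["num_1", "num_2"],   # hypothetical continuous columns
    categorical_cols=["cat_1", "cat_2"],  # hypothetical categorical columns
)
model_config = TabTransformerConfig(
    task="classification",
    share_embedding=True,                 # turn on the new shared embeddings
    share_embedding_strategy="fraction",  # reserve a slice of each embedding
    shared_embedding_fraction=0.25,       # 25% of input_embed_dim is shared per column
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5, batch_size=512),
)
tabular_model.fit(train=train)            # `train` is a pandas DataFrame
```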

pytorch_tabular/models/autoint/autoint.py

Lines changed: 6 additions & 0 deletions

@@ -165,3 +165,9 @@ def forward(self, x: Dict):
                 y_min, y_max = self.hparams.target_range[i]
                 y_hat[:, i] = y_min + nn.Sigmoid()(y_hat[:, i]) * (y_max - y_min)
         return {"logits": y_hat, "backbone_features": x}
+
+    def extract_embedding(self):
+        if len(self.hparams.categorical_cols) > 0:
+            return self.backbone.cat_embedding_layers
+        else:
+            raise ValueError("Model has been trained with no categorical feature and therefore can't be used as a Categorical Encoder")
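The new `extract_embedding` hook makes a trained AutoInt model usable as a categorical encoder. Below is a hedged sketch of reading the learned embeddings back out; it assumes `tabular_model` is a fitted TabularModel that exposes the underlying network as `tabular_model.model` and the merged config as `tabular_model.config`:

```python
# Sketch only: pull the learned categorical embeddings out of a trained AutoInt model.
embedding_layers = tabular_model.model.extract_embedding()

# One nn.Embedding per categorical column, in the order of categorical_cols.
for col, layer in zip(tabular_model.config.categorical_cols, embedding_layers):
    print(col, tuple(layer.weight.shape))  # (cardinality, embedding_dim)
```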

pytorch_tabular/models/node/node_model.py

Lines changed: 7 additions & 0 deletions

@@ -142,3 +142,10 @@ def forward(self, x: Dict):
                 y_min, y_max = self.hparams.target_range[i]
                 y_hat[:, i] = y_min + nn.Sigmoid()(y_hat[:, i]) * (y_max - y_min)
         return {"logits": y_hat, "backbone_features": x}
+
+    def extract_embedding(self):
+        if self.hparams.embed_categorical:
+            if self.embedding_cat_dim != 0:
+                return self.embedding_layers
+        else:
+            raise ValueError("Model has been trained with no categorical feature and therefore can't be used as a Categorical Encoder")
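Unlike AutoInt, NODE only builds its own embedding layers when categorical embedding is switched on, so `extract_embedding` is only meaningful for such a configuration. A hedged sketch of the relevant config, with all other NodeConfig parameters left at their defaults:

```python
# Sketch only: NODE must be configured to embed categorical features for
# extract_embedding to return anything; otherwise it raises ValueError.
from pytorch_tabular.models import NodeConfig

model_config = NodeConfig(
    task="classification",
    embed_categorical=True,  # build nn.Embedding layers for categorical columns
)
```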

pytorch_tabular/models/tab_transformer/components.py

Lines changed: 31 additions & 3 deletions

@@ -44,9 +44,37 @@ def forward(self, x):
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
         return self.to_out(out)

-
-# transformer
-
+# Shamelessly copied with slight adaptation from https://github.com/jrzaurin/pytorch-widedeep/blob/b487b06721c5abe56ac68c8a38580b95e0897fd4/pytorch_widedeep/models/tab_transformer.py
+class SharedEmbeddings(nn.Module):
+    def __init__(
+        self,
+        num_embed: int,
+        embed_dim: int,
+        add_shared_embed: bool = False,
+        frac_shared_embed: float = 0.25,
+    ):
+        super(SharedEmbeddings, self).__init__()
+        assert (
+            frac_shared_embed < 1
+        ), "'frac_shared_embed' must be less than 1"
+
+        self.add_shared_embed = add_shared_embed
+        self.embed = nn.Embedding(num_embed, embed_dim, padding_idx=0)
+        self.embed.weight.data.clamp_(-2, 2)
+        if add_shared_embed:
+            col_embed_dim = embed_dim
+        else:
+            col_embed_dim = int(embed_dim * frac_shared_embed)
+        self.shared_embed = nn.Parameter(torch.empty(1, col_embed_dim).uniform_(-1, 1))
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        out = self.embed(X)
+        shared_embed = self.shared_embed.expand(out.shape[0], -1)
+        if self.add_shared_embed:
+            out += shared_embed
+        else:
+            out[:, : shared_embed.shape[1]] = shared_embed
+        return out

 class TransformerEncoderBlock(nn.Module):
     def __init__(
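A hedged usage sketch of `SharedEmbeddings` for a single column (not part of the diff; it assumes the class is importable from its module path):

```python
# Sketch only: what SharedEmbeddings does for one categorical column with
# 10 categories and an 8-dimensional embedding.
import torch
from pytorch_tabular.models.tab_transformer.components import SharedEmbeddings

x = torch.tensor([1, 3, 7])  # encoded category ids for 3 rows

# "fraction" strategy: the first 25% of each embedding (2 of 8 dims here)
# is overwritten by a column-level vector shared across every category.
frac = SharedEmbeddings(num_embed=10, embed_dim=8, add_shared_embed=False, frac_shared_embed=0.25)
out = frac(x)   # shape (3, 8); out[:, :2] is identical for all three rows

# "add" strategy: a full 8-dim column-level vector is added to each value embedding.
add = SharedEmbeddings(num_embed=10, embed_dim=8, add_shared_embed=True, frac_shared_embed=0.25)
out = add(x)    # shape (3, 8)
```

The `fraction` strategy trades a slice of every value embedding for a column identifier, while `add` keeps the full value embedding and superimposes the column vector on top of it.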

pytorch_tabular/models/tab_transformer/config.py

Lines changed: 25 additions & 6 deletions

@@ -54,6 +54,31 @@ class TabTransformerConfig(ModelConfig):
             "help": "The embedding dimension for the input categorical features. Defaults to 32"
         },
     )
+    embedding_dropout: float = field(
+        default=0.1,
+        metadata={
+            "help": "Dropout to be applied to the Categorical Embedding. Defaults to 0.1"
+        },
+    )
+    share_embedding: bool = field(
+        default=False,
+        metadata={
+            "help": "The flag turns on shared embeddings in the input embedding process. The key idea here is to have an embedding for the feature as a whole along with embeddings for each unique value of that column. For more details refer to Appendix A of the TabTransformer paper. Defaults to False"
+        },
+    )
+    share_embedding_strategy: Optional[str] = field(
+        default="fraction",
+        metadata={
+            "help": "There are two strategies for adding shared embeddings. 1. `add` - A separate embedding for the feature is added to the embedding of the unique values of the feature. 2. `fraction` - A fraction of the input embedding is reserved for the shared embedding of the feature. Defaults to fraction.",
+            "choices": ["add", "fraction"],
+        },
+    )
+    shared_embedding_fraction: float = field(
+        default=0.25,
+        metadata={
+            "help": "Fraction of the input_embed_dim to be reserved by the shared embedding. Should be less than one. Defaults to 0.25"
+        },
+    )
     num_heads: int = field(
         default=8,
         metadata={
@@ -72,12 +97,6 @@ class TabTransformerConfig(ModelConfig):
             "help": "The number of hidden units in the Multi-Headed Attention layers. Defaults to None and will be same as input_dim."
         },
     )
-    embedding_dropout: float = field(
-        default=0.1,
-        metadata={
-            "help": "Dropout to be applied to the Categorical Embedding. Defaults to 0.1"
-        },
-    )
    attn_dropout: float = field(
        default=0.1,
        metadata={
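A hedged sketch of choosing the alternative `add` strategy (not in the diff); with `add`, the shared vector spans the full `input_embed_dim`, so `shared_embedding_fraction` is not used. It assumes the config is exported from `pytorch_tabular.models`, as the example script does:

```python
# Sketch only: shared embeddings with the `add` strategy; the column-level
# vector is added to each value's embedding rather than replacing a slice of it.
from pytorch_tabular.models import TabTransformerConfig

model_config = TabTransformerConfig(
    task="classification",
    share_embedding=True,
    share_embedding_strategy="add",
)
```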

pytorch_tabular/models/tab_transformer/tab_transformer.py

Lines changed: 55 additions & 15 deletions

@@ -1,7 +1,16 @@
 # Pytorch Tabular
 # Author: Manu Joseph <manujoseph@gmail.com>
 # For license information, see LICENSE.TXT
-# Inspired by https://github.com/lucidrains/tab-transformer-pytorch/blob/main/tab_transformer_pytorch/tab_transformer_pytorch.py
+# Inspired by the following implementations:
+# 1. lucidrains - https://github.com/lucidrains/tab-transformer-pytorch/
+#    If you are interested in Transformers, you should definitely check out his repositories.
+# 2. PyTorch Wide and Deep - https://github.com/jrzaurin/pytorch-widedeep/
+#    It is another library for tabular data which supports multi-modal problems.
+#    Check out the library if you haven't already.
+# 3. AutoGluon - https://github.com/awslabs/autogluon
+#    AutoGluon is an AutoML library which supports tabular data as well. It is from Amazon Research and is in MXNet.
+# 4. LabML Annotated Deep Learning Papers - The position-wise FF was shamelessly copied from
+#    https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers
 """TabTransformer Model"""
 import logging
 from typing import Dict, OrderedDict
@@ -13,29 +22,46 @@
 from einops import rearrange

 from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn
-from .components import TransformerEncoderBlock
+from .components import TransformerEncoderBlock, SharedEmbeddings

 from ..base_model import BaseModel

 logger = logging.getLogger(__name__)

+
 class TabTransformerBackbone(pl.LightningModule):
     def __init__(self, config: DictConfig):
         super().__init__()
+        assert config.share_embedding_strategy in [
+            "add",
+            "fraction",
+        ], f"`share_embedding_strategy` should be one of `add` or `fraction`, not {config.share_embedding_strategy}"
        self.save_hyperparameters(config)
        self._build_network()
-        #TODO Add output_dim

     def _build_network(self):
         if len(self.hparams.categorical_cols) > 0:
             # Category Embedding layers
-            # self.embedding_dropout = nn.Dropout(self.hparams.embedding_dropout)
-            self.cat_embedding_layers = nn.ModuleList(
-                [
-                    nn.Embedding(cardinality, self.hparams.input_embed_dim)
-                    for cardinality in self.hparams.categorical_cardinality
-                ]
-            )
+            if self.hparams.share_embedding:
+                self.cat_embedding_layers = nn.ModuleList(
+                    [
+                        SharedEmbeddings(
+                            cardinality,
+                            self.hparams.input_embed_dim,
+                            add_shared_embed=self.hparams.share_embedding_strategy
+                            == "add",
+                            frac_shared_embed=self.hparams.shared_embedding_fraction,
+                        )
+                        for cardinality in self.hparams.categorical_cardinality
+                    ]
+                )
+            else:
+                self.cat_embedding_layers = nn.ModuleList(
+                    [
+                        nn.Embedding(cardinality, self.hparams.input_embed_dim)
+                        for cardinality in self.hparams.categorical_cardinality
+                    ]
+                )
         if self.hparams.embedding_dropout != 0:
             self.embed_dropout = nn.Dropout(self.hparams.embedding_dropout)
         self.transformer_blocks = OrderedDict()
@@ -44,17 +70,20 @@ def _build_network(self):
                 input_embed_dim=self.hparams.input_embed_dim,
                 num_heads=self.hparams.num_heads,
                 ff_hidden_multiplier=self.hparams.ff_hidden_multiplier,
-                ff_activation = self.hparams.transformer_activation,
+                ff_activation=self.hparams.transformer_activation,
                 attn_dropout=self.hparams.attn_dropout,
                 ff_dropout=self.hparams.ff_dropout,
                 add_norm_dropout=self.hparams.add_norm_dropout,
             )
         self.transformer_blocks = nn.Sequential(self.transformer_blocks)
-
+        self.attention_weights = [None] * self.hparams.num_attn_blocks
         if self.hparams.batch_norm_continuous_input:
             self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
         # Final MLP Layers
-        _curr_units = self.hparams.input_embed_dim*len(self.hparams.categorical_cols) + self.hparams.continuous_dim
+        _curr_units = (
+            self.hparams.input_embed_dim * len(self.hparams.categorical_cols)
+            + self.hparams.continuous_dim
+        )
         # Linear Layers
         layers = []
         for units in self.hparams.out_ff_layers.split("-"):
@@ -87,7 +116,7 @@ def forward(self, x: Dict):
             x = self.embed_dropout(x)
         for i, block in enumerate(self.transformer_blocks):
             x = block(x)
-        #Flatten (Batch, N_Categorical, Hidden) --> (Batch, N_CategoricalxHidden)
+        # Flatten (Batch, N_Categorical, Hidden) --> (Batch, N_CategoricalxHidden)
         x = rearrange(x, "b n h -> b (n h)")
         if self.hparams.continuous_dim > 0:
             if self.hparams.batch_norm_continuous_input:
@@ -99,6 +128,7 @@ def forward(self, x: Dict):
         x = self.linear_layers(x)
         return x

+
 class TabTransformerModel(BaseModel):
     def __init__(self, config: DictConfig, **kwargs):
         super().__init__(config, **kwargs)
@@ -111,7 +141,11 @@ def _build_network(self):
         self.output_layer = nn.Linear(
             self.backbone.output_dim, self.hparams.output_dim
         )  # output_dim auto-calculated from other config
-        _initialize_layers(self.hparams.out_ff_activation, self.hparams.out_ff_initialization, self.output_layer)
+        _initialize_layers(
+            self.hparams.out_ff_activation,
+            self.hparams.out_ff_initialization,
+            self.output_layer,
+        )

     def forward(self, x: Dict):
         x = self.backbone(x)
@@ -124,3 +158,9 @@ def forward(self, x: Dict):
                 y_min, y_max = self.hparams.target_range[i]
                 y_hat[:, i] = y_min + nn.Sigmoid()(y_hat[:, i]) * (y_max - y_min)
         return {"logits": y_hat, "backbone_features": x}
+
+    def extract_embedding(self):
+        if len(self.hparams.categorical_cols) > 0:
+            return self.cat_embedding_layers
+        else:
+            raise ValueError("Model has been trained with no categorical feature and therefore can't be used as a Categorical Encoder")
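A hedged sketch (not in the diff) of inspecting the per-column shared embeddings of a fitted TabTransformer. It assumes `tabular_model` is a trained TabularModel built with `share_embedding=True`, and that the underlying network and its backbone are reachable as `tabular_model.model.backbone`:

```python
# Sketch only: the backbone holds one SharedEmbeddings module per categorical column.
layers = tabular_model.model.backbone.cat_embedding_layers

for col, layer in zip(tabular_model.config.categorical_cols, layers):
    per_value = layer.embed.weight  # (cardinality, input_embed_dim) per-value embeddings
    shared = layer.shared_embed     # (1, shared_dim) column-level shared vector
    print(col, tuple(per_value.shape), tuple(shared.shape))
```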
