Commit 86c71bc

-- style formatting for TabTransformer
-- DocStrings update for TabTransformer
1 parent 75dc0cd commit 86c71bc

3 files changed: +117 -71 lines changed

pytorch_tabular/models/tab_transformer/components.py

Lines changed: 44 additions & 36 deletions
@@ -1,12 +1,30 @@
+# Pytorch Tabular
+# Author: Manu Joseph <manujoseph@gmail.com>
+# For license information, see LICENSE.TXT
+# Inspired by implementations
+# 1. lucidrains - https://github.com/lucidrains/tab-transformer-pytorch/
+# If you are interested in Transformers, you should definitely check out his repositories.
+# 2. PyTorch Wide and Deep - https://github.com/jrzaurin/pytorch-widedeep/
+# It is another library for tabular data, which supports multi modal problems.
+# Check out the library if you haven't already.
+# 3. AutoGluon - https://github.com/awslabs/autogluon
+# AutoGluon is an AuttoML library which supports Tabular data as well. it is from Amazon Research and is in MXNet
+# 4. LabML Annotated Deep Learning Papers - The position-wise FF was shamelessly copied from
+# https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers
 from typing import Optional
+
 import torch
 import torch.nn.functional as F
-from torch import nn, einsum
-from pytorch_tabular.models import common #import PositionWiseFeedForward, GEGLU, ReGLU, SwiGLU
 from einops import rearrange
+from torch import einsum, nn
+
+from pytorch_tabular.models import common


 class AddNorm(nn.Module):
+    """
+    Applies LayerNorm, Dropout and adds to input. Standard AddNorm operations in Transformers
+    """
     def __init__(self, input_dim: int, dropout: float):
         super(AddNorm, self).__init__()
         self.dropout = nn.Dropout(dropout)
@@ -17,11 +35,16 @@ def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:


 class MultiHeadedAttention(nn.Module):
+    """
+    Multi Headed Attention Block in Transformers
+    """
     def __init__(
         self, input_dim: int, num_heads: int = 8, head_dim: int = 16, dropout: int = 0.1
     ):
         super().__init__()
-        assert input_dim % num_heads == 0, "'input_dim' must be multiples of 'num_heads'"
+        assert (
+            input_dim % num_heads == 0
+        ), "'input_dim' must be multiples of 'num_heads'"
         inner_dim = head_dim * num_heads
         self.n_heads = num_heads
         self.scale = head_dim ** -0.5
@@ -44,19 +67,21 @@ def forward(self, x):
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
         return self.to_out(out)

-#Shamelessly copied with slight adaptation from https://github.com/jrzaurin/pytorch-widedeep/blob/b487b06721c5abe56ac68c8a38580b95e0897fd4/pytorch_widedeep/models/tab_transformer.py
+
+# Slight adaptation from https://github.com/jrzaurin/pytorch-widedeep which in turn adapted from AutoGluon
 class SharedEmbeddings(nn.Module):
+    """
+    Enables different values in a categorical feature to share some embeddings across
+    """
     def __init__(
         self,
         num_embed: int,
         embed_dim: int,
         add_shared_embed: bool = False,
-        frac_shared_embed: float=0.25,
+        frac_shared_embed: float = 0.25,
     ):
         super(SharedEmbeddings, self).__init__()
-        assert (
-            frac_shared_embed < 1
-        ), "'frac_shared_embed' must be less than 1"
+        assert frac_shared_embed < 1, "'frac_shared_embed' must be less than 1"

         self.add_shared_embed = add_shared_embed
         self.embed = nn.Embedding(num_embed, embed_dim, padding_idx=0)
@@ -76,7 +101,10 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
         out[:, : shared_embed.shape[1]] = shared_embed
         return out

+
 class TransformerEncoderBlock(nn.Module):
+    """A single Transformer Encoder Block
+    """
     def __init__(
         self,
         input_embed_dim: int,
@@ -97,17 +125,19 @@ def __init__(
             else transformer_head_dim,
             dropout=attn_dropout,
         )
-
+
         try:
-            self.pos_wise_ff = getattr(common, ff_activation)(d_model=input_embed_dim,
-                                                              d_ff=input_embed_dim * ff_hidden_multiplier,
-                                                              dropout=ff_dropout)
+            self.pos_wise_ff = getattr(common, ff_activation)(
+                d_model=input_embed_dim,
+                d_ff=input_embed_dim * ff_hidden_multiplier,
+                dropout=ff_dropout,
+            )
         except AttributeError:
             self.pos_wise_ff = getattr(common, "PositionWiseFeedForward")(
                 d_model=input_embed_dim,
                 d_ff=input_embed_dim * ff_hidden_multiplier,
                 dropout=ff_dropout,
-                activation = getattr(nn, self.hparams.ff_activation)
+                activation=getattr(nn, self.hparams.ff_activation),
             )
         self.attn_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
         self.ff_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
@@ -116,26 +146,4 @@ def forward(self, x):
         y = self.mha(x)
         x = self.attn_add_norm(x, y)
         y = self.pos_wise_ff(y)
-        return self.ff_add_norm(x, y)
-
-
-# class MLP(nn.Module):
-#     def __init__(self, dims, act=None):
-#         super().__init__()
-#         dims_pairs = list(zip(dims[:-1], dims[1:]))
-#         layers = []
-#         for ind, (dim_in, dim_out) in enumerate(dims_pairs):
-#             is_last = ind >= (len(dims) - 1)
-#             linear = nn.Linear(dim_in, dim_out)
-#             layers.append(linear)
-
-#             if is_last:
-#                 continue
-
-#             act = default(act, nn.ReLU())
-#             layers.append(act)
-
-#         self.mlp = nn.Sequential(*layers)
-
-#     def forward(self, x):
-#         return self.mlp(x)
+        return self.ff_add_norm(x, y)
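For readers skimming the diff, the new SharedEmbeddings docstring above is terse, so here is a hypothetical, self-contained sketch of the idea it describes: every value of a categorical column gets its own embedding, and a single column-level embedding is either added to it or overwrites a reserved slice of it. The class and variable names below are illustrative only, not the library's API.

import torch
from torch import nn


class ToySharedEmbedding(nn.Module):
    # Illustrative re-implementation of the sharing idea above, not the library class.
    def __init__(self, num_embed: int, embed_dim: int, add_shared: bool = False, frac_shared: float = 0.25):
        super().__init__()
        assert frac_shared < 1, "'frac_shared' must be less than 1"
        self.add_shared = add_shared
        # per-value embeddings for the column
        self.embed = nn.Embedding(num_embed, embed_dim)
        # one column-level embedding shared by every value of the feature
        shared_dim = embed_dim if add_shared else int(embed_dim * frac_shared)
        self.shared = nn.Parameter(torch.empty(1, shared_dim).uniform_(-0.1, 0.1))

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        out = self.embed(idx).clone()                      # (batch, embed_dim)
        if self.add_shared:
            out = out + self.shared                        # 'add' strategy: add the shared vector
        else:
            out[:, : self.shared.shape[1]] = self.shared   # 'fraction' strategy: overwrite the reserved slice
        return out


emb = ToySharedEmbedding(num_embed=10, embed_dim=16, frac_shared=0.25)
print(emb(torch.randint(0, 10, (4,))).shape)               # torch.Size([4, 16])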

pytorch_tabular/models/tab_transformer/config.py

Lines changed: 64 additions & 28 deletions
@@ -12,36 +12,72 @@
 class TabTransformerConfig(ModelConfig):
     """Tab Transformer configuration
     Args:
-        task (str): Specify whether the problem is regression of classification.Choices are: regression classification
+        task (str): Specify whether the problem is regression of classification.
+            Choices are: [`regression`,`classification`].
+        embedding_dims (Union[List[int], NoneType]): The dimensions of the embedding for
+            each categorical column as a list of tuples (cardinality, embedding_dim).
+            If left empty, will infer using the cardinality of the categorical column using
+            the rule min(50, (x + 1) // 2)
         learning_rate (float): The learning rate of the model
-        loss (Union[str, NoneType]): The loss function to be applied.
-            By Default it is MSELoss for regression and CrossEntropyLoss for classification.
-            Unless you are sure what you are doing, leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification
-        metrics (Union[List[str], NoneType]): the list of metrics you need to track during training.
-            The metrics should be one of the metrics implemented in PyTorch Lightning.
-            By default, it is Accuracy if classification and MeanSquaredLogError for regression
-        metrics_params (Union[List, NoneType]): The parameters to be passed to the Metrics initialized
-        target_range (Union[List, NoneType]): The range in which we should limit the output variable. Currently ignored for multi-target regression
-            Typically used for Regression problems. If left empty, will not apply any restrictions
+        loss (Union[str, NoneType]): The loss function to be applied.
+            By Default it is MSELoss for regression and CrossEntropyLoss for classification.
+            Unless you are sure what you are doing, leave it at MSELoss or L1Loss for regression
+            and CrossEntropyLoss for classification
+        metrics (Union[List[str], NoneType]): the list of metrics you need to track during training.
+            The metrics should be one of the functional metrics implemented in ``torchmetrics``.
+            By default, it is accuracy if classification and mean_squared_error for regression
+        metrics_params (Union[List, NoneType]): The parameters to be passed to the metrics function
+        target_range (Union[List, NoneType]): The range in which we should limit the output variable.
+            Currently ignored for multi-target regression. Typically used for Regression problems.
+            If left empty, will not apply any restrictions

-        attn_embed_dim (int): The number of hidden units in the Multi-Headed Attention layers. Defaults to 32
-        num_heads (int): The number of heads in the Multi-Headed Attention layer. Defaults to 2
-        num_attn_blocks (int): The number of layers of stacked Multi-Headed Attention layers. Defaults to 2
-        attn_dropouts (float): Dropout between layers of Multi-Headed Attention Layers. Defaults to 0.0
-        has_residuals (bool): Flag to have a residual connect from enbedded output to attention layer output.
-            Defaults to True
-        embedding_dim (int): The dimensions of the embedding for continuous and categorical columns. Defaults to 16
-        embedding_dropout (float): probability of an embedding element to be zeroed. Defaults to 0.0
-        deep_layers (bool): Flag to enable a deep MLP layer before the Multi-Headed Attention layer. Defaults to False
-        layers (str): Hyphen-separated number of layers and units in the deep MLP. Defaults to 128-64-32
-        activation (str): The activation type in the deep MLP. The default activaion in PyTorch like
-            ReLU, TanH, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity.
+        input_embed_dim (int): The embedding dimension for the input categorical features.
+            Defaults to 32
+        embedding_dropout (float): Dropout to be applied to the Categorical Embedding.
+            Defaults to 0.1
+        share_embedding (bool): The flag turns on shared embeddings in the input embedding process.
+            The key idea here is to have an embedding for the feature as a whole along with embeddings of
+            each unique values of that column. For more details refer to Appendix A of the TabTransformer paper.
+            Defaults to False
+        share_embedding_strategy (Union[str, NoneType]): There are two strategies in adding shared embeddings.
+            1. `add` - A separate embedding for the feature is added to the embedding of the unique values of the feature.
+            2. `fraction` - A fraction of the input embedding is reserved for the shared embedding of the feature.
+            Defaults to fraction.
+            Choices are: [`add`,`fraction`].
+        shared_embedding_fraction (float): Fraction of the input_embed_dim to be reserved by the shared embedding.
+            Should be less than one. Defaults to 0.25
+        num_heads (int): The number of heads in the Multi-Headed Attention layer.
+            Defaults to 8
+        num_attn_blocks (int): The number of layers of stacked Multi-Headed Attention layers.
+            Defaults to 6
+        transformer_head_dim (Union[int, NoneType]): The number of hidden units in the Multi-Headed Attention layers.
+            Defaults to None and will be same as input_dim.
+        attn_dropout (float): Dropout to be applied after Multi headed Attention.
+            Defaults to 0.1
+        add_norm_dropout (float): Dropout to be applied in the AddNorm Layer.
+            Defaults to 0.1
+        ff_dropout (float): Dropout to be applied in the Positionwise FeedForward Network.
+            Defaults to 0.1
+        ff_hidden_multiplier (int): Multiple by which the Positionwise FF layer scales the input.
+            Defaults to 4
+        transformer_activation (str): The activation type in the transformer feed forward layers.
+            In addition to the default activation in PyTorch like ReLU, TanH, LeakyReLU, etc.
+            https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity,
+            GEGLU, ReGLU and SwiGLU are also implemented(https://arxiv.org/pdf/2002.05202.pdf).
+            Defaults to GEGLU
+        out_ff_layers (str): Hyphen-separated number of layers and units in the deep MLP.
+            Defaults to 128-64-32
+        out_ff_activation (str): The activation type in the deep MLP. The default activaion in PyTorch like ReLU, TanH, LeakyReLU, etc.
+            https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity.
             Defaults to ReLU
-        dropout (float): probability of an classification element to be zeroed in the deep MLP. Defaults to 0.0
-        use_batch_norm (bool): Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False
-        batch_norm_continuous_input (bool): If True, we will normalize the contiinuous layer by passing it through a BatchNorm layer. Defaults to False
-        attention_pooling (bool): If True, will combine the attention outputs of each block for final prediction. Defaults to False
-        initialization (str): Initialization scheme for the linear layers. Defaults to `kaiming`.
+        out_ff_dropout (float): Probability of an classification element to be zeroed in the deep MLP.
+            Defaults to 0.0
+        use_batch_norm (bool): Flag to include a BatchNorm layer after each Linear Layer+DropOut.
+            Defaults to False
+        batch_norm_continuous_input (bool): If True, we will normalize the continuous layer by passing it through a BatchNorm layer.
+            Defaults to False
+        out_ff_initialization (str): Initialization scheme for the linear layers.
+            Defaults to `kaiming`.
             Choices are: [`kaiming`,`xavier`,`random`].

     Raises:
@@ -170,7 +206,7 @@ class TabTransformerConfig(ModelConfig):
     _config_name: str = field(default="TabTransformerConfig")


-# cls = AutoIntConfig
+# cls = TabTransformerConfig
 # desc = "Configuration for Data."
 # doc_str = f"{desc}\nArgs:"
 # for key in cls.__dataclass_fields__.keys():
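The rewritten docstring enumerates the model's hyperparameters; as a rough orientation, a config object built from those fields might look like the sketch below. The field names are taken directly from the Args section above; the import path simply mirrors the file location shown in this diff and may differ from the package's public re-exports.

from pytorch_tabular.models.tab_transformer.config import TabTransformerConfig

# Minimal sketch: only a few of the documented fields are set explicitly,
# everything else falls back to the defaults described in the docstring.
model_config = TabTransformerConfig(
    task="classification",            # or "regression"
    input_embed_dim=32,               # embedding dim of each categorical feature
    share_embedding=True,             # turn on the shared-embedding scheme
    share_embedding_strategy="fraction",
    shared_embedding_fraction=0.25,   # must be < 1
    num_heads=8,
    num_attn_blocks=6,
    transformer_activation="GEGLU",   # GEGLU / ReGLU / SwiGLU or a torch.nn activation name
    out_ff_layers="128-64-32",        # hyphen-separated deep MLP spec
)
print(model_config)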

pytorch_tabular/models/tab_transformer/tab_transformer.py

Lines changed: 9 additions & 7 deletions
@@ -1,15 +1,15 @@
 # Pytorch Tabular
 # Author: Manu Joseph <manujoseph@gmail.com>
 # For license information, see LICENSE.TXT
-# Inspired by implementations
+# Inspired by implementations
 # 1. lucidrains - https://github.com/lucidrains/tab-transformer-pytorch/
 # If you are interested in Transformers, you should definitely check out his repositories.
-# 2. PyTorch Wide and Deep - https://github.com/jrzaurin/pytorch-widedeep/
-# It is another library for tabular data, which supports multi modal problems.
+# 2. PyTorch Wide and Deep - https://github.com/jrzaurin/pytorch-widedeep/
+# It is another library for tabular data, which supports multi modal problems.
 # Check out the library if you haven't already.
 # 3. AutoGluon - https://github.com/awslabs/autogluon
 # AutoGluon is an AuttoML library which supports Tabular data as well. it is from Amazon Research and is in MXNet
-# 4. LabML Annotated Deep Learning Papers - The position-wise FF was shamelessly copied from
+# 4. LabML Annotated Deep Learning Papers - The position-wise FF was shamelessly copied from
 # https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers
 """TabTransformer Model"""
 import logging
@@ -18,13 +18,13 @@
 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
-from omegaconf import DictConfig
 from einops import rearrange
+from omegaconf import DictConfig

 from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn
-from .components import TransformerEncoderBlock, SharedEmbeddings

 from ..base_model import BaseModel
+from .components import SharedEmbeddings, TransformerEncoderBlock

 logger = logging.getLogger(__name__)

@@ -163,4 +163,6 @@ def extract_embedding(self):
         if len(self.hparams.categorical_cols) > 0:
             return self.cat_embedding_layers
         else:
-            raise ValueError("Model has been trained with no categorical feature and therefore can't be used as a Categorical Encoder")
+            raise ValueError(
+                "Model has been trained with no categorical feature and therefore can't be used as a Categorical Encoder"
+            )
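The last hunk only reflows the ValueError call in extract_embedding, which hands back the learned categorical embedding layers so a trained model can be reused as a categorical encoder. A hypothetical, self-contained sketch of that guard-and-return pattern is below; ToyEncoder is a stand-in for illustration, not the library's TabTransformer class.

import torch
from torch import nn


class ToyEncoder(nn.Module):
    # Stand-in module with the same guard as extract_embedding above.
    def __init__(self, cardinalities, embed_dim=32):
        super().__init__()
        self.cat_embedding_layers = nn.ModuleList(
            nn.Embedding(card, embed_dim) for card in cardinalities
        )

    def extract_embedding(self):
        if len(self.cat_embedding_layers) > 0:
            return self.cat_embedding_layers
        raise ValueError(
            "Model has been trained with no categorical feature and therefore can't be used as a Categorical Encoder"
        )


enc = ToyEncoder(cardinalities=[5, 12])
for layer in enc.extract_embedding():
    print(tuple(layer.weight.shape))   # (5, 32) then (12, 32)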
