Commit 1c8bdab

-- created Tab Transformer config
1 parent 1ccfcc2 commit 1c8bdab

3 files changed: 291 additions & 0 deletions
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from .autoint import AutoIntBackbone, AutoIntModel
from .config import AutoIntConfig

__all__ = ["AutoIntModel", "AutoIntBackbone", "AutoIntConfig"]
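
For reference, this __init__ simply re-exports the AutoInt classes so downstream code can import them from the package rather than from the .autoint and .config submodules. A minimal import sketch, assuming the package resolves to pytorch_tabular.models.autoint (the file path itself is not shown in this diff):

# Package-level import enabled by the re-exports above; the dotted path is an
# assumption, since the diff does not show where this __init__.py lives.
from pytorch_tabular.models.autoint import AutoIntBackbone, AutoIntConfig, AutoIntModel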
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
# Pytorch Tabular
# Author: Manu Joseph <manujoseph@gmail.com>
# For license information, see LICENSE.TXT
"""TabTransformer Config"""
from dataclasses import dataclass, field
from typing import List, Optional

from pytorch_tabular.config import ModelConfig, _validate_choices


@dataclass
class TabTransformerConfig(ModelConfig):
    """Tab Transformer configuration
    Args:
        task (str): Specify whether the problem is regression or classification. Choices are: regression, classification
        learning_rate (float): The learning rate of the model
        loss (Union[str, NoneType]): The loss function to be applied.
            By default it is MSELoss for regression and CrossEntropyLoss for classification.
            Unless you are sure what you are doing, leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification
        metrics (Union[List[str], NoneType]): The list of metrics you need to track during training.
            The metrics should be one of the metrics implemented in PyTorch Lightning.
            By default, it is Accuracy for classification and MeanSquaredLogError for regression
        metrics_params (Union[List, NoneType]): The parameters to be passed to the metrics initialized
        target_range (Union[List, NoneType]): The range in which we should limit the output variable. Currently ignored for multi-target regression.
            Typically used for regression problems. If left empty, will not apply any restrictions

        transformer_embed_dim (int): The number of hidden units in the Multi-Headed Attention layers. Defaults to 32
        num_heads (int): The number of heads in the Multi-Headed Attention layer. Defaults to 8
        num_attn_blocks (int): The number of layers of stacked Multi-Headed Attention layers. Defaults to 6
        attn_dropouts (float): Dropout between layers of Multi-Headed Attention Layers. Defaults to 0.1
        ff_dropouts (float): Dropout after FF layers. Defaults to 0.1
        ff_hidden_multipliers (tuple): Multiples by which the layers scale from Transformer output to logits. Defaults to (4, 2)
        mlp_activation (str): The activation type in the final FF layer. The default activations in PyTorch like
            ReLU, TanH, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity.
            Defaults to ReLU
        transformer_activation (str): The activation type in the transformer feed forward layers. In addition to the
            default activations in PyTorch, GatedGLUs are also implemented (https://arxiv.org/pdf/2002.05202.pdf).
            Defaults to GEGLU
        initialization (str): Initialization scheme for the linear layers. Defaults to `kaiming`.
            Choices are: [`kaiming`,`xavier`,`random`].

    Raises:
        NotImplementedError: Raises an error if task is not in ['regression','classification']
    """

    transformer_embed_dim: int = field(
        default=32,
        metadata={
            "help": "The number of hidden units in the Multi-Headed Attention layers. Defaults to 32"
        },
    )
    num_heads: int = field(
        default=8,
        metadata={
            "help": "The number of heads in the Multi-Headed Attention layer. Defaults to 8"
        },
    )
    num_attn_blocks: int = field(
        default=6,
        metadata={
            "help": "The number of layers of stacked Multi-Headed Attention layers. Defaults to 6"
        },
    )
    attn_dropouts: float = field(
        default=0.1,
        metadata={
            "help": "Dropout between layers of Multi-Headed Attention Layers. Defaults to 0.1"
        },
    )
    ff_dropouts: float = field(
        default=0.1,
        metadata={"help": "Dropout after FF layers. Defaults to 0.1"},
    )
    ff_hidden_multipliers: tuple = field(
        default=(4, 2),
        metadata={
            "help": "Multiples by which the layers scale from Transformer output to logits. Defaults to (4, 2)"
        },
    )
    mlp_activation: str = field(
        default="ReLU",
        metadata={
            "help": "The activation type in the final FF layer. The default activation in PyTorch like ReLU, TanH, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity. Defaults to ReLU"
        },
    )
    transformer_activation: str = field(
        default="GEGLU",
        metadata={
            "help": "The activation type in the transformer feed forward layers. In addition to the default activations in PyTorch like ReLU, TanH, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity, GatedGLUs are also implemented (https://arxiv.org/pdf/2002.05202.pdf). Defaults to GEGLU"
        },
    )
    initialization: str = field(
        default="kaiming",
        metadata={
            "help": "Initialization scheme for the linear layers. Defaults to `kaiming`",
            "choices": ["kaiming", "xavier", "random"],
        },
    )
    _module_src: str = field(default="tab_transformer")
    _model_name: str = field(default="TabTransformerModel")
    _config_name: str = field(default="TabTransformerConfig")


# cls = TabTransformerConfig
# desc = "Configuration for Data."
# doc_str = f"{desc}\nArgs:"
# for key in cls.__dataclass_fields__.keys():
#     atr = cls.__dataclass_fields__[key]
#     if atr.init:
#         type = str(atr.type).replace("<class '", "").replace("'>", "").replace("typing.", "")
#         help_str = atr.metadata.get("help", "")
#         if "choices" in atr.metadata.keys():
#             help_str += f'. Choices are: [{",".join(["`"+str(ch)+"`" for ch in atr.metadata["choices"]])}].'
#         # help_str += f'. Defaults to {atr.default}'
#         doc_str += f'\n\t\t{key} ({type}): {help_str}'

# print(doc_str)
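
To make the new config concrete, here is a minimal usage sketch. It only sets fields visible in this diff, plus task, which the docstring attributes to the parent ModelConfig; the import path is an assumption based on _module_src = "tab_transformer" and is not part of this commit.

# Hedged usage sketch: construct the config with a few non-default values.
from pytorch_tabular.models.tab_transformer.config import TabTransformerConfig  # path assumed

model_config = TabTransformerConfig(
    task="classification",           # required field inherited from ModelConfig
    num_heads=4,                     # default is 8
    num_attn_blocks=4,               # default is 6
    attn_dropouts=0.1,
    transformer_activation="GEGLU",
    ff_hidden_multipliers=(4, 2),
)
assert model_config._model_name == "TabTransformerModel"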
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Pytorch Tabular
# Author: Manu Joseph <manujoseph@gmail.com>
# For license information, see LICENSE.TXT
# Inspired by https://github.com/lucidrains/tab-transformer-pytorch/blob/main/tab_transformer_pytorch/tab_transformer_pytorch.py
"""TabTransformer Model"""
import logging
from typing import Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
from omegaconf import DictConfig

from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn

from ..base_model import BaseModel

logger = logging.getLogger(__name__)


# TODO: don't use embedding_dims
class TabTransformerBackbone(pl.LightningModule):
    def __init__(self, config: DictConfig):
        super().__init__()
        self.save_hyperparameters(config)
        self._build_network()

    def _build_network(self):
        if len(self.hparams.categorical_cols) > 0:
            # Category Embedding layers
            self.cat_embedding_layers = nn.ModuleList(
                [
                    nn.Embedding(cardinality, self.hparams.embedding_dim)
                    for cardinality in self.hparams.categorical_cardinality
                ]
            )
        if self.hparams.batch_norm_continuous_input:
            self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
        # Continuous Embedding Layer
        self.cont_embedding_layer = nn.Embedding(
            self.hparams.continuous_dim, self.hparams.embedding_dim
        )
        if self.hparams.embedding_dropout != 0 and len(self.hparams.categorical_cols) > 0:
            self.embed_dropout = nn.Dropout(self.hparams.embedding_dropout)
        # Deep Layers
        _curr_units = self.hparams.embedding_dim
        if self.hparams.deep_layers:
            activation = getattr(nn, self.hparams.activation)
            # Linear Layers
            layers = []
            for units in self.hparams.layers.split("-"):
                layers.extend(
                    _linear_dropout_bn(
                        self.hparams,
                        _curr_units,
                        int(units),
                        activation,
                        self.hparams.dropout,
                    )
                )
                _curr_units = int(units)
            self.linear_layers = nn.Sequential(*layers)
        # Projection to Multi-Headed Attention Dims
        self.attn_proj = nn.Linear(_curr_units, self.hparams.attn_embed_dim)
        _initialize_layers(self.hparams, self.attn_proj)
        # Multi-Headed Attention Layers
        self.self_attns = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    self.hparams.attn_embed_dim,
                    self.hparams.num_heads,
                    dropout=self.hparams.attn_dropouts,
                )
                for _ in range(self.hparams.num_attn_blocks)
            ]
        )
        if self.hparams.has_residuals:
            self.V_res_embedding = torch.nn.Linear(
                _curr_units,
                self.hparams.attn_embed_dim * self.hparams.num_attn_blocks
                if self.hparams.attention_pooling
                else self.hparams.attn_embed_dim,
            )
        self.output_dim = (
            self.hparams.continuous_dim + self.hparams.categorical_dim
        ) * self.hparams.attn_embed_dim
        if self.hparams.attention_pooling:
            self.output_dim = self.output_dim * self.hparams.num_attn_blocks

    def forward(self, x: Dict):
        # (B, N)
        continuous_data, categorical_data = x["continuous"], x["categorical"]
        x = None
        if len(self.hparams.categorical_cols) > 0:
            x_cat = [
                embedding_layer(categorical_data[:, i]).unsqueeze(1)
                for i, embedding_layer in enumerate(self.cat_embedding_layers)
            ]
            # (B, N, E)
            x = torch.cat(x_cat, 1)
        if self.hparams.continuous_dim > 0:
            cont_idx = (
                torch.arange(self.hparams.continuous_dim)
                .expand(continuous_data.size(0), -1)
                .to(self.device)
            )
            if self.hparams.batch_norm_continuous_input:
                continuous_data = self.normalizing_batch_norm(continuous_data)
            x_cont = torch.mul(
                continuous_data.unsqueeze(2),
                self.cont_embedding_layer(cont_idx),
            )
            # (B, N, E)
            x = x_cont if x is None else torch.cat([x, x_cont], 1)
        if self.hparams.embedding_dropout != 0 and len(self.hparams.categorical_cols) > 0:
            x = self.embed_dropout(x)
        if self.hparams.deep_layers:
            x = self.linear_layers(x)
        # (N, B, E*) --> E* is the Attn Dimension
        cross_term = self.attn_proj(x).transpose(0, 1)
        if self.hparams.attention_pooling:
            attention_ops = []
        for self_attn in self.self_attns:
            cross_term, _ = self_attn(cross_term, cross_term, cross_term)
            if self.hparams.attention_pooling:
                attention_ops.append(cross_term)
        if self.hparams.attention_pooling:
            cross_term = torch.cat(attention_ops, dim=-1)
        # (B, N, E*)
        cross_term = cross_term.transpose(0, 1)
        if self.hparams.has_residuals:
            # (B, N, E*) --> Projecting Embedded input to Attention sub-space
            V_res = self.V_res_embedding(x)
            cross_term = cross_term + V_res
        # (B, NxE*)
        cross_term = nn.ReLU()(cross_term).reshape(-1, self.output_dim)
        return cross_term


class TabTransformerModel(BaseModel):
    def __init__(self, config: DictConfig, **kwargs):
        super().__init__(config, **kwargs)

    def _build_network(self):
        # Backbone
        self.backbone = TabTransformerBackbone(self.hparams)
        self.dropout = nn.Dropout(self.hparams.dropout)
        # Adding the last layer
        self.output_layer = nn.Linear(
            self.backbone.output_dim, self.hparams.output_dim
        )  # output_dim auto-calculated from other config
        _initialize_layers(self.hparams, self.output_layer)

    def forward(self, x: Dict):
        x = self.backbone(x)
        x = self.dropout(x)
        y_hat = self.output_layer(x)
        if (self.hparams.task == "regression") and (
            self.hparams.target_range is not None
        ):
            for i in range(self.hparams.output_dim):
                y_min, y_max = self.hparams.target_range[i]
                y_hat[:, i] = y_min + nn.Sigmoid()(y_hat[:, i]) * (y_max - y_min)
        return {"logits": y_hat, "backbone_features": x}
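
To show the tensor contract end to end, here is a hedged sketch that builds the backbone directly from a hand-rolled DictConfig. Every key in it corresponds to a self.hparams.* read visible in _build_network/forward above, except activation and initialization, which are included on the assumption that the imported _initialize_layers helper needs them; the import path is likewise an assumption. Note that the backbone still reads attn_embed_dim and the other AutoInt-style keys rather than the transformer_embed_dim field defined in the new config, so the sketch follows the backbone's reads; in normal use these hparams are assembled by the library from the data and model configs.

import torch
from omegaconf import OmegaConf

# Assumed import path; the commit does not show where this file lives.
from pytorch_tabular.models.tab_transformer.tab_transformer import TabTransformerBackbone

hparams = OmegaConf.create(
    {
        # data-derived keys, normally filled in by pytorch_tabular
        "categorical_cols": ["cat_a", "cat_b"],
        "categorical_cardinality": [5, 7],
        "categorical_dim": 2,
        "continuous_dim": 3,
        # keys read by _build_network / forward above
        "embedding_dim": 16,
        "embedding_dropout": 0.0,
        "batch_norm_continuous_input": False,
        "deep_layers": False,
        "attn_embed_dim": 32,
        "num_heads": 2,
        "num_attn_blocks": 2,
        "attn_dropouts": 0.0,
        "has_residuals": True,
        "attention_pooling": False,
        # assumed requirements of the _initialize_layers helper (not shown in this diff)
        "activation": "ReLU",
        "initialization": "kaiming",
    }
)

backbone = TabTransformerBackbone(hparams)
batch = {
    "categorical": torch.randint(0, 5, (8, 2)),  # (B, categorical_dim) integer codes
    "continuous": torch.randn(8, 3),             # (B, continuous_dim) floats
}
features = backbone(batch)
# output_dim = (continuous_dim + categorical_dim) * attn_embed_dim = (3 + 2) * 32
print(features.shape)  # torch.Size([8, 160])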
