Commit d6e80c5 (parent: 014e619)

-- working TabTransformer
-- made some utility functions independent of config

10 files changed: 358 additions, 222 deletions

examples/to_test_classification.py

18 additions, 10 deletions

@@ -1,3 +1,4 @@
+from pytorch_tabular.models.tab_transformer.config import TabTransformerConfig
 import torch
 import numpy as np
 from torch.functional import norm
@@ -88,16 +89,21 @@
     continuous_feature_transform=None,#"quantile_normal",
     normalize_continuous_features=False,
 )
-model_config = CategoryEmbeddingModelConfig(task="classification", metrics=["f1","accuracy"], metrics_params=[{"num_classes":num_classes},{}])
-# model_config = NodeConfig(
+# model_config = CategoryEmbeddingModelConfig(task="classification", metrics=["f1","accuracy"], metrics_params=[{"num_classes":num_classes},{}])
+model_config = NodeConfig(
+    task="classification",
+    depth=4,
+    num_trees=1024,
+    input_dropout=0.0,
+    metrics=["f1", "accuracy"],
+    metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
+)
+# model_config = TabTransformerConfig(
 #     task="classification",
-#     depth=4,
-#     num_trees=1024,
-#     input_dropout=0.0,
 #     metrics=["f1", "accuracy"],
 #     metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
 # )
-trainer_config = TrainerConfig(gpus=-1, auto_select_gpus=True, fast_dev_run=False, max_epochs=5, batch_size=1024)
+trainer_config = TrainerConfig(gpus=-1, auto_select_gpus=True, fast_dev_run=False, max_epochs=5, batch_size=512)
 experiment_config = ExperimentConfig(project_name="PyTorch Tabular Example",
     run_name="node_forest_cov",
     exp_watch="gradients",
@@ -130,8 +136,10 @@
 result = tabular_model.evaluate(test)
 print(result)
 # test.drop(columns=target_name, inplace=True)
-# pred_df = tabular_model.predict(test)
+pred_df = tabular_model.predict(test)
+print(pred_df.head())
 # pred_df.to_csv("output/temp2.csv")
-# tabular_model.save_model("test_save")
-# new_model = TabularModel.load_from_checkpoint("test_save")
-# result = new_model.evaluate(test)
+tabular_model.save_model("test_save")
+new_model = TabularModel.load_from_checkpoint("test_save")
+result = new_model.evaluate(test)
+print(result)
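Since the commit message says the TabTransformer is now working, a minimal sketch of running it through this same example script is shown below. Only task, metrics, and metrics_params are taken from the commented-out block in the diff; the variable names data_config, optimizer_config, trainer_config, train, val, and test are assumed to be the objects already defined earlier in the script, and every other TabTransformerConfig argument is left at its default.

# Hedged sketch: swaps the NodeConfig above for the TabTransformerConfig imported at the top of the script.
from pytorch_tabular import TabularModel
from pytorch_tabular.models.tab_transformer.config import TabTransformerConfig

model_config = TabTransformerConfig(
    task="classification",
    metrics=["f1", "accuracy"],
    metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
)
tabular_model = TabularModel(
    data_config=data_config,            # assumed: DataConfig defined earlier in the script
    model_config=model_config,
    optimizer_config=optimizer_config,  # assumed: OptimizerConfig defined earlier in the script
    trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)
print(tabular_model.evaluate(test))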

pytorch_tabular/models/__init__.py

4 additions, 0 deletions

@@ -12,6 +12,7 @@
     AutoIntMDNConfig
 )
 from .autoint import AutoIntConfig, AutoIntModel
+from .tab_transformer import TabTransformerConfig, TabTransformerModel
 from .base_model import BaseModel
 from . import category_embedding, node, mixture_density, tabnet, autoint

@@ -33,9 +34,12 @@
     "AutoIntMDNConfig",
     "AutoIntConfig",
     "AutoIntModel",
+    "TabTransformerConfig",
+    "TabTransformerModel",
     "category_embedding",
     "node",
     "mixture_density",
     "tabnet",
     "autoint",
+    "tab_transformer"
 ]

pytorch_tabular/models/autoint/autoint.py

5 additions, 5 deletions

@@ -46,24 +46,24 @@ def _build_network(self):
         # Deep Layers
         _curr_units = self.hparams.embedding_dim
         if self.hparams.deep_layers:
-            activation = getattr(nn, self.hparams.activation)
             # Linear Layers
             layers = []
             for units in self.hparams.layers.split("-"):
                 layers.extend(
                     _linear_dropout_bn(
-                        self.hparams,
+                        self.hparams.activation,
+                        self.hparams.initialization,
+                        self.hparams.use_batch_norm,
                         _curr_units,
                         int(units),
-                        activation,
                         self.hparams.dropout,
                     )
                 )
                 _curr_units = int(units)
             self.linear_layers = nn.Sequential(*layers)
         # Projection to Multi-Headed Attention Dims
         self.attn_proj = nn.Linear(_curr_units, self.hparams.attn_embed_dim)
-        _initialize_layers(self.hparams, self.attn_proj)
+        _initialize_layers(self.hparams.activation, self.hparams.initialization, self.attn_proj)
         # Multi-Headed Attention Layers
         self.self_attns = nn.ModuleList(
             [
@@ -152,7 +152,7 @@ def _build_network(self):
         self.output_layer = nn.Linear(
             self.backbone.output_dim, self.hparams.output_dim
         )  # output_dim auto-calculated from other config
-        _initialize_layers(self.hparams, self.output_layer)
+        _initialize_layers(self.hparams.activation, self.hparams.initialization, self.output_layer)

     def forward(self, x: Dict):
         x = self.backbone(x)
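The refactor above is what the commit message calls making the utility functions independent of config: instead of receiving the whole self.hparams object, _initialize_layers and _linear_dropout_bn now take the individual values they need (activation, initialization, use_batch_norm, and so on). A minimal sketch of what the decoupled _initialize_layers could look like follows; the signature is read off the call sites in this diff, but the body is an assumption and the real helper in pytorch_tabular may differ.

import torch.nn as nn

def _initialize_layers(activation: str, initialization: str, layer: nn.Module):
    # Illustrative sketch only: pick an initializer by name and apply it to the layer weights.
    if initialization == "kaiming":
        nonlinearity = "leaky_relu" if activation == "LeakyReLU" else "relu"
        nn.init.kaiming_normal_(layer.weight, nonlinearity=nonlinearity)
    elif initialization == "xavier":
        nn.init.xavier_normal_(layer.weight)
    elif initialization == "random":
        nn.init.normal_(layer.weight)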

pytorch_tabular/models/category_embedding/category_embedding_model.py

4 additions, 4 deletions

@@ -24,7 +24,6 @@ def __init__(self, config: DictConfig, **kwargs):
         self._build_network()

     def _build_network(self):
-        activation = getattr(nn, self.hparams.activation)
         # Linear Layers
         layers = []
         _curr_units = self.embedding_cat_dim + self.hparams.continuous_dim
@@ -33,10 +32,11 @@ def _build_network(self):
         for units in self.hparams.layers.split("-"):
             layers.extend(
                 _linear_dropout_bn(
-                    self.hparams,
+                    self.hparams.activation,
+                    self.hparams.initialization,
+                    self.hparams.use_batch_norm,
                     _curr_units,
                     int(units),
-                    activation,
                     self.hparams.dropout,
                 )
             )
@@ -69,7 +69,7 @@ def _build_network(self):
         self.output_layer = nn.Linear(
             self.backbone.output_dim, self.hparams.output_dim
         )  # output_dim auto-calculated from other config
-        _initialize_layers(self.hparams, self.output_layer)
+        _initialize_layers(self.hparams.activation, self.hparams.initialization, self.output_layer)

     def unpack_input(self, x: Dict):
         continuous_data, categorical_data = x["continuous"], x["categorical"]
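The same decoupling applies to _linear_dropout_bn. Its new argument order can be read off the call sites above (activation, initialization, use_batch_norm, input units, output units, dropout), and since the result is passed to layers.extend(...) it must return a list of modules. A hedged sketch, reusing the _initialize_layers sketch from the previous section; the actual pytorch_tabular implementation may differ in details such as layer ordering.

import torch.nn as nn

def _linear_dropout_bn(activation, initialization, use_batch_norm, in_units, out_units, dropout):
    # Illustrative sketch: Linear -> activation, optionally followed by BatchNorm1d and Dropout.
    _activation = getattr(nn, activation)      # e.g. "ReLU" -> nn.ReLU
    linear = nn.Linear(in_units, out_units)
    _initialize_layers(activation, initialization, linear)  # sketch defined above
    layers = [linear, _activation()]
    if use_batch_norm:
        layers.append(nn.BatchNorm1d(num_features=out_units))
    if dropout != 0:
        layers.append(nn.Dropout(dropout))
    return layers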

pytorch_tabular/models/common.py

120 additions, 0 deletions

@@ -0,0 +1,120 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+
+from einops import rearrange
+
+
+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) + x
+
+
+# https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/feed_forward.py
+class PositionWiseFeedForward(nn.Module):
+    """
+    title: Position-wise Feed-Forward Network (FFN)
+    summary: Documented reusable implementation of the position-wise feedforward network.
+
+    # Position-wise Feed-Forward Network (FFN)
+    This is a [PyTorch](https://pytorch.org) implementation
+    of the position-wise feedforward network used in transformers.
+    The FFN consists of two fully connected layers.
+    The number of dimensions in the hidden layer, $d_{ff}$, is generally set to around
+    four times that of the token embedding, $d_{model}$,
+    so it is sometimes also called the expand-and-contract network.
+    There is an activation at the hidden layer, which is
+    usually set to the ReLU (Rectified Linear Unit) activation, $$\max(0, x)$$
+    That is, the FFN function is
+    $$FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$
+    where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.
+    Sometimes the
+    GELU (Gaussian Error Linear Unit) activation is used instead of ReLU:
+    $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
+    ### Gated Linear Units
+    This is a generic implementation that supports different variants including
+    [Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU).
+    We have also implemented experiments on these:
+    * [experiment that uses `labml.configs`](glu_variants/experiment.html)
+    * [simpler version from scratch](glu_variants/simple.html)
+    """
+
+    def __init__(self, d_model: int, d_ff: int,
+                 dropout: float = 0.1,
+                 activation=nn.ReLU(),
+                 is_gated: bool = False,
+                 bias1: bool = True,
+                 bias2: bool = True,
+                 bias_gate: bool = True):
+        """
+        * `d_model` is the number of features in a token embedding
+        * `d_ff` is the number of features in the hidden layer of the FFN
+        * `dropout` is the dropout probability for the hidden layer
+        * `is_gated` specifies whether the hidden layer is gated
+        * `bias1` specifies whether the first fully connected layer should have a learnable bias
+        * `bias2` specifies whether the second fully connected layer should have a learnable bias
+        * `bias_gate` specifies whether the fully connected layer for the gate should have a learnable bias
+        """
+        super().__init__()
+        # Layer one, parameterized by weight $W_1$ and bias $b_1$
+        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
+        # Layer two, parameterized by weight $W_2$ and bias $b_2$
+        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
+        # Hidden layer dropout
+        self.dropout = nn.Dropout(dropout)
+        # Activation function $f$
+        self.activation = activation
+        # Whether there is a gate
+        self.is_gated = is_gated
+        if is_gated:
+            # If there is a gate, the linear layer to transform inputs to
+            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
+            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
+
+    def forward(self, x: torch.Tensor):
+        # $f(x W_1 + b_1)$
+        g = self.activation(self.layer1(x))
+        # If gated, $f(x W_1 + b_1) \otimes (x V + b)$
+        if self.is_gated:
+            x = g * self.linear_v(x)
+        # Otherwise
+        else:
+            x = g
+        # Apply dropout
+        x = self.dropout(x)
+        # $(f(x W_1 + b_1) \otimes (x V + b)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
+        # depending on whether it is gated
+        return self.layer2(x)
+
+# GLU Variants Improve Transformer: https://arxiv.org/pdf/2002.05202.pdf
+class GEGLU(nn.Module):
+    def __init__(self, d_model: int, d_ff: int,
+                 dropout: float = 0.1):
+        super().__init__()
+        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.GELU(), True, False, False, False)
+
+    def forward(self, x: torch.Tensor):
+        return self.ffn(x)
+
+class ReGLU(nn.Module):
+    def __init__(self, d_model: int, d_ff: int,
+                 dropout: float = 0.1):
+        super().__init__()
+        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.ReLU(), True, False, False, False)
+
+    def forward(self, x: torch.Tensor):
+        return self.ffn(x)
+
+class SwiGLU(nn.Module):
+    def __init__(self, d_model: int, d_ff: int,
+                 dropout: float = 0.1):
+        super().__init__()
+        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout, nn.SiLU(), True, False, False, False)
+
+    def forward(self, x: torch.Tensor):
+        return self.ffn(x)
+
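The three gated variants above differ only in the activation passed to PositionWiseFeedForward (GELU, ReLU, SiLU), each with is_gated=True and no biases on the second and gate layers. A quick shape-only smoke test, assuming the new file is importable as pytorch_tabular.models.common as the diff header suggests:

import torch
from pytorch_tabular.models.common import GEGLU, PositionWiseFeedForward

x = torch.randn(32, 16, 64)       # (batch, n_tokens, d_model)
ffn = PositionWiseFeedForward(d_model=64, d_ff=256)
geglu = GEGLU(d_model=64, d_ff=256)
assert ffn(x).shape == x.shape    # plain FFN maps back to d_model
assert geglu(x).shape == x.shape  # the gated variant does too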
3 additions, 3 deletions

@@ -1,4 +1,4 @@
-from .autoint import AutoIntBackbone, AutoIntModel
-from .config import AutoIntConfig
+from .tab_transformer import TabTransformerBackbone, TabTransformerModel
+from .config import TabTransformerConfig

-__all__ = ["AutoIntModel", "AutoIntBackbone", "AutoIntConfig"]
+__all__ = ["TabTransformerBackbone", "TabTransformerModel", "TabTransformerConfig"]
