Commit 8d6fb81

-- added FeatureTokenizerTransformer
-- refactored common components of TabTransformer to common.py
1 parent 05d7594

8 files changed: +626 -159 lines


examples/to_test_classification.py

Lines changed: 18 additions & 7 deletions
@@ -1,4 +1,5 @@
 from pytorch_tabular.models.tab_transformer.config import TabTransformerConfig
+from pytorch_tabular.models.ft_transformer.config import FTTransformerConfig
 import torch
 import numpy as np
 from torch.functional import norm
@@ -98,9 +99,19 @@
 # metrics=["f1", "accuracy"],
 # metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
 # )
-model_config = TabTransformerConfig(
+# model_config = TabTransformerConfig(
+# task="classification",
+# metrics=["f1", "accuracy"],
+# share_embedding = True,
+# share_embedding_strategy="fraction",
+# shared_embedding_fraction=0.25,
+# metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
+# )
+model_config = FTTransformerConfig(
     task="classification",
     metrics=["f1", "accuracy"],
+    # embedding_initialization=None,
+    embedding_bias=False,
     share_embedding = True,
     share_embedding_strategy="fraction",
     shared_embedding_fraction=0.25,
@@ -139,10 +150,10 @@
 result = tabular_model.evaluate(test)
 print(result)
 # test.drop(columns=target_name, inplace=True)
-pred_df = tabular_model.predict(test)
-print(pred_df.head())
+# pred_df = tabular_model.predict(test)
+# print(pred_df.head())
 # pred_df.to_csv("output/temp2.csv")
-tabular_model.save_model("test_save")
-new_model = TabularModel.load_from_checkpoint("test_save")
-result = new_model.evaluate(test)
-print(result)
+# tabular_model.save_model("test_save")
+# new_model = TabularModel.load_from_checkpoint("test_save")
+# result = new_model.evaluate(test)
+# print(result)
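
For orientation, here is a minimal end-to-end sketch of how the new FTTransformerConfig plugs into the usual pytorch_tabular workflow. It is not part of this commit: the synthetic dataframe, the column names, and the DataConfig/TrainerConfig/OptimizerConfig settings are illustrative assumptions; only the FTTransformerConfig arguments mirror the example above.

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.ft_transformer.config import FTTransformerConfig

# Small synthetic binary-classification dataset (illustrative only)
X, y = make_classification(n_samples=1000, n_features=6, n_informative=4, random_state=42)
df = pd.DataFrame(X, columns=[f"num_{i}" for i in range(6)])
df["cat_col"] = np.random.choice(["a", "b", "c"], size=len(df))
df["target"] = y

data_config = DataConfig(
    target=["target"],
    continuous_cols=[f"num_{i}" for i in range(6)],
    categorical_cols=["cat_col"],
)
model_config = FTTransformerConfig(
    task="classification",
    metrics=["f1", "accuracy"],
    embedding_bias=False,
    share_embedding=True,
    share_embedding_strategy="fraction",
    shared_embedding_fraction=0.25,
    metrics_params=[{"num_classes": 2, "average": "macro"}, {}],
)
trainer_config = TrainerConfig(max_epochs=5, gpus=0)
optimizer_config = OptimizerConfig()

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=df)
print(tabular_model.evaluate(df))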

pytorch_tabular/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@
 )
 from .autoint import AutoIntConfig, AutoIntModel
 from .tab_transformer import TabTransformerConfig, TabTransformerModel
+from .ft_transformer import FTTransformerConfig, FTTransformerModel
 from .base_model import BaseModel
 from . import category_embedding, node, mixture_density, tabnet, autoint

@@ -36,6 +37,8 @@
     "AutoIntModel",
     "TabTransformerConfig",
     "TabTransformerModel",
+    "FTTransformerConfig",
+    "FTTransformerModel",
     "category_embedding",
     "node",
     "mixture_density",

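With the re-exports above, the new model classes become importable directly from pytorch_tabular.models, alongside the existing TabTransformer ones. A trivial check (not part of the commit):

from pytorch_tabular.models import FTTransformerConfig, FTTransformerModel

print(FTTransformerConfig.__module__)  # pytorch_tabular.models.ft_transformer.config
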
pytorch_tabular/models/common.py

Lines changed: 142 additions & 2 deletions
@@ -1,8 +1,11 @@
+from typing import Optional
+
 import torch
 import torch.nn.functional as F
-from torch import nn, einsum
-
 from einops import rearrange
+from torch import einsum, nn
+
+from pytorch_tabular.models import common


 class Residual(nn.Module):
@@ -118,3 +121,140 @@ def __init__(self, d_model: int, d_ff: int,
     def forward(self, x: torch.Tensor):
         return self.ffn(x)

+# Inspired by implementations:
+# 1. lucidrains - https://github.com/lucidrains/tab-transformer-pytorch/
+# If you are interested in Transformers, you should definitely check out his repositories.
+# 2. PyTorch Wide and Deep - https://github.com/jrzaurin/pytorch-widedeep/
+# It is another library for tabular data, which supports multi-modal problems.
+# Check out the library if you haven't already.
+# 3. AutoGluon - https://github.com/awslabs/autogluon
+# AutoGluon is an AutoML library which supports tabular data as well. It is from Amazon Research and is built on MXNet.
+# 4. LabML Annotated Deep Learning Papers - the position-wise FF was shamelessly copied from
+# https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers
+
+class AddNorm(nn.Module):
+    """
+    Applies Dropout to the residual, adds it to the input, and applies LayerNorm. The standard AddNorm operation in Transformers.
+    """
+    def __init__(self, input_dim: int, dropout: float):
+        super(AddNorm, self).__init__()
+        self.dropout = nn.Dropout(dropout)
+        self.ln = nn.LayerNorm(input_dim)
+
+    def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
+        return self.ln(self.dropout(Y) + X)
+
+
+class MultiHeadedAttention(nn.Module):
+    """
+    Multi-Headed Attention block used in Transformers.
+    """
+    def __init__(
+        self, input_dim: int, num_heads: int = 8, head_dim: int = 16, dropout: float = 0.1
+    ):
+        super().__init__()
+        assert (
+            input_dim % num_heads == 0
+        ), "'input_dim' must be a multiple of 'num_heads'"
+        inner_dim = head_dim * num_heads
+        self.n_heads = num_heads
+        self.scale = head_dim ** -0.5
+
+        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=False)
+        self.to_out = nn.Linear(inner_dim, input_dim)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        h = self.n_heads
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+
+        attn = sim.softmax(dim=-1)
+        attn = self.dropout(attn)
+
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
+# Slight adaptation from https://github.com/jrzaurin/pytorch-widedeep, which in turn adapted it from AutoGluon
+class SharedEmbeddings(nn.Module):
+    """
+    Enables different values of a categorical feature to share part of their embeddings.
+    """
+    def __init__(
+        self,
+        num_embed: int,
+        embed_dim: int,
+        add_shared_embed: bool = False,
+        frac_shared_embed: float = 0.25,
+    ):
+        super(SharedEmbeddings, self).__init__()
+        assert frac_shared_embed < 1, "'frac_shared_embed' must be less than 1"
+
+        self.add_shared_embed = add_shared_embed
+        self.embed = nn.Embedding(num_embed, embed_dim, padding_idx=0)
+        self.embed.weight.data.clamp_(-2, 2)
+        if add_shared_embed:
+            col_embed_dim = embed_dim
+        else:
+            col_embed_dim = int(embed_dim * frac_shared_embed)
+        self.shared_embed = nn.Parameter(torch.empty(1, col_embed_dim).uniform_(-1, 1))
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        out = self.embed(X)
+        shared_embed = self.shared_embed.expand(out.shape[0], -1)
+        if self.add_shared_embed:
+            out += shared_embed
+        else:
+            out[:, : shared_embed.shape[1]] = shared_embed
+        return out
+
+
+class TransformerEncoderBlock(nn.Module):
+    """A single Transformer Encoder Block
+    """
+    def __init__(
+        self,
+        input_embed_dim: int,
+        num_heads: int = 8,
+        ff_hidden_multiplier: int = 4,
+        ff_activation: str = "GEGLU",
+        attn_dropout: float = 0.1,
+        ff_dropout: float = 0.1,
+        add_norm_dropout: float = 0.1,
+        transformer_head_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.mha = MultiHeadedAttention(
+            input_embed_dim,
+            num_heads,
+            head_dim=input_embed_dim
+            if transformer_head_dim is None
+            else transformer_head_dim,
+            dropout=attn_dropout,
+        )
+
+        try:
+            self.pos_wise_ff = getattr(common, ff_activation)(
+                d_model=input_embed_dim,
+                d_ff=input_embed_dim * ff_hidden_multiplier,
+                dropout=ff_dropout,
+            )
+        except AttributeError:
+            self.pos_wise_ff = getattr(common, "PositionWiseFeedForward")(
+                d_model=input_embed_dim,
+                d_ff=input_embed_dim * ff_hidden_multiplier,
+                dropout=ff_dropout,
+                activation=getattr(nn, ff_activation),
+            )
+        self.attn_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
+        self.ff_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
+
+    def forward(self, x):
+        y = self.mha(x)
+        x = self.attn_add_norm(x, y)
+        y = self.pos_wise_ff(y)
+        return self.ff_add_norm(x, y)
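
A quick sketch (not part of the commit) of how the new building blocks in common.py compose on dummy tensors. It assumes the GEGLU position-wise feed-forward already defined earlier in common.py; tensor shapes and hyperparameters are arbitrary.

import torch
from pytorch_tabular.models.common import SharedEmbeddings, TransformerEncoderBlock

batch_size, num_features, embed_dim = 32, 10, 32

# SharedEmbeddings: embeddings for one categorical column, where the first
# frac_shared_embed of each embedding is overwritten by a column-level shared vector
shared_emb = SharedEmbeddings(num_embed=8, embed_dim=embed_dim, frac_shared_embed=0.25)
cat_values = torch.randint(0, 8, (batch_size,))
col_embeddings = shared_emb(cat_values)  # (batch_size, embed_dim)

# TransformerEncoderBlock: multi-headed self-attention followed by a position-wise
# feed-forward, each wrapped in AddNorm; expects (batch, num_features, embed_dim)
encoder = TransformerEncoderBlock(
    input_embed_dim=embed_dim,
    num_heads=8,
    ff_hidden_multiplier=4,
    ff_activation="GEGLU",
)
tokens = torch.randn(batch_size, num_features, embed_dim)
out = encoder(tokens)  # (batch_size, num_features, embed_dim)
print(col_embeddings.shape, out.shape)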

pytorch_tabular/models/ft_transformer/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from .ft_transformer import FTTransformerBackbone, FTTransformerModel
+from .config import FTTransformerConfig
+
+__all__ = ["FTTransformerBackbone", "FTTransformerModel", "FTTransformerConfig"]
