+from typing import Optional
+
 import torch
 import torch.nn.functional as F
-from torch import nn, einsum
-
 from einops import rearrange
+from torch import einsum, nn
+
+from pytorch_tabular.models import common


 class Residual(nn.Module):
@@ -118,3 +121,140 @@ def __init__(self, d_model: int, d_ff: int,
     def forward(self, x: torch.Tensor):
         return self.ffn(x)

+# Inspired by the following implementations:
+# 1. lucidrains - https://github.com/lucidrains/tab-transformer-pytorch/
+#    If you are interested in Transformers, you should definitely check out his repositories.
+# 2. PyTorch Wide and Deep - https://github.com/jrzaurin/pytorch-widedeep/
+#    It is another library for tabular data, which also supports multi-modal problems.
+#    Check out the library if you haven't already.
+# 3. AutoGluon - https://github.com/awslabs/autogluon
+#    AutoGluon is an AutoML library which supports tabular data as well. It is from Amazon Research and is built on MXNet.
+# 4. LabML Annotated Deep Learning Papers - The position-wise FF was shamelessly copied from
+#    https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers
+
+class AddNorm(nn.Module):
+    """
+    Applies Dropout to the sub-layer output, adds the residual input, and applies LayerNorm.
+    The standard Add & Norm operation in Transformers.
+    """
+    def __init__(self, input_dim: int, dropout: float):
+        super(AddNorm, self).__init__()
+        self.dropout = nn.Dropout(dropout)
+        self.ln = nn.LayerNorm(input_dim)
+
+    def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
+        return self.ln(self.dropout(Y) + X)
+
+
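# Illustrative usage sketch (not part of the original commit): AddNorm expects the residual
# input X and the sub-layer output Y to share the same trailing dimension as `input_dim`,
# since LayerNorm normalizes over that last axis. The shapes below are assumed.
#
#   add_norm = AddNorm(input_dim=32, dropout=0.1)
#   x = torch.randn(64, 16, 32)   # residual branch: (batch, num_features, embed_dim)
#   y = torch.randn(64, 16, 32)   # sub-layer output of the same shape
#   out = add_norm(x, y)          # LayerNorm(Dropout(y) + x) -> shape (64, 16, 32)
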
+class MultiHeadedAttention(nn.Module):
+    """
+    Multi-Headed Attention block as used in Transformers
+    """
+    def __init__(
+        self, input_dim: int, num_heads: int = 8, head_dim: int = 16, dropout: float = 0.1
+    ):
+        super().__init__()
+        assert (
+            input_dim % num_heads == 0
+        ), "'input_dim' must be a multiple of 'num_heads'"
+        inner_dim = head_dim * num_heads
+        self.n_heads = num_heads
+        self.scale = head_dim ** -0.5
+
+        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=False)
+        self.to_out = nn.Linear(inner_dim, input_dim)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        h = self.n_heads
+        # Project to queries, keys and values and split out the heads
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+        # Scaled dot-product attention
+        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+
+        attn = sim.softmax(dim=-1)
+        attn = self.dropout(attn)
+
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+        # Merge the heads back and project to the input dimension
+        out = rearrange(out, "b h n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
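# Illustrative usage sketch (not part of the original commit): with the defaults num_heads=8
# and head_dim=16, queries/keys/values are projected to 8 * 16 = 128 dimensions, attention is
# computed per head over the feature/token axis, and the result is projected back to
# `input_dim`. The shapes below are assumed.
#
#   mha = MultiHeadedAttention(input_dim=64, num_heads=8, head_dim=16, dropout=0.1)
#   x = torch.randn(32, 10, 64)   # (batch, num_features, input_dim)
#   out = mha(x)                  # attention weights: (32, 8, 10, 10); output: (32, 10, 64)
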
+# Slight adaptation from https://github.com/jrzaurin/pytorch-widedeep, which was in turn adapted from AutoGluon
+class SharedEmbeddings(nn.Module):
+    """
+    Enables the different values of a categorical feature to share part of their embedding
+    """
+    def __init__(
+        self,
+        num_embed: int,
+        embed_dim: int,
+        add_shared_embed: bool = False,
+        frac_shared_embed: float = 0.25,
+    ):
+        super(SharedEmbeddings, self).__init__()
+        assert frac_shared_embed < 1, "'frac_shared_embed' must be less than 1"
+
+        self.add_shared_embed = add_shared_embed
+        self.embed = nn.Embedding(num_embed, embed_dim, padding_idx=0)
+        self.embed.weight.data.clamp_(-2, 2)
+        if add_shared_embed:
+            col_embed_dim = embed_dim
+        else:
+            col_embed_dim = int(embed_dim * frac_shared_embed)
+        self.shared_embed = nn.Parameter(torch.empty(1, col_embed_dim).uniform_(-1, 1))
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        out = self.embed(X)
+        shared_embed = self.shared_embed.expand(out.shape[0], -1)
+        if self.add_shared_embed:
+            # Add the shared vector to every embedding of the column
+            out += shared_embed
+        else:
+            # Overwrite the leading fraction of every embedding with the shared vector
+            out[:, : shared_embed.shape[1]] = shared_embed
+        return out
+
+
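# Illustrative usage sketch (not part of the original commit): each categorical column gets its
# own SharedEmbeddings. With add_shared_embed=False, the first int(embed_dim * frac_shared_embed)
# positions of every embedding are overwritten with a single learned vector shared by all
# categories of that column. The values below are assumed.
#
#   emb = SharedEmbeddings(num_embed=10, embed_dim=16, add_shared_embed=False, frac_shared_embed=0.25)
#   X = torch.randint(0, 10, (32,))   # category indices for one column: (batch,)
#   out = emb(X)                      # (32, 16); out[:, :4] is the shared part
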
+class TransformerEncoderBlock(nn.Module):
+    """A single Transformer Encoder Block"""
+    def __init__(
+        self,
+        input_embed_dim: int,
+        num_heads: int = 8,
+        ff_hidden_multiplier: int = 4,
+        ff_activation: str = "GEGLU",
+        attn_dropout: float = 0.1,
+        ff_dropout: float = 0.1,
+        add_norm_dropout: float = 0.1,
+        transformer_head_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.mha = MultiHeadedAttention(
+            input_embed_dim,
+            num_heads,
+            head_dim=input_embed_dim
+            if transformer_head_dim is None
+            else transformer_head_dim,
+            dropout=attn_dropout,
+        )
+
+        # Use the gated position-wise feed-forward named by `ff_activation` (e.g. GEGLU) if it
+        # exists in `common`, otherwise fall back to the plain PositionWiseFeedForward with the
+        # matching torch.nn activation
+        try:
+            self.pos_wise_ff = getattr(common, ff_activation)(
+                d_model=input_embed_dim,
+                d_ff=input_embed_dim * ff_hidden_multiplier,
+                dropout=ff_dropout,
+            )
+        except AttributeError:
+            self.pos_wise_ff = common.PositionWiseFeedForward(
+                d_model=input_embed_dim,
+                d_ff=input_embed_dim * ff_hidden_multiplier,
+                dropout=ff_dropout,
+                activation=getattr(nn, ff_activation),
+            )
+        self.attn_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
+        self.ff_add_norm = AddNorm(input_embed_dim, add_norm_dropout)
+
+    def forward(self, x):
+        # Self-attention, Add & Norm, then the position-wise feed-forward and a second Add & Norm
+        y = self.mha(x)
+        x = self.attn_add_norm(x, y)
+        y = self.pos_wise_ff(y)
+        return self.ff_add_norm(x, y)
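
# Illustrative usage sketch (not part of the original commit): a block keeps the embedding
# dimension unchanged, so several blocks can be stacked. `ff_activation` is looked up on the
# imported `common` module (e.g. "GEGLU", the default above), with PositionWiseFeedForward as
# the fallback. The shapes and values below are assumed.
#
#   block = TransformerEncoderBlock(input_embed_dim=32, num_heads=8, ff_activation="GEGLU")
#   x = torch.randn(64, 12, 32)   # (batch, num_features, embed_dim)
#   out = block(x)                # -> (64, 12, 32)
#   encoder = nn.Sequential(*[TransformerEncoderBlock(32) for _ in range(4)])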