@@ -9,7 +9,8 @@
 import torch
 import torch.nn as nn
 from omegaconf import DictConfig
-from pytorch_tabular.utils import _initialize_layers
+
+from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn
 
 from ..base_model import BaseModel
 
@@ -21,40 +22,32 @@ def __init__(self, config: DictConfig, **kwargs):
         self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
         super().__init__(config, **kwargs)
 
-    def _linear_dropout_bn(self, in_units, out_units, activation, dropout):
-        layers = []
-        if self.hparams.use_batch_norm:
-            layers.append(nn.BatchNorm1d(num_features=in_units))
-        linear = nn.Linear(in_units, out_units)
-        _initialize_layers(self.hparams, linear)
-        layers.extend([linear, activation()])
-        if dropout != 0:
-            layers.append(nn.Dropout(dropout))
-        return layers
-
     def _build_network(self):
-        # Embedding layers
+        # Category Embedding layers
         self.cat_embedding_layers = nn.ModuleList(
-            [nn.Embedding(x, y) for x, y in self.hparams.cat_embedding_dims]
-        )
-        self.cont_embedding_layers = nn.ModuleList(
             [
-                nn.Embedding(1, self.hparams.cont_embedding_dim)
-                for i in range(self.hparams.continuous_dim)
+                nn.Embedding(cardinality, self.hparams.embedding_dim)
+                for cardinality in self.hparams.categorical_cardinality
             ]
         )
+        if self.hparams.batch_norm_continuous_input:
+            self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
+        # Continuous Embedding Layer
+        self.cont_embedding_layer = nn.Embedding(
+            self.hparams.continuous_dim, self.hparams.embedding_dim
+        )
         if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
             self.embed_dropout = nn.Dropout(self.hparams.embedding_dropout)
-        # if self.hparams.use_batch_norm:
-        #     self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim + self.hparams.embedding_cat_dim)
+        # Deep Layers
+        _curr_units = self.hparams.embedding_dim
         if self.hparams.deep_layers:
             activation = getattr(nn, self.hparams.activation)
             # Linear Layers
             layers = []
-            _curr_units = self.hparams.continuous_dim + self.embedding_cat_dim
             for units in self.hparams.layers.split("-"):
                 layers.extend(
-                    self._linear_dropout_bn(
+                    _linear_dropout_bn(
+                        self.hparams,
                         _curr_units,
                         int(units),
                         activation,
@@ -63,9 +56,10 @@ def _build_network(self):
                 )
                 _curr_units = int(units)
             self.linear_layers = nn.Sequential(*layers)
-        else:
-            _curr_units = self.hparams.continuous_dim + self.embedding_cat_dim
-
+        # Projection to Multi-Headed Attention Dims
+        self.attn_proj = nn.Linear(_curr_units, self.hparams.attn_embed_dim)
+        _initialize_layers(self.hparams, self.attn_proj)
+        # Multi-Headed Attention Layers
         self.self_attns = nn.ModuleList(
             [
                 nn.MultiheadAttention(
@@ -76,15 +70,56 @@ def _build_network(self):
                 for _ in range(self.hparams.num_attn_blocks)
             ]
         )
-        self.atten_output_dim = (
-            len(self.hparams.continuous_cols + self.hparams.categorical_cols)
-            * self.hparams.atten_embed_dim
-        )
-
+        if self.hparams.has_residuals:
+            self.V_res_embedding = torch.nn.Linear(
+                _curr_units, self.hparams.attn_embed_dim
+            )
+        self.output_dim = (
+            self.hparams.continuous_dim + self.hparams.categorical_dim
+        ) * self.hparams.attn_embed_dim
 
-    def forward(self, x):
-        x = self.linear_layers(x)
-        return x
+    def forward(self, x: Dict):
+        # (B, N)
+        continuous_data, categorical_data = x["continuous"], x["categorical"]
+        x = None
+        if self.embedding_cat_dim != 0:
+            x_cat = [
+                embedding_layer(categorical_data[:, i]).unsqueeze(1)
+                for i, embedding_layer in enumerate(self.cat_embedding_layers)
+            ]
+            # (B, N, E)
+            x = torch.cat(x_cat, 1)
+        if self.hparams.continuous_dim > 0:
+            cont_idx = (
+                torch.arange(self.hparams.continuous_dim)
+                .expand(continuous_data.size(0), -1)
+                .to(self.device)
+            )
+            if self.hparams.batch_norm_continuous_input:
+                continuous_data = self.normalizing_batch_norm(continuous_data)
+            x_cont = torch.mul(
+                continuous_data.unsqueeze(2),
+                self.cont_embedding_layer(cont_idx),
+            )
+            # (B, N, E)
+            x = x_cont if x is None else torch.cat([x, x_cont], 1)
+        if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
+            x = self.embed_dropout(x)
+        if self.hparams.deep_layers:
+            x = self.linear_layers(x)
+        # (N, B, E*) --> E* is the Attn Dimension
+        cross_term = self.attn_proj(x).transpose(0, 1)
+        for self_attn in self.self_attns:
+            cross_term, _ = self_attn(cross_term, cross_term, cross_term)
+        # (B, N, E*)
+        cross_term = cross_term.transpose(0, 1)
+        if self.hparams.has_residuals:
+            # (B, N, E*) --> Projecting Embedded input to Attention sub-space
+            V_res = self.V_res_embedding(x)
+            cross_term = cross_term + V_res
+        # (B, NxE*)
+        cross_term = nn.ReLU()(cross_term).reshape(-1, self.output_dim)
+        return cross_term
 
 
 class AutoIntModel(BaseModel):
@@ -94,46 +129,18 @@ def __init__(self, config: DictConfig, **kwargs):
         super().__init__(config, **kwargs)
 
     def _build_network(self):
-        # Embedding layers
-        self.embedding_layers = nn.ModuleList(
-            [nn.Embedding(x, y) for x, y in self.hparams.embedding_dims]
-        )
-        # Continuous Layers
-        if self.hparams.batch_norm_continuous_input:
-            self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
         # Backbone
         self.backbone = AutoIntBackbone(self.hparams)
+        self.dropout = nn.Dropout(self.hparams.dropout)
         # Adding the last layer
         self.output_layer = nn.Linear(
             self.backbone.output_dim, self.hparams.output_dim
         )  # output_dim auto-calculated from other config
         _initialize_layers(self.hparams, self.output_layer)
 
-    def unpack_input(self, x: Dict):
-        continuous_data, categorical_data = x["continuous"], x["categorical"]
-        if self.embedding_cat_dim != 0:
-            x = []
-            # for i, embedding_layer in enumerate(self.embedding_layers):
-            #     x.append(embedding_layer(categorical_data[:, i]))
-            x = [
-                embedding_layer(categorical_data[:, i])
-                for i, embedding_layer in enumerate(self.embedding_layers)
-            ]
-            x = torch.cat(x, 1)
-
-        if self.hparams.continuous_dim != 0:
-            if self.hparams.batch_norm_continuous_input:
-                continuous_data = self.normalizing_batch_norm(continuous_data)
-
-            if self.embedding_cat_dim != 0:
-                x = torch.cat([x, continuous_data], 1)
-            else:
-                x = continuous_data
-        return x
-
     def forward(self, x: Dict):
-        x = self.unpack_input(x)
         x = self.backbone(x)
+        x = self.dropout(x)
         y_hat = self.output_layer(x)
         if (self.hparams.task == "regression") and (
             self.hparams.target_range is not None
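
The `_linear_dropout_bn` block builder removed from the backbone above now comes from `pytorch_tabular.utils`. A minimal sketch of what that shared helper plausibly looks like, assuming it keeps the logic of the removed method with the config passed explicitly instead of read from `self` (the exact implementation lives in the library's utils module):

import torch.nn as nn

from pytorch_tabular.utils import _initialize_layers


def _linear_dropout_bn(hparams, in_units, out_units, activation, dropout):
    # Sketch of the shared helper, inferred from the removed method above:
    # optional BatchNorm -> initialized Linear -> activation -> optional Dropout.
    layers = []
    if hparams.use_batch_norm:
        layers.append(nn.BatchNorm1d(num_features=in_units))
    linear = nn.Linear(in_units, out_units)
    _initialize_layers(hparams, linear)  # same weight-init helper imported in the diff
    layers.extend([linear, activation()])
    if dropout != 0:
        layers.append(nn.Dropout(dropout))
    return layers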
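
For readers tracing the shape comments in the new `forward`, here is a standalone sketch of the same data flow with toy sizes; every dimension below is an illustrative assumption, not a value taken from the library:

import torch
import torch.nn as nn

# Toy sizes (assumptions for illustration only)
B = 32                    # batch size
cardinalities = [5, 12]   # two categorical features
n_cont = 3                # three continuous features
E, E_attn, heads, blocks = 16, 8, 2, 3

categorical = torch.stack([torch.randint(0, c, (B,)) for c in cardinalities], dim=1)
continuous = torch.randn(B, n_cont)

# One embedding table per categorical column -> (B, 1, E) each, concatenated on dim 1.
cat_embeddings = nn.ModuleList([nn.Embedding(c, E) for c in cardinalities])
x_cat = torch.cat(
    [emb(categorical[:, i]).unsqueeze(1) for i, emb in enumerate(cat_embeddings)], dim=1
)

# Continuous features: learn one E-dim vector per column and scale it by the value,
# mirroring continuous_data.unsqueeze(2) * cont_embedding_layer(cont_idx) in the diff.
cont_embedding = nn.Embedding(n_cont, E)
cont_idx = torch.arange(n_cont).expand(B, -1)                # (B, n_cont) column indices
x_cont = continuous.unsqueeze(2) * cont_embedding(cont_idx)  # (B, n_cont, E)

x = torch.cat([x_cat, x_cont], dim=1)                        # (B, N, E) with N = 2 + 3

# Attention core: project to the attention dim and run stacked self-attention over the
# feature axis. nn.MultiheadAttention expects sequence-first input, hence the transposes.
attn_proj = nn.Linear(E, E_attn)
self_attns = nn.ModuleList([nn.MultiheadAttention(E_attn, heads) for _ in range(blocks)])
v_res = nn.Linear(E, E_attn)                                 # residual projection of the embedded input

cross_term = attn_proj(x).transpose(0, 1)                    # (N, B, E*)
for attn in self_attns:
    cross_term, _ = attn(cross_term, cross_term, cross_term)
cross_term = cross_term.transpose(0, 1) + v_res(x)           # (B, N, E*)
out = nn.ReLU()(cross_term).reshape(B, -1)                   # (B, NxE*)
print(out.shape)                                             # torch.Size([32, 40])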