
Commit 38166cf

-- added feature importance to FT Transformer
1 parent 8d6fb81 commit 38166cf

8 files changed, +109 -43 lines changed

examples/to_test_classification.py

Lines changed: 20 additions & 15 deletions
@@ -99,25 +99,25 @@
 # metrics=["f1", "accuracy"],
 # metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
 # )
-# model_config = TabTransformerConfig(
+model_config = TabTransformerConfig(
+    task="classification",
+    metrics=["f1", "accuracy"],
+    share_embedding = True,
+    share_embedding_strategy="add",
+    shared_embedding_fraction=0.25,
+    metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
+)
+# model_config = FTTransformerConfig(
 # task="classification",
 # metrics=["f1", "accuracy"],
+# # embedding_initialization=None,
+# embedding_bias=True,
 # share_embedding = True,
 # share_embedding_strategy="fraction",
 # shared_embedding_fraction=0.25,
 # metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
 # )
-model_config = FTTransformerConfig(
-    task="classification",
-    metrics=["f1", "accuracy"],
-    # embedding_initialization=None,
-    embedding_bias=False,
-    share_embedding = True,
-    share_embedding_strategy="fraction",
-    shared_embedding_fraction=0.25,
-    metrics_params=[{"num_classes": num_classes, "average": "macro"}, {}],
-)
-trainer_config = TrainerConfig(gpus=-1, auto_select_gpus=True, fast_dev_run=False, max_epochs=5, batch_size=512)
+trainer_config = TrainerConfig(gpus=-1, auto_select_gpus=True, fast_dev_run=True, max_epochs=5, batch_size=512)
 experiment_config = ExperimentConfig(project_name="PyTorch Tabular Example",
     run_name="node_forest_cov",
     exp_watch="gradients",
@@ -147,9 +147,14 @@
     # loss=cust_loss,
     train_sampler=sampler)

-result = tabular_model.evaluate(test)
-print(result)
-# test.drop(columns=target_name, inplace=True)
+from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer
+transformer = CategoricalEmbeddingTransformer(tabular_model)
+train_transform = transformer.fit_transform(train)
+# test_transform = transformer.transform(test)
+# ft = tabular_model.model.feature_importance()
+# result = tabular_model.evaluate(test)
+# print(result)
+# test.drop(columns=target_name, inplace=True)
 # pred_df = tabular_model.predict(test)
 # print(pred_df.head())
 # pred_df.to_csv("output/temp2.csv")
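
A hedged usage sketch tying together the two pieces exercised above, the categorical-embedding transformer and the new attention-based feature importance. It assumes `tabular_model` and `train` exist as in this example and that the model was trained with the (here commented-out) FTTransformerConfig, since the commit adds feature_importance() only to the FT Transformer:

# Sketch only; not part of the commit.
from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer

transformer = CategoricalEmbeddingTransformer(tabular_model)
train_transform = transformer.fit_transform(train)   # categorical cols -> learned embeddings

# Attention-derived feature importance introduced by this commit (FT Transformer only)
importance_df = tabular_model.model.feature_importance()
print(importance_df.sort_values("importance", ascending=False))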

pytorch_tabular/models/base_model.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def calculate_metrics(self, y, y_hat, tag):
         for metric, metric_str, metric_params in zip(
             self.metrics, self.hparams.metrics, self.hparams.metrics_params
         ):
-            if (self.hparams.task == "regression") and (self.hparams.output_dim > 1):
+            if (self.hparams.task == "regression"):
                 _metrics = []
                 for i in range(self.hparams.output_dim):
                     if (
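
The effect of this one-line change is that regression metrics are now always computed per output dimension and averaged, even when output_dim is 1. A simplified, self-contained sketch of that pattern (illustrative only, not the library code):

import torch

def per_output_metric(metric, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # y_hat, y: (batch, output_dim); compute the metric per column and average
    scores = [metric(y_hat[:, i], y[:, i]) for i in range(y.shape[-1])]
    return torch.stack(scores).mean()

mse = lambda pred, target: torch.mean((pred - target) ** 2)
print(per_output_metric(mse, torch.randn(32, 2), torch.randn(32, 2)))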

pytorch_tabular/models/common.py

Lines changed: 15 additions & 3 deletions
@@ -150,7 +150,7 @@ class MultiHeadedAttention(nn.Module):
     Multi Headed Attention Block in Transformers
     """
     def __init__(
-        self, input_dim: int, num_heads: int = 8, head_dim: int = 16, dropout: int = 0.1
+        self, input_dim: int, num_heads: int = 8, head_dim: int = 16, dropout: int = 0.1, keep_attn: bool = True
     ):
         super().__init__()
         assert (
@@ -159,6 +159,7 @@ def __init__(
         inner_dim = head_dim * num_heads
         self.n_heads = num_heads
         self.scale = head_dim ** -0.5
+        self.keep_attn = keep_attn

         self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=False)
         self.to_out = nn.Linear(inner_dim, input_dim)
@@ -173,7 +174,8 @@ def forward(self, x):

         attn = sim.softmax(dim=-1)
         attn = self.dropout(attn)
-
+        if self.keep_attn:
+            self.attn_weights = attn
         out = einsum("b h i j, b h j d -> b h i d", attn, v)
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
         return self.to_out(out)
@@ -211,7 +213,15 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
         else:
             out[:, : shared_embed.shape[1]] = shared_embed
         return out
-
+
+    @property
+    def weight(self):
+        w = self.embed.weight.detach()
+        if self.add_shared_embed:
+            w += self.shared_embed
+        else:
+            w[:, : self.shared_embed.shape[1]] = self.shared_embed
+        return w

 class TransformerEncoderBlock(nn.Module):
     """A single Transformer Encoder Block
@@ -223,6 +233,7 @@ def __init__(
         ff_hidden_multiplier: int = 4,
         ff_activation: str = "GEGLU",
         attn_dropout: float = 0.1,
+        keep_attn: bool = True,
         ff_dropout: float = 0.1,
         add_norm_dropout: float = 0.1,
         transformer_head_dim: Optional[int] = None,
@@ -235,6 +246,7 @@ def __init__(
             if transformer_head_dim is None
             else transformer_head_dim,
             dropout=attn_dropout,
+            keep_attn = keep_attn
         )

         try:
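
A minimal, self-contained illustration of the keep_attn idea added to MultiHeadedAttention: when the flag is on, the module caches its softmax weights so a downstream model can turn them into feature importance. The names below are hypothetical and the single-head block is a simplification, not the library's implementation:

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    """Illustrative single-head attention that optionally caches its weights."""
    def __init__(self, dim: int, keep_attn: bool = True):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.keep_attn = keep_attn
        self.attn_weights = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        if self.keep_attn:            # same idea as MultiHeadedAttention.keep_attn
            self.attn_weights = attn  # cached for later feature-importance use
        return attn @ v

x = torch.randn(4, 10, 16)            # (batch, tokens, dim)
print(TinyAttention(16)(x).shape)      # torch.Size([4, 10, 16])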

pytorch_tabular/models/ft_transformer/config.py

Lines changed: 6 additions & 0 deletions
@@ -128,6 +128,12 @@ class FTTransformerConfig(ModelConfig):
            "help": "Fraction of the input_embed_dim to be reserved by the shared embedding. Should be less than one. Defaults to 0.25"
        },
    )
+    attn_feature_importance: bool = field(
+        default = True,
+        metadata={
+            "help": "If you are facing memory issues, you can turn off feature importance which will not save the attention weights. Defaults to True"
+        },
+    )
    num_heads: int = field(
        default=8,
        metadata={
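
A hedged example of the new flag in use; the import path follows the library's usual pattern and is assumed here:

from pytorch_tabular.models import FTTransformerConfig

# Default is True; switch it off to stop caching attention weights when memory is tight.
model_config = FTTransformerConfig(
    task="classification",
    attn_feature_importance=False,
)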

pytorch_tabular/models/ft_transformer/ft_transformer.py

Lines changed: 32 additions & 2 deletions
@@ -3,10 +3,11 @@
 # For license information, see LICENSE.TXT
 """Feature Tokenizer Transformer Model"""
 import logging
-from collections import OrderedDict
 import math
+from collections import OrderedDict
 from typing import Dict

+import pandas as pd
 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
@@ -118,9 +119,11 @@ def _build_network(self):
             attn_dropout=self.hparams.attn_dropout,
             ff_dropout=self.hparams.ff_dropout,
             add_norm_dropout=self.hparams.add_norm_dropout,
+            keep_attn=self.hparams.attn_feature_importance  # Can use Attn Weights to derive feature importance
         )
         self.transformer_blocks = nn.Sequential(self.transformer_blocks)
-        self.attention_weights = [None] * self.hparams.num_attn_blocks
+        if self.hparams.attn_feature_importance:
+            self.attention_weights_ = [None] * self.hparams.num_attn_blocks
         if self.hparams.batch_norm_continuous_input:
             self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
         # Final MLP Layers
@@ -177,11 +180,31 @@ def forward(self, x: Dict):
         x = self.add_cls(x)
         for i, block in enumerate(self.transformer_blocks):
             x = block(x)
+            if self.hparams.attn_feature_importance:
+                self.attention_weights_[i] = block.mha.attn_weights
+                # self.feature_importance_+=block.mha.attn_weights[:,:,:,-1].sum(dim=1)
+                # self._calculate_feature_importance(block.mha.attn_weights)
+        if self.hparams.attn_feature_importance:
+            self._calculate_feature_importance()
         # Flatten (Batch, N_Categorical, Hidden) --> (Batch, N_CategoricalxHidden)
         # x = rearrange(x, "b n h -> b (n h)")
         # Taking only CLS token for the prediction head
         x = self.linear_layers(x[:, -1])
         return x
+
+    # Not Tested Properly
+    def _calculate_feature_importance(self):
+        # if self.feature_importance_.device != self.device:
+        #     self.feature_importance_ = self.feature_importance_.to(self.device)
+
+        n, h, f, _ = self.attention_weights_[0].shape
+        L = len(self.attention_weights_)
+        self.local_feature_importance = torch.zeros((n, f), device=self.device)
+        for attn_weights in self.attention_weights_:
+            self.local_feature_importance += attn_weights[:, :, :, -1].sum(dim=1)
+        self.local_feature_importance = (1 / (h * L)) * self.local_feature_importance[:, :-1]
+        self.feature_importance_ = self.local_feature_importance.mean(dim=0)
+        # self.feature_importance_count_+=attn_weights.shape[0]


 class FTTransformerModel(BaseModel):
@@ -221,3 +244,10 @@ def extract_embedding(self):
             raise ValueError(
                 "Model has been trained with no categorical feature and therefore can't be used as a Categorical Encoder"
             )
+
+    def feature_importance(self):
+        if self.hparams.attn_feature_importance:
+            importance_df = pd.DataFrame({"Features": self.hparams.categorical_cols + self.hparams.continuous_cols, "importance": self.backbone.feature_importance_.detach().cpu().numpy()})
+            return importance_df
+        else:
+            raise ValueError("If you want Feature Importance, `attn_feature_importance` should be `True`.")
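
To make the arithmetic in _calculate_feature_importance easier to follow, here is a standalone sketch that mirrors it. Given one (batch, heads, tokens, tokens) attention tensor per block, where the last token is the appended CLS token used by the prediction head, a feature's importance is the attention its token pays to CLS, summed over heads, accumulated over blocks, scaled by 1/(heads x blocks), and averaged over the batch. Illustrative only; the function name is hypothetical:

import torch

def attention_feature_importance(attention_weights):
    # attention_weights: list with one (n, h, f, f) tensor per transformer block,
    # where the last of the f tokens is the CLS token.
    n, h, f, _ = attention_weights[0].shape
    L = len(attention_weights)
    local_importance = torch.zeros((n, f))
    for attn in attention_weights:
        local_importance += attn[:, :, :, -1].sum(dim=1)    # attention paid to CLS, summed over heads
    local_importance = local_importance[:, :-1] / (h * L)   # drop the CLS row, normalize
    return local_importance.mean(dim=0)                     # (f - 1,) global importances

# toy check with random attention maps: 3 blocks, batch 8, 2 heads, 4 features + CLS
weights = [torch.rand(8, 2, 5, 5).softmax(dim=-1) for _ in range(3)]
print(attention_feature_importance(weights))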

pytorch_tabular/models/tab_transformer/tab_transformer.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ def _build_network(self):
             attn_dropout=self.hparams.attn_dropout,
             ff_dropout=self.hparams.ff_dropout,
             add_norm_dropout=self.hparams.add_norm_dropout,
+            keep_attn = False  # No easy way to convert TabTransformer Attn Weights to Feature Importance
         )
         self.transformer_blocks = nn.Sequential(self.transformer_blocks)
         self.attention_weights = [None] * self.hparams.num_attn_blocks

pytorch_tabular/tabular_datamodule.py

Lines changed: 14 additions & 8 deletions
@@ -232,14 +232,20 @@ def preprocess_data(
         # Target Transforms
         if all([col in data.columns for col in self.config.target]):
             if self.do_target_transform:
-                target_transforms = []
-                for col in self.config.target:
-                    _target_transform = copy.deepcopy(self.target_transform_template)
-                    data[col] = _target_transform.fit_transform(
-                        data[col].values.reshape(-1, 1)
-                    )
-                    target_transforms.append(_target_transform)
-                self.target_transforms = target_transforms
+                if stage == "fit":
+                    target_transforms = []
+                    for col in self.config.target:
+                        _target_transform = copy.deepcopy(self.target_transform_template)
+                        data[col] = _target_transform.fit_transform(
+                            data[col].values.reshape(-1, 1)
+                        )
+                        target_transforms.append(_target_transform)
+                    self.target_transforms = target_transforms
+                else:
+                    for col, _target_transform in zip(self.config.target, self.target_transforms):
+                        data[col] = _target_transform.transform(
+                            data[col].values.reshape(-1, 1)
+                        )
         return data, added_features

     def setup(self, stage: Optional[str] = None) -> None:
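
The intent of this change is that target transforms are fitted only during the "fit" stage and merely reused afterwards, so later data is transformed with the statistics learned from the training targets. A minimal sketch of that split using sklearn's PowerTransformer as a stand-in for the configured target transform (illustrative, not the datamodule code):

import numpy as np
from sklearn.preprocessing import PowerTransformer

train_y = np.random.lognormal(size=(100, 1))
test_y = np.random.lognormal(size=(20, 1))

target_transform = PowerTransformer()
train_y_t = target_transform.fit_transform(train_y)   # stage == "fit": fit and transform
test_y_t = target_transform.transform(test_y)         # later stages: reuse the fitted transform
print(train_y_t.shape, test_y_t.shape)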

pytorch_tabular/tabular_model.py

Lines changed: 20 additions & 14 deletions
@@ -184,7 +184,10 @@ def _get_run_name_uid(self) -> Tuple[str, int]:
         """
         if hasattr(self.config, "run_name") and self.config.run_name is not None:
             name = self.config.run_name
-        elif hasattr(self.config, "checkpoints_name") and self.config.checkpoints_name is not None:
+        elif (
+            hasattr(self.config, "checkpoints_name")
+            and self.config.checkpoints_name is not None
+        ):
             name = self.config.checkpoints_name
         else:
             name = self.config.task
@@ -287,7 +290,6 @@ def _prepare_model(self, loss, metrics, optimizer, optimizer_params, reset):
         )
         # Data Aware Initialization(for the models that need it)
         self.model.data_aware_initialization(self.datamodule)
-

     def _prepare_trainer(self, max_epochs=None, min_epochs=None):
         logger.info("Preparing the Trainer...")
@@ -297,7 +299,7 @@ def _prepare_trainer(self, max_epochs=None, min_epochs=None):
         self.config.min_epochs = min_epochs
         # TODO get Trainer Arguments from the init signature
         trainer_sig = inspect.signature(pl.Trainer.__init__)
-        trainer_args = [p for p in trainer_sig.parameters.keys() if p!="self"]
+        trainer_args = [p for p in trainer_sig.parameters.keys() if p != "self"]
         trainer_args_config = {
             k: v for k, v in self.config.items() if k in trainer_args
         }
@@ -314,9 +316,14 @@ def load_best_model(self):
         if self.trainer.checkpoint_callback is not None:
             logger.info("Loading the best model...")
             ckpt_path = self.trainer.checkpoint_callback.best_model_path
-            logger.debug(f"Model Checkpoint: {ckpt_path}")
-            ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
-            self.model.load_state_dict(ckpt["state_dict"])
+            if ckpt_path != "":
+                logger.debug(f"Model Checkpoint: {ckpt_path}")
+                ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
+                self.model.load_state_dict(ckpt["state_dict"])
+            else:
+                logger.info(
+                    "No best model available to load. Did you run it more than 1 epoch?..."
+                )
         else:
             logger.info(
                 "No best model available to load. Did you run it more than 1 epoch?..."
@@ -737,19 +744,18 @@ def load_from_checkpoint(cls, dir: str):
         custom_params = joblib.load(os.path.join(dir, "custom_params.sav"))
         model_args = {}
         if custom_params.get("custom_loss") is not None:
-            model_args['loss'] = "MSELoss"
+            model_args["loss"] = "MSELoss"
         if custom_params.get("custom_metrics") is not None:
-            model_args['metrics'] = ["mean_squared_error"]
-            model_args['metric_params'] = [{}]
+            model_args["metrics"] = ["mean_squared_error"]
+            model_args["metric_params"] = [{}]
         if custom_params.get("custom_optimizer") is not None:
-            model_args['optimizer'] = "Adam"
+            model_args["optimizer"] = "Adam"
         if custom_params.get("custom_optimizer_params") is not None:
-            model_args['optimizer_params'] = {}
-
+            model_args["optimizer_params"] = {}
+
         # Initializing with default metrics, losses, and optimizers. Will revert once initialized
         model = model_callable.load_from_checkpoint(
-            checkpoint_path=os.path.join(dir, "model.ckpt"),
-            **model_args
+            checkpoint_path=os.path.join(dir, "model.ckpt"), **model_args
         )
         # else:
         #     # Initializing with default values
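
The guard added in load_best_model covers the case where a checkpoint callback exists but never wrote a file, so best_model_path is an empty string (for example with fast_dev_run=True, as switched on in the example script above). A small self-contained sketch of the same pattern, with a hypothetical helper name:

import torch

def load_best_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    # Only try to load when the checkpoint callback actually produced a file.
    if ckpt_path != "":
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
    else:
        print("No best model available to load. Did you run it more than 1 epoch?...")

load_best_weights(torch.nn.Linear(4, 1), "")   # falls back to the message above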
