Commit b49f827

-- added attention pooling for AutoInt
-- added backbone features to ret_logits prediction
-- added AutoInt attention pooling test cases
1 parent b5d033a commit b49f827

6 files changed: +32 -7 lines changed

examples/to_test_regression.py

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,8 @@
 # )
 # # model_config.validate()
 # model_config = CategoryEmbeddingModelConfig(task="regression")
-model_config = AutoIntConfig(task="regression", deep_layers=True, embedding_dropout=0.2, batch_norm_continuous_input=True)
+model_config = AutoIntConfig(task="regression", deep_layers=True, embedding_dropout=0.2,
+    batch_norm_continuous_input=True, attention_pooling=True)
 trainer_config = TrainerConfig(checkpoints=None, max_epochs=25, gpus=1, profiler=None, fast_dev_run=False, auto_lr_find=True)
 # experiment_config = ExperimentConfig(
 #     project_name="DeepGMM_test",

pytorch_tabular/models/autoint/autoint.py

Lines changed: 12 additions & 1 deletion
@@ -72,11 +72,16 @@ def _build_network(self):
         )
         if self.hparams.has_residuals:
             self.V_res_embedding = torch.nn.Linear(
-                _curr_units, self.hparams.attn_embed_dim
+                _curr_units,
+                self.hparams.attn_embed_dim * self.hparams.num_attn_blocks
+                if self.hparams.attention_pooling
+                else self.hparams.attn_embed_dim,
             )
         self.output_dim = (
             self.hparams.continuous_dim + self.hparams.categorical_dim
         ) * self.hparams.attn_embed_dim
+        if self.hparams.attention_pooling:
+            self.output_dim = self.output_dim * self.hparams.num_attn_blocks

     def forward(self, x: Dict):
         # (B, N)
@@ -109,8 +114,14 @@ def forward(self, x: Dict):
         x = self.linear_layers(x)
         # (N, B, E*) --> E* is the Attn Dimention
         cross_term = self.attn_proj(x).transpose(0, 1)
+        if self.hparams.attention_pooling:
+            attention_ops = []
         for self_attn in self.self_attns:
             cross_term, _ = self_attn(cross_term, cross_term, cross_term)
+            if self.hparams.attention_pooling:
+                attention_ops.append(cross_term)
+        if self.hparams.attention_pooling:
+            cross_term = torch.cat(attention_ops, dim=-1)
         # (B, N, E*)
         cross_term = cross_term.transpose(0, 1)
         if self.hparams.has_residuals:
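To make the change above easier to follow, here is a minimal standalone sketch of the attention-pooling idea (not the library code): the output of every self-attention block is collected and concatenated along the embedding dimension, so the downstream head sees features from all blocks instead of only the last one. The module and its default sizes are hypothetical; only the names attn_embed_dim, num_attn_blocks and attention_pooling mirror the hparams in the diff.

import torch
import torch.nn as nn

class AttentionPoolingStack(nn.Module):
    """Hypothetical stand-in for the attention stack: concatenate every
    block's output (attention pooling) instead of keeping only the last."""

    def __init__(self, attn_embed_dim=16, num_heads=2, num_attn_blocks=3, attention_pooling=True):
        super().__init__()
        self.attention_pooling = attention_pooling
        self.self_attns = nn.ModuleList(
            [nn.MultiheadAttention(attn_embed_dim, num_heads) for _ in range(num_attn_blocks)]
        )

    def forward(self, x):
        # x: (N, B, E) -- N field embeddings, batch B, embedding dim E
        attention_ops = []
        cross_term = x
        for self_attn in self.self_attns:
            cross_term, _ = self_attn(cross_term, cross_term, cross_term)
            attention_ops.append(cross_term)
        if self.attention_pooling:
            # (N, B, E * num_attn_blocks): every block contributes features
            return torch.cat(attention_ops, dim=-1)
        return cross_term  # (N, B, E): only the last block, as before

# Shape check: 8 fields, batch of 4, E=16, 3 blocks -> last dim 16 * 3 = 48
print(AttentionPoolingStack()(torch.randn(8, 4, 16)).shape)  # torch.Size([8, 4, 48])

This also explains the dimension changes in _build_network: with pooling enabled the flattened output grows by a factor of num_attn_blocks, which is why both V_res_embedding and output_dim are scaled accordingly.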

pytorch_tabular/models/autoint/config.py

Lines changed: 9 additions & 2 deletions
@@ -39,7 +39,8 @@ class AutoIntConfig(ModelConfig):
             Defaults to ReLU
         dropout (float): probability of an classification element to be zeroed in the deep MLP. Defaults to 0.0
         use_batch_norm (bool): Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False
-        batch_norm_continuous_input (bool): If True, we will normalize the contiinuous layer by passing it through a BatchNorm layer
+        batch_norm_continuous_input (bool): If True, we will normalize the contiinuous layer by passing it through a BatchNorm layer. Defaults to False
+        attention_pooling (bool): If True, will combine the attention outputs of each block for final prediction. Defaults to False
         initialization (str): Initialization scheme for the linear layers. Defaults to `kaiming`.
             Choices are: [`kaiming`,`xavier`,`random`].
@@ -122,7 +123,13 @@ class AutoIntConfig(ModelConfig):
     batch_norm_continuous_input: bool = field(
         default=False,
         metadata={
-            "help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer"
+            "help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer. Defaults to False"
+        },
+    )
+    attention_pooling: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, will combine the attention outputs of each block for final prediction. Defaults to False"
         },
     )
     initialization: str = field(
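As a quick reference, a hedged end-to-end sketch of enabling the new flag, assuming the library's usual DataConfig/TrainerConfig/OptimizerConfig/TabularModel entry points; the dataframe and column names are placeholders, not part of this commit.

import pandas as pd
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import AutoIntConfig

# Placeholder data: two continuous features and a regression target
train_df = pd.DataFrame({"feat_1": range(100), "feat_2": range(100), "target": range(100)})

data_config = DataConfig(target=["target"], continuous_cols=["feat_1", "feat_2"])
model_config = AutoIntConfig(task="regression", attention_pooling=True)  # new flag from this commit
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=3, checkpoints=None, gpus=0),
)
# tabular_model.fit(train=train_df)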

pytorch_tabular/models/base_model.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ def _setup_metrics(self):
                 raise e
         else:
             self.metrics = self.custom_metrics
+            self.hparams.metrics = [m.__name__ for m in self.custom_metrics]

     def calculate_loss(self, y, y_hat, tag):
         if (self.hparams.task == "regression") and (self.hparams.output_dim > 1):

pytorch_tabular/tabular_model.py

Lines changed: 4 additions & 2 deletions
@@ -627,8 +627,8 @@ def predict(
             y_hat, ret_value = self.model.predict(batch, ret_model_output=True)
             if ret_logits:
                 for k, v in ret_value.items():
-                    if k == "backbone_features":
-                        continue
+                    # if k == "backbone_features":
+                    #     continue
                     logits_predictions[k].append(v.detach().cpu())
             point_predictions.append(y_hat.detach().cpu())
             if is_probabilistic:
@@ -751,6 +751,8 @@ def load_from_checkpoint(cls, dir: str):
         tabular_model.model = model
         tabular_model.datamodule = datamodule
         tabular_model.callbacks = callbacks
+        #TODO max_epochs and min_epochs, make it optional
+        #TODO custom model and custom metrics need to be dealt with separately
         tabular_model._prepare_trainer()
         tabular_model.trainer.model = model
         tabular_model.logger = logger
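Given the predict change above (backbone_features is no longer skipped when ret_logits is requested), a hedged usage sketch of what the commit message calls "backbone features also to ret_logits prediction"; the exact column naming in the returned dataframe is an assumption, not confirmed by the diff.

# Assumes `tabular_model` has been fit and `test` is a pandas DataFrame
pred_df = tabular_model.predict(test, ret_logits=True)

# Before this commit only logits were returned; now the values collected under
# the "backbone_features" key should also be present in the output
backbone_cols = [c for c in pred_df.columns if "backbone_features" in c]
print(len(backbone_cols))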

tests/test_autoint.py

Lines changed: 4 additions & 1 deletion
@@ -30,6 +30,7 @@
 @pytest.mark.parametrize("target_range", [True, False])
 @pytest.mark.parametrize("deep_layers", [True, False])
 @pytest.mark.parametrize("batch_norm_continuous_input", [True, False])
+@pytest.mark.parametrize("attention_pooling", [True, False])
 def test_regression(
     regression_data,
     multi_target,
@@ -39,7 +40,8 @@ def test_regression(
     normalize_continuous_features,
     target_range,
     deep_layers,
-    batch_norm_continuous_input
+    batch_norm_continuous_input,
+    attention_pooling
 ):
     (train, test, target) = regression_data
     if len(continuous_cols) + len(categorical_cols) == 0:
@@ -65,6 +67,7 @@ def test_regression(
     model_config_params["target_range"] = _target_range
     model_config_params["deep_layers"] = deep_layers
     model_config_params["batch_norm_continuous_input"] = batch_norm_continuous_input
+    model_config_params["attention_pooling"] = attention_pooling
     model_config = AutoIntConfig(**model_config_params)
     trainer_config = TrainerConfig(
         max_epochs=3, checkpoints=None, early_stopping=None, gpus=0
