
Commit b5d033a

Merge branch 'autoint' into develop

2 parents: c34c2a0 + 635582b

14 files changed: +782 −205661 lines

examples/regression_with_MDN.ipynb

Lines changed: 0 additions & 205630 deletions
Large diffs are not rendered by default.

examples/to_test_regression.py

Lines changed: 7 additions & 3 deletions
@@ -12,6 +12,7 @@
 from pytorch_tabular.models.category_embedding.config import (
     CategoryEmbeddingModelConfig,
 )
+from pytorch_tabular.models import AutoIntModel, AutoIntConfig

 from pytorch_tabular.models.mixture_density import (
     CategoryEmbeddingMDNConfig, MixtureDensityHeadConfig, NODEMDNConfig
@@ -33,6 +34,8 @@
 dataset = fetch_california_housing(data_home="data", as_frame=True)
 dataset.frame["HouseAgeBin"] = pd.qcut(dataset.frame["HouseAge"], q=4)
 dataset.frame.HouseAgeBin = "age_" + dataset.frame.HouseAgeBin.cat.codes.astype(str)
+dataset.frame["AveRoomsBin"] = pd.qcut(dataset.frame["AveRooms"], q=3)
+dataset.frame.AveRoomsBin = "av_rm_" + dataset.frame.AveRoomsBin.cat.codes.astype(str)

 test_idx = dataset.frame.sample(int(0.2 * len(dataset.frame)), random_state=42).index
 test = dataset.frame[dataset.frame.index.isin(test_idx)]
@@ -49,7 +52,7 @@
         "Longitude",
     ],
     # continuous_cols=[],
-    categorical_cols=["HouseAgeBin"],
+    categorical_cols=["HouseAgeBin", "AveRoomsBin"],
     continuous_feature_transform=None,  # "yeo-johnson",
     normalize_continuous_features=True,
 )
@@ -61,8 +64,9 @@
 #     mdn_config = mdn_config
 # )
 # # model_config.validate()
-model_config = NodeConfig(task="regression", depth=2, embed_categorical=False)
-trainer_config = TrainerConfig(checkpoints=None, max_epochs=5, gpus=1, profiler=None)
+# model_config = CategoryEmbeddingModelConfig(task="regression")
+model_config = AutoIntConfig(task="regression", deep_layers=True, embedding_dropout=0.2, batch_norm_continuous_input=True)
+trainer_config = TrainerConfig(checkpoints=None, max_epochs=25, gpus=1, profiler=None, fast_dev_run=False, auto_lr_find=True)
 # experiment_config = ExperimentConfig(
 #     project_name="DeepGMM_test",
 #     run_name="wand_debug",

pytorch_tabular/models/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -8,9 +8,12 @@
     MixtureDensityHeadConfig,
     NODEMDNConfig,
     NODEMDN,
+    AutoIntMDN,
+    AutoIntMDNConfig
 )
+from .autoint import AutoIntConfig, AutoIntModel
 from .base_model import BaseModel
-from . import category_embedding, node, mixture_density, tabnet
+from . import category_embedding, node, mixture_density, tabnet, autoint

 __all__ = [
     "CategoryEmbeddingModel",
@@ -26,8 +29,13 @@
     "MixtureDensityHeadConfig",
     "NODEMDNConfig",
     "NODEMDN",
+    "AutoIntMDN",
+    "AutoIntMDNConfig",
+    "AutoIntConfig",
+    "AutoIntModel",
     "category_embedding",
     "node",
     "mixture_density",
     "tabnet",
+    "autoint",
 ]
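
With these re-exports in place, the new classes resolve from the package root; a short illustration (not part of the diff):

# Both import paths below should work after this change.
from pytorch_tabular.models import AutoIntConfig, AutoIntModel
from pytorch_tabular.models.autoint import AutoIntBackbone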
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .autoint import AutoIntBackbone, AutoIntModel
2+
from .config import AutoIntConfig
3+
4+
__all__ = ["AutoIntModel", "AutoIntBackbone", "AutoIntConfig"]
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Pytorch Tabular
2+
# Author: Manu Joseph <manujoseph@gmail.com>
3+
# For license information, see LICENSE.TXT
4+
# Inspired by https://github.com/rixwew/pytorch-fm/blob/master/torchfm/model/afi.py
5+
"""AutomaticFeatureInteraction Model"""
6+
import logging
7+
from typing import Dict
8+
9+
import torch
10+
import torch.nn as nn
11+
from omegaconf import DictConfig
12+
13+
from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn
14+
15+
from ..base_model import BaseModel
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class AutoIntBackbone(BaseModel):
21+
def __init__(self, config: DictConfig, **kwargs):
22+
self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
23+
super().__init__(config, **kwargs)
24+
25+
def _build_network(self):
26+
# Category Embedding layers
27+
self.cat_embedding_layers = nn.ModuleList(
28+
[
29+
nn.Embedding(cardinality, self.hparams.embedding_dim)
30+
for cardinality in self.hparams.categorical_cardinality
31+
]
32+
)
33+
if self.hparams.batch_norm_continuous_input:
34+
self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
35+
# Continuous Embedding Layer
36+
self.cont_embedding_layer = nn.Embedding(
37+
self.hparams.continuous_dim, self.hparams.embedding_dim
38+
)
39+
if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
40+
self.embed_dropout = nn.Dropout(self.hparams.embedding_dropout)
41+
# Deep Layers
42+
_curr_units = self.hparams.embedding_dim
43+
if self.hparams.deep_layers:
44+
activation = getattr(nn, self.hparams.activation)
45+
# Linear Layers
46+
layers = []
47+
for units in self.hparams.layers.split("-"):
48+
layers.extend(
49+
_linear_dropout_bn(
50+
self.hparams,
51+
_curr_units,
52+
int(units),
53+
activation,
54+
self.hparams.dropout,
55+
)
56+
)
57+
_curr_units = int(units)
58+
self.linear_layers = nn.Sequential(*layers)
59+
# Projection to Multi-Headed Attention Dims
60+
self.attn_proj = nn.Linear(_curr_units, self.hparams.attn_embed_dim)
61+
_initialize_layers(self.hparams, self.attn_proj)
62+
# Multi-Headed Attention Layers
63+
self.self_attns = nn.ModuleList(
64+
[
65+
nn.MultiheadAttention(
66+
self.hparams.attn_embed_dim,
67+
self.hparams.num_heads,
68+
dropout=self.hparams.attn_dropouts,
69+
)
70+
for _ in range(self.hparams.num_attn_blocks)
71+
]
72+
)
73+
if self.hparams.has_residuals:
74+
self.V_res_embedding = torch.nn.Linear(
75+
_curr_units, self.hparams.attn_embed_dim
76+
)
77+
self.output_dim = (
78+
self.hparams.continuous_dim + self.hparams.categorical_dim
79+
) * self.hparams.attn_embed_dim
80+
81+
def forward(self, x: Dict):
82+
# (B, N)
83+
continuous_data, categorical_data = x["continuous"], x["categorical"]
84+
x = None
85+
if self.embedding_cat_dim != 0:
86+
x_cat = [
87+
embedding_layer(categorical_data[:, i]).unsqueeze(1)
88+
for i, embedding_layer in enumerate(self.cat_embedding_layers)
89+
]
90+
# (B, N, E)
91+
x = torch.cat(x_cat, 1)
92+
if self.hparams.continuous_dim > 0:
93+
cont_idx = (
94+
torch.arange(self.hparams.continuous_dim)
95+
.expand(continuous_data.size(0), -1)
96+
.to(self.device)
97+
)
98+
if self.hparams.batch_norm_continuous_input:
99+
continuous_data = self.normalizing_batch_norm(continuous_data)
100+
x_cont = torch.mul(
101+
continuous_data.unsqueeze(2),
102+
self.cont_embedding_layer(cont_idx),
103+
)
104+
# (B, N, E)
105+
x = x_cont if x is None else torch.cat([x, x_cont], 1)
106+
if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
107+
x = self.embed_dropout(x)
108+
if self.hparams.deep_layers:
109+
x = self.linear_layers(x)
110+
# (N, B, E*) --> E* is the Attn Dimention
111+
cross_term = self.attn_proj(x).transpose(0, 1)
112+
for self_attn in self.self_attns:
113+
cross_term, _ = self_attn(cross_term, cross_term, cross_term)
114+
# (B, N, E*)
115+
cross_term = cross_term.transpose(0, 1)
116+
if self.hparams.has_residuals:
117+
# (B, N, E*) --> Projecting Embedded input to Attention sub-space
118+
V_res = self.V_res_embedding(x)
119+
cross_term = cross_term + V_res
120+
# (B, NxE*)
121+
cross_term = nn.ReLU()(cross_term).reshape(-1, self.output_dim)
122+
return cross_term
123+
124+
125+
class AutoIntModel(BaseModel):
126+
def __init__(self, config: DictConfig, **kwargs):
127+
# The concatenated output dim of the embedding layer
128+
self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
129+
super().__init__(config, **kwargs)
130+
131+
def _build_network(self):
132+
# Backbone
133+
self.backbone = AutoIntBackbone(self.hparams)
134+
self.dropout = nn.Dropout(self.hparams.dropout)
135+
# Adding the last layer
136+
self.output_layer = nn.Linear(
137+
self.backbone.output_dim, self.hparams.output_dim
138+
) # output_dim auto-calculated from other config
139+
_initialize_layers(self.hparams, self.output_layer)
140+
141+
def forward(self, x: Dict):
142+
x = self.backbone(x)
143+
x = self.dropout(x)
144+
y_hat = self.output_layer(x)
145+
if (self.hparams.task == "regression") and (
146+
self.hparams.target_range is not None
147+
):
148+
for i in range(self.hparams.output_dim):
149+
y_min, y_max = self.hparams.target_range[i]
150+
y_hat[:, i] = y_min + nn.Sigmoid()(y_hat[:, i]) * (y_max - y_min)
151+
return {"logits": y_hat, "backbone_features": x}
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Pytorch Tabular
2+
# Author: Manu Joseph <manujoseph@gmail.com>
3+
# For license information, see LICENSE.TXT
4+
"""AutomaticFeatureInteraction Config"""
5+
from dataclasses import dataclass, field
6+
from typing import List, Optional
7+
8+
from pytorch_tabular.config import ModelConfig, _validate_choices
9+
10+
11+
@dataclass
12+
class AutoIntConfig(ModelConfig):
13+
"""AutomaticFeatureInteraction configuration
14+
Args:
15+
task (str): Specify whether the problem is regression of classification.Choices are: regression classification
16+
learning_rate (float): The learning rate of the model
17+
loss (Union[str, NoneType]): The loss function to be applied.
18+
By Default it is MSELoss for regression and CrossEntropyLoss for classification.
19+
Unless you are sure what you are doing, leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification
20+
metrics (Union[List[str], NoneType]): the list of metrics you need to track during training.
21+
The metrics should be one of the metrics implemented in PyTorch Lightning.
22+
By default, it is Accuracy if classification and MeanSquaredLogError for regression
23+
metrics_params (Union[List, NoneType]): The parameters to be passed to the Metrics initialized
24+
target_range (Union[List, NoneType]): The range in which we should limit the output variable. Currently ignored for multi-target regression
25+
Typically used for Regression problems. If left empty, will not apply any restrictions
26+
27+
attn_embed_dim (int): The number of hidden units in the Multi-Headed Attention layers. Defaults to 32
28+
num_heads (int): The number of heads in the Multi-Headed Attention layer. Defaults to 2
29+
num_attn_blocks (int): The number of layers of stacked Multi-Headed Attention layers. Defaults to 2
30+
attn_dropouts (float): Dropout between layers of Multi-Headed Attention Layers. Defaults to 0.0
31+
has_residuals (bool): Flag to have a residual connect from enbedded output to attention layer output.
32+
Defaults to True
33+
embedding_dim (int): The dimensions of the embedding for continuous and categorical columns. Defaults to 16
34+
embedding_dropout (float): probability of an embedding element to be zeroed. Defaults to 0.0
35+
deep_layers (bool): Flag to enable a deep MLP layer before the Multi-Headed Attention layer. Defaults to False
36+
layers (str): Hyphen-separated number of layers and units in the deep MLP. Defaults to 128-64-32
37+
activation (str): The activation type in the deep MLP. The default activaion in PyTorch like
38+
ReLU, TanH, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity.
39+
Defaults to ReLU
40+
dropout (float): probability of an classification element to be zeroed in the deep MLP. Defaults to 0.0
41+
use_batch_norm (bool): Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False
42+
batch_norm_continuous_input (bool): If True, we will normalize the contiinuous layer by passing it through a BatchNorm layer
43+
initialization (str): Initialization scheme for the linear layers. Defaults to `kaiming`.
44+
Choices are: [`kaiming`,`xavier`,`random`].
45+
46+
Raises:
47+
NotImplementedError: Raises an error if task is not in ['regression','classification']
48+
"""
49+
50+
attn_embed_dim: int = field(
51+
default=32,
52+
metadata={
53+
"help": "The number of hidden units in the Multi-Headed Attention layers. Defaults to 32"
54+
},
55+
)
56+
num_heads: int = field(
57+
default=2,
58+
metadata={
59+
"help": "The number of heads in the Multi-Headed Attention layer. Defaults to 2"
60+
},
61+
)
62+
num_attn_blocks: int = field(
63+
default=3,
64+
metadata={
65+
"help": "The number of layers of stacked Multi-Headed Attention layers. Defaults to 2"
66+
},
67+
)
68+
attn_dropouts: float = field(
69+
default=0.0,
70+
metadata={
71+
"help": "Dropout between layers of Multi-Headed Attention Layers. Defaults to 0.0"
72+
},
73+
)
74+
has_residuals: bool = field(
75+
default=True,
76+
metadata={
77+
"help": "Flag to have a residual connect from enbedded output to attention layer output. Defaults to True"
78+
},
79+
)
80+
embedding_dim: int = field(
81+
default=16,
82+
metadata={
83+
"help": "The dimensions of the embedding for continuous and categorical columns. Defaults to 16"
84+
},
85+
)
86+
embedding_dropout: float = field(
87+
default=0.0,
88+
metadata={
89+
"help": "probability of an embedding element to be zeroed. Defaults to 0.0"
90+
},
91+
)
92+
deep_layers: bool = field(
93+
default=False,
94+
metadata={
95+
"help": "Flag to enable a deep MLP layer before the Multi-Headed Attention layer. Defaults to False"
96+
},
97+
)
98+
layers: str = field(
99+
default="128-64-32",
100+
metadata={
101+
"help": "Hyphen-separated number of layers and units in the deep MLP. Defaults to 128-64-32"
102+
},
103+
)
104+
activation: str = field(
105+
default="ReLU",
106+
metadata={
107+
"help": "The activation type in the deep MLP. The default activaion in PyTorch like ReLU, TanH, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity. Defaults to ReLU"
108+
},
109+
)
110+
dropout: float = field(
111+
default=0.0,
112+
metadata={
113+
"help": "probability of an classification element to be zeroed in the deep MLP. Defaults to 0.0"
114+
},
115+
)
116+
use_batch_norm: bool = field(
117+
default=False,
118+
metadata={
119+
"help": "Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False"
120+
},
121+
)
122+
batch_norm_continuous_input: bool = field(
123+
default=False,
124+
metadata={
125+
"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer"
126+
},
127+
)
128+
initialization: str = field(
129+
default="kaiming",
130+
metadata={
131+
"help": "Initialization scheme for the linear layers. Defaults to `kaiming`",
132+
"choices": ["kaiming", "xavier", "random"],
133+
},
134+
)
135+
_module_src: str = field(default="autoint")
136+
_model_name: str = field(default="AutoIntModel")
137+
_config_name: str = field(default="AutoIntConfig")
138+
139+
140+
# cls = AutoIntConfig
141+
# desc = "Configuration for Data."
142+
# doc_str = f"{desc}\nArgs:"
143+
# for key in cls.__dataclass_fields__.keys():
144+
# atr = cls.__dataclass_fields__[key]
145+
# if atr.init:
146+
# type = str(atr.type).replace("<class '","").replace("'>","").replace("typing.","")
147+
# help_str = atr.metadata.get("help","")
148+
# if "choices" in atr.metadata.keys():
149+
# help_str += f'. Choices are: [{",".join(["`"+str(ch)+"`" for ch in atr.metadata["choices"]])}].'
150+
# # help_str += f'. Defaults to {atr.default}'
151+
# doc_str+=f'\n\t\t{key} ({type}): {help_str}'
152+
153+
# print(doc_str)
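
As a quick illustration of how these fields behave (my own sketch, not part of the commit), the config can be instantiated with the same overrides the test script uses, while unspecified fields keep their declared defaults:

# Hedged sketch: constructs AutoIntConfig with the overrides from examples/to_test_regression.py.
from pytorch_tabular.models import AutoIntConfig

cfg = AutoIntConfig(
    task="regression",
    deep_layers=True,                   # enables the 128-64-32 MLP before the attention blocks
    embedding_dropout=0.2,
    batch_norm_continuous_input=True,
)
print(cfg.attn_embed_dim, cfg.num_heads, cfg.num_attn_blocks)  # 32 2 3 (defaults)
print(cfg.layers)                                              # "128-64-32"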
