12 | 12 | class TabTransformerConfig(ModelConfig): |
13 | 13 | """Tab Transformer configuration |
14 | 14 | Args: |
15 | | - task (str): Specify whether the problem is regression of classification.Choices are: regression classification |
| 15 | + task (str): Specify whether the problem is regression or classification. |
| 16 | + Choices are: [`regression`,`classification`]. |
| 17 | + embedding_dims (Union[List[int], NoneType]): The dimensions of the embedding for |
| 18 | + each categorical column as a list of tuples (cardinality, embedding_dim). |
| 19 | + If left empty, will infer from the cardinality of the categorical column using |
| 20 | + the rule min(50, (x + 1) // 2) |
16 | 21 | learning_rate (float): The learning rate of the model |
17 | | - loss (Union[str, NoneType]): The loss function to be applied. |
18 | | - By Default it is MSELoss for regression and CrossEntropyLoss for classification. |
19 | | - Unless you are sure what you are doing, leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification |
20 | | - metrics (Union[List[str], NoneType]): the list of metrics you need to track during training. |
21 | | - The metrics should be one of the metrics implemented in PyTorch Lightning. |
22 | | - By default, it is Accuracy if classification and MeanSquaredLogError for regression |
23 | | - metrics_params (Union[List, NoneType]): The parameters to be passed to the Metrics initialized |
24 | | - target_range (Union[List, NoneType]): The range in which we should limit the output variable. Currently ignored for multi-target regression |
25 | | - Typically used for Regression problems. If left empty, will not apply any restrictions |
| 22 | + loss (Union[str, NoneType]): The loss function to be applied. |
| 23 | + By default it is MSELoss for regression and CrossEntropyLoss for classification. |
| 24 | + Unless you are sure what you are doing, leave it at MSELoss or L1Loss for regression |
| 25 | + and CrossEntropyLoss for classification. |
| 26 | + metrics (Union[List[str], NoneType]): The list of metrics you need to track during training. |
| 27 | + The metrics should be one of the functional metrics implemented in ``torchmetrics``. |
| 28 | + By default, it is accuracy for classification and mean_squared_error for regression |
| 29 | + metrics_params (Union[List, NoneType]): The parameters to be passed to the metrics function |
| 30 | + target_range (Union[List, NoneType]): The range in which we should limit the output variable. |
| 31 | + Currently ignored for multi-target regression. Typically used for regression problems. |
| 32 | + If left empty, will not apply any restrictions. |
26 | 33 |
27 | | - attn_embed_dim (int): The number of hidden units in the Multi-Headed Attention layers. Defaults to 32 |
28 | | - num_heads (int): The number of heads in the Multi-Headed Attention layer. Defaults to 2 |
29 | | - num_attn_blocks (int): The number of layers of stacked Multi-Headed Attention layers. Defaults to 2 |
30 | | - attn_dropouts (float): Dropout between layers of Multi-Headed Attention Layers. Defaults to 0.0 |
31 | | - has_residuals (bool): Flag to have a residual connect from enbedded output to attention layer output. |
32 | | - Defaults to True |
33 | | - embedding_dim (int): The dimensions of the embedding for continuous and categorical columns. Defaults to 16 |
34 | | - embedding_dropout (float): probability of an embedding element to be zeroed. Defaults to 0.0 |
35 | | - deep_layers (bool): Flag to enable a deep MLP layer before the Multi-Headed Attention layer. Defaults to False |
36 | | - layers (str): Hyphen-separated number of layers and units in the deep MLP. Defaults to 128-64-32 |
37 | | - activation (str): The activation type in the deep MLP. The default activaion in PyTorch like |
38 | | - ReLU, TanH, LeakyReLU, etc. https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity. |
| 34 | + input_embed_dim (int): The embedding dimension for the input categorical features. |
| 35 | + Defaults to 32 |
| 36 | + embedding_dropout (float): Dropout to be applied to the Categorical Embedding. |
| 37 | + Defaults to 0.1 |
| 38 | + share_embedding (bool): This flag turns on shared embeddings in the input embedding process. |
| 39 | + The key idea here is to have an embedding for the feature as a whole along with embeddings of |
| 40 | + each unique value of that column. For more details refer to Appendix A of the TabTransformer paper. |
| 41 | + Defaults to False |
| 42 | + share_embedding_strategy (Union[str, NoneType]): There are two strategies for adding shared embeddings. |
| 43 | + 1. `add` - A separate embedding for the feature is added to the embedding of the unique values of the feature. |
| 44 | + 2. `fraction` - A fraction of the input embedding is reserved for the shared embedding of the feature. |
| 45 | + Defaults to fraction. |
| 46 | + Choices are: [`add`,`fraction`]. |
| 47 | + shared_embedding_fraction (float): Fraction of the input_embed_dim to be reserved by the shared embedding. |
| 48 | + Should be less than one. Defaults to 0.25 |
| 49 | + num_heads (int): The number of heads in the Multi-Headed Attention layer. |
| 50 | + Defaults to 8 |
| 51 | + num_attn_blocks (int): The number of layers of stacked Multi-Headed Attention layers. |
| 52 | + Defaults to 6 |
| 53 | + transformer_head_dim (Union[int, NoneType]): The number of hidden units in the Multi-Headed Attention layers. |
| 54 | + Defaults to None and will be the same as input_dim. |
| 55 | + attn_dropout (float): Dropout to be applied after Multi headed Attention. |
| 56 | + Defaults to 0.1 |
| 57 | + add_norm_dropout (float): Dropout to be applied in the AddNorm Layer. |
| 58 | + Defaults to 0.1 |
| 59 | + ff_dropout (float): Dropout to be applied in the Positionwise FeedForward Network. |
| 60 | + Defaults to 0.1 |
| 61 | + ff_hidden_multiplier (int): Multiple by which the Positionwise FF layer scales the input. |
| 62 | + Defaults to 4 |
| 63 | + transformer_activation (str): The activation type in the transformer feed forward layers. |
| 64 | + In addition to the default activations in PyTorch like ReLU, TanH, LeakyReLU, etc. |
| 65 | + https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity, |
| 66 | + GEGLU, ReGLU and SwiGLU are also implemented (https://arxiv.org/pdf/2002.05202.pdf). |
| 67 | + Defaults to GEGLU |
| 68 | + out_ff_layers (str): Hyphen-separated number of layers and units in the deep MLP. |
| 69 | + Defaults to 128-64-32 |
| 70 | + out_ff_activation (str): The activation type in the deep MLP. The default activations in PyTorch like ReLU, TanH, LeakyReLU, etc. |
| 71 | + https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity. |
39 | 72 | Defaults to ReLU |
40 | | - dropout (float): probability of an classification element to be zeroed in the deep MLP. Defaults to 0.0 |
41 | | - use_batch_norm (bool): Flag to include a BatchNorm layer after each Linear Layer+DropOut. Defaults to False |
42 | | - batch_norm_continuous_input (bool): If True, we will normalize the contiinuous layer by passing it through a BatchNorm layer. Defaults to False |
43 | | - attention_pooling (bool): If True, will combine the attention outputs of each block for final prediction. Defaults to False |
44 | | - initialization (str): Initialization scheme for the linear layers. Defaults to `kaiming`. |
| 73 | + out_ff_dropout (float): Probability of a classification element to be zeroed in the deep MLP. |
| 74 | + Defaults to 0.0 |
| 75 | + use_batch_norm (bool): Flag to include a BatchNorm layer after each Linear Layer+DropOut. |
| 76 | + Defaults to False |
| 77 | + batch_norm_continuous_input (bool): If True, we will normalize the continuous layer by passing it through a BatchNorm layer. |
| 78 | + Defaults to False |
| 79 | + out_ff_initialization (str): Initialization scheme for the linear layers. |
| 80 | + Defaults to `kaiming`. |
45 | 81 | Choices are: [`kaiming`,`xavier`,`random`]. |
46 | 82 |
47 | 83 | Raises: |
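The embedding_dims inference rule documented above, min(50, (x + 1) // 2) where x is the cardinality of a categorical column, can be illustrated with a short sketch; the column names and cardinalities below are hypothetical.

```python
# Sketch of the embedding_dims inference rule from the docstring above:
# for each categorical column, embedding_dim = min(50, (cardinality + 1) // 2).
# Column names and cardinalities are made up for illustration.
cardinalities = {"city": 120, "gender": 2, "occupation": 37}
embedding_dims = [(card, min(50, (card + 1) // 2)) for card in cardinalities.values()]
print(embedding_dims)  # [(120, 50), (2, 1), (37, 19)]
```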
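The gated activations mentioned for transformer_activation (GEGLU, ReGLU and SwiGLU, from https://arxiv.org/pdf/2002.05202.pdf) all share one pattern: split the feed-forward projection into two halves and gate one half with an activation of the other. A minimal sketch of that pattern, not the library's actual modules:

```python
import torch
import torch.nn.functional as F

# Each variant splits the last dimension into two halves (assumes an even size)
# and multiplies one half by an activation of the other.
def geglu(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return a * F.gelu(b)

def reglu(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return a * F.relu(b)

def swiglu(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return a * F.silu(b)
```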
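A minimal usage sketch setting a few of the fields documented above. It assumes TabTransformerConfig can be imported from pytorch_tabular.models; the import path and defaults may differ between library versions.

```python
# Hedged usage sketch; only fields documented in the docstring above are set.
from pytorch_tabular.models import TabTransformerConfig

model_config = TabTransformerConfig(
    task="classification",             # or "regression"
    input_embed_dim=32,                # embedding dim for the categorical features
    num_heads=8,                       # heads in each Multi-Headed Attention layer
    num_attn_blocks=6,                 # stacked attention blocks
    share_embedding=True,              # shared per-column embeddings (Appendix A of the paper)
    share_embedding_strategy="fraction",
    shared_embedding_fraction=0.25,    # portion of input_embed_dim reserved for the shared part
    out_ff_layers="128-64-32",         # hyphen-separated layers/units of the MLP head
    out_ff_activation="ReLU",
)
```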
@@ -170,7 +206,7 @@ class TabTransformerConfig(ModelConfig): |
170 | 206 | _config_name: str = field(default="TabTransformerConfig") |
171 | 207 |
172 | 208 |
173 | | -# cls = AutoIntConfig |
| 209 | +# cls = TabTransformerConfig |
174 | 210 | # desc = "Configuration for Data." |
175 | 211 | # doc_str = f"{desc}\nArgs:" |
176 | 212 | # for key in cls.__dataclass_fields__.keys(): |
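The commented-out lines above look like a helper for regenerating the Args section of the docstring from the dataclass fields, but the loop body is not shown in this diff. A hedged sketch of what such a generator could look like, assuming each field stores its description under a help key in its metadata (an assumption, not confirmed by the diff):

```python
# Hypothetical completion of the commented-out docstring generator above.
# Assumes each dataclass field keeps its description in metadata["help"].
cls = TabTransformerConfig
desc = "Configuration for Data."
doc_str = f"{desc}\nArgs:"
for key in cls.__dataclass_fields__.keys():
    atr = cls.__dataclass_fields__[key]
    if atr.init:  # skip fields that are excluded from __init__
        help_str = atr.metadata.get("help", "")
        doc_str += f"\n    {key} ({atr.type}): {help_str}"
print(doc_str)
```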