Skip to content

Commit e1d5752

Browse files
authored
General_improvements and Doc changes (#351)
* general improvements and doc changes * fix
1 parent 64d3d8f commit e1d5752

File tree

11 files changed

+150
-25
lines changed

11 files changed

+150
-25
lines changed

docs/tabular_model.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ For self-supervised learning, there is a different API because the process is di
129129
options:
130130
show_root_heading: yes
131131
heading_level: 4
132+
::: pytorch_tabular.TabularModel.cross_validate
133+
options:
134+
show_root_heading: yes
135+
heading_level: 4
136+
::: pytorch_tabular.TabularModel.bagging_predict
137+
options:
138+
show_root_heading: yes
139+
heading_level: 4
132140

133141
# Artifact Saving and Loading
134142

src/pytorch_tabular/categorical_encoders.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@
2222

2323
class BaseEncoder:
2424
def __init__(self, cols, handle_unseen, min_samples, imputed, handle_missing):
25+
"""Base class for categorical encoders.
26+
Args:
27+
cols (list): list of columns to encode, or None (then all dataset columns will be encoded at fitting time)
28+
handle_unseen (str):
29+
'error' - raise an error if a category unseen at fitting time is found
30+
'ignore' - skip unseen categories
31+
'impute' - impute new categories to a predefined value, which is same as NAN_CATEGORY
32+
min_samples (int): minimum samples to take category as valid
33+
imputed (float): value to impute unseen categories
34+
handle_missing (str):
35+
'error' - raise an error if missing values are found in columns to encode
36+
'impute' - impute missing values to a predefined value, which is same as NAN_CATEGORY
37+
"""
2538
self.cols = cols
2639
self.handle_unseen = handle_unseen
2740
self.handle_missing = handle_missing
@@ -87,11 +100,21 @@ def _before_fit_check(self, X, y):
87100
assert X.shape[0] == y.shape[0]
88101

89102
def save_as_object_file(self, path):
103+
"""Save the encoder as a pickle file.
104+
105+
Args:
106+
path (str): path to save the encoder
107+
"""
90108
if not self._mapping:
91109
raise ValueError("`fit` method must be called before `save_as_object_file`.")
92110
pickle.dump(self.__dict__, open(path, "wb"))
93111

94112
def load_from_object_file(self, path):
113+
"""Load the encoder from a pickle file.
114+
115+
Args:
116+
path (str): path to load the encoder
117+
"""
95118
for k, v in pickle.load(open(path, "rb")).items():
96119
setattr(self, k, v)
97120

src/pytorch_tabular/config/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"""Config."""
55
import os
66
import re
7-
import warnings
87
from dataclasses import MISSING, dataclass, field
98
from typing import Any, Dict, Iterable, List, Optional
109

@@ -242,7 +241,8 @@ def __post_init__(self):
242241

243242
@dataclass
244243
class TrainerConfig:
245-
"""Trainer configuration
244+
"""Trainer configuration.
245+
246246
Args:
247247
batch_size (int): Number of samples in each batch of training
248248
@@ -539,7 +539,6 @@ def __post_init__(self):
539539
if self.accelerator is None:
540540
self.accelerator = "cpu"
541541
if self.devices_list is not None:
542-
warnings.warn("Ignoring devices in favor of devices_list")
543542
self.devices = self.devices_list
544543
delattr(self, "devices_list")
545544
for key in self.early_stopping_kwargs.keys():

src/pytorch_tabular/feature_extractor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(self, tabular_model, extract_keys=["backbone_features"], drop_origi
2424
2525
Args:
2626
tabular_model (TabularModel): The trained TabularModel object
27+
extract_keys (list, optional): The keys of the features to extract. Defaults to ["backbone_features"].
28+
drop_original (bool, optional): Whether to drop the original columns. Defaults to True.
2729
"""
2830
assert not (
2931
isinstance(tabular_model.model, NODEModel)
@@ -102,10 +104,20 @@ def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
102104
return self.transform(X)
103105

104106
def save_as_object_file(self, path):
107+
"""Saves the feature extractor as a pickle file.
108+
109+
Args:
110+
path (str): The path to save the file
111+
"""
105112
if not self._mapping:
106113
raise ValueError("`fit` method must be called before `save_as_object_file`.")
107114
pickle.dump(self.__dict__, open(path, "wb"))
108115

109116
def load_from_object_file(self, path):
117+
"""Loads the feature extractor from a pickle file.
118+
119+
Args:
120+
path (str): The path to load the file from
121+
"""
110122
for k, v in pickle.load(open(path, "rb")).items():
111123
setattr(self, k, v)

src/pytorch_tabular/models/autoint/autoint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515

1616
class AutoIntBackbone(nn.Module):
1717
def __init__(self, config: DictConfig):
18+
"""Automatic Feature Interaction Network.
19+
20+
Args:
21+
config (DictConfig): config of the model
22+
"""
1823
super().__init__()
1924
self.hparams = config
2025
self._build_network()

src/pytorch_tabular/models/base_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,14 @@ def __init__(
157157
warnings.warn(
158158
"Wandb is not installed. Please install wandb to log logits. "
159159
"You can install wandb using pip install wandb or install PyTorch Tabular"
160-
" using pip install pytorch-tabular[all]"
160+
" using pip install pytorch-tabular[extra]"
161161
)
162162
if not PLOTLY_INSTALLED:
163163
self.do_log_logits = False
164164
warnings.warn(
165165
"Plotly is not installed. Please install plotly to log logits. "
166166
"You can install plotly using pip install plotly or install PyTorch Tabular"
167-
" using pip install pytorch-tabular[all]"
167+
" using pip install pytorch-tabular[extra]"
168168
)
169169

170170
@abstractmethod
@@ -376,7 +376,7 @@ def pack_output(self, y_hat: torch.Tensor, backbone_features: torch.tensor) -> D
376376
"""
377377
# if self.head is the Identity function it means that we cannot extract backbone features,
378378
# because the model cannot be divide in backbone and head (i.e. TabNet)
379-
if type(self.head) == nn.Identity:
379+
if type(self.head) is nn.Identity:
380380
return {"logits": y_hat}
381381
return {"logits": y_hat, "backbone_features": backbone_features}
382382

src/pytorch_tabular/models/mixture_density/mdn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Author: Manu Joseph <manujoseph@gmail.com>
33
# For license information, see LICENSE.TXT
44
"""Mixture Density Models."""
5-
import warnings
65
from typing import Dict, Optional, Union
76

87
import torch
@@ -13,7 +12,9 @@
1312
from pytorch_tabular import models
1413
from pytorch_tabular.config.config import ModelConfig
1514
from pytorch_tabular.models.common.heads import blocks
16-
from pytorch_tabular.models.tab_transformer.tab_transformer import TabTransformerBackbone
15+
from pytorch_tabular.models.tab_transformer.tab_transformer import (
16+
TabTransformerBackbone,
17+
)
1718
from pytorch_tabular.tabular_model import getattr_nested
1819
from pytorch_tabular.utils import get_logger
1920

@@ -22,7 +23,7 @@
2223
try:
2324
import wandb
2425
except ImportError:
25-
warnings.warn("Wandb not installed. WandbLogger will not work.")
26+
pass
2627

2728
logger = get_logger(__name__)
2829

src/pytorch_tabular/models/node/config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from dataclasses import dataclass, field
23
from typing import Optional
34

@@ -195,11 +196,27 @@ class NodeConfig(ModelConfig):
195196
},
196197
)
197198

199+
head: Optional[str] = field(
200+
default=None,
201+
)
202+
198203
_module_src: str = field(default="models.node")
199204
_model_name: str = field(default="NODEModel")
200205
_backbone_name: str = field(default="NODEBackbone")
201206
_config_name: str = field(default="NodeConfig")
202207

208+
def __post_init__(self):
209+
if self.head is not None:
210+
warnings.warn(
211+
"`head` and `head_config` is ignored as NODE has a specific"
212+
" head which subsets the tree outputs. Set `head=None`"
213+
" to turn off the warning"
214+
)
215+
else:
216+
# Setting Head to LinearHead for compatibility
217+
self.head = "LinearHead"
218+
return super().__post_init__()
219+
203220

204221
# if __name__ == "__main__":
205222
# from pytorch_tabular.utils import generate_doc_dataclass

src/pytorch_tabular/ssl_models/base_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ def __init__(
4444
custom_optimizer_params: Dict = {},
4545
**kwargs,
4646
):
47+
"""Base Model for all SSL Models.
48+
49+
Args:
50+
config (DictConfig): Configuration defined by the user
51+
mode (str, optional): Mode of the model. Defaults to "pretrain".
52+
encoder (Optional[nn.Module], optional): Encoder of the model. Defaults to None.
53+
decoder (Optional[nn.Module], optional): Decoder of the model. Defaults to None.
54+
custom_optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizer to use. Defaults to None.
55+
custom_optimizer_params (Dict, optional): Custom optimizer parameters to use. Defaults to {}.
56+
"""
4757
super().__init__()
4858
assert "inferred_config" in kwargs, "inferred_config not found in initialization arguments"
4959
inferred_config = kwargs["inferred_config"]
@@ -167,7 +177,9 @@ def test_step(self, batch, batch_idx):
167177

168178
def on_validation_epoch_end(self) -> None:
169179
if hasattr(self.hparams, "log_logits") and self.hparams.log_logits:
170-
warnings.warn("Logging Logits is disabled for SSL tasks")
180+
warnings.warn(
181+
"Logging Logits is disabled for SSL tasks. Set `log_logits` to False" " to turn off this warning"
182+
)
171183
super().on_validation_epoch_end()
172184

173185
def configure_optimizers(self):

src/pytorch_tabular/tabular_model.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -714,8 +714,9 @@ def fit(
714714
else:
715715
if train is not None:
716716
warnings.warn(
717-
"train data is provided but datamodule is provided."
718-
" Ignoring the train data and using the datamodule"
717+
"train data and datamodule is provided."
718+
" Ignoring the train data and using the datamodule."
719+
" Set either one of them to None to avoid this warning."
719720
)
720721
model = self.prepare_model(
721722
datamodule,
@@ -791,8 +792,9 @@ def pretrain(
791792
else:
792793
if train is not None:
793794
warnings.warn(
794-
"train data is provided but datamodule is provided."
795-
" Ignoring the train data and using the datamodule"
795+
"train data and datamodule is provided."
796+
" Ignoring the train data and using the datamodule."
797+
" Set either one of them to None to avoid this warning."
796798
)
797799
model = self.prepare_model(
798800
datamodule,
@@ -1050,8 +1052,9 @@ def finetune(
10501052
else:
10511053
if train is not None:
10521054
warnings.warn(
1053-
"train data is provided but datamodule is provided."
1054-
" Ignoring the train data and using the datamodule"
1055+
"train data and datamodule is provided."
1056+
" Ignoring the train data and using the datamodule."
1057+
" Set either one of them to None to avoid this warning."
10551058
)
10561059
if freeze_backbone:
10571060
for param in self.model.backbone.parameters():
@@ -1197,7 +1200,9 @@ def predict(
11971200
If classification, it returns probabilities and final prediction
11981201
"""
11991202
warnings.warn(
1200-
"`include_input_features` will be deprecated in the next release.",
1203+
"`include_input_features` will be deprecated in the next release."
1204+
" Please add index columns to the test dataframe if you want to"
1205+
" retain some features like the key or id",
12011206
DeprecationWarning,
12021207
)
12031208
assert all(q <= 1 and q >= 0 for q in quantiles), "Quantiles should be a decimal between 0 and 1"
@@ -1286,6 +1291,11 @@ def predict(
12861291
pred_df["prediction"] = self.datamodule.label_encoder.inverse_transform(
12871292
np.argmax(point_predictions, axis=1)
12881293
)
1294+
warnings.warn(
1295+
"Classification prediction column will be renamed to `{target_col}_prediction` "
1296+
"in the next release to maintain consistency with regression.",
1297+
DeprecationWarning,
1298+
)
12891299
if ret_logits:
12901300
for k, v in logits_predictions.items():
12911301
v = torch.cat(v, dim=0).numpy()
@@ -1558,6 +1568,7 @@ def explain(
15581568
Defaults to None.
15591569
15601570
**kwargs: Additional keyword arguments to be passed to the Captum method `attribute` function.
1571+
15611572
Returns:
15621573
DataFrame: The dataframe with the feature importance
15631574
"""
@@ -1587,7 +1598,7 @@ def explain(
15871598
if len(data) <= 100:
15881599
warnings.warn(
15891600
f"{method} gives better results when the number of samples is"
1590-
" large. For better results, try usingmore samples or some other"
1601+
" large. For better results, try using more samples or some other"
15911602
" methods like GradientShap which works well on single examples."
15921603
)
15931604
is_full_baselines = method in ["GradientShap", "DeepLiftShap"]
@@ -1742,6 +1753,7 @@ def cross_validate(
17421753
fold, they will be valid for all the other folds. Defaults to True.
17431754
17441755
**kwargs: Additional keyword arguments to be passed to the `fit` method of the model.
1756+
17451757
Returns:
17461758
DataFrame: The dataframe with the cross validation results
17471759
"""
@@ -1900,6 +1912,7 @@ def bagging_predict(
19001912
Defaults to None.
19011913
19021914
**kwargs: Additional keyword arguments to be passed to the `fit` method of the model.
1915+
19031916
Returns:
19041917
DataFrame: The dataframe with the bagged predictions.
19051918
"""

0 commit comments

Comments
 (0)