
Commit 1b2abf5

verbosity fix for tabular datamodule as well (#353)
1 parent 3350681 commit 1b2abf5

3 files changed: 27 additions, 19 deletions

src/pytorch_tabular/models/common/layers/activations.py

Lines changed: 6 additions & 6 deletions
@@ -95,13 +95,13 @@ def forward(self, input: Tensor, r: Tensor):
         return self.tsoftmax(input, t)


-"""
-An implementation of entmax (Peters et al., 2019). See
-https://arxiv.org/pdf/1905.05702 for detailed description.
+# """
+# An implementation of entmax (Peters et al., 2019). See
+# https://arxiv.org/pdf/1905.05702 for detailed description.

-This builds on previous work with sparsemax (Martins & Astudillo, 2016).
-See https://arxiv.org/pdf/1602.02068.
-"""
+# This builds on previous work with sparsemax (Martins & Astudillo, 2016).
+# See https://arxiv.org/pdf/1602.02068.
+# """

 # Author: Ben Peters
 # Author: Vlad Niculae <vlad@vene.ro>
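
Context for this hunk: the quoted block was a stray module-level string sitting in the middle of the vendored entmax code, not an actual docstring, so the commit turns it into comments. For readers unfamiliar with the reference, entmax and sparsemax are softmax alternatives that can assign exactly zero probability to low-scoring entries. A minimal sketch, assuming the standalone entmax package from PyPI (pip install entmax) as a stand-in for the implementation vendored in this file:

import torch
from entmax import sparsemax  # assumption: standalone package, not pytorch_tabular's vendored copy

logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])

# softmax keeps every entry strictly positive...
print(torch.softmax(logits, dim=-1))  # all four probabilities > 0

# ...while sparsemax projects onto the simplex and zeroes out low scores:
print(sparsemax(logits, dim=-1))      # tensor([[1., 0., 0., 0.]])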

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 20 additions & 13 deletions
@@ -100,8 +100,8 @@ def __getitem__(self, idx):
         """Generates one sample of data."""
         return {
             "target": self.y[idx],
-            "continuous": self.continuous_X[idx] if self.continuous_cols else torch.Tensor(),
-            "categorical": self.categorical_X[idx] if self.categorical_cols else torch.Tensor(),
+            "continuous": (self.continuous_X[idx] if self.continuous_cols else torch.Tensor()),
+            "categorical": (self.categorical_X[idx] if self.categorical_cols else torch.Tensor()),
         }


@@ -140,6 +140,7 @@ def __init__(
         seed: Optional[int] = 42,
         cache_data: str = "memory",
         copy_data: bool = True,
+        verbose: bool = True,
     ):
         """The Pytorch Lightning Datamodule for Tabular Data.

@@ -168,6 +169,8 @@ def __init__(
                 "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

             copy_data (bool): If True, will copy the dataframes before preprocessing. Defaults to True.
+
+            verbose (bool): Sets the verbosity of the datamodule logging. Defaults to True.
         """
         super().__init__()
         self.train = train.copy() if copy_data else train
@@ -181,6 +184,7 @@ def __init__(
         self.train_sampler = train_sampler
         self.config = config
         self.seed = seed
+        self.verbose = verbose
         self._fitted = False
         self._setup_cache(cache_data)
         self._inferred_config = self._update_config(config)
@@ -266,7 +270,7 @@ def _encode_categorical_columns(self, data: DataFrame, stage: str) -> DataFrame:
         logger.debug("Encoding Categorical Columns using OrdinalEncoder")
         self.categorical_encoder = OrdinalEncoder(
             cols=self.config.categorical_cols,
-            handle_unseen="impute" if self.config.handle_unknown_categories else "error",
+            handle_unseen=("impute" if self.config.handle_unknown_categories else "error"),
             handle_missing="impute" if self.config.handle_missing_values else "error",
         )
         data = self.categorical_encoder.fit_transform(data)
@@ -400,7 +404,7 @@ def _cache_dataset(self):

     def split_train_val(self, train):
         logger.debug(
-            f"No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation"
+            "No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation"
         )
         val_idx = train.sample(
             int(self.config.validation_split * len(train)),
@@ -420,7 +424,8 @@ def setup(self, stage: Optional[str] = None) -> None:
         """
         if not (stage is None or stage == "fit" or stage == "ssl_finetune"):
             return
-        logger.info(f"Setting up the datamodule for {self.config.task} task")
+        if self.verbose:
+            logger.info(f"Setting up the datamodule for {self.config.task} task")
         is_ssl = stage == "ssl_finetune"
         if self.validation is None:
             self.train, self.validation = self.split_train_val(self.train)
@@ -496,7 +501,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
                 "Is_year_end",
                 "Is_year_start",
                 "Is_month_start",
-                "Week" "Day",
+                "WeekDay",
                 "Dayofweek",
                 "Dayofyear",
             ],
@@ -508,7 +513,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
                 "Is_year_end",
                 "Is_year_start",
                 "Is_month_start",
-                "Week" "Day",
+                "WeekDay",
                 "Dayofweek",
                 "Dayofyear",
             ],
@@ -520,7 +525,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
                 "Is_year_end",
                 "Is_year_start",
                 "Is_month_start",
-                "Week" "Day",
+                "WeekDay",
                 "Dayofweek",
                 "Dayofyear",
                 "Hour",
@@ -533,7 +538,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
                 "Is_year_end",
                 "Is_year_start",
                 "Is_month_start",
-                "Week" "Day",
+                "WeekDay",
                 "Dayofweek",
                 "Dayofyear",
                 "Hour",
@@ -645,16 +650,18 @@ def _load_dataset_from_cache(self, tag: str = "train"):
             try:
                 dataset = getattr(self, f"{tag}_dataset")
             except AttributeError:
-                raise AttributeError(f"{tag}_dataset not found in memory. Please provide the data for {tag} dataloader")
+                raise AttributeError(
+                    f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader"
+                )
         elif self.cache_mode is self.CACHE_MODES.DISK:
             try:
                 dataset = torch.load(self.cache_dir / f"{tag}_dataset")
             except FileNotFoundError:
                 raise FileNotFoundError(
-                    f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader"
+                    f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"
                 )
         elif self.cache_mode is self.CACHE_MODES.INFERENCE:
-            raise RuntimeError("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead")
+            raise RuntimeError("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead")
         else:
             raise ValueError(f"{self.cache_mode} is not a valid cache mode")
         return dataset
@@ -741,7 +748,7 @@ def prepare_inference_dataloader(
             data=df,
             categorical_cols=self.config.categorical_cols,
             continuous_cols=self.config.continuous_cols,
-            target=self.target if all(col in df.columns for col in self.target) else None,
+            target=(self.target if all(col in df.columns for col in self.target) else None),
         )
         return DataLoader(
             dataset,
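
A note on the "Week" "Day" → "WeekDay" hunks and the reworked error messages above: Python concatenates adjacent string literals at parse time, so "Week" "Day" was already the single string "WeekDay", and spelling it out removes the missing-comma ambiguity. The wrapped log and error messages rely on the same mechanism deliberately, and the f-prefix is only needed on the pieces that actually contain placeholders. A quick illustration:

# Adjacent string literals are merged by the parser:
assert "Week" "Day" == "WeekDay"

# The same mechanism wraps long messages across lines; only the piece
# with a placeholder needs the f-prefix:
tag = "train"
msg = f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader"
assert msg == "train_dataset not found in memory. Please provide the data for train dataloader"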

src/pytorch_tabular/tabular_model.py

Lines changed: 1 addition & 0 deletions
@@ -508,6 +508,7 @@ def prepare_dataloader(
             train_sampler=train_sampler,
             seed=seed,
             cache_data=cache_data,
+            verbose=self.verbose,
         )
         datamodule.prepare_data()
         datamodule.setup("fit")
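
End to end, the change means a quiet TabularModel now also quiets its datamodule. A minimal usage sketch, assuming TabularModel already accepts verbose in its constructor (the diff shows it carries self.verbose) and using placeholder column names:

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

model = TabularModel(
    data_config=DataConfig(target=["target"], continuous_cols=["feat_a", "feat_b"]),
    model_config=CategoryEmbeddingModelConfig(task="regression"),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
    verbose=False,  # with this commit, also suppresses the datamodule's
                    # "Setting up the datamodule for ... task" info log
)
# model.fit(train=train_df)  # train_df: a pandas DataFrame with the columns above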
