@@ -100,8 +100,8 @@ def __getitem__(self, idx):
100100 """Generates one sample of data."""
101101 return {
102102 "target" : self .y [idx ],
103- "continuous" : self .continuous_X [idx ] if self .continuous_cols else torch .Tensor (),
104- "categorical" : self .categorical_X [idx ] if self .categorical_cols else torch .Tensor (),
103+ "continuous" : ( self .continuous_X [idx ] if self .continuous_cols else torch .Tensor () ),
104+ "categorical" : ( self .categorical_X [idx ] if self .categorical_cols else torch .Tensor () ),
105105 }
106106
107107
@@ -140,6 +140,7 @@ def __init__(
140140 seed : Optional [int ] = 42 ,
141141 cache_data : str = "memory" ,
142142 copy_data : bool = True ,
143+ verbose : bool = True ,
143144 ):
144145 """The Pytorch Lightning Datamodule for Tabular Data.
145146
@@ -168,6 +169,8 @@ def __init__(
168169 "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
169170
170171 copy_data (bool): If True, will copy the dataframes before preprocessing. Defaults to True.
172+
173+ verbose (bool): Sets the verbosity of the databodule logging
171174 """
172175 super ().__init__ ()
173176 self .train = train .copy () if copy_data else train
@@ -181,6 +184,7 @@ def __init__(
181184 self .train_sampler = train_sampler
182185 self .config = config
183186 self .seed = seed
187+ self .verbose = verbose
184188 self ._fitted = False
185189 self ._setup_cache (cache_data )
186190 self ._inferred_config = self ._update_config (config )
@@ -266,7 +270,7 @@ def _encode_categorical_columns(self, data: DataFrame, stage: str) -> DataFrame:
266270 logger .debug ("Encoding Categorical Columns using OrdinalEncoder" )
267271 self .categorical_encoder = OrdinalEncoder (
268272 cols = self .config .categorical_cols ,
269- handle_unseen = "impute" if self .config .handle_unknown_categories else "error" ,
273+ handle_unseen = ( "impute" if self .config .handle_unknown_categories else "error" ) ,
270274 handle_missing = "impute" if self .config .handle_missing_values else "error" ,
271275 )
272276 data = self .categorical_encoder .fit_transform (data )
@@ -400,7 +404,7 @@ def _cache_dataset(self):
400404
401405 def split_train_val (self , train ):
402406 logger .debug (
403- f "No validation data provided." f" Using { self .config .validation_split * 100 } % of train data as validation"
407+ "No validation data provided." f" Using { self .config .validation_split * 100 } % of train data as validation"
404408 )
405409 val_idx = train .sample (
406410 int (self .config .validation_split * len (train )),
@@ -420,7 +424,8 @@ def setup(self, stage: Optional[str] = None) -> None:
420424 """
421425 if not (stage is None or stage == "fit" or stage == "ssl_finetune" ):
422426 return
423- logger .info (f"Setting up the datamodule for { self .config .task } task" )
427+ if self .verbose :
428+ logger .info (f"Setting up the datamodule for { self .config .task } task" )
424429 is_ssl = stage == "ssl_finetune"
425430 if self .validation is None :
426431 self .train , self .validation = self .split_train_val (self .train )
@@ -496,7 +501,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
496501 "Is_year_end" ,
497502 "Is_year_start" ,
498503 "Is_month_start" ,
499- "Week" "Day " ,
504+ "WeekDay " ,
500505 "Dayofweek" ,
501506 "Dayofyear" ,
502507 ],
@@ -508,7 +513,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
508513 "Is_year_end" ,
509514 "Is_year_start" ,
510515 "Is_month_start" ,
511- "Week" "Day " ,
516+ "WeekDay " ,
512517 "Dayofweek" ,
513518 "Dayofyear" ,
514519 ],
@@ -520,7 +525,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
520525 "Is_year_end" ,
521526 "Is_year_start" ,
522527 "Is_month_start" ,
523- "Week" "Day " ,
528+ "WeekDay " ,
524529 "Dayofweek" ,
525530 "Dayofyear" ,
526531 "Hour" ,
@@ -533,7 +538,7 @@ def time_features_from_frequency_str(cls, freq_str: str) -> List[str]:
533538 "Is_year_end" ,
534539 "Is_year_start" ,
535540 "Is_month_start" ,
536- "Week" "Day " ,
541+ "WeekDay " ,
537542 "Dayofweek" ,
538543 "Dayofyear" ,
539544 "Hour" ,
@@ -645,16 +650,18 @@ def _load_dataset_from_cache(self, tag: str = "train"):
645650 try :
646651 dataset = getattr (self , f"{ tag } _dataset" )
647652 except AttributeError :
648- raise AttributeError (f"{ tag } _dataset not found in memory. Please provide the data for { tag } dataloader" )
653+ raise AttributeError (
654+ f"{ tag } _dataset not found in memory. Please provide the data for" f" { tag } dataloader"
655+ )
649656 elif self .cache_mode is self .CACHE_MODES .DISK :
650657 try :
651658 dataset = torch .load (self .cache_dir / f"{ tag } _dataset" )
652659 except FileNotFoundError :
653660 raise FileNotFoundError (
654- f"{ tag } _dataset not found in { self .cache_dir } . Please provide the data for { tag } dataloader"
661+ f"{ tag } _dataset not found in { self .cache_dir } . Please provide the" f" data for { tag } dataloader"
655662 )
656663 elif self .cache_mode is self .CACHE_MODES .INFERENCE :
657- raise RuntimeError ("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead" )
664+ raise RuntimeError ("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead" )
658665 else :
659666 raise ValueError (f"{ self .cache_mode } is not a valid cache mode" )
660667 return dataset
@@ -741,7 +748,7 @@ def prepare_inference_dataloader(
741748 data = df ,
742749 categorical_cols = self .config .categorical_cols ,
743750 continuous_cols = self .config .continuous_cols ,
744- target = self .target if all (col in df .columns for col in self .target ) else None ,
751+ target = ( self .target if all (col in df .columns for col in self .target ) else None ) ,
745752 )
746753 return DataLoader (
747754 dataset ,
0 commit comments