Skip to content

Commit de2ebe1

Browse files
Flag to disable Lightning Logs (#367)
* added flag to disable lightning logs * changed default value to suppress logs as false * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 24ce9e5 commit de2ebe1

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/pytorch_tabular/tabular_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@
5151
from pytorch_tabular.utils import (
5252
OOMException,
5353
OutOfMemoryHandler,
54+
count_parameters,
5455
get_logger,
5556
getattr_nested,
5657
pl_load,
58+
suppress_lightning_logs,
5759
)
58-
from pytorch_tabular.utils.nn_utils import count_parameters
5960

6061
try:
6162
import captum.attr
@@ -79,6 +80,7 @@ def __init__(
7980
model_callable: Optional[Callable] = None,
8081
model_state_dict_path: Optional[Union[str, Path]] = None,
8182
verbose: bool = True,
83+
suppress_lightning_logger: bool = False,
8284
) -> None:
8385
"""The core model which orchestrates everything from initializing the datamodule, the model, trainer, etc.
8486
@@ -111,8 +113,13 @@ def __init__(
111113
If provided, will load the state dict after initializing the model from config.
112114
113115
verbose (bool): turns off and on the logging. Defaults to True.
116+
117+
suppress_lightning_logger (bool): If True, will suppress the default logging from PyTorch Lightning.
118+
Defaults to False.
114119
"""
115120
super().__init__()
121+
if suppress_lightning_logger:
122+
suppress_lightning_logs()
116123
self.verbose = verbose
117124
self.exp_manager = ExperimentRunManager()
118125
if config is None:

src/pytorch_tabular/tabular_model_tuner.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
TrainerConfig,
2323
)
2424
from pytorch_tabular.tabular_model import TabularModel
25-
from pytorch_tabular.utils import OOMException, OutOfMemoryHandler, get_logger
25+
from pytorch_tabular.utils import OOMException, OutOfMemoryHandler, get_logger, suppress_lightning_logs
2626

2727
logger = get_logger(__name__)
2828

@@ -45,6 +45,7 @@ def __init__(
4545
trainer_config: Optional[Union[TrainerConfig, str]] = None,
4646
model_callable: Optional[Callable] = None,
4747
model_state_dict_path: Optional[Union[str, Path]] = None,
48+
suppress_lightning_logger: bool = True,
4849
**kwargs,
4950
):
5051
"""Tabular Model Tuner helps you tune the hyperparameters of a TabularModel.
@@ -53,21 +54,30 @@ def __init__(
5354
data_config (Optional[Union[DataConfig, str]], optional): The DataConfig for the TabularModel.
5455
If str is passed, will initialize the DataConfig using the yaml file in that path.
5556
Defaults to None.
57+
5658
model_config (Optional[Union[ModelConfig, str]], optional): The ModelConfig for the TabularModel.
5759
If str is passed, will initialize the ModelConfig using the yaml file in that path.
5860
Defaults to None.
61+
5962
optimizer_config (Optional[Union[OptimizerConfig, str]], optional): The OptimizerConfig for the
6063
TabularModel. If str is passed, will initialize the OptimizerConfig using the yaml file in
6164
that path. Defaults to None.
65+
6266
trainer_config (Optional[Union[TrainerConfig, str]], optional): The TrainerConfig for the TabularModel.
6367
If str is passed, will initialize the TrainerConfig using the yaml file in that path.
6468
Defaults to None.
69+
6570
model_callable (Optional[Callable], optional): A callable that returns a PyTorch Tabular Model.
6671
If provided, will ignore the model_config and use this callable to initialize the model.
6772
Defaults to None.
73+
6874
model_state_dict_path (Optional[Union[str, Path]], optional): Path to the state dict of the model.
75+
6976
If provided, will ignore the model_config and use this state dict to initialize the model.
7077
Defaults to None.
78+
79+
suppress_lightning_logger (bool, optional): Whether to suppress the lightning logger. Defaults to True.
80+
7181
**kwargs: Additional keyword arguments to be passed to the TabularModel init.
7282
"""
7383
if trainer_config.profiler is not None:
@@ -84,6 +94,7 @@ def __init__(
8494
self.model_config = model_config
8595
self.optimizer_config = optimizer_config
8696
self.trainer_config = trainer_config
97+
self.suppress_lightning_logger = suppress_lightning_logger
8798
self.tabular_model_init_kwargs = {
8899
"model_callable": model_callable,
89100
"model_state_dict_path": model_state_dict_path,
@@ -208,6 +219,8 @@ def tune(
208219
assert mode in ["max", "min"], "mode must be one of ['max', 'min']"
209220
assert metric is not None, "metric must be specified"
210221
assert isinstance(search_space, dict) and len(search_space) > 0, "search_space must be a non-empty dict"
222+
if self.suppress_lightning_logger:
223+
suppress_lightning_logs()
211224
if cv is not None and validation is not None:
212225
warnings.warn(
213226
"Both validation and cv are provided. Ignoring validation and using cv. Use "

0 commit comments

Comments
 (0)