2222 TrainerConfig ,
2323)
2424from pytorch_tabular .tabular_model import TabularModel
25- from pytorch_tabular .utils import OOMException , OutOfMemoryHandler , get_logger
25+ from pytorch_tabular .utils import OOMException , OutOfMemoryHandler , get_logger , suppress_lightning_logs
2626
2727logger = get_logger (__name__ )
2828
@@ -45,6 +45,7 @@ def __init__(
4545 trainer_config : Optional [Union [TrainerConfig , str ]] = None ,
4646 model_callable : Optional [Callable ] = None ,
4747 model_state_dict_path : Optional [Union [str , Path ]] = None ,
48+ suppress_lightning_logger : bool = True ,
4849 ** kwargs ,
4950 ):
5051 """Tabular Model Tuner helps you tune the hyperparameters of a TabularModel.
@@ -53,21 +54,30 @@ def __init__(
5354 data_config (Optional[Union[DataConfig, str]], optional): The DataConfig for the TabularModel.
5455 If str is passed, will initialize the DataConfig using the yaml file in that path.
5556 Defaults to None.
57+
5658 model_config (Optional[Union[ModelConfig, str]], optional): The ModelConfig for the TabularModel.
5759 If str is passed, will initialize the ModelConfig using the yaml file in that path.
5860 Defaults to None.
61+
5962 optimizer_config (Optional[Union[OptimizerConfig, str]], optional): The OptimizerConfig for the
6063 TabularModel. If str is passed, will initialize the OptimizerConfig using the yaml file in
6164 that path. Defaults to None.
65+
6266 trainer_config (Optional[Union[TrainerConfig, str]], optional): The TrainerConfig for the TabularModel.
6367 If str is passed, will initialize the TrainerConfig using the yaml file in that path.
6468 Defaults to None.
69+
6570 model_callable (Optional[Callable], optional): A callable that returns a PyTorch Tabular Model.
6671 If provided, will ignore the model_config and use this callable to initialize the model.
6772 Defaults to None.
73+
6874 model_state_dict_path (Optional[Union[str, Path]], optional): Path to the state dict of the model.
75+
6976 If provided, will ignore the model_config and use this state dict to initialize the model.
7077 Defaults to None.
78+
79+ suppress_lightning_logger (bool, optional): Whether to suppress the lightning logger. Defaults to True.
80+
7181 **kwargs: Additional keyword arguments to be passed to the TabularModel init.
7282 """
7383 if trainer_config .profiler is not None :
@@ -84,6 +94,7 @@ def __init__(
8494 self .model_config = model_config
8595 self .optimizer_config = optimizer_config
8696 self .trainer_config = trainer_config
97+ self .suppress_lightning_logger = suppress_lightning_logger
8798 self .tabular_model_init_kwargs = {
8899 "model_callable" : model_callable ,
89100 "model_state_dict_path" : model_state_dict_path ,
@@ -208,6 +219,8 @@ def tune(
208219 assert mode in ["max" , "min" ], "mode must be one of ['max', 'min']"
209220 assert metric is not None , "metric must be specified"
210221 assert isinstance (search_space , dict ) and len (search_space ) > 0 , "search_space must be a non-empty dict"
222+ if self .suppress_lightning_logger :
223+ suppress_lightning_logs ()
211224 if cv is not None and validation is not None :
212225 warnings .warn (
213226 "Both validation and cv are provided. Ignoring validation and using cv. Use "
0 commit comments