@@ -184,7 +184,10 @@ def _get_run_name_uid(self) -> Tuple[str, int]:
         """
         if hasattr(self.config, "run_name") and self.config.run_name is not None:
             name = self.config.run_name
-        elif hasattr(self.config, "checkpoints_name") and self.config.checkpoints_name is not None:
+        elif (
+            hasattr(self.config, "checkpoints_name")
+            and self.config.checkpoints_name is not None
+        ):
             name = self.config.checkpoints_name
         else:
             name = self.config.task
@@ -287,7 +290,6 @@ def _prepare_model(self, loss, metrics, optimizer, optimizer_params, reset):
         )
         # Data Aware Initialization(for the models that need it)
         self.model.data_aware_initialization(self.datamodule)
-

     def _prepare_trainer(self, max_epochs=None, min_epochs=None):
         logger.info("Preparing the Trainer...")
@@ -297,7 +299,7 @@ def _prepare_trainer(self, max_epochs=None, min_epochs=None):
             self.config.min_epochs = min_epochs
         # TODO get Trainer Arguments from the init signature
         trainer_sig = inspect.signature(pl.Trainer.__init__)
-        trainer_args = [p for p in trainer_sig.parameters.keys() if p != "self"]
+        trainer_args = [p for p in trainer_sig.parameters.keys() if p != "self"]
         trainer_args_config = {
             k: v for k, v in self.config.items() if k in trainer_args
         }
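The change in this hunk is formatting-only, but the surrounding context lines show the technique the trainer setup relies on: the valid pl.Trainer keyword arguments are discovered at runtime with inspect.signature and used to filter the config. A minimal, self-contained sketch of that filtering, with a plain dict standing in for the real config object (the dict contents are made up for illustration):

import inspect
import pytorch_lightning as pl

# Hypothetical flat config; the real object in the module is richer than a dict.
config = {"max_epochs": 5, "min_epochs": 1, "batch_size": 64}

# Keep only the keys that pl.Trainer.__init__ actually declares.
trainer_sig = inspect.signature(pl.Trainer.__init__)
trainer_args = [p for p in trainer_sig.parameters.keys() if p != "self"]
trainer_kwargs = {k: v for k, v in config.items() if k in trainer_args}

trainer = pl.Trainer(**trainer_kwargs)  # batch_size is silently dropped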
@@ -314,9 +316,14 @@ def load_best_model(self):
         if self.trainer.checkpoint_callback is not None:
             logger.info("Loading the best model...")
             ckpt_path = self.trainer.checkpoint_callback.best_model_path
-            logger.debug(f"Model Checkpoint: {ckpt_path}")
-            ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
-            self.model.load_state_dict(ckpt["state_dict"])
+            if ckpt_path != "":
+                logger.debug(f"Model Checkpoint: {ckpt_path}")
+                ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
+                self.model.load_state_dict(ckpt["state_dict"])
+            else:
+                logger.info(
+                    "No best model available to load. Did you run it more than 1 epoch?..."
+                )
         else:
             logger.info(
                 "No best model available to load. Did you run it more than 1 epoch?..."
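The guard added here matters because the checkpoint callback's best_model_path is an empty string until a checkpoint has actually been written, for example when training stops before the first save. A rough standalone sketch of the same pattern, assuming a finished trainer and model and the same pl_load helper the module imports:

from pytorch_lightning.utilities.cloud_io import load as pl_load

ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path != "":  # empty string means no checkpoint was ever saved
    ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt["state_dict"])
else:
    print("No best model available to load.")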
@@ -737,19 +744,18 @@ def load_from_checkpoint(cls, dir: str):
         custom_params = joblib.load(os.path.join(dir, "custom_params.sav"))
         model_args = {}
         if custom_params.get("custom_loss") is not None:
-            model_args['loss'] = "MSELoss"
+            model_args["loss"] = "MSELoss"
         if custom_params.get("custom_metrics") is not None:
-            model_args['metrics'] = ["mean_squared_error"]
-            model_args['metric_params'] = [{}]
+            model_args["metrics"] = ["mean_squared_error"]
+            model_args["metric_params"] = [{}]
         if custom_params.get("custom_optimizer") is not None:
-            model_args['optimizer'] = "Adam"
+            model_args["optimizer"] = "Adam"
         if custom_params.get("custom_optimizer_params") is not None:
-            model_args['optimizer_params'] = {}
-
+            model_args["optimizer_params"] = {}
+
         # Initializing with default metrics, losses, and optimizers. Will revert once initialized
         model = model_callable.load_from_checkpoint(
-            checkpoint_path=os.path.join(dir, "model.ckpt"),
-            **model_args
+            checkpoint_path=os.path.join(dir, "model.ckpt"), **model_args
         )
         # else:
         #     # Initializing with default values
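The quote changes in this hunk are cosmetic, but the surrounding logic is worth a note: custom loss, metric, and optimizer objects cannot travel inside the Lightning checkpoint, so placeholder arguments ("MSELoss", "Adam", empty params) are passed to load_from_checkpoint and the saved objects are reattached afterwards from custom_params.sav. A hedged usage sketch from the caller's side, assuming this is the TabularModel classmethod and using placeholder names for the directory and dataframe:

from pytorch_tabular import TabularModel

# Reload a model that was previously saved to a directory.
tabular_model = TabularModel.load_from_checkpoint("saved_model_dir")
pred_df = tabular_model.predict(test_df)  # test_df: a pandas DataFrame of new rows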