diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 2bcb1d8f4b1fd..2e2bc939705ef 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -560,6 +560,36 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         else:
             self.config = parser.parse_args(args)
 
+    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+        """Adapt checkpoint hyperparameters before instantiating the model class.
+
+        This method allows for customization of hyperparameters loaded from a checkpoint when
+        using a different model class than the one used for training. For example, when loading
+        a checkpoint from a TrainingModule to use with an InferenceModule that has different
+        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.
+
+        Args:
+            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.
+
+        Returns:
+            Dictionary of adapted hyperparameters to be used for model instantiation.
+
+        Example::
+
+            class MyCLI(LightningCLI):
+                def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+                    # Remove training-specific hyperparameters not needed for inference
+                    checkpoint_hparams.pop("lr", None)
+                    checkpoint_hparams.pop("weight_decay", None)
+                    return checkpoint_hparams
+
+        Note:
+            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
+            hyperparameters, you may need to modify it as well to point to your new module class.
+
+        """
+        return checkpoint_hparams
+
     def _parse_ckpt_path(self) -> None:
         """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
         if not self.config.get("subcommand"):
@@ -571,6 +601,12 @@ def _parse_ckpt_path(self) -> None:
             hparams.pop("_instantiator", None)
             if not hparams:
                 return
+
+            # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
+            hparams = self.adapt_checkpoint_hparams(hparams)
+            if not hparams:
+                return
+
             if "_class_path" in hparams:
                 hparams = {
                     "class_path": hparams.pop("_class_path"),