From 910a712a07cae43fade20208a7ee0fbd2dcc6218 Mon Sep 17 00:00:00 2001
From: arrdel
Date: Fri, 5 Dec 2025 21:58:11 -0500
Subject: [PATCH 1/2] Add adapt_checkpoint_hparams hook for customizing
 checkpoint hyperparameter loading

Fixes #21255

This commit adds the adapt_checkpoint_hparams() public method to
LightningCLI, allowing users to customize hyperparameters loaded from
checkpoints before they are used to instantiate model classes. This is
particularly useful when using checkpoints from a TrainingModule with a
different InferenceModule class that has different __init__ parameters.

Problem: When loading a checkpoint trained with TrainingModule(lr=1e-3)
into an InferenceModule() that doesn't accept 'lr' as a parameter, the
CLI would fail during instantiation because it tries to pass all
checkpoint hyperparameters to the new module class.

Solution: Added adapt_checkpoint_hparams() hook that is called in
_parse_ckpt_path() after loading checkpoint hyperparameters but before
applying them. Users can override this method to:
- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify _class_path for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning {}

Example usage:

    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:
- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward
compatibility (default implementation returns hyperparameters unchanged).
---
 src/lightning/pytorch/cli.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 2bcb1d8f4b1fd..87b3634dda413 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -560,6 +560,36 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         else:
             self.config = parser.parse_args(args)
 
+    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+        """Adapt checkpoint hyperparameters before instantiating the model class.
+
+        This method allows for customization of hyperparameters loaded from a checkpoint when
+        using a different model class than the one used for training. For example, when loading
+        a checkpoint from a TrainingModule to use with an InferenceModule that has different
+        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.
+
+        Args:
+            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.
+
+        Returns:
+            Dictionary of adapted hyperparameters to be used for model instantiation.
+
+        Example::
+
+            class MyCLI(LightningCLI):
+                def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+                    # Remove training-specific hyperparameters not needed for inference
+                    checkpoint_hparams.pop("lr", None)
+                    checkpoint_hparams.pop("weight_decay", None)
+                    return checkpoint_hparams
+
+        Note:
+            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
+            hyperparameters, you may need to modify it as well to point to your new module class.
+
+        """
+        return checkpoint_hparams
+
     def _parse_ckpt_path(self) -> None:
         """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
         if not self.config.get("subcommand"):
@@ -571,6 +601,12 @@ def _parse_ckpt_path(self) -> None:
         hparams.pop("_instantiator", None)
         if not hparams:
             return
+
+        # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
+        hparams = self.adapt_checkpoint_hparams(hparams)
+        if not hparams:
+            return
+
         if "_class_path" in hparams:
             hparams = {
                 "class_path": hparams.pop("_class_path"),

From ad1a0285dfa3d2b9fcb98ee725d14d5ff63fd6fb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 6 Dec 2025 02:59:15 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/cli.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 87b3634dda413..2e2bc939705ef 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -601,12 +601,12 @@ def _parse_ckpt_path(self) -> None:
         hparams.pop("_instantiator", None)
         if not hparams:
             return
-
+
         # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
         hparams = self.adapt_checkpoint_hparams(hparams)
         if not hparams:
             return
-
+
         if "_class_path" in hparams:
             hparams = {
                 "class_path": hparams.pop("_class_path"),