36 changes: 36 additions & 0 deletions src/lightning/pytorch/cli.py
@@ -560,6 +560,36 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
        else:
            self.config = parser.parse_args(args)

    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
Copilot AI commented on Dec 6, 2025:
Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.

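A minimal sketch of the suggested signature, assuming Any stays imported from typing in cli.py; the same change applies to the docstring example below:

    # Built-in generics (PEP 585, Python 3.9+) replace typing.Dict:
    def adapt_checkpoint_hparams(self, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]: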
"""Adapt checkpoint hyperparameters before instantiating the model class.

This method allows for customization of hyperparameters loaded from a checkpoint when
using a different model class than the one used for training. For example, when loading
a checkpoint from a TrainingModule to use with an InferenceModule that has different
``__init__`` parameters, you can remove or modify incompatible hyperparameters.

Args:
checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

Returns:
Dictionary of adapted hyperparameters to be used for model instantiation.

Example::

class MyCLI(LightningCLI):
def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
                    # Remove training-specific hyperparameters not needed for inference
                    checkpoint_hparams.pop("lr", None)
                    checkpoint_hparams.pop("weight_decay", None)
                    return checkpoint_hparams

        Note:
            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
            hyperparameters, you may need to modify it as well to point to your new module class.

        """
        return checkpoint_hparams
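A hedged sketch of the subclass-mode case described in the Note: the class paths my_project.TrainingModule and my_project.InferenceModule and the popped lr key are hypothetical placeholders, assuming a checkpoint produced by CLI training in subclass mode.

    class MyInferenceCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
            # Redirect the stored class path from the training module to the inference module
            if checkpoint_hparams.get("_class_path") == "my_project.TrainingModule":
                checkpoint_hparams["_class_path"] = "my_project.InferenceModule"
            # Drop init arguments the inference module does not accept
            checkpoint_hparams.pop("lr", None)
            return checkpoint_hparams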
Comment on lines +563 to +591
Copilot AI commented on Dec 6, 2025:
The new adapt_checkpoint_hparams() hook lacks test coverage. Given that tests/tests_pytorch/test_cli.py contains comprehensive tests for checkpoint loading functionality (e.g., test_lightning_cli_ckpt_path_argument_hparams and test_lightning_cli_ckpt_path_argument_hparams_subclass_mode), tests should be added to verify:

  1. The hook is called when loading checkpoint hyperparameters
  2. Modifications made in the hook are applied correctly
  3. Returning an empty dict properly skips checkpoint hyperparameter loading
  4. The hook works in both regular and subclass modes

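A rough, untested sketch of points 1 and 2, loosely modeled on the existing ckpt_path hparams tests; ConfigurableModel, AdaptingCLI, and the final assertions are hypothetical and assume the adapted hparams update the model config as the diff below describes.

    from typing import Any
    from unittest import mock

    from lightning.pytorch import Trainer
    from lightning.pytorch.cli import LightningCLI
    from lightning.pytorch.demos.boring_classes import BoringModel


    class ConfigurableModel(BoringModel):
        def __init__(self, lr: float = 0.1, hidden: int = 32):
            super().__init__()
            self.save_hyperparameters()


    class AdaptingCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
            checkpoint_hparams.pop("lr", None)  # drop a training-only hyperparameter
            return checkpoint_hparams


    def test_adapt_checkpoint_hparams_is_applied(tmp_path):
        # Produce a real checkpoint whose hparams contain a key the hook removes
        model = ConfigurableModel(lr=0.5, hidden=64)
        trainer = Trainer(fast_dev_run=True, default_root_dir=tmp_path)
        trainer.fit(model)
        ckpt_path = tmp_path / "model.ckpt"
        trainer.save_checkpoint(ckpt_path)

        argv = [
            "any.py",
            "fit",
            f"--ckpt_path={ckpt_path}",
            "--trainer.fast_dev_run=true",
            f"--trainer.default_root_dir={tmp_path}",
        ]
        with mock.patch("sys.argv", argv):
            cli = AdaptingCLI(ConfigurableModel)

        # "hidden" should come from the checkpoint; "lr" was popped, so the default applies
        assert cli.model.hparams["hidden"] == 64
        assert cli.model.hparams["lr"] == 0.1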

    def _parse_ckpt_path(self) -> None:
        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
        if not self.config.get("subcommand"):
@@ -571,6 +601,12 @@ def _parse_ckpt_path(self) -> None:
        hparams.pop("_instantiator", None)
        if not hparams:
            return

        # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
        hparams = self.adapt_checkpoint_hparams(hparams)
        if not hparams:
            return

        if "_class_path" in hparams:
            hparams = {
                "class_path": hparams.pop("_class_path"),