Conversation


@arrdel commented Dec 6, 2025

What does this PR do?

Fixes #21255

This PR adds a public adapt_checkpoint_hparams() hook to LightningCLI that allows users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This solves the problem of loading checkpoints across different module classes (e.g., from TrainingModule to InferenceModule).

Problem

When using LightningCLI with checkpoints, hyperparameters saved during training are automatically loaded and applied when running other subcommands (test, predict, etc.). This is convenient when using the same module class, but fails when using a different class with incompatible __init__ parameters.

Example scenario:

# TrainingModule saves 'lr' via save_hyperparameters()
class TrainingModule(LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()

# InferenceModule doesn't accept 'lr'
class InferenceModule(LightningModule):
    def __init__(self):  # No 'lr' parameter!
        super().__init__()

Running cli predict --ckpt_path checkpoint.ckpt with InferenceModule fails because the CLI tries to pass lr=1e-3 from the checkpoint to InferenceModule.__init__().
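
For context, these values come from the checkpoint's hyper_parameters entry, which the CLI reads and feeds back into the model's __init__. An illustrative way to inspect them (the file name is hypothetical):

import torch

ckpt = torch.load("checkpoint.ckpt", map_location="cpu", weights_only=False)
print(ckpt["hyper_parameters"])  # e.g. {'lr': 0.001}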

Solution

Added adapt_checkpoint_hparams() public method that users can override to customize loaded hyperparameters:

class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
        # Remove training-specific hyperparameters
        checkpoint_hparams.pop("lr", None)
        checkpoint_hparams.pop("weight_decay", None)
        return checkpoint_hparams
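
With the override in place, the customized CLI is used exactly like LightningCLI. A minimal sketch, assuming the MyCLI class above lives in a hypothetical mymodule package (neither file nor module names are part of this PR):

# main.py (illustrative)
from mymodule import InferenceModule, MyCLI  # hypothetical module providing the classes above

if __name__ == "__main__":
    MyCLI(InferenceModule)

# invoked as, for example:
#   python main.py predict --ckpt_path checkpoint.ckpt
# 'lr' from the checkpoint is dropped before InferenceModule is instantiated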

Implementation Details

  • Added: adapt_checkpoint_hparams() public method in LightningCLI
  • Modified: _parse_ckpt_path() to call the hook after loading the checkpoint hyperparameters but before applying them (a simplified sketch of this wiring follows the list)
  • Backward compatible: Default implementation returns hyperparameters unchanged
  • Flexible: Users can remove, modify, or completely disable checkpoint hyperparameters
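
A schematic stand-in for that wiring, not the actual Lightning source; the function and parameter names here are hypothetical and only illustrate where the hook sits relative to checkpoint loading:

def _apply_checkpoint_hparams(cli, checkpoint: dict, model_config: dict) -> dict:
    hparams = checkpoint.get("hyper_parameters", {})  # read from the checkpoint file
    hparams = cli.adapt_checkpoint_hparams(hparams)   # user hook; default returns the input unchanged
    if hparams:                                       # an empty dict skips checkpoint hparams entirely
        model_config = {**model_config, **hparams}    # merge before the model class is instantiated
    return model_config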

Why This Approach?

As discussed in #21255, this approach is preferable to the alternatives:

  1. Better than disabling checkpoint loading: Preserves valuable hyperparameter information (e.g., hidden_dim)
  2. Better than CLI flags: Maintains consistency with Trainer parameter pattern
  3. Better than modifying private methods: Provides official public API

Testing

The implementation:

  • ✅ Maintains backward compatibility (existing code unaffected)
  • ✅ Provides maximum flexibility via public hook
  • ✅ Works with both regular and subclass module modes
  • ✅ Handles _class_path modification when needed

Example Use Cases

  1. Remove training-only parameters:

    def adapt_checkpoint_hparams(self, hparams):
        hparams.pop("lr", None)
        return hparams
  2. Change module class in subclass mode:

    def adapt_checkpoint_hparams(self, hparams):
        hparams["_class_path"] = "mymodule.InferenceModule"
        return hparams
  3. Disable all checkpoint hyperparameters:

    def adapt_checkpoint_hparams(self, hparams):
        return {}
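
These can also be combined. For instance, in subclass mode a single hook can both retarget the class and drop optimizer-only keys (the class path below is illustrative):

def adapt_checkpoint_hparams(self, hparams):
    hparams["_class_path"] = "mymodule.InferenceModule"
    for key in ("lr", "weight_decay"):
        hparams.pop(key, None)
    return hparams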

Does your PR introduce any breaking changes?

No, this is a purely additive change. The default implementation returns hyperparameters unchanged, preserving existing behavior.

Before submitting

  • Was this discussed/approved via a GitHub issue? Yes - Allow weight reuse in a different lightning module #21255
  • Did you read the contributor guideline?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? Yes
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

cc: @mauvilsa @ziw-liu


📚 Documentation preview 📚: https://pytorch-lightning--21408.org.readthedocs.build/en/21408/

…ameter loading

Fixes Lightning-AI#21255

This commit adds the adapt_checkpoint_hparams() public method to LightningCLI,
allowing users to customize hyperparameters loaded from checkpoints before they
are used to instantiate model classes. This is particularly useful when using
checkpoints from a TrainingModule with a different InferenceModule class that
has different __init__ parameters.

Problem:
When loading a checkpoint trained with TrainingModule(lr=1e-3) into an
InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail
during instantiation because it tries to pass all checkpoint hyperparameters
to the new module class.

Solution:
Added adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path()
after loading checkpoint hyperparameters but before applying them. Users can
override this method to:
- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify _class_path for subclass mode
- Transform hyperparameter names/values
- Completely disable checkpoint hyperparameters by returning {}

Example usage:
    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:
- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility
(default implementation returns hyperparameters unchanged).
Copilot AI review requested due to automatic review settings December 6, 2025 02:58
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Dec 6, 2025
Copilot AI (Contributor) left a comment

Pull request overview

This PR adds a public adapt_checkpoint_hparams() hook to LightningCLI that enables users to customize hyperparameters loaded from checkpoints before model instantiation. This addresses the issue of loading checkpoints across different module classes (e.g., from TrainingModule to InferenceModule) where incompatible __init__ parameters would otherwise cause failures.

Key Changes:

  • Added adapt_checkpoint_hparams() public method with comprehensive documentation
  • Integrated the hook into _parse_ckpt_path() to allow customization before hyperparameter application
  • Maintained backward compatibility with a default no-op implementation

Comment on lines +563 to +591
    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt checkpoint hyperparameters before instantiating the model class.

        This method allows for customization of hyperparameters loaded from a checkpoint when
        using a different model class than the one used for training. For example, when loading
        a checkpoint from a TrainingModule to use with an InferenceModule that has different
        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.

        Args:
            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

        Returns:
            Dictionary of adapted hyperparameters to be used for model instantiation.

        Example::

            class MyCLI(LightningCLI):
                def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
                    # Remove training-specific hyperparameters not needed for inference
                    checkpoint_hparams.pop("lr", None)
                    checkpoint_hparams.pop("weight_decay", None)
                    return checkpoint_hparams

        Note:
            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
            hyperparameters, you may need to modify it as well to point to your new module class.

        """
        return checkpoint_hparams

Copilot AI Dec 6, 2025

The new adapt_checkpoint_hparams() hook lacks test coverage. Given that tests/tests_pytorch/test_cli.py contains comprehensive tests for checkpoint loading functionality (e.g., test_lightning_cli_ckpt_path_argument_hparams and test_lightning_cli_ckpt_path_argument_hparams_subclass_mode), tests should be added to verify:

  1. The hook is called when loading checkpoint hyperparameters
  2. Modifications made in the hook are applied correctly
  3. Returning an empty dict properly skips checkpoint hyperparameter loading
  4. The hook works in both regular and subclass modes
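
A minimal sketch of a starting point for such a test, covering only the override contract rather than the full _parse_ckpt_path() flow; it assumes the method signature added in this PR, and the CLI subclass name is illustrative:

from lightning.pytorch.cli import LightningCLI

class _StripLRCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        checkpoint_hparams.pop("lr", None)
        return checkpoint_hparams

def test_adapt_checkpoint_hparams_override():
    hparams = {"lr": 1e-3, "hidden_dim": 64}
    # called unbound: neither the default implementation nor this override uses `self`
    assert LightningCLI.adapt_checkpoint_hparams(None, dict(hparams)) == hparams
    assert _StripLRCLI.adapt_checkpoint_hparams(None, dict(hparams)) == {"hidden_dim": 64}

End-to-end coverage of the four points above would follow the existing fit-then-reload pattern of test_lightning_cli_ckpt_path_argument_hparams.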

        else:
            self.config = parser.parse_args(args)

    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:

Copilot AI Dec 6, 2025

Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.

        Example::

            class MyCLI(LightningCLI):
                def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:

Copilot AI Dec 6, 2025

Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.


Labels

pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Allow weight reuse in a different lightning module

1 participant