Hf checkpoint conversion for distributed checkpoints #424
Conversation
Should we maybe move the conversion directory into the checkpointing directory with this PR (after review)?
Pull request overview
This PR adds support for converting distributed checkpoint (DCP) formats (FSDP2, PP, TP) to HuggingFace transformers format. The conversion is implemented as a two-step process: first converting DCP checkpoints to standard PyTorch format, then using the existing conversion pipeline to create HuggingFace models.
- Added new `convert_dcp_to_torch` module to handle DCP-to-PyTorch checkpoint conversion (see the sketch below)
- Extended the GPT-2 conversion script to support DCP checkpoints via a `--dcp` flag
- Introduced a `ConfigDictType` type alias for better type consistency across configuration handling
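As a rough illustration of the first step, PyTorch ships a utility for collapsing a sharded DCP checkpoint into a single torch.save file; the new `convert_dcp_to_torch` module presumably adds the config transformation on top. This is a hedged sketch with placeholder paths, not the PR's actual code:

```python
# Hedged sketch of step 1 (DCP -> plain PyTorch checkpoint); the directory and
# file names below are placeholders, not values used by this PR.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Consolidate the sharded DCP checkpoint into one torch.save file that the
# existing GPT-2 -> HuggingFace conversion pipeline can then consume.
dcp_to_torch_save("checkpoints/step_1000", "checkpoints/step_1000_consolidated.pt")
```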
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| src/modalities/checkpointing/convert_dcp_to_torch.py | New module implementing DCP to PyTorch checkpoint conversion with config file transformation |
| src/modalities/conversion/gpt2/convert_gpt2.py | Added --dcp flag support, new convert_gpt2_dcp function, and refactored main entry point |
| src/modalities/conversion/gpt2/conversion_model.py | Updated type hints to use ConfigDictType and added dtype assertion in model checking |
| src/modalities/config/config.py | Added ConfigDictType alias, new save_yaml_config_dict function, and fixed implicit return in resolver |
| src/modalities/models/utils.py | Updated type hints to use ConfigDictType for consistency |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…env context manager.
- Now only loading model weights into memory, no optimizer or scheduler weights (sketched below).
- Always creating an FP32 config since FSDP2 always has FP32 weights.
- Disabled overwriting of existing config files.
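A minimal sketch of what loading only the model weights could look like with PyTorch's public DCP APIs; the top-level "model" key and the helper function are assumptions for illustration, not the actual implementation:

```python
# Hedged sketch: read only the model subtree of a DCP checkpoint, leaving
# optimizer and scheduler state untouched on disk. Assumes the checkpoint was
# saved with a top-level "model" key.
import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict

def load_model_weights_only(model: nn.Module, checkpoint_dir: str) -> None:
    model_sd = get_model_state_dict(model)  # pre-allocated target tensors
    dcp.load({"model": model_sd}, checkpoint_id=checkpoint_dir)  # reads only the "model" entries
    set_model_state_dict(model, model_sd)
```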
…ity again. This was not required after all.
- Detection and warning if an attention implementation other than the HuggingFace default is used, since this is not saved with the checkpoint (see the sketch below).
- Correct handling and matching of FSDP2 mixed precision behavior (in particular for rotary positional embeddings).
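A hedged sketch of the kind of attention-implementation warning described above; the parameter names and the assumed default ("sdpa") are illustrative, not the PR's exact code:

```python
import logging

logger = logging.getLogger(__name__)

def warn_on_non_default_attention(attn_implementation: str, hf_default: str = "sdpa") -> None:
    # The attention implementation is not stored in the checkpoint, so the
    # converted HuggingFace model will silently fall back to its default;
    # warn if the training run used something else (e.g. flash_attention_2).
    if attn_implementation != hf_default:
        logger.warning(
            "Checkpoint was trained with attn_implementation=%s, which is not saved "
            "with the checkpoint; the converted model will default to %s.",
            attn_implementation,
            hf_default,
        )
```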
…llama implementation.
At this time, this bug seems to be fixed in main, and we should be able to use a version >4.57.3 once it is released.
Problematic line: https://github.com/huggingface/transformers/blob/47b0e478f324b54f177ea7998a0791870fdd0324/src/transformers/utils/generic.py#L947
Fixed version: https://github.com/huggingface/transformers/blob/d3ee06b8cb5e45aab51b85aafd54f4b3f7cad2e2/src/transformers/utils/generic.py#L791
…onment variables.
…l uses of that function work.
…for missing fields in the original config.
therealdavidos left a comment
nice one!
```python
self._env_override = EnvOverride(
    {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(rdvz_port),
        "RANK": str(global_rank),
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(world_size),
    }
)
```
Isn't this stuff typically taken care of by torchrun? Why do we need the CudaEnv class?
The MultiProcessingCudaEnv is useful when running distributed code from Python directly. Previously, we only used this for our unit tests. For conversion, it is also needed so the conversion script can work with the DCP model.
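For context, here is a minimal sketch (assuming a single local process) of what such an environment override enables when launching distributed code from plain Python rather than torchrun. EnvOverride itself comes from the diff above; everything else here is illustrative rather than the actual implementation:

```python
import os
import torch.distributed as dist

def init_local_process_group(rdvz_port: int = 29500) -> None:
    # torchrun normally exports these variables for each worker; when driving the
    # conversion from plain Python we have to provide them ourselves before
    # init_process_group can rendezvous.
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(rdvz_port),
            "RANK": "0",
            "LOCAL_RANK": "0",
            "WORLD_SIZE": "1",
        }
    )
    dist.init_process_group(backend="gloo")
```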
What does this PR do?
Implements checkpoint conversion for DCP checkpoints (FSDP2, PP, TP). For this, the checkpoint first gets converted to a normal PyTorch checkpoint (together with a corresponding config) and then gets converted using the existing code.
Note: Currently, no new tests were added due to the effort of creating and manipulating a DCP checkpoint as needed for those tests.
General Changes
Breaking Changes
Checklist before submitting final PR
- Ran the tests (python tests/tests.py)
- Updated the changelog (CHANGELOG_DEV.md)