Multi stage pipeline parallelism support #418
base: main
Conversation
…elism support. (WIP)
…odel. Also made None returns more visible in get_module_class_from_name().
…ith interleaved 1F1B.
…tack traces/views).
- Switched loss comparisons from abs=1e-16 to rel=1e-2. It still needs to be investigated why this is necessary for some configurations (see the sketch below).
- Additional configs and test setups, currently commented out due to their long runtime.
- Easier configurability of expected checkpoint paths (for debugging/experimenting).
- Better error logging.
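For reference, a minimal sketch of the difference between the two tolerance modes, assuming the comparison uses pytest.approx (the loss values below are made up):

```python
import pytest

# Hypothetical losses from a reference run and a resumed/parallelized run.
reference_loss = 10.4321
recomputed_loss = 10.4305

# Old check: an absolute tolerance of 1e-16 fails for even tiny float drift.
assert recomputed_loss != pytest.approx(reference_loss, abs=1e-16)

# New check: a relative tolerance of 1e-2 allows up to ~1% deviation and passes here.
assert recomputed_loss == pytest.approx(reference_loss, rel=1e-2)
```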
else:
    assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
    sd = get_optimizer_state_dict(
        model=app_state.model_parts[0],
        optimizers=app_state.optimizer,
        # NOTE: Flattening is required for pipeline parallelism to work correctly.
        # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
        options=StateDictOptions(flatten_optimizer_state_dict=True),
    )
Should we remove this, since in the case of PP we now always have an OptimizersList, which takes care of the flattening?
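Purely to illustrate the question: if flattening is only ever needed for PP, and PP is now always served by OptimizersList, the single-optimizer branch could look like the sketch below. This is hypothetical only and not the actual modalities implementation.

```python
# Hypothetical simplification of the non-OptimizersList branch:
# no flatten_optimizer_state_dict needed for the single-optimizer (non-PP) case.
sd = get_optimizer_state_dict(
    model=app_state.model_parts[0],
    optimizers=app_state.optimizer,
)
```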
@model_validator(mode="before")
@classmethod
def warn_deprecated_alias(cls, data: Any) -> Any:
    if isinstance(data, dict) and "wrapped_model" in data:
        warnings.warn(
            "Field 'wrapped_model' is deprecated. Use 'wrapped_model_or_parts' instead.",
            DeprecationWarning,
            stacklevel=3,
        )
    return data
Should we use this deprecation warning? If yes, should we also use it in other configs where a field got renamed to its plural form?
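If we keep it and want to reuse it for other pluralized fields, one option is a small shared helper along these lines; this is only a sketch, and the helper name and call sites are assumptions:

```python
import warnings
from typing import Any


def warn_renamed_field(data: Any, old: str, new: str) -> Any:
    """Warn when a deprecated (singular) field name is still present in a raw config dict."""
    if isinstance(data, dict) and old in data:
        warnings.warn(
            f"Field '{old}' is deprecated. Use '{new}' instead.",
            DeprecationWarning,
            stacklevel=3,
        )
    return data
```

Each affected config's `@model_validator(mode="before")` would then just return `warn_renamed_field(data, "wrapped_model", "wrapped_model_or_parts")`.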
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), |
These are currently deactivated due to their long runtime. Should we activate them anyway?
The first and the third commented-out configs are the same, right?
I don't think that
("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
is necessary since we already test
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),
which is the same setup + data parallelism, correct?
And since we have
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1),
we can probably skip
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),
Yeah, these configs are mostly useful for debugging with fewer ranks. Probably makes sense to have them turned off (or even delete them in the future).
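Instead of deleting or commenting them out, another option would be tagging the expensive parametrizations so they are skipped by default but still runnable on demand. A sketch, assuming a `slow` marker is registered in the project's pytest configuration and using a placeholder name for the parametrize list:

```python
import pytest

WARM_START_SETTINGS = [
    ("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),
    # Kept, but only collected when explicitly requested, e.g. `pytest -m slow`.
    pytest.param(
        "gpt2_train_num_steps_7_pp.yaml",
        "gpt2_warm_start_from_step_4_fsdp2.yaml",
        4,
        2,
        marks=pytest.mark.slow,
    ),
]
```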
( # FIXME wpe and drop probably should not get the higher weight
    ["transformer.wte", "transformer.wpe", "transformer.drop"],
    self._input_layer_equivalence,
),
I added this FIXME; does anyone have an opinion on whether I can remove wpe and drop from this list?
rrutmann left a comment
There are some tests failing for me:
/workspaces/modalities/tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False]
TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'
/workspaces/modalities/tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False]
TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'
/workspaces/modalities/tests/fsdp2_parallelization/test_tensor_parallelism.py::TestTensorParallelism::test_tp_sharding[swiglu-fsdp2_config_path1-tp_config_path1]
torch.multiprocessing.spawn.ProcessExitedException: process 2 terminated with signal SIGABRT
As well as an error when collecting one of the tests:
______ ERROR collecting tests/checkpointing/test_checkpoint_conversion.py ______
tests/checkpointing/test_checkpoint_conversion.py:59: in <module>
    @pytest.mark.skipif(
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:401: in __call__
    store_mark(unwrapped_func, self.mark, stacklevel=3)
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:466: in store_mark
    warnings.warn(MARKED_FIXTURE, stacklevel=stacklevel)
E   pytest.PytestRemovedIn9Warning: Marks applied to fixtures have no effect
E   See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function
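Regarding the collection error: pytest ignores marks applied to fixtures, so the skipif at test_checkpoint_conversion.py:59 would need to either move onto the tests that use the fixture or become an explicit skip inside the fixture body. A sketch of the second option; the fixture name and the skip condition are placeholders, since the original ones are not shown here:

```python
import pytest
import torch


@pytest.fixture
def converted_checkpoint():
    # Marks on fixtures have no effect (PytestRemovedIn9Warning),
    # so skip explicitly inside the fixture instead.
    if torch.cuda.device_count() < 1:  # placeholder for the original skipif condition
        pytest.skip("Test requires a CUDA device.")
    ...
```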
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first and the third commented-out configs are the same, right?
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2), | ||
| # ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that
("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
is necessary since we already test
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),
which is the same setup + data parallelism, correct?
And since we have
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1),
we can probably skip
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),
rrutmann left a comment
Great work, thank you. A few tests are failing (see my comment), but aside from that, no major changes required from my side
Also enabled extra="forbid" in BaseModel to prevent accidental extra fields.
Note: Only strings are supported, not more complex path aliases.
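For readers who haven't used it: in Pydantic v2, `extra="forbid"` turns unknown keys into validation errors, and a deprecated name can be carried as a plain string alias. A minimal, illustrative sketch; the field names and types are not the actual modalities configs:

```python
from pydantic import BaseModel, ConfigDict, Field


class ExampleConfig(BaseModel):
    # Unknown keys now raise a ValidationError instead of being silently ignored.
    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    # Plain string alias only; AliasPath/AliasChoices are not covered.
    wrapped_model_or_parts: str = Field(alias="wrapped_model")


ExampleConfig(wrapped_model="gpt2")            # accepted via the deprecated alias
ExampleConfig(wrapped_model_or_parts="gpt2")   # accepted via the new name
# ExampleConfig(typo_field="x")                # would raise a ValidationError
```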
…ecated all aliases created due to multi stage pp.
Co-authored-by: Richard Rutmann <97447451+rrutmann@users.noreply.github.com>
…s in code base. Also added missing deprecation marker for GPT2MFUCalculatorConfig.
What does this PR do?
Adds support for multi stage pipeline parallelism schedules, in particular interleaved 1F1B.
Issue #408
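For context on the schedule itself: interleaved 1F1B places several non-adjacent model chunks on each pipeline rank, shrinking the pipeline bubble compared to plain 1F1B at the cost of more communication. Below is a rough sketch of how such a schedule is driven via torch.distributed.pipelining; it is not the modalities integration, and model splitting, process-group setup, and the names local_chunks, num_stages, device, loss_fn, input_ids, and is_first_stage_rank are placeholders:

```python
from torch.distributed.pipelining import PipelineStage, ScheduleInterleaved1F1B

# With 2 pipeline ranks and 4 stages, rank 0 might hold stages 0 and 2,
# rank 1 stages 1 and 3 -- i.e. multiple PipelineStage objects per rank.
stages = [
    PipelineStage(chunk, stage_index=idx, num_stages=num_stages, device=device)
    for idx, chunk in local_chunks  # (global stage index, nn.Module) pairs on this rank
]

schedule = ScheduleInterleaved1F1B(stages, n_microbatches=8, loss_fn=loss_fn)

# The rank holding the first stage feeds the input; the others only participate.
# (Target/loss handling on the last stage is omitted for brevity.)
if is_first_stage_rank:
    schedule.step(input_ids)
else:
    schedule.step()
```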
General Changes
Breaking Changes
Checklist before submitting final PR
- Ran the tests (python tests/tests.py)
- Updated the changelog (CHANGELOG_DEV.md)