Skip to content

Conversation

@BlueCrescent
Copy link
Collaborator

@BlueCrescent BlueCrescent commented Nov 7, 2025

What does this PR do?

Adds support for multi stage pipeline parallelism schedules, in particular interleaved 1F1B.
Issue #408

General Changes

  • Made code compatible with having multiple stages per rank.
  • Switched to interleaved 1F1B in some configs.
  • Note: In warmstart test, drastically increased epsilon for loss comparison.

Breaking Changes

  • Changes should be backwards compatible.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

…odel.

Also made None returns more visible in get_module_class_from_name().
- Switched from using abs=1e-16 to rel=1e-2 for loss comparisons. Need to investigate further, why this is necessary for some configurations.
- Additional configs and test setups which are however commented out due to the long runtime of these tests.
- Easier configurability for expected checkpoint paths (for debugging/messing around).
- Better error logging.
Comment on lines +196 to +204
else:
assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
sd = get_optimizer_state_dict(
model=app_state.model_parts[0],
optimizers=app_state.optimizer,
# NOTE: Flattening is required for pipeline parallelism to work correctly.
# see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove this, since in case of PP we now always have an optimizer list which takes care of the flattening?

@BlueCrescent BlueCrescent marked this pull request as ready for review November 24, 2025 09:35
Comment on lines 59 to 69
@model_validator(mode="before")
@classmethod
def warn_deprecated_alias(cls, data: Any) -> Any:
if isinstance(data, dict) and "wrapped_model" in data:
warnings.warn(
"Field 'wrapped_model' is deprecated. Use 'wrapped_model_or_parts' instead.",
DeprecationWarning,
stacklevel=3,
)
return data

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use this deprecation warning? If yes, should we use it also in other configs where a field got renamed to plural?

Comment on lines +54 to +56
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),
# ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are currently deactivated due to the long runtime of these tests. Should we activate them anyways?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first and the third commented-out configs are the same, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that
("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
is necessary since we already test
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),
which is the same setup + data parallelism, correct?

And since we have
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1),
we can probably skip
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),

Copy link
Collaborator Author

@BlueCrescent BlueCrescent Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, these configs are mostly useful for debugging with fewer ranks. Probably makes sense to have them turned off (or even delete them in the future).

Comment on lines +108 to +111
( # FIXME wpe and drop probably should not get the higher weight
["transformer.wte", "transformer.wpe", "transformer.drop"],
self._input_layer_equivalence,
),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this FIXME, anyone got an opinion on whether I can remove wpe and drop from this list?

@rrutmann rrutmann self-requested a review November 25, 2025 13:37
Copy link
Collaborator

@rrutmann rrutmann left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some tests failing for me:

/workspaces/modalities/tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False]
TypeError: check_model_inputs..wrapped_fn() got an unexpected keyword argument 'input_ids'

/workspaces/modalities/tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False]
TypeError: check_model_inputs..wrapped_fn() got an unexpected keyword argument 'input_ids'

/workspaces/modalities/tests/fsdp2_parallelization/test_tensor_parallelism.py::TestTensorParallelism::test_tp_sharding[swiglu-fsdp2_config_path1-tp_config_path1]
torch.multiprocessing.spawn.ProcessExitedException: process 2 terminated with signal SIGABRT

As well as an error importing one of the tests:
______ ERROR collecting tests/checkpointing/test_checkpoint_conversion.py ______
tests/checkpointing/test_checkpoint_conversion.py:59: in
@pytest.mark.skipif(
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:401: in call
store_mark(unwrapped_func, self.mark, stacklevel=3)
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:466: in store_mark
warnings.warn(MARKED_FIXTURE, stacklevel=stacklevel)
E pytest.PytestRemovedIn9Warning: Marks applied to fixtures have no effect
E See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function

Comment on lines +54 to +56
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),
# ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first and the third commented-out configs are the same, right?

Comment on lines +54 to +56
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),
# ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that
("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
is necessary since we already test
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),
which is the same setup + data parallelism, correct?

And since we have
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1),
we can probably skip
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),

Copy link
Collaborator

@rrutmann rrutmann left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work, thank you. A few tests are failing (see my comment), but aside from that, no major changes required from my side

@le1nux le1nux self-requested a review December 10, 2025 09:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants