fix: Diverse model seeding across PP ranks #426
base: main
Conversation
BlueCrescent left a comment:
Overall LGTM.
Should we also explicitly allow seeding for the "model_initialized" component?
It will probably inherit the random state from the model_raw component, but it seems a bit risky to me to assume that (also in the future) no other interaction with the random state happens between these two components (though probably only interactions that are asymmetrical between the ranks would be problematic). In particular, since we cannot guarantee the order in which the components are built, something like a dataloader component might even re-seed the random state. A small illustration of this concern follows below.
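A small, self-contained illustration of the concern (the seeds and the components mentioned in the comments are purely hypothetical here):

```python
import torch

torch.manual_seed(1234)       # seed set when building the model_raw component
_ = torch.randn(3)            # model instantiation consumes the RNG stream

torch.manual_seed(0)          # another component (e.g. a dataloader) re-seeds globally

w_inherited = torch.randn(3)  # "inherited" state no longer reflects model_raw's seeding

torch.manual_seed(1234)       # explicit seeding in model_initialized removes that dependency
w_explicit = torch.randn(3)
```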
tests/fsdp2_parallelization/test_parallel_seed_initialization.py (outdated comment, resolved)
le1nux left a comment:
I checked the seeding (not the test), and from my understanding the changes do not provide the expected results (which is also what @BlueCrescent was hinting at).
When we seed the raw model, the model weights are indeed deterministic at instantiation time. However, we also have model weight initialization, which runs afterwards and would simply override the weights and the seeding.
Additionally, passing device_mesh to the model couples two components that should normally not know anything about each other.
I think we have to integrate the seeding into the weight initializer component, and we can also pass the device_mesh in there.
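A minimal sketch of that design, assuming a hypothetical `WeightInitializer` component and a `"pp"` mesh dimension name (neither is taken from the actual Modalities code):

```python
from typing import Optional

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh


class WeightInitializer:
    """Hypothetical weight initializer that owns the seeding itself."""

    def __init__(self, seed: int, device_mesh: Optional[DeviceMesh] = None):
        self.seed = seed
        self.device_mesh = device_mesh

    def initialize(self, model: nn.Module) -> nn.Module:
        # Offset the seed by the pipeline-parallel rank so each PP stage draws
        # different initial weights; the raw model never sees the device_mesh.
        pp_rank = self.device_mesh.get_local_rank("pp") if self.device_mesh is not None else 0
        torch.manual_seed(self.seed + pp_rank)
        for module in model.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        return model
```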
Yes, that makes sense. I moved the seeding to the model initialization component.
See #426 (comment).
BlueCrescent left a comment:
LGTM
What does this PR do?
This PR gives each PP rank a unique model seed, so that parameters are initialized differently across pipeline-parallel ranks.
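A hedged sketch of the core idea: offset a base seed by the pipeline-parallel rank obtained from the device mesh, so each PP stage draws different initial parameters. The mesh shape, the `"pp"` dimension name, and `base_seed` are illustrative assumptions, not Modalities' actual configuration.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def seed_for_pp_rank(base_seed: int, pp_rank: int) -> int:
    # Same seed within a PP stage, different seed across stages.
    return base_seed + pp_rank


if dist.is_available() and dist.is_initialized():
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "dp"))
    pp_rank = mesh.get_local_rank("pp")
    torch.manual_seed(seed_for_pp_rank(base_seed=42, pp_rank=pp_rank))
```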
General Changes
Breaking Changes
Checklist before submitting final PR
- Tests run (python tests/tests.py)
- Changelog updated (CHANGELOG_DEV.md)