Commit b12fbba
Fully Sharded 2D Parallelism (#3558)
Summary:
Pull Request resolved: #3558
**This diff introduces Fully Sharded 2D Parallelism in TorchRec. It saves a significant amount of memory (50%+) by sharding embedding tables while they are not in use.**
After the embedding lookup, the embedding table is further sharded across the data parallel dimension until it is needed in the backward pass. This gives the model layers after the embedding lookup more memory headroom, enabling further scaling of the dense architecture. **Practically speaking, this saves 50%+ embedding memory per GPU, which accounts for upwards of 10GB of memory savings on large models.**
The peak memory during this step becomes ```O(shard + shard/num_replication)```, and the embedding memory drops to ```O(shard/num_replication)``` after the lookup step.
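As a rough illustration of the two bounds above, the sketch below plugs in made-up numbers. The function name, the 20 GB shard size, and the replication factor of 2 are assumptions chosen for the example, and the big-O constants are taken as 1.

```python
from typing import Tuple


def embedding_memory_gb(shard_gb: float, num_replication: int) -> Tuple[float, float]:
    """Return (peak, steady) embedding memory per GPU in GB.

    Hypothetical helper: treats the O(...) bounds as exact, which ignores
    allocator overhead and any temporary communication buffers.
    """
    if num_replication < 1:
        raise ValueError("num_replication must be >= 1")
    steady = shard_gb / num_replication  # slice that stays resident after the lookup
    peak = shard_gb + steady             # full shard plus the slice during the step
    return peak, steady


peak, steady = embedding_memory_gb(shard_gb=20.0, num_replication=2)
print(f"peak={peak:.1f} GB, steady={steady:.1f} GB")
```

With these illustrative inputs the resident embedding memory after the lookup is half of the original shard, which matches the 50% figure for a replication factor of 2.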
The memory frees and collective communications are done in an overhead-free manner: they are handled asynchronously on multiple streams to maximize the overlap between computation and communication collectives.
With Fully Sharded 2D, the embedding weight synchronization has to happen every step; otherwise trained batches are lost across ranks. We use an asynchronous reduce scatter after the embedding lookup step and fully overlap this collective with compute, so it exposes no additional overhead.
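To make the reduce scatter step concrete, here is a single-process simulation in plain Python. It is a toy, not the TorchRec code path: the function name is hypothetical, and averaging as the reduction op is an assumption (the commit message does not state which op is used).

```python
from typing import List


def reduce_scatter_mean(replicas: List[List[float]]) -> List[List[float]]:
    """Simulate reduce scatter across replicas holding the same shard.

    Every replica contributes its full copy; the copies are reduced
    element-wise (mean is an assumption) and each replica keeps only
    one contiguous chunk of the result.
    """
    n = len(replicas)
    length = len(replicas[0])
    if any(len(r) != length for r in replicas) or length % n != 0:
        raise ValueError("replicas must have equal length divisible by the replica count")
    reduced = [sum(column) / n for column in zip(*replicas)]
    chunk = length // n
    return [reduced[i * chunk:(i + 1) * chunk] for i in range(n)]


replica_weights = [
    [1.0, 2.0, 3.0, 4.0],  # copy of the shard on replica 0
    [3.0, 4.0, 5.0, 6.0],  # copy of the shard on replica 1
]
print(reduce_scatter_mean(replica_weights))  # [[2.0, 3.0], [4.0, 5.0]]
```

After the call each replica holds `1/num_replication` of the synchronized weights, which is the state the memory bound `O(shard/num_replication)` describes.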
A new awaitable, ```ReduceScatterResizeAwaitable```, is introduced under the Fully Sharded path and is called with the SDD output_dist all to all. This awaitable ```wait()```s on the async reduce scatter and then calls the ```resize()``` operation on the embedding memory, ensuring no race conditions.
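The ordering that avoids the race (finish the collective first, shrink the buffer second) can be sketched with the standard library alone. Everything below is a hypothetical stand-in: the class and helper names are invented for illustration, a thread pool plays the role of the async collective, and a Python list plays the role of the embedding memory. The real awaitable works on device tensors and streams.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List


class ToyReduceScatterResizeAwaitable:
    """Toy awaitable: wait for the async reduction, then shrink the buffer."""

    def __init__(self, pending: "Future[List[float]]", buffer: List[float]) -> None:
        self._pending = pending
        self._buffer = buffer

    def wait(self) -> List[float]:
        local_slice = self._pending.result()  # block until the collective has finished
        self._buffer[:] = local_slice         # only now is it safe to drop the full shard
        return self._buffer


def fake_async_reduce_scatter(
    pool: ThreadPoolExecutor, full_shard: List[float], rank: int, world: int
) -> "Future[List[float]]":
    """Stand-in for an async collective: returns this rank's chunk later."""
    chunk = len(full_shard) // world
    snapshot = list(full_shard)  # the collective reads the full shard
    return pool.submit(lambda: snapshot[rank * chunk:(rank + 1) * chunk])


shard = [0.1, 0.2, 0.3, 0.4]
with ThreadPoolExecutor(max_workers=1) as pool:
    pending = fake_async_reduce_scatter(pool, shard, rank=1, world=2)
    awaitable = ToyReduceScatterResizeAwaitable(pending, shard)
    # Other compute would overlap with the collective here.
    awaitable.wait()
print(shard)  # [0.3, 0.4]
```

Tying the resize to `wait()` means no caller can shrink the buffer while the collective may still be reading it.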
Users can enable Fully Sharded 2D through a new arg, `sharding_strategy`, which takes a `ShardingStrategy`:
```
DMPCollection(..., sharding_strategy=ShardingStrategy.FULLY_SHARDED)
```
This is part of our work to create overhead-free 2D parallelism, which will allow us to use it for every model.
Remaining work after this diff is to launch an async all gather in the backward pass, make the planner aware of these memory savings, and integrate this work with per-module 2D.
Checkpointing compatibility remains as checkpoints are created after a given train step. At this point each rank contains the full embedding weights that are synced across the DP dimension.
Reviewed By: liangbeixu
Differential Revision: D82253387
fbshipit-source-id: 2fdc6be611556ac3da9a8fea6a9017efd1bbc503
1 parent 7ecee99 commit b12fbba
File tree
16 files changed, +522 -31 lines changed
- torchrec/distributed
  - sharding