Commit aa1eeda
Add uneven shard size support to Fully Sharded 2D collectives and unit tests (#3584)
Summary:
Pull Request resolved: #3584
Adds support for uneven sharding splits across the data-parallel dimension. Sharding types such as row-wise and table-row-wise can produce uneven shards, which cause the current collectives in fully sharded 2D to fail. We add padding so that the collectives see equal shapes.
The collectives handle shapes as follows:
```
total_size = self._emb_module.weights_dev.numel()
shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
padded_total_size = shard_size * num_groups
padding_size = padded_total_size - total_size
if padding_size > 0:
    input_tensor = torch.nn.functional.pad(
        self._emb_module.weights_dev.contiguous(),
        (0, padding_size),
        value=0.0,
    )
else:
    input_tensor = self._emb_module.weights_dev.contiguous()
```
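The padding arithmetic can be checked in isolation. The following is a minimal sketch in plain Python, not the TorchRec implementation: `padded_shard_layout` is a hypothetical helper name introduced here for illustration, and it only reproduces the size computation from the excerpt above.

```python
from typing import Tuple


def padded_shard_layout(total_size: int, num_groups: int) -> Tuple[int, int, int]:
    """Return (shard_size, padded_total_size, padding_size) for a flat buffer.

    Hypothetical helper: mirrors only the size arithmetic of the excerpt above.
    """
    if num_groups <= 0:
        raise ValueError("num_groups must be positive")
    shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
    padded_total_size = shard_size * num_groups
    padding_size = padded_total_size - total_size
    return shard_size, padded_total_size, padding_size


# 10 elements over 4 groups: each shard holds 3, so 2 zeros are appended.
print(padded_shard_layout(10, 4))  # (3, 12, 2)
# 12 elements over 4 groups divide evenly, so no padding is needed.
print(padded_shard_layout(12, 4))  # (3, 12, 0)
```

The padding is always smaller than `num_groups`, so the extra memory per buffer is at most `num_groups - 1` elements.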
Padding is applied to the rightmost shard (the same happens with TorchRec uneven sharding, where the last shard is the uneven one).
The all_gather also accounts for this:
```
num_groups = self._env.num_sharding_groups()
shard_size = self._shard_buf.numel()
padded_total_size = shard_size * num_groups
self._unsharded_param.untyped_storage().resize_(
    padded_total_size * self._element_size
)
self._emb_module.weights_dev = self._unsharded_param[
    : self._original_shape.numel()
]
```
```
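The pad, shard, gather, and trim round trip can be simulated in a single process. This is a sketch under simplifying assumptions: plain Python lists stand in for tensors, list concatenation stands in for the all_gather collective, and `shard_with_padding` and `gather_and_trim` are hypothetical names that do not appear in the diff.

```python
from typing import List


def shard_with_padding(values: List[float], num_groups: int) -> List[List[float]]:
    """Right-pad a flat list with zeros and split it into equal-sized shards."""
    shard_size = (len(values) + num_groups - 1) // num_groups  # ceil division
    padded = values + [0.0] * (shard_size * num_groups - len(values))
    return [padded[i * shard_size:(i + 1) * shard_size] for i in range(num_groups)]


def gather_and_trim(shards: List[List[float]], original_numel: int) -> List[float]:
    """Concatenate equal-sized shards, then slice off the trailing padding."""
    gathered = [x for shard in shards for x in shard]  # stand-in for all_gather
    return gathered[:original_numel]


weights = [1.0, 2.0, 3.0, 4.0, 5.0]
shards = shard_with_padding(weights, num_groups=2)
print(shards)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
print(gather_and_trim(shards, len(weights)) == weights)  # True
```

Because the zeros sit only at the end of the last shard, slicing the gathered buffer to the original element count recovers the unpadded weights exactly.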
This diff also adds the required unit tests for all sharding types for fully sharded 2D (sequence and pooled embeddings).
Reviewed By: liangbeixu, kausv
Differential Revision: D87406987
fbshipit-source-id: d1311bd665a6ce2443035f2da92ca73cdb892db3
1 parent 7c30d39 commit aa1eeda
File tree
4 files changed, +754 -37 lines changed
- torchrec/distributed
- test_utils
- tests