
Commit aa1eeda

iamzainhuda authored and meta-codesync[bot] committed

Add uneven shard size support to Fully Sharded 2D collectives and unit tests (#3584)
Summary:
Pull Request resolved: #3584

Adds support for uneven sharding splits across the data parallel dimension. In sharding types such as row-wise and table-row-wise, uneven shard sizes can occur, and these cases cause the current fully sharded 2D collectives to fail. We add padding so that the collectives always see equal shapes. The shape handling works as follows:

```python
total_size = self._emb_module.weights_dev.numel()
shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
padded_total_size = shard_size * num_groups
padding_size = padded_total_size - total_size
if padding_size > 0:
    input_tensor = torch.nn.functional.pad(
        self._emb_module.weights_dev.contiguous(),
        (0, padding_size),
        value=0.0,
    )
else:
    input_tensor = self._emb_module.weights_dev.contiguous()
```

Padding is applied to the right-most shard (the same convention TorchRec uses for uneven sharding, where the last shard is the uneven one). The all_gather accounts for this as well:

```python
num_groups = self._env.num_sharding_groups()
shard_size = self._shard_buf.numel()
padded_total_size = shard_size * num_groups
self._unsharded_param.untyped_storage().resize_(
    padded_total_size * self._element_size
)
self._emb_module.weights_dev = self._unsharded_param[
    : self._original_shape.numel()
]
```

This diff also adds the required unit tests for all sharding types for fully sharded 2D (sequence and pooled embeddings).

Reviewed By: liangbeixu, kausv

Differential Revision: D87406987

fbshipit-source-id: d1311bd665a6ce2443035f2da92ca73cdb892db3
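For reference, a minimal standalone sketch of the padding arithmetic above, using plain tensors and toy sizes (no process group; `num_groups` and the 10-element buffer are illustrative values, not the TorchRec API):

```python
import torch

# Toy values: 4 sharding groups, a flat weight buffer of 10 elements that
# does not divide evenly across the groups.
num_groups = 4
weights_dev = torch.arange(10, dtype=torch.float32)

total_size = weights_dev.numel()
shard_size = (total_size + num_groups - 1) // num_groups  # ceil division -> 3
padded_total_size = shard_size * num_groups               # 12
padding_size = padded_total_size - total_size             # 2

# Zero padding goes on the right, so only the last shard is affected --
# matching TorchRec's convention that the last shard is the uneven one.
input_tensor = torch.nn.functional.pad(
    weights_dev.contiguous(), (0, padding_size), value=0.0
)
shards = input_tensor.chunk(num_groups)
print([s.tolist() for s in shards])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 0]]  <- padding lands in the last shard only
```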
1 parent 7c30d39 commit aa1eeda

4 files changed: +754 −37 lines

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 69 additions & 35 deletions
```diff
@@ -2548,15 +2548,13 @@ def __init__(
         self._env: ShardingEnv2D = env
 
         self.weights_sharded = False
+        self._element_size = self._emb_module.weights_dev.element_size()
         # pyre-ignore[8]
         self._original_shape: torch.Size = self._emb_module.weights_dev.shape
         # pyre-ignore[8]
         self._unsharded_param: torch.Tensor = self._emb_module.weights_dev
-        self._stash_nbytes: int = (
-            self._emb_module.weights_dev.untyped_storage().nbytes()  # pyre-ignore[29]
-        )
         self._shard_buf_nbytes: int = 0
-        self.shard_buf: Optional[torch.Tensor] = None
+        self._shard_buf: Optional[torch.Tensor] = None
 
         self._async_stream: torch.cuda.Stream = torch.cuda.Stream(
             device=self._emb_module.weights_dev.device
@@ -2573,18 +2571,26 @@ def _all_gather_table_weights(self) -> None:
         if not self.weights_sharded:
             return
         self._wait_on_reduce_scatter()
-        self._unsharded_param.untyped_storage().resize_(self._stash_nbytes)
+        num_groups = self._env.num_sharding_groups()
+        shard_size = self._shard_buf.numel()
+        padded_total_size = shard_size * num_groups
+
+        self._unsharded_param.untyped_storage().resize_(
+            padded_total_size * self._element_size
+        )
 
         dist.all_gather_into_tensor(
             output_tensor=self._unsharded_param,
-            input_tensor=self.shard_buf,
+            input_tensor=self._shard_buf,
             group=self._env.replica_pg,
             async_op=False,
         )
         # pyre-ignore[16]
-        self._emb_module.weights_dev = self._unsharded_param
+        self._emb_module.weights_dev = self._unsharded_param[
+            : self._original_shape.numel()
+        ]
         # pyre-ignore[16]
-        self.shard_buf.untyped_storage().resize_(0)
+        self._shard_buf.untyped_storage().resize_(0)
         self.weights_sharded = False
 
     def _hybird_sharded_backward_hook(
@@ -2633,26 +2639,38 @@ def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable:
 
         # pyre-ignore[29]
         total_size = self._emb_module.weights_dev.numel()
-        shard_size = total_size // num_groups
 
-        if self.shard_buf is None:
-            self.shard_buf = torch.empty(
+        shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
+        padded_total_size = shard_size * num_groups
+        padding_size = padded_total_size - total_size
+
+        if padding_size > 0:
+            input_tensor = torch.nn.functional.pad(
+                self._emb_module.weights_dev.contiguous(),
+                (0, padding_size),
+                value=0.0,
+            )
+        else:
+            input_tensor = self._emb_module.weights_dev.contiguous()
+
+        if self._shard_buf is None:
+            self._shard_buf = torch.empty(
                 shard_size,
                 # pyre-ignore[6]
                 dtype=self._emb_module.weights_dev.dtype,
                 # pyre-ignore[6]
                 device=self._emb_module.weights_dev.device,
             )
             # pyre-ignore[16]
-            self._shard_buf_nbytes = self.shard_buf.untyped_storage().nbytes()
+            self._shard_buf_nbytes = self._shard_buf.untyped_storage().nbytes()
         else:
-            self.shard_buf.untyped_storage().resize_(self._shard_buf_nbytes)
+            self._shard_buf.untyped_storage().resize_(self._shard_buf_nbytes)
 
         # pyre-ignore[29]
         input_tensor = self._emb_module.weights_dev.contiguous()
 
         self._async_work = dist.reduce_scatter_tensor(
-            output=self.shard_buf,
+            output=self._shard_buf,
             input=input_tensor,
             op=dist.ReduceOp.AVG,
             group=self._env.replica_pg,
@@ -2665,14 +2683,14 @@ def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable:
 
         def resize_callback() -> None:
             self._emb_module.weights_dev.untyped_storage().resize_(0)  # pyre-ignore[29]
-            self._emb_module.weights_dev = self.shard_buf  # pyre-ignore[16]
+            self._emb_module.weights_dev = self._shard_buf  # pyre-ignore[16]
 
         return ReduceScatterResizeAwaitable(
             async_work=self._async_work,
             async_event=self._async_event,
             async_stream=self._async_stream,
             unsharded_param=self._unsharded_param,
-            shard_buf=self.shard_buf,
+            shard_buf=self._shard_buf,
             resize_callback=resize_callback,
         )
 
@@ -3590,15 +3608,13 @@ def __init__(
         self._env: ShardingEnv2D = env
 
         self.weights_sharded = False
+        self._element_size = self._emb_module.weights_dev.element_size()
         # pyre-ignore[8]
         self._original_shape: torch.Size = self._emb_module.weights_dev.shape
         # pyre-ignore[8]
         self._unsharded_param: torch.Tensor = self._emb_module.weights_dev
-        self._stash_nbytes: int = (
-            self._emb_module.weights_dev.untyped_storage().nbytes()  # pyre-ignore[29]
-        )
         self._shard_buf_nbytes: int = 0
-        self.shard_buf: Optional[torch.Tensor] = None
+        self._shard_buf: Optional[torch.Tensor] = None
 
         self._async_stream: torch.cuda.Stream = torch.cuda.Stream(
             device=self._emb_module.weights_dev.device
@@ -3615,18 +3631,27 @@ def _all_gather_table_weights(self) -> None:
         if not self.weights_sharded:
             return
         self._wait_on_reduce_scatter()
-        self._unsharded_param.untyped_storage().resize_(self._stash_nbytes)
+
+        num_groups = self._env.num_sharding_groups()
+        shard_size = self._shard_buf.numel()
+        padded_total_size = shard_size * num_groups
+
+        self._unsharded_param.untyped_storage().resize_(
+            padded_total_size * self._element_size
+        )
 
         dist.all_gather_into_tensor(
             output_tensor=self._unsharded_param,
-            input_tensor=self.shard_buf,
+            input_tensor=self._shard_buf,
             group=self._env.replica_pg,
             async_op=False,
         )
         # pyre-ignore[16]
-        self._emb_module.weights_dev = self._unsharded_param
+        self._emb_module.weights_dev = self._unsharded_param[
+            : self._original_shape.numel()
+        ]
         # pyre-ignore[16]
-        self.shard_buf.untyped_storage().resize_(0)
+        self._shard_buf.untyped_storage().resize_(0)
         self.weights_sharded = False
 
     def _hybird_sharded_backward_hook(
@@ -3675,26 +3700,35 @@ def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable:
 
         # pyre-ignore[29]
         total_size = self._emb_module.weights_dev.numel()
-        shard_size = total_size // num_groups
 
-        if self.shard_buf is None:
-            self.shard_buf = torch.empty(
+        shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
+        padded_total_size = shard_size * num_groups
+        padding_size = padded_total_size - total_size
+
+        if padding_size > 0:
+            input_tensor = torch.nn.functional.pad(
+                self._emb_module.weights_dev.contiguous(),
+                (0, padding_size),
+                value=0.0,
+            )
+        else:
+            input_tensor = self._emb_module.weights_dev.contiguous()
+
+        if self._shard_buf is None:
+            self._shard_buf = torch.empty(
                 shard_size,
                 # pyre-ignore[6]
                 dtype=self._emb_module.weights_dev.dtype,
                 # pyre-ignore[6]
                 device=self._emb_module.weights_dev.device,
             )
             # pyre-ignore[16]
-            self._shard_buf_nbytes = self.shard_buf.untyped_storage().nbytes()
+            self._shard_buf_nbytes = self._shard_buf.untyped_storage().nbytes()
         else:
-            self.shard_buf.untyped_storage().resize_(self._shard_buf_nbytes)
-
-        # pyre-ignore[29]
-        input_tensor = self._emb_module.weights_dev.contiguous()
+            self._shard_buf.untyped_storage().resize_(self._shard_buf_nbytes)
 
         self._async_work = dist.reduce_scatter_tensor(
-            output=self.shard_buf,
+            output=self._shard_buf,
             input=input_tensor,
             op=dist.ReduceOp.AVG,
             group=self._env.replica_pg,
@@ -3707,14 +3741,14 @@ def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable:
 
         def resize_callback() -> None:
             self._emb_module.weights_dev.untyped_storage().resize_(0)  # pyre-ignore[29]
-            self._emb_module.weights_dev = self.shard_buf  # pyre-ignore[16]
+            self._emb_module.weights_dev = self._shard_buf  # pyre-ignore[16]
 
         return ReduceScatterResizeAwaitable(
            async_work=self._async_work,
            async_event=self._async_event,
            async_stream=self._async_stream,
            unsharded_param=self._unsharded_param,
-           shard_buf=self.shard_buf,
+           shard_buf=self._shard_buf,
            resize_callback=resize_callback,
        )
```

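The kernel changes above rely on the usual stash/restore storage pattern: release the full weight storage once the reduce-scatter owns the shard, then grow it back to the padded size before the all-gather and expose only the first `original_shape.numel()` elements to the kernel. A minimal CPU-only sketch of that pattern, assuming only public `untyped_storage()` behavior and toy sizes:

```python
import torch

# Stand-in for the unsharded embedding weights; original numel == 10.
unsharded_param = torch.zeros(10, dtype=torch.float32)
original_numel = unsharded_param.numel()
element_size = unsharded_param.element_size()

# After the reduce-scatter, each rank only keeps its shard, so the backing
# storage of the full parameter can be released in place.
unsharded_param.untyped_storage().resize_(0)
assert unsharded_param.untyped_storage().nbytes() == 0

# Before the all-gather, grow the storage to the *padded* total size so the
# collective can write num_groups equal-sized shards into it.
num_groups, shard_size = 4, 3                     # ceil(10 / 4) == 3
padded_total_size = shard_size * num_groups       # 12 elements
unsharded_param.untyped_storage().resize_(padded_total_size * element_size)

# The embedding module is then pointed at only the first original_numel
# elements, so the trailing padding never leaks into the kernel.
weights_view = unsharded_param.flatten()[:original_numel]
print(weights_view.shape)  # torch.Size([10])
```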
torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -26,7 +26,7 @@
     SharderType,
     sharding_single_rank_test,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingType
+from torchrec.distributed.types import ModuleSharder, ShardingStrategy, ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType
 from torchrec.test_utils import seed_and_log, skip_if_asan_class
 from torchrec.types import DataType
@@ -161,6 +161,7 @@ def _test_sharding(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        sharding_strategy: Optional[ShardingStrategy] = None,
     ) -> None:
         self._build_tables_and_groups(data_type=data_type)
         # directly run the test with single process
@@ -191,6 +192,7 @@ def _test_sharding(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                sharding_strategy=sharding_strategy,
             )
         else:
             self._run_multi_process_test(
@@ -219,6 +221,7 @@ def _test_sharding(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                sharding_strategy=sharding_strategy,
             )
 
     def _test_dynamic_sharding(
```

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -62,6 +62,7 @@
     ShardedTensor,
     ShardingEnv,
     ShardingPlan,
+    ShardingStrategy,
     ShardingType,
 )
 from torchrec.modules.embedding_configs import (
@@ -790,6 +791,7 @@ def sharding_single_rank_test_single_process(
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
     random_seed: Optional[int] = None,
+    sharding_strategy: Optional[ShardingStrategy] = None,
 ) -> None:
     batch_size = random.randint(0, batch_size) if allow_zero_batch_size else batch_size
     # Generate model & inputs.
@@ -956,6 +958,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
             use_inter_host_allreduce=use_inter_host_allreduce,
             custom_all_reduce=all_reduce_func,
             submodule_configs=submodule_configs,
+            sharding_strategy=sharding_strategy,
         )
     else:
         local_model = DistributedModelParallel(
@@ -1069,6 +1072,7 @@ def sharding_single_rank_test(
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
     random_seed: Optional[int] = None,
+    sharding_strategy: Optional[ShardingStrategy] = None,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         assert ctx.pg is not None
@@ -1104,6 +1108,7 @@ def sharding_single_rank_test(
             offsets_dtype=offsets_dtype,
             lengths_dtype=lengths_dtype,
             random_seed=random_seed,
+            sharding_strategy=sharding_strategy,
         )
 
 
```

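The added unit tests themselves are not reproduced in this view. As a standalone stand-in for the collective-level invariant they rely on, the sketch below sweeps a few uneven `(total_size, num_groups)` pairs and checks the pad / shard / regather / truncate round trip described in the summary (pure-tensor simulation; the real tests run the sharded models end to end):

```python
import torch


def round_trip(total_size: int, num_groups: int) -> bool:
    """Pad, shard, regather, and truncate a toy weight buffer; return True
    if the original contents are recovered exactly."""
    weights = torch.randn(total_size)
    shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
    padding = shard_size * num_groups - total_size
    padded = torch.nn.functional.pad(weights, (0, padding), value=0.0)
    shards = padded.chunk(num_groups)              # stand-in for reduce_scatter shards
    gathered = torch.cat(shards)                   # stand-in for all_gather_into_tensor
    return torch.equal(gathered[:total_size], weights)


# Uneven cases that would break an exact total_size // num_groups split,
# plus one even case as a control.
for total_size, num_groups in [(10, 4), (7, 2), (13, 5), (16, 4)]:
    assert round_trip(total_size, num_groups), (total_size, num_groups)
print("all uneven splits round-trip correctly")
```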