
Commit 28c82d5

iamzainhuda authored and meta-codesync[bot] committed
dynamic 2D + fully sharded 2D (#3600)
Summary:
Pull Request resolved: #3600

Add support for dynamic (D76774334) + fully sharded 2D together. Users can specify which modules to apply fully sharded 2D to by adding `ShardingStrategy` in their submodule configs.

Reviewed By: aliafzal

Differential Revision: D88675533

fbshipit-source-id: ed7c7a4b767aa9317848e5ed65e7dcc2795c5f29
1 parent 71a0539 commit 28c82d5
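
For orientation, the user-facing change is a new `sharding_strategy` field on the per-submodule config. A minimal sketch, assuming the config is handed to DMPCollection via its submodule configs (the wiring is in torchrec/distributed/model_parallel.py below); `plan=None` mirrors the new tests, which supply the sub-tree plan later:

from torchrec.distributed.types import DMPCollectionConfig, ShardingStrategy
from torchrec.modules.embedding_modules import EmbeddingCollection

# Only the EmbeddingCollection sub-tree opts into fully sharded 2D; submodules
# without an explicit strategy keep ShardingStrategy.DEFAULT.
ec_submodule_config = DMPCollectionConfig(
    module=EmbeddingCollection,
    plan=None,  # pyre-ignore[6] -- sub-tree sharding plan supplied later
    sharding_group_size=2,
    sharding_strategy=ShardingStrategy.FULLY_SHARDED,
)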

File tree

4 files changed: +219 -46 lines

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 4 additions & 7 deletions
@@ -2565,9 +2565,8 @@ def _all_gather_table_weights(self) -> None:
         if not self.weights_sharded:
             return
         self._wait_on_reduce_scatter()
-        num_groups = self._env.num_sharding_groups()
         shard_size = self._shard_buf.numel()
-        padded_total_size = shard_size * num_groups
+        padded_total_size = shard_size * self._env.num_sharding_groups()
 
         self._unsharded_param.untyped_storage().resize_(
             padded_total_size * self._element_size
@@ -2629,11 +2628,11 @@ def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable:
         """
         with torch.no_grad():
             self.weights_sharded = True
-            num_groups = self._env.num_sharding_groups()
 
             # pyre-ignore[29]
             total_size = self._emb_module.weights_dev.numel()
 
+            num_groups = self._env.num_sharding_groups()
             shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
             padded_total_size = shard_size * num_groups
             padding_size = padded_total_size - total_size
@@ -3594,7 +3593,6 @@ def __init__(
         env: Optional[ShardingEnv] = None,
     ) -> None:
         super().__init__(config, pg, device, sharding_type)
-
         assert isinstance(
             env, ShardingEnv2D
         ), "env is required for ShardedBatchedFusedEmbeddingBag"
@@ -3625,9 +3623,8 @@ def _all_gather_table_weights(self) -> None:
             return
         self._wait_on_reduce_scatter()
 
-        num_groups = self._env.num_sharding_groups()
         shard_size = self._shard_buf.numel()
-        padded_total_size = shard_size * num_groups
+        padded_total_size = shard_size * self._env.num_sharding_groups()
 
         self._unsharded_param.untyped_storage().resize_(
             padded_total_size * self._element_size
@@ -3689,11 +3686,11 @@ def _reduce_scatter_weights_async(self) -> ReduceScatterResizeAwaitable:
         """
         with torch.no_grad():
             self.weights_sharded = True
-            num_groups = self._env.num_sharding_groups()
 
             # pyre-ignore[29]
             total_size = self._emb_module.weights_dev.numel()
 
+            num_groups = self._env.num_sharding_groups()
             shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
             padded_total_size = shard_size * num_groups
             padding_size = padded_total_size - total_size
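
The ceil division above keeps every sharding group's shard the same size, so _all_gather_table_weights can recover the padded total simply as shard_size * num_sharding_groups(). A standalone sketch of that arithmetic (the helper function is illustrative, not part of TorchRec):

from typing import Tuple

def padded_shard_sizes(total_size: int, num_groups: int) -> Tuple[int, int, int]:
    # Mirrors the arithmetic in _reduce_scatter_weights_async.
    shard_size = (total_size + num_groups - 1) // num_groups  # ceil division
    padded_total_size = shard_size * num_groups  # what _all_gather_table_weights resizes to
    padding_size = padded_total_size - total_size  # extra elements so the flat weights split evenly
    return shard_size, padded_total_size, padding_size

# e.g. 10 weight elements across 4 sharding groups -> shards of 3, padded total of 12, 2 elements of padding
assert padded_shard_sizes(10, 4) == (3, 12, 2)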

torchrec/distributed/model_parallel.py

Lines changed: 2 additions & 0 deletions
@@ -931,6 +931,7 @@ def __init__(
                     plan=submodule_config.plan,
                     sharding_group_size=submodule_config.sharding_group_size,
                     use_inter_host_allreduce=submodule_config.use_inter_host_allreduce,
+                    sharding_strategy=submodule_config.sharding_strategy,
                 )
             )
 
@@ -1022,6 +1023,7 @@ def _shard_modules_impl(
                     device_mesh=ctx.device_mesh,
                     node_group_size=ctx.sharding_group_size,
                     use_inter_host_allreduce=ctx.use_inter_host_allreduce,
+                    sharding_strategy=ctx.sharding_strategy,
                 )
                 break
 
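A rough sketch of what these two one-line additions accomplish, assuming DMPCollection builds a DMPCollectionContext per submodule config (only the keyword arguments are taken from the diff; the surrounding class code is elided):

from torchrec.distributed.types import (
    DMPCollectionConfig,
    DMPCollectionContext,
    ShardingStrategy,
)
from torchrec.modules.embedding_modules import EmbeddingCollection

submodule_config = DMPCollectionConfig(
    module=EmbeddingCollection,
    plan=None,  # pyre-ignore[6]
    sharding_group_size=2,
    sharding_strategy=ShardingStrategy.FULLY_SHARDED,
)

# First hunk: the per-submodule strategy is copied into the internal context.
ctx = DMPCollectionContext(
    module=submodule_config.module,
    plan=submodule_config.plan,
    sharding_group_size=submodule_config.sharding_group_size,
    use_inter_host_allreduce=submodule_config.use_inter_host_allreduce,
    sharding_strategy=submodule_config.sharding_strategy,
)
# Second hunk: _shard_modules_impl later forwards ctx.sharding_strategy into the
# per-submodule sharding call alongside ctx.device_mesh and the group sizes.
assert ctx.sharding_strategy is ShardingStrategy.FULLY_SHARDED
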
torchrec/distributed/tests/test_2d_sharding.py

Lines changed: 168 additions & 0 deletions
@@ -952,6 +952,172 @@ def test_sharding_dynamic_2D(
             submodule_configs=[ec_submodule_config],
         )
 
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 7,
+        "Not enough GPUs, this test requires at least eight GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharding_type=st.just(ShardingType.ROW_WISE.value),
+        kernel_type=st.sampled_from(
+            [
+                # EmbeddingComputeKernel.DENSE.value,
+                EmbeddingComputeKernel.FUSED.value,
+            ]
+        ),
+        qcomms_config=st.sampled_from(
+            [
+                None,
+                QCommsConfig(
+                    forward_precision=CommType.FP16, backward_precision=CommType.BF16
+                ),
+            ]
+        ),
+        apply_optimizer_in_backward_config=st.sampled_from(
+            [
+                None,
+                {
+                    "embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
+                    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
+                },
+            ]
+        ),
+        variable_batch_size=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
+    def test_fully_sharded_dynamic_2D(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        qcomms_config: Optional[QCommsConfig],
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ],
+        variable_batch_size: bool,
+    ) -> None:
+        assume(
+            apply_optimizer_in_backward_config is None
+            or kernel_type != EmbeddingComputeKernel.DENSE.value
+        )
+
+        # add sharding plan for embedding collection later
+        ec_submodule_config = DMPCollectionConfig(
+            module=EmbeddingCollection,
+            sharding_group_size=2,
+            plan=None,  # pyre-ignore[6]
+            sharding_strategy=ShardingStrategy.FULLY_SHARDED,
+        )
+
+        self._test_sharding(
+            world_size=self.WORLD_SIZE,
+            world_size_2D=self.WORLD_SIZE_2D,
+            sharders=[  # pyre-ignore[6]
+                cast(
+                    ModuleSharder[nn.Module],
+                    create_test_sharder(
+                        SharderType.EMBEDDING_BAG_COLLECTION.value,
+                        sharding_type,
+                        kernel_type,
+                        qcomms_config=qcomms_config,
+                        device=torch.device("cuda"),
+                    ),
+                ),
+            ],
+            backend="nccl",
+            qcomms_config=qcomms_config,
+            constraints={
+                table.name: ParameterConstraints(min_partition=2)
+                for table in self.tables
+            },
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            submodule_configs=[ec_submodule_config],
+            sharding_strategy=ShardingStrategy.FULLY_SHARDED,
+        )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 7,
+        "Not enough GPUs, this test requires at least eight GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharding_type=st.just(ShardingType.ROW_WISE.value),
+        kernel_type=st.sampled_from(
+            [
+                # EmbeddingComputeKernel.DENSE.value,
+                EmbeddingComputeKernel.FUSED.value,
+            ]
+        ),
+        qcomms_config=st.sampled_from(
+            [
+                None,
+                QCommsConfig(
+                    forward_precision=CommType.FP16, backward_precision=CommType.BF16
+                ),
+            ]
+        ),
+        apply_optimizer_in_backward_config=st.sampled_from(
+            [
+                None,
+                {
+                    "embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
+                    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
+                },
+            ]
+        ),
+        variable_batch_size=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
+    def test_partially_fully_sharded_dynamic_2D(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        qcomms_config: Optional[QCommsConfig],
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ],
+        variable_batch_size: bool,
+    ) -> None:
+        assume(
+            apply_optimizer_in_backward_config is None
+            or kernel_type != EmbeddingComputeKernel.DENSE.value
+        )
+
+        # add sharding plan for embedding collection later
+        ec_submodule_config = DMPCollectionConfig(
+            module=EmbeddingCollection,
+            sharding_group_size=2,
+            plan=None,  # pyre-ignore[6]
+            sharding_strategy=ShardingStrategy.FULLY_SHARDED,  # only apply fully sharded to EC tables
+        )
+
+        self._test_sharding(
+            world_size=self.WORLD_SIZE,
+            world_size_2D=self.WORLD_SIZE_2D,
+            sharders=[  # pyre-ignore[6]
+                cast(
+                    ModuleSharder[nn.Module],
+                    create_test_sharder(
+                        SharderType.EMBEDDING_BAG_COLLECTION.value,
+                        sharding_type,
+                        kernel_type,
+                        qcomms_config=qcomms_config,
+                        device=torch.device("cuda"),
+                    ),
+                ),
+            ],
+            backend="nccl",
+            qcomms_config=qcomms_config,
+            constraints={
+                table.name: ParameterConstraints(min_partition=2)
+                for table in self.tables
+            },
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            submodule_configs=[ec_submodule_config],
+            sharding_strategy=ShardingStrategy.DEFAULT,
+        )
+
     def _test_sharding(
         self,
         sharders: List[TestEmbeddingCollectionSharder],
@@ -969,6 +1135,7 @@ def _test_sharding(
         variable_batch_size: bool = False,
         variable_batch_per_feature: bool = False,
         submodule_configs: Optional[List[DMPCollectionConfig]] = None,
+        sharding_strategy: ShardingStrategy = ShardingStrategy.DEFAULT,
     ) -> None:
         self._run_multi_process_test(
             callable=sharding_single_rank_test,
@@ -988,6 +1155,7 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
             submodule_configs=submodule_configs,
+            sharding_strategy=sharding_strategy,
         )
 
 
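The two new tests differ only in how the strategies are combined: the EC submodule config requests FULLY_SHARDED in both, while the top-level sharding_strategy passed to _test_sharding is FULLY_SHARDED in one and DEFAULT in the other. A compact, illustrative summary (the dict keys are labels for the two cases, not TorchRec API):

from torchrec.distributed.types import ShardingStrategy

# test_fully_sharded_dynamic_2D: everything runs fully sharded 2D.
case_fully_sharded = {
    "top_level_strategy": ShardingStrategy.FULLY_SHARDED,
    "ec_submodule_strategy": ShardingStrategy.FULLY_SHARDED,
}

# test_partially_fully_sharded_dynamic_2D: only the EC sub-tree opts in; the rest
# of the model keeps the default 2D behavior.
case_partially_fully_sharded = {
    "top_level_strategy": ShardingStrategy.DEFAULT,
    "ec_submodule_strategy": ShardingStrategy.FULLY_SHARDED,
}
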
torchrec/distributed/types.py

Lines changed: 45 additions & 39 deletions
@@ -931,6 +931,51 @@ class ShardingStrategy(Enum):
     FULLY_SHARDED = "fully_sharded"
 
 
+class DMPCollectionConfig:
+    module: Type[nn.Module]
+    plan: "ShardingPlan" = field(repr=False)  # sub-tree-specific sharding plan
+    sharding_group_size: int
+    node_group_size: Optional[int] = None
+    use_inter_host_allreduce: bool = False
+    sharding_strategy: ShardingStrategy = ShardingStrategy.DEFAULT
+
+    def __init__(
+        self,
+        module: Type[nn.Module],
+        plan: "ShardingPlan",
+        sharding_group_size: int,
+        node_group_size: Optional[int] = None,
+        use_inter_host_allreduce: bool = False,
+        sharding_strategy: ShardingStrategy = ShardingStrategy.DEFAULT,
+    ) -> None:
+        self.module = module
+        self.plan = plan
+        self.sharding_group_size = sharding_group_size
+        self.node_group_size = node_group_size
+        self.use_inter_host_allreduce = use_inter_host_allreduce
+        self.sharding_strategy = sharding_strategy
+
+    def __post_init__(self) -> None:
+        if isinstance(self.module, ShardedModule):
+            raise ValueError(
+                f"ShardedModule should not be passed into DMPCollectionConfig: got {type(self.module)}"
+            )
+
+
+# for internal use in DMPCollection
+class DMPCollectionContext(DMPCollectionConfig):
+    device_mesh: "DeviceMesh" = field(init=False)
+    sharding_pg: "dist.ProcessGroup" = field(init=False)
+    replica_pg: "dist.ProcessGroup" = field(init=False)
+    modules_to_sync: List[Tuple[nn.Module, nn.Module]] = field(
+        init=False, default_factory=list
+    )
+    sharded_module: Optional[nn.Module] = field(init=False, default=None)
+    sharding_strategy: ShardingStrategy = field(
+        init=False, default=ShardingStrategy.DEFAULT
+    )
+
+
 class ShardingEnv2D(ShardingEnv):
     """
     Creates a sharding environment for 2D parallelism, enables usage of 2D parallelism in sharding
@@ -1375,42 +1420,3 @@ class ShardingBucketMetadata:
     num_buckets_per_shard: List[int]
     bucket_offsets_per_shard: List[int]
     bucket_size: int
-
-
-class DMPCollectionConfig:
-    module: Type[nn.Module]
-    plan: "ShardingPlan" = field(repr=False)  # sub-tree-specific sharding plan
-    sharding_group_size: int
-    node_group_size: Optional[int] = None
-    use_inter_host_allreduce: bool = False
-
-    def __init__(
-        self,
-        module: Type[nn.Module],
-        plan: "ShardingPlan",
-        sharding_group_size: int,
-        node_group_size: Optional[int] = None,
-        use_inter_host_allreduce: bool = False,
-    ) -> None:
-        self.module = module
-        self.plan = plan
-        self.sharding_group_size = sharding_group_size
-        self.node_group_size = node_group_size
-        self.use_inter_host_allreduce = use_inter_host_allreduce
-
-    def __post_init__(self) -> None:
-        if isinstance(self.module, ShardedModule):
-            raise ValueError(
-                f"ShardedModule should not be passed into DMPCollectionConfig: got {type(self.module)}"
-            )
-
-
-# for internal use in DMPCollection
-class DMPCollectionContext(DMPCollectionConfig):
-    device_mesh: "DeviceMesh" = field(init=False)
-    sharding_pg: "dist.ProcessGroup" = field(init=False)
-    replica_pg: "dist.ProcessGroup" = field(init=False)
-    modules_to_sync: List[Tuple[nn.Module, nn.Module]] = field(
-        init=False, default_factory=list
-    )
-    sharded_module: Optional[nn.Module] = field(init=False, default=None)