Skip to content

Commit 051ceae

Browse files
Raahul Kalyaan Jakkameta-codesync[bot]
authored andcommitted
Added unit test for Heuristic Storage Reservation (#3511)
Summary: Pull Request resolved: #3511 **Context:** Heuristic Storage reservation is a common component for all planner that checks if the given module along with the constraints can be sharded across the topology. **In this diff:** We added a UT to validate the error for storage use in the storage reservation process. If the given module is larger than the provided topology. We need to OOM the process asap with appropriate error to notify the PG Reviewed By: kausv, mserturk Differential Revision: D85892579 fbshipit-source-id: 03eb679e6cabf8c030092c46d93a7b030e4c3814
1 parent b12fbba commit 051ceae

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

torchrec/distributed/planner/tests/test_storage_reservations.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
_get_module_size,
2020
HeuristicalStorageReservation,
2121
)
22-
from torchrec.distributed.planner.types import Topology
22+
from torchrec.distributed.planner.types import PlannerError, PlannerErrorType, Topology
2323

2424
from torchrec.distributed.test_utils.test_model import TestTowerInteraction
2525
from torchrec.distributed.types import ModuleSharder
@@ -36,6 +36,36 @@ def __init__(self, shardable_sparse: nn.Module) -> None:
3636

3737

3838
class TestHeuristicalStorageReservation(unittest.TestCase):
39+
40+
def test_validate_storage_reservations_errors(self) -> None:
41+
tables = [
42+
EmbeddingBagConfig(
43+
num_embeddings=1_000_000,
44+
embedding_dim=1024,
45+
name="table_0",
46+
feature_names=["feature_0"],
47+
),
48+
]
49+
50+
ebc = EmbeddingBagCollection(tables)
51+
model = TestModel(shardable_sparse=ebc)
52+
53+
# Reserving 100% of HBM to make sure the heuristic storage reservation fails
54+
heuristical_storage_reservation = HeuristicalStorageReservation(percentage=1)
55+
with self.assertRaises(PlannerError) as context:
56+
heuristical_storage_reservation.reserve(
57+
topology=Topology(world_size=1, compute_device="cuda"),
58+
batch_size=1024,
59+
module=model,
60+
sharders=cast(
61+
List[ModuleSharder[nn.Module]], [EmbeddingBagCollectionSharder()]
62+
),
63+
)
64+
65+
self.assertEqual(
66+
context.exception.error_type, PlannerErrorType.INSUFFICIENT_STORAGE
67+
)
68+
3969
def test_storage_reservations_ebc(self) -> None:
4070
tables = [
4171
EmbeddingBagConfig(

0 commit comments

Comments
 (0)