Skip to content

Commit ea81677

Browse files
committed
resolve CUDA OOM issue in tests
1 parent 9c2faf5 commit ea81677

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/py/dynamo/models/test_weight_stripped_engine.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class TestWeightStrippedEngine(TestCase):
2929
)
3030
def test_three_ways_to_compile(self):
3131
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
32-
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
32+
example_inputs = (torch.randn((2, 3, 224, 224)).to("cuda"),)
3333
exp_program = torch.export.export(pyt_model, example_inputs)
3434

3535
settings = {
@@ -48,6 +48,7 @@ def test_three_ways_to_compile(self):
4848
**settings,
4949
)
5050
gm1_output = gm1(*example_inputs)
51+
torch.cuda.empty_cache()
5152

5253
# 2. Compile with torch.compile using tensorrt backend
5354
gm2 = torch.compile(
@@ -56,8 +57,10 @@ def test_three_ways_to_compile(self):
5657
options=settings,
5758
)
5859
gm2_output = gm2(*example_inputs)
60+
torch.cuda.empty_cache()
5961

6062
pyt_model_output = pyt_model(*example_inputs)
63+
torch.cuda.empty_cache()
6164

6265
assert torch.allclose(
6366
pyt_model_output, gm1_output, 1e-2, 1e-2

0 commit comments

Comments
 (0)