Skip to content

Commit 70e903b

Browse files
authored
[xpu][test] Port 2 test/quantization/pt2e/test_{quantize_pt2e, quantize_pt2e_qat} UT files to intel XPU (#3405)
* add test/quantization/pt2e/test_quantize_pt2e.py * add test/quantization/pt2e/test_quantize_pt2e.py * test/quantization/pt2e/test_quantize_pt2e_qat.py * test/quantization/pt2e/test_quantize_pt2e_qat.py * fix format issue * update format * increase timeout for xpu
1 parent 1272f3c commit 70e903b

File tree

3 files changed

+21
-17
lines changed

3 files changed

+21
-17
lines changed

.github/workflows/xpu_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
test:
2222
# Don't run on forked repos or empty test matrix
2323
# if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]'
24-
timeout-minutes: 60
24+
timeout-minutes: 120
2525
runs-on: linux.idc.xpu
2626
env:
2727
DOCKER_IMAGE: ci-image:pytorch-linux-noble-xpu-n-py3
@@ -166,7 +166,7 @@ jobs:
166166
GITHUB_RUN_NUMBER: ${{ github.run_number }}
167167
GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }}
168168
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
169-
timeout-minutes: 60
169+
timeout-minutes: 120
170170
run: |
171171
set -x
172172

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from torch.testing._internal.common_utils import (
3434
TEST_CUDA,
35+
TEST_XPU,
3536
TemporaryFileName,
3637
instantiate_parametrized_tests,
3738
parametrize,
@@ -70,9 +71,10 @@
7071
QuantizationConfig,
7172
)
7273
from torchao.testing.pt2e.utils import PT2EQuantizationTestCase
73-
from torchao.utils import torch_version_at_least
74+
from torchao.utils import get_current_accelerator_device, torch_version_at_least
7475

75-
DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else [])
76+
DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["xpu"] if TEST_XPU else [])
77+
_DEVICE = get_current_accelerator_device()
7678

7779
if torch_version_at_least("2.7.0"):
7880
from torch.testing._internal.common_utils import (
@@ -2154,7 +2156,7 @@ def __init__(self) -> None:
21542156
def forward(self, x):
21552157
return self.bn(x)
21562158

2157-
if TEST_CUDA or TEST_HPU:
2159+
if TEST_CUDA or TEST_HPU or TEST_XPU:
21582160
m = M().train().to(device)
21592161
example_inputs = (torch.randn((1, 3, 3, 3), device=device),)
21602162

@@ -2229,9 +2231,9 @@ def forward(self, x):
22292231
x = self.dropout(x)
22302232
return x
22312233

2232-
if TEST_CUDA:
2233-
m = M().train().cuda()
2234-
example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
2234+
if TEST_CUDA or TEST_XPU:
2235+
m = M().train().to(_DEVICE)
2236+
example_inputs = (torch.randn(1, 3, 3, 3).to(_DEVICE),)
22352237
else:
22362238
m = M().train()
22372239
example_inputs = (torch.randn(1, 3, 3, 3),)
@@ -2243,7 +2245,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
22432245
bn_op = bn_train_op if train else bn_eval_op
22442246
bn_node = self._get_node(m, bn_op)
22452247
self.assertTrue(bn_node is not None)
2246-
if TEST_CUDA:
2248+
if TEST_CUDA or TEST_XPU:
22472249
self.assertEqual(bn_node.args[5], train)
22482250
dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
22492251
self.assertEqual(dropout_node.args[2], train)

test/quantization/pt2e/test_quantize_pt2e_qat.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
skipIfNoQNNPACK,
2929
)
3030
from torch.testing._internal.common_quantized import override_quantized_engine
31-
from torch.testing._internal.common_utils import run_tests
31+
from torch.testing._internal.common_utils import TEST_XPU, run_tests
3232

3333
from torchao.quantization.pt2e import (
3434
FusedMovingAvgObsFakeQuantize,
@@ -52,7 +52,9 @@
5252
XNNPACKQuantizer,
5353
get_symmetric_quantization_config,
5454
)
55-
from torchao.utils import torch_version_at_least
55+
from torchao.utils import get_current_accelerator_device, torch_version_at_least
56+
57+
_DEVICE = get_current_accelerator_device()
5658

5759

5860
class PT2EQATTestCase(QuantizationTestCase):
@@ -453,10 +455,10 @@ def test_qat_conv_bn_fusion(self):
453455
self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
454456
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
455457

456-
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
458+
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable")
457459
def test_qat_conv_bn_fusion_cuda(self):
458-
m = self._get_conv_bn_model().cuda()
459-
example_inputs = (self.example_inputs[0].cuda(),)
460+
m = self._get_conv_bn_model().to(_DEVICE)
461+
example_inputs = (self.example_inputs[0].to(_DEVICE),)
460462
self._verify_symmetric_xnnpack_qat_graph(
461463
m,
462464
example_inputs,
@@ -540,10 +542,10 @@ def test_qat_conv_bn_relu_fusion(self):
540542
self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True)
541543
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
542544

543-
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
545+
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable")
544546
def test_qat_conv_bn_relu_fusion_cuda(self):
545-
m = self._get_conv_bn_model(has_relu=True).cuda()
546-
example_inputs = (self.example_inputs[0].cuda(),)
547+
m = self._get_conv_bn_model(has_relu=True).to(_DEVICE)
548+
example_inputs = (self.example_inputs[0].to(_DEVICE),)
547549
self._verify_symmetric_xnnpack_qat_graph(
548550
m,
549551
example_inputs,

0 commit comments

Comments
 (0)