
Commit aa21b80

skip certain mxfp8 tests for cuda < 12.8 (#3443)

skip certain mxfp8 tests when mxfp8_cuda extension is unavailable
1 parent 69ce0fd commit aa21b80
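
For context: the skip reasons added below indicate that the MXFP8 CUDA kernels require CUDA 12.8 or newer, so on older toolkits the mxfp8_cuda extension is unavailable. A minimal sketch of an importability probe for the extension (the flag name is hypothetical; the commit itself gates on the CUDA version rather than on import success):

try:
    from torchao.prototype import mxfp8_cuda  # noqa: F401
    HAS_MXFP8_CUDA = True  # extension was compiled and is importable
except ImportError:
    HAS_MXFP8_CUDA = False  # extension missing, e.g. built with CUDA < 12.8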

Showing 3 changed files with 61 additions and 9 deletions.

test/prototype/mx_formats/test_kernels.py

Lines changed: 13 additions & 0 deletions

@@ -43,6 +43,7 @@
 from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
+    is_cuda_version_at_least,
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
@@ -529,6 +530,10 @@ def test_rearrange(shape):
     not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
+)
 @pytest.mark.parametrize("M", (32, 256))
 @pytest.mark.parametrize("K", (32, 256))
 @pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
@@ -577,6 +582,10 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
     not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
+)
 def test_cuda_mx_dim0_not_supported():
     from torchao.prototype import mxfp8_cuda
 
@@ -601,6 +610,10 @@ def test_cuda_mx_dim0_not_supported():
     not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
+)
 def test_cuda_mx_dim1_invalid_block_size():
     from torchao.prototype import mxfp8_cuda
 
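A note on the pattern above: pytest evaluates stacked skipif decorators independently and skips the test if any condition holds, so the existing capability gate and the new toolkit-version gate compose cleanly. A minimal sketch of the resulting gating (the test name and the torch.cuda.is_available gate are illustrative, not copied from this file):

import pytest
import torch
from torchao.utils import is_cuda_version_at_least, is_sm_at_least_100

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="MXFP8 requires CUDA capability 10.0 or greater",
)
@pytest.mark.skipif(
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
def test_mxfp8_gating_sketch():
    # The extension import stays inside the test body so collection
    # succeeds even when the extension was never built; the decorators
    # skip the test before this line runs.
    from torchao.prototype import mxfp8_cuda  # noqa: F401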
test/prototype/mx_formats/test_mx_linear.py

Lines changed: 32 additions & 6 deletions

@@ -26,6 +26,7 @@
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    is_cuda_version_at_least,
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
@@ -50,12 +51,25 @@ def run_around_tests():
 elem_dtypes = (
     [
         # test each dtype
-        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
+        (
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+        ),
         (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
         (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
-        (torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
-        # only test one type of mixed-dtype overrides, to save testing time
-        (torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
+        (
+            torch.float4_e2m1fn_x2,
+            torch.float4_e2m1fn_x2,
+            torch.float4_e2m1fn_x2,
+        ),
+        # only test one type of mixed-dtype overrides, to save
+        # testing time
+        (
+            torch.float8_e4m3fn,
+            torch.float4_e2m1fn_x2,
+            torch.float4_e2m1fn_x2,
+        ),
     ]
     if torch_version_at_least("2.8.0")
     else [
@@ -117,6 +131,8 @@ def test_linear_eager_vs_hp(
         pytest.skip("unsupported configuration")
     elif not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for MX dim1 cast cuda kernel")
+    elif not is_cuda_version_at_least(12, 8):
+        pytest.skip("CUDA version >= 12.8 required for MXFP8 CUDA extension")
 
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
@@ -166,7 +182,12 @@ def test_linear_eager_vs_hp(
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+    not is_sm_at_least_100(),
+    reason="CUDA capability >= 10.0 required for mxfloat8",
+)
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8",
 )
 @pytest.mark.parametrize(
     "recipe_name",
@@ -303,6 +324,10 @@ def test_linear_compile(
         ScaleCalculationMode.RCEIL,
     ):
         pytest.skip("unsupported configuration")
+    elif not is_sm_at_least_100():
+        pytest.skip("CUDA capability >= 10.0 required for MX dim1 cast cuda kernel")
+    elif not is_cuda_version_at_least(12, 8):
+        pytest.skip("CUDA version >= 12.8 required for MXFP8")
 
     if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
         # TODO(future PR): properly enable float32 + bfloat16 for every
@@ -318,7 +343,8 @@ def test_linear_compile(
     ):
         # TODO(future): debug this
        pytest.skip(
-            "there are currently accuracy issues with this configuration on H100 and below"
+            "there are currently accuracy issues with this configuration "
+            "on H100 and below"
         )
 
     M, K, N = 128, 256, 512
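
Unlike the decorator gates in test_kernels.py, test_linear_eager_vs_hp and test_linear_compile skip at runtime, because whether the MX dim1 cast CUDA kernel is exercised depends on the parametrized configuration. A minimal sketch of that pattern, with a hypothetical parameter name standing in for the real configuration check:

import pytest
from torchao.utils import is_cuda_version_at_least, is_sm_at_least_100

@pytest.mark.parametrize("use_dim1_cuda_kernel", [False, True])  # hypothetical
def test_runtime_skip_sketch(use_dim1_cuda_kernel):
    if use_dim1_cuda_kernel:
        # Only configurations that hit the CUDA kernel need the gates.
        if not is_sm_at_least_100():
            pytest.skip("CUDA capability >= 10.0 required for MX dim1 cast cuda kernel")
        elif not is_cuda_version_at_least(12, 8):
            pytest.skip("CUDA version >= 12.8 required for MXFP8 CUDA extension")
    # ... exercise the linear op for the surviving configurations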

torchao/utils.py

Lines changed: 16 additions & 3 deletions

@@ -29,6 +29,7 @@
     "get_model_size_in_bytes",
     "unwrap_tensor_subclass",
     "TorchAOBaseTensor",
+    "is_cuda_version_at_least",
     "is_MI300",
     "is_sm_at_least_89",
     "is_sm_at_least_90",
@@ -512,9 +513,11 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
     if hasattr(self, "optional_tensor_data_names"):
         # either both are None or both are not Tensors and the shape match
         _optional_tensor_shape_match = all(
-            getattr(self, t_name).shape == getattr(src, t_name).shape
-            if getattr(self, t_name) is not None
-            else getattr(src, t_name) is None
+            (
+                getattr(self, t_name).shape == getattr(src, t_name).shape
+                if getattr(self, t_name) is not None
+                else getattr(src, t_name) is None
+            )
             for t_name in self.optional_tensor_data_names
         )
 
@@ -1097,6 +1100,16 @@ def is_sm_at_least_100():
     )
 
 
+def is_cuda_version_at_least(major: int, minor: int) -> bool:
+    if not torch.cuda.is_available():
+        return False
+    cuda_version = torch.version.cuda
+    if cuda_version is None:
+        return False
+    cuda_major, cuda_minor = map(int, cuda_version.split(".")[:2])
+    return (cuda_major, cuda_minor) >= (major, minor)
+
+
 def check_cpu_version(device, version="2.6.0"):
     if isinstance(device, torch.device):
         device = device.type
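
A note on the helper's design: it compares (major, minor) as an integer tuple, which orders versions correctly even once the minor number reaches two digits, whereas comparing the raw version strings would not. For illustration:

>>> (12, 10) >= (12, 8)  # tuple comparison treats 12.10 as newer than 12.8
True
>>> "12.10" >= "12.8"    # string comparison: '1' < '8', so "12.10" sorts first
False

It also returns False outright when CUDA is unavailable or torch.version.cuda is None (e.g., CPU-only or ROCm builds), so callers can use it directly as a skipif condition.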
