 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    is_cuda_version_at_least,
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
@@ -50,12 +51,25 @@ def run_around_tests():
 elem_dtypes = (
     [
         # test each dtype
-        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
+        (
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+        ),
         (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
         (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
-        (torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
-        # only test one type of mixed-dtype overrides, to save testing time
-        (torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
+        (
+            torch.float4_e2m1fn_x2,
+            torch.float4_e2m1fn_x2,
+            torch.float4_e2m1fn_x2,
+        ),
+        # only test one type of mixed-dtype overrides, to save
+        # testing time
+        (
+            torch.float8_e4m3fn,
+            torch.float4_e2m1fn_x2,
+            torch.float4_e2m1fn_x2,
+        ),
     ]
     if torch_version_at_least("2.8.0")
     else [
@@ -117,6 +131,8 @@ def test_linear_eager_vs_hp(
         pytest.skip("unsupported configuration")
     elif not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for MX dim1 cast cuda kernel")
+    elif not is_cuda_version_at_least(12, 8):
+        pytest.skip("CUDA version >= 12.8 required for MXFP8 CUDA extension")
 
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
@@ -166,7 +182,12 @@ def test_linear_eager_vs_hp(
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+    not is_sm_at_least_100(),
+    reason="CUDA capability >= 10.0 required for mxfloat8",
+)
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8",
 )
 @pytest.mark.parametrize(
     "recipe_name",
@@ -303,6 +324,10 @@ def test_linear_compile(
         ScaleCalculationMode.RCEIL,
     ):
         pytest.skip("unsupported configuration")
+    elif not is_sm_at_least_100():
+        pytest.skip("CUDA capability >= 10.0 required for MX dim1 cast cuda kernel")
+    elif not is_cuda_version_at_least(12, 8):
+        pytest.skip("CUDA version >= 12.8 required for MXFP8")
 
     if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
         # TODO(future PR): properly enable float32 + bfloat16 for every
@@ -318,7 +343,8 @@ def test_linear_compile(
     ):
         # TODO(future): debug this
         pytest.skip(
-            "there are currently accuracy issues with this configuration on H100 and below"
+            "there are currently accuracy issues with this configuration "
+            "on H100 and below"
         )
 
     M, K, N = 128, 256, 512
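For reference, the new skip guards assume a helper that compares the CUDA toolkit version reported by PyTorch against a (major, minor) pair. A minimal sketch of such a helper is shown below, built on torch.version.cuda; this is an assumption about its behavior, not the actual torchao.utils.is_cuda_version_at_least implementation, which may differ.

import torch


def is_cuda_version_at_least(major: int, minor: int) -> bool:
    # Sketch only: torch.version.cuda is a string such as "12.8",
    # or None for CPU-only builds of PyTorch.
    if torch.version.cuda is None:
        return False
    cuda_major, cuda_minor = (int(part) for part in torch.version.cuda.split(".")[:2])
    # Tuple comparison handles e.g. (12, 10) >= (12, 8) correctly.
    return (cuda_major, cuda_minor) >= (major, minor)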