 from torch.utils._triton import has_triton

 from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu
+from torchao.utils import get_current_accelerator_device

 bit_widths = (1, 2, 3, 4, 5, 6, 7)
 dimensions = (0, -1, 1)
+_DEVICE = get_current_accelerator_device()


 @pytest.fixture(autouse=True)
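For context, the new module-level `_DEVICE` constant is what lets every test below drop its hard-coded `.cuda()` call: tensors are moved with `.to(_DEVICE)` instead, so the same tests can run on CUDA, XPU, or any other backend PyTorch exposes through `torch.accelerator`. A minimal sketch of the idea, assuming `get_current_accelerator_device` behaves roughly like the hypothetical `_pick_device` helper below (check `torchao/utils.py` for the real definition):

import torch

def _pick_device() -> torch.device:
    # Hypothetical stand-in for torchao.utils.get_current_accelerator_device:
    # prefer whatever accelerator backend this PyTorch build exposes
    # (CUDA, XPU, MPS, ...) and fall back to CPU on accelerator-less machines.
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

_DEVICE = _pick_device()
test_tensor = torch.randint(0, 2**4, (32, 32, 32), dtype=torch.uint8).to(_DEVICE)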
@@ -30,40 +32,46 @@ def test_CPU(bit_width, dim):
     assert unpacked.allclose(test_tensor)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
 def test_GPU(bit_width, dim):
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
 def test_compile(bit_width, dim):
     torch._dynamo.config.specialize_int = True
     torch.compile(pack, fullgraph=True)
     torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)


 # these test cases are for the example pack walk through in the bitpacking.py file
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_pack_example():
     test_tensor = torch.tensor(
         [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
-    ).cuda()
+    ).to(_DEVICE)
     shard_4, shard_2 = pack(test_tensor, 6)
     print(shard_4, shard_2)
-    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
-    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+    assert (
+        torch.tensor([0, 105, 151, 37], dtype=torch.uint8).to(_DEVICE).allclose(shard_4)
+    )
+    assert torch.tensor([39, 146], dtype=torch.uint8).to(_DEVICE).allclose(shard_2)
     unpacked = unpack([shard_4, shard_2], 6)
     assert unpacked.allclose(test_tensor)

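The expected shard values in `test_pack_example` can be reproduced by hand. The sketch below is a reconstruction inferred from the asserted outputs, not code copied from `bitpacking.py`: each 6-bit value is split into its low 4 bits and high 2 bits, and each group is packed little-end-first into its own uint8 shard.

import torch

vals = torch.tensor([0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8)

low4 = vals & 0xF          # lower 4 bits of each 6-bit value
high2 = (vals >> 4) & 0x3  # upper 2 bits of each 6-bit value

# 4-bit shard: two fields per byte; element i sits in the low nibble,
# element i + 4 in the high nibble.
shard_4 = low4[:4] | (low4[4:] << 4)

# 2-bit shard: four fields per byte; element i + 2*j sits at bit offset 2*j.
shard_2 = high2[0:2] | (high2[2:4] << 2) | (high2[4:6] << 4) | (high2[6:8] << 6)

assert shard_4.tolist() == [0, 105, 151, 37]
assert shard_2.tolist() == [39, 146]

Unpacking reverses the two shifts, which is why `unpack([shard_4, shard_2], 6)` recovers the original tensor in the test above.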