11import math
2+ from typing import Optional
23
34import pytest
45from hypothesis import given
6+ from hypothesis import strategies as st
57
6- from array_api_tests .typing import DataType
8+ from array_api_tests .typing import Array , DataType
79
8- from . import _array_module as xp
910from . import dtype_helpers as dh
1011from . import hypothesis_helpers as hh
1112from . import pytest_helpers as ph
1213from . import xps
14+ from ._array_module import mod as xp
1315
1416pytestmark = [
1517 pytest .mark .ci ,
2123fft_shapes_strat = hh .shapes (min_dims = 1 ).filter (lambda s : math .prod (s ) > 1 )
2224
2325
26+ def n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
27+ size = math .prod (x .shape )
28+ n = data .draw (st .none () | st .integers (size // 2 , size * 2 ), label = "n" )
29+ axis = data .draw (st .integers (- 1 , x .ndim - 1 ), label = "axis" )
30+ norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
31+ kwargs = data .draw (
32+ hh .specified_kwargs (
33+ ("n" , n , None ),
34+ ("axis" , axis , - 1 ),
35+ ("norm" , norm , "backward" ),
36+ ),
37+ label = "kwargs" ,
38+ )
39+ return n , axis , norm , kwargs
40+
41+
2442def assert_fft_dtype (func_name : str , * , in_dtype : DataType , out_dtype : DataType ):
2543 if in_dtype == xp .float32 :
2644 expected = xp .complex64
@@ -34,29 +52,80 @@ def assert_fft_dtype(func_name: str, *, in_dtype: DataType, out_dtype: DataType)
3452 )
3553
3654
37- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
38- def test_fft (x ):
39- out = xp .fft .fft (x )
55+ def assert_n_axis_shape (
56+ func_name : str , * , x : Array , n : Optional [int ], axis : int , out : Array
57+ ):
58+ if n is None :
59+ expected_shape = x .shape
60+ else :
61+ _axis = len (x .shape ) - 1 if axis == - 1 else axis
62+ expected_shape = x .shape [:_axis ] + (n ,) + x .shape [_axis + 1 :]
63+ ph .assert_shape (func_name , out_shape = out .shape , expected = expected_shape )
64+
65+
66+ @given (
67+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
68+ data = st .data (),
69+ )
70+ def test_fft (x , data ):
71+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
72+
73+ out = xp .fft .fft (x , ** kwargs )
74+
4075 assert_fft_dtype ("fft" , in_dtype = x .dtype , out_dtype = out .dtype )
41- ph .assert_shape ("fft" , out_shape = out .shape , expected = x .shape )
76+ assert_n_axis_shape ("fft" , x = x , n = n , axis = axis , out = out )
77+
4278
79+ @given (
80+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
81+ data = st .data (),
82+ )
83+ def test_ifft (x , data ):
84+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
85+
86+ out = xp .fft .ifft (x , ** kwargs )
4387
44- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
45- def test_ifft (x ):
46- out = xp .fft .ifft (x )
4788 assert_fft_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
48- ph .assert_shape ("ifft" , out_shape = out .shape , expected = x .shape )
89+ assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
90+
91+
92+ # TODO:
93+ # test_fftn
94+ # test_ifftn
95+
96+
97+ @given (
98+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
99+ data = st .data (),
100+ )
101+ def test_rfft (x , data ):
102+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
103+
104+ out = xp .fft .rfft (x , ** kwargs )
105+
106+ assert_fft_dtype ("rfft" , in_dtype = x .dtype , out_dtype = out .dtype )
107+ assert_n_axis_shape ("rfft" , x = x , n = n , axis = axis , out = out )
108+
109+
110+ @given (
111+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
112+ data = st .data (),
113+ )
114+ def test_irfft (x , data ):
115+ n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
49116
117+ out = xp .fft .irfft (x , ** kwargs )
50118
51- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
52- def test_fftn (x ):
53- out = xp .fft .fftn (x )
54- assert_fft_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
55- ph .assert_shape ("fftn" , out_shape = out .shape , expected = x .shape )
119+ assert_fft_dtype ("irfft" , in_dtype = x .dtype , out_dtype = out .dtype )
120+ # TODO: assert shape
56121
57122
58- @given (x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ))
59- def test_ifftn (x ):
60- out = xp .fft .ifftn (x )
61- assert_fft_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
62- ph .assert_shape ("ifftn" , out_shape = out .shape , expected = x .shape )
123+ # TODO:
124+ # test_rfftn
125+ # test_irfftn
126+ # test_hfft
127+ # test_ihfft
128+ # fftfreq
129+ # rfftfreq
130+ # fftshift
131+ # ifftshift
0 commit comments