11import math
2- from typing import Optional
2+ from typing import List , Optional
33
44import pytest
55from hypothesis import given
1010from . import dtype_helpers as dh
1111from . import hypothesis_helpers as hh
1212from . import pytest_helpers as ph
13+ from . import shape_helpers as sh
1314from . import xps
1415from ._array_module import mod as xp
1516
2324fft_shapes_strat = hh .shapes (min_dims = 1 ).filter (lambda s : math .prod (s ) > 1 )
2425
2526
26- def n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
27+ def draw_n_axis_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
2728 size = math .prod (x .shape )
28- n = data .draw (st .none () | st .integers (size // 2 , size * 2 ), label = "n" )
29+ n = data .draw (st .none () | st .integers (( size // 2 ), math . ceil ( size * 1.5 ) ), label = "n" )
2930 axis = data .draw (st .integers (- 1 , x .ndim - 1 ), label = "axis" )
3031 norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
3132 kwargs = data .draw (
@@ -39,6 +40,32 @@ def n_axis_norm_kwargs(x: Array, data: st.DataObject) -> tuple:
3940 return n , axis , norm , kwargs
4041
4142
43+ def draw_s_axes_norm_kwargs (x : Array , data : st .DataObject ) -> tuple :
44+ all_axes = list (range (x .ndim ))
45+ axes = data .draw (
46+ st .none () | st .lists (st .sampled_from (all_axes ), min_size = 1 , unique = True ),
47+ label = "axes" ,
48+ )
49+ _axes = all_axes if axes is None else axes
50+ axes_sides = [x .shape [axis ] for axis in _axes ]
51+ s_strat = st .tuples (
52+ * [st .integers (max (side // 2 , 1 ), math .ceil (side * 1.5 )) for side in axes_sides ]
53+ )
54+ if axes is None :
55+ s_strat = st .none () | s_strat
56+ s = data .draw (s_strat , label = "s" )
57+ norm = data .draw (st .sampled_from (["backward" , "ortho" , "forward" ]), label = "norm" )
58+ kwargs = data .draw (
59+ hh .specified_kwargs (
60+ ("s" , s , None ),
61+ ("axes" , axes , None ),
62+ ("norm" , norm , "backward" ),
63+ ),
64+ label = "kwargs" ,
65+ )
66+ return s , axes , norm , kwargs
67+
68+
4269def assert_fft_dtype (func_name : str , * , in_dtype : DataType , out_dtype : DataType ):
4370 if in_dtype == xp .float32 :
4471 expected = xp .complex64
@@ -63,12 +90,32 @@ def assert_n_axis_shape(
6390 ph .assert_shape (func_name , out_shape = out .shape , expected = expected_shape )
6491
6592
93+ def assert_s_axes_shape (
94+ func_name : str ,
95+ * ,
96+ x : Array ,
97+ s : Optional [List [int ]],
98+ axes : Optional [List [int ]],
99+ out : Array ,
100+ ):
101+ _axes = sh .normalise_axis (axes , x .ndim )
102+ _s = x .shape if s is None else s
103+ expected = []
104+ for i in range (x .ndim ):
105+ if i in _axes :
106+ side = _s [_axes .index (i )]
107+ else :
108+ side = x .shape [i ]
109+ expected .append (side )
110+ ph .assert_shape (func_name , out_shape = out .shape , expected = tuple (expected ))
111+
112+
66113@given (
67114 x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
68115 data = st .data (),
69116)
70117def test_fft (x , data ):
71- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
118+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
72119
73120 out = xp .fft .fft (x , ** kwargs )
74121
@@ -81,25 +128,46 @@ def test_fft(x, data):
81128 data = st .data (),
82129)
83130def test_ifft (x , data ):
84- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
131+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
85132
86133 out = xp .fft .ifft (x , ** kwargs )
87134
88135 assert_fft_dtype ("ifft" , in_dtype = x .dtype , out_dtype = out .dtype )
89136 assert_n_axis_shape ("ifft" , x = x , n = n , axis = axis , out = out )
90137
91138
92- # TODO:
93- # test_fftn
94- # test_ifftn
139+ @given (
140+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
141+ data = st .data (),
142+ )
143+ def test_fftn (x , data ):
144+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
145+
146+ out = xp .fft .fftn (x , ** kwargs )
147+
148+ assert_fft_dtype ("fftn" , in_dtype = x .dtype , out_dtype = out .dtype )
149+ assert_s_axes_shape ("fftn" , x = x , s = s , axes = axes , out = out )
150+
151+
152+ @given (
153+ x = xps .arrays (dtype = hh .all_floating_dtypes (), shape = fft_shapes_strat ),
154+ data = st .data (),
155+ )
156+ def test_ifftn (x , data ):
157+ s , axes , norm , kwargs = draw_s_axes_norm_kwargs (x , data )
158+
159+ out = xp .fft .ifftn (x , ** kwargs )
160+
161+ assert_fft_dtype ("ifftn" , in_dtype = x .dtype , out_dtype = out .dtype )
162+ assert_s_axes_shape ("ifftn" , x = x , s = s , axes = axes , out = out )
95163
96164
97165@given (
98166 x = xps .arrays (dtype = xps .floating_dtypes (), shape = fft_shapes_strat ),
99167 data = st .data (),
100168)
101169def test_rfft (x , data ):
102- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
170+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
103171
104172 out = xp .fft .rfft (x , ** kwargs )
105173
@@ -112,7 +180,7 @@ def test_rfft(x, data):
112180 data = st .data (),
113181)
114182def test_irfft (x , data ):
115- n , axis , norm , kwargs = n_axis_norm_kwargs (x , data )
183+ n , axis , norm , kwargs = draw_n_axis_norm_kwargs (x , data )
116184
117185 out = xp .fft .irfft (x , ** kwargs )
118186
0 commit comments