11#ifndef CP_ALGO_MATH_FFT_HPP
22#define CP_ALGO_MATH_FFT_HPP
3+ #pragma GCC push_options
4+ #pragma GCC target("avx2")
35#include " ../number_theory/modint.hpp"
46#include " ../util/checkpoint.hpp"
57#include " ../random/rng.hpp"
@@ -29,7 +31,7 @@ namespace cp_algo::math::fft {
2931 }
3032 }
3133
32- simd_target static std::pair<vftype, vftype>
34+ static std::pair<vftype, vftype>
3335 do_split (auto const & a, size_t idx, u64x4 mul) {
3436 if (idx >= std::size (a)) {
3537 return std::pair{vftype (), vftype ()};
@@ -48,7 +50,7 @@ namespace cp_algo::math::fft {
4850 }
4951
5052 dft (size_t n): A(n), B(n) {init ();}
51- simd_target dft (auto const & a, size_t n, bool partial = true ): A(n), B(n) {
53+ dft (auto const & a, size_t n, bool partial = true ): A(n), B(n) {
5254 init ();
5355 base b2x32 = bpow (base (2 ), 32 );
5456 u64x4 cur = {
@@ -77,7 +79,7 @@ namespace cp_algo::math::fft {
7779 }
7880 }
7981 }
80- simd_target static void do_dot_iter (point rt, vpoint& Cv, vpoint& Dv, vpoint const & Av, vpoint const & Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
82+ static void do_dot_iter (point rt, vpoint& Cv, vpoint& Dv, vpoint const & Av, vpoint const & Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
8183 AC += Av * Cv; AD += Av * Dv;
8284 BC += Bv * Cv; BD += Bv * Dv;
8385 real (Cv) = rotate_right (real (Cv));
@@ -93,8 +95,8 @@ namespace cp_algo::math::fft {
9395 }
9496
9597 template <bool overwrite = true , bool partial = true >
96- simd_target void dot (auto const & C, auto const & D, auto &Aout, auto &Bout, auto &Cout) const {
97- cvector::exec_on_evals<1 >(A.size () / flen, [&](size_t k, point rt) {
98+ void dot (auto const & C, auto const & D, auto &Aout, auto &Bout, auto &Cout) const {
99+ cvector::exec_on_evals<1 >(A.size () / flen, [&](size_t k, point rt) __attribute__ ((always_inline)) {
98100 k *= flen;
99101 vpoint AC, AD, BC, BD;
100102 AC = AD = BC = BD = vz;
@@ -125,11 +127,11 @@ namespace cp_algo::math::fft {
125127 checkpoint (" dot" );
126128 }
127129
128- [[gnu::target( " avx2 " )]] void dot (auto &&C, auto const & D) {
130+ void dot (auto &&C, auto const & D) {
129131 dot (C, D, A, B, C);
130132 }
131133
132- simd_target static void do_recover_iter (size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
134+ static void do_recover_iter (size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
133135 auto A0 = lround (A), A1 = lround (C), A2 = lround (B);
134136 auto Ai = A0 + A1 * split () + A2 * splitsplit + uint64_t (base::modmod ());
135137 auto Au = montgomery_reduce (u64x4 (Ai), mod, imod);
@@ -140,7 +142,7 @@ namespace cp_algo::math::fft {
140142 }
141143 }
142144
143- simd_target void recover_mod (auto &&C, auto &res, size_t k) {
145+ void recover_mod (auto &&C, auto &res, size_t k) {
144146 size_t check = (k + flen - 1 ) / flen * flen;
145147 assert (res.size () >= check);
146148 size_t n = A.size ();
@@ -168,7 +170,7 @@ namespace cp_algo::math::fft {
168170 checkpoint (" recover mod" );
169171 }
170172
171- simd_target void mul (auto &&C, auto const & D, auto &res, size_t k) {
173+ void mul (auto &&C, auto const & D, auto &res, size_t k) {
172174 assert (A.size () == C.size ());
173175 size_t n = A.size ();
174176 if (!n) {
@@ -181,10 +183,10 @@ namespace cp_algo::math::fft {
181183 C.ifft ();
182184 recover_mod (C, res, k);
183185 }
184- simd_target void mul_inplace (auto &&B, auto & res, size_t k) {
186+ void mul_inplace (auto &&B, auto & res, size_t k) {
185187 mul (B.A , B.B , res, k);
186188 }
187- simd_target void mul (auto const & B, auto & res, size_t k) {
189+ void mul (auto const & B, auto & res, size_t k) {
188190 mul (cvector (B.A ), B.B , res, k);
189191 }
190192 big_vector<base> operator *= (dft &B) {
@@ -209,7 +211,7 @@ namespace cp_algo::math::fft {
209211 template <modint_type base> uint32_t dft<base>::mod = {};
210212 template <modint_type base> uint32_t dft<base>::imod = {};
211213
212- [[gnu::target( " avx2 " )]] void mul_slow (auto &a, auto const & b, size_t k) {
214+ void mul_slow (auto &a, auto const & b, size_t k) {
213215 if (std::empty (a) || std::empty (b)) {
214216 a.clear ();
215217 } else {
@@ -230,7 +232,7 @@ namespace cp_algo::math::fft {
230232 }
231233 return std::max (flen, std::bit_ceil (as + bs - 1 ) / 2 );
232234 }
233- [[gnu::target( " avx2 " )]] void mul_truncate (auto &a, auto const & b, size_t k) {
235+ void mul_truncate (auto &a, auto const & b, size_t k) {
234236 using base = std::decay_t <decltype (a[0 ])>;
235237 if (std::min ({k, std::size (a), std::size (b)}) < magic) {
236238 mul_slow (a, b, k);
@@ -247,7 +249,7 @@ namespace cp_algo::math::fft {
247249 }
248250
249251 // store mod x^n-k in first half, x^n+k in second half
250- simd_target void mod_split (auto &&x, size_t n, auto k) {
252+ void mod_split (auto &&x, size_t n, auto k) {
251253 using base = std::decay_t <decltype (k)>;
252254 dft<base>::init ();
253255 assert (std::size (x) == 2 * n);
@@ -279,7 +281,7 @@ namespace cp_algo::math::fft {
279281 }
280282 cp_algo::checkpoint (" mod split" );
281283 }
282- [[gnu::target( " avx2 " )]] void cyclic_mul (auto &a, auto &&b, size_t k) {
284+ void cyclic_mul (auto &a, auto &&b, size_t k) {
283285 assert (std::popcount (k) == 1 );
284286 assert (std::size (a) == std::size (b) && std::size (a) == k);
285287 using base = std::decay_t <decltype (a[0 ])>;
@@ -312,13 +314,13 @@ namespace cp_algo::math::fft {
312314 }
313315 cp_algo::checkpoint (" mod join" );
314316 }
315- [[gnu::target( " avx2 " )]] auto make_copy (auto &&x) {
317+ auto make_copy (auto &&x) {
316318 return x;
317319 }
318- [[gnu::target( " avx2 " )]] void cyclic_mul (auto &a, auto const & b, size_t k) {
320+ void cyclic_mul (auto &a, auto const & b, size_t k) {
319321 return cyclic_mul (a, make_copy (b), k);
320322 }
321- [[gnu::target( " avx2 " )]] void mul (auto &a, auto &&b) {
323+ void mul (auto &a, auto &&b) {
322324 size_t N = size (a) + size (b);
323325 if (N > (1 << 20 )) {
324326 N--;
@@ -331,7 +333,7 @@ namespace cp_algo::math::fft {
331333 mul_truncate (a, b, N - 1 );
332334 }
333335 }
334- [[gnu::target( " avx2 " )]] void mul (auto &a, auto const & b) {
336+ void mul (auto &a, auto const & b) {
335337 size_t N = size (a) + size (b);
336338 if (N > (1 << 20 )) {
337339 mul (a, make_copy (b));
@@ -340,4 +342,5 @@ namespace cp_algo::math::fft {
340342 }
341343 }
342344}
345+ #pragma GCC pop_options
343346#endif // CP_ALGO_MATH_FFT_HPP
0 commit comments