Skip to content

Commit 04ccac1

Browse files
committed
Put target("avx2") pragma and __attribute__((always_inline)) on affected lambdas
1 parent 5c4d741 commit 04ccac1

File tree

10 files changed

+82
-61
lines changed

10 files changed

+82
-61
lines changed

cp-algo/linalg/matrix.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#ifndef CP_ALGO_LINALG_MATRIX_HPP
22
#define CP_ALGO_LINALG_MATRIX_HPP
3+
#pragma GCC push_options
4+
#pragma GCC target("avx2")
35
#include "../random/rng.hpp"
46
#include "../math/common.hpp"
57
#include "vector.hpp"
@@ -304,4 +306,5 @@ namespace cp_algo::linalg {
304306
template<typename base_t>
305307
auto operator *(base_t t, matrix<base_t> const& A) {return A * t;}
306308
}
309+
#pragma GCC pop_options
307310
#endif // CP_ALGO_LINALG_MATRIX_HPP

cp-algo/linalg/vector.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#ifndef CP_ALGO_LINALG_VECTOR_HPP
22
#define CP_ALGO_LINALG_VECTOR_HPP
3+
#pragma GCC push_options
4+
#pragma GCC target("avx2")
35
#include "../random/rng.hpp"
46
#include "../number_theory/modint.hpp"
57
#include "../util/big_alloc.hpp"
@@ -152,4 +154,5 @@ namespace cp_algo::linalg {
152154
size_t counter = 0;
153155
};
154156
}
157+
#pragma GCC pop_options
155158
#endif // CP_ALGO_LINALG_VECTOR_HPP

cp-algo/math/cvector.hpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#ifndef CP_ALGO_MATH_CVECTOR_HPP
22
#define CP_ALGO_MATH_CVECTOR_HPP
3+
#pragma GCC push_options
4+
#pragma GCC target("avx2")
35
#include "../util/simd.hpp"
46
#include "../util/complex.hpp"
57
#include "../util/checkpoint.hpp"
@@ -15,7 +17,7 @@ namespace cp_algo::math::fft {
1517
using point = complex<ftype>;
1618
using vpoint = complex<vftype>;
1719
static constexpr vftype vz = {};
18-
simd_target vpoint vi(vpoint const& r) {
20+
vpoint vi(vpoint const& r) {
1921
return {-imag(r), real(r)};
2022
}
2123

@@ -30,7 +32,7 @@ namespace cp_algo::math::fft {
3032
vpoint& at(size_t k) {return r[k / flen];}
3133
vpoint at(size_t k) const {return r[k / flen];}
3234
template<class pt = point>
33-
simd_inline void set(size_t k, pt const& t) {
35+
inline void set(size_t k, pt const& t) {
3436
if constexpr(std::is_same_v<pt, point>) {
3537
real(r[k / flen])[k % flen] = real(t);
3638
imag(r[k / flen])[k % flen] = imag(t);
@@ -39,7 +41,7 @@ namespace cp_algo::math::fft {
3941
}
4042
}
4143
template<class pt = point>
42-
simd_inline pt get(size_t k) const {
44+
inline pt get(size_t k) const {
4345
if constexpr(std::is_same_v<pt, point>) {
4446
return {real(r[k / flen])[k % flen], imag(r[k / flen])[k % flen]};
4547
} else {
@@ -79,18 +81,18 @@ namespace cp_algo::math::fft {
7981
return roots[std::bit_width(n)];
8082
}
8183
template<int step>
82-
simd_target static void exec_on_eval(size_t n, size_t k, auto &&callback) {
84+
static void exec_on_eval(size_t n, size_t k, auto &&callback) {
8385
callback(k, root(4 * step * n) * eval_point(step * k));
8486
}
8587
template<int step>
86-
simd_target static void exec_on_evals(size_t n, auto &&callback) {
88+
static void exec_on_evals(size_t n, auto &&callback) {
8789
point factor = root(4 * step * n);
8890
for(size_t i = 0; i < n; i++) {
8991
callback(i, factor * eval_point(step * i));
9092
}
9193
}
9294

93-
simd_target static void do_dot_iter(point rt, vpoint& Bv, vpoint const& Av, vpoint& res) {
95+
static void do_dot_iter(point rt, vpoint& Bv, vpoint const& Av, vpoint& res) {
9496
res += Av * Bv;
9597
real(Bv) = rotate_right(real(Bv));
9698
imag(Bv) = rotate_right(imag(Bv));
@@ -99,9 +101,9 @@ namespace cp_algo::math::fft {
99101
imag(Bv)[0] = x * imag(rt) + y * real(rt);
100102
}
101103

102-
simd_target void dot(cvector const& t) {
104+
void dot(cvector const& t) {
103105
size_t n = this->size();
104-
exec_on_evals<1>(n / flen, [&](size_t k, point rt) {
106+
exec_on_evals<1>(n / flen, [&](size_t k, point rt) __attribute__((always_inline)) {
105107
k *= flen;
106108
auto [Ax, Ay] = at(k);
107109
auto Bv = t.at(k);
@@ -115,11 +117,11 @@ namespace cp_algo::math::fft {
115117
checkpoint("dot");
116118
}
117119
template<bool partial = true>
118-
simd_target void ifft() {
120+
void ifft() {
119121
size_t n = size();
120122
if constexpr (!partial) {
121123
point pi(0, 1);
122-
exec_on_evals<4>(n / 4, [&](size_t k, point rt) {
124+
exec_on_evals<4>(n / 4, [&](size_t k, point rt) __attribute__((always_inline)) {
123125
k *= 4;
124126
point v1 = conj(rt);
125127
point v2 = v1 * v1;
@@ -136,7 +138,7 @@ namespace cp_algo::math::fft {
136138
}
137139
bool parity = std::countr_zero(n) % 2;
138140
if(parity) {
139-
exec_on_evals<2>(n / (2 * flen), [&](size_t k, point rt) {
141+
exec_on_evals<2>(n / (2 * flen), [&](size_t k, point rt) __attribute__((always_inline)) {
140142
k *= 2 * flen;
141143
vpoint cvrt = {vz + real(rt), vz - imag(rt)};
142144
auto B = at(k) - at(k + flen);
@@ -149,7 +151,7 @@ namespace cp_algo::math::fft {
149151
size_t level = std::countr_one(leaf + 3);
150152
for(size_t lvl = 4 + parity; lvl <= level; lvl += 2) {
151153
size_t i = (1 << lvl) / 4;
152-
exec_on_eval<4>(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
154+
exec_on_eval<4>(n >> lvl, leaf >> lvl, [&](size_t k, point rt) __attribute__((always_inline)) {
153155
k <<= lvl;
154156
vpoint v1 = {vz + real(rt), vz - imag(rt)};
155157
vpoint v2 = v1 * v1;
@@ -177,15 +179,15 @@ namespace cp_algo::math::fft {
177179
}
178180
}
179181
template<bool partial = true>
180-
simd_target void fft() {
182+
void fft() {
181183
size_t n = size();
182184
bool parity = std::countr_zero(n) % 2;
183185
for(size_t leaf = 0; leaf < n; leaf += 4 * flen) {
184186
size_t level = std::countr_zero(n + leaf);
185187
level -= level % 2 != parity;
186188
for(size_t lvl = level; lvl >= 4; lvl -= 2) {
187189
size_t i = (1 << lvl) / 4;
188-
exec_on_eval<4>(n >> lvl, leaf >> lvl, [&](size_t k, point rt) {
190+
exec_on_eval<4>(n >> lvl, leaf >> lvl, [&](size_t k, point rt) __attribute__((always_inline)) {
189191
k <<= lvl;
190192
vpoint v1 = {vz + real(rt), vz + imag(rt)};
191193
vpoint v2 = v1 * v1;
@@ -204,7 +206,7 @@ namespace cp_algo::math::fft {
204206
}
205207
}
206208
if(parity) {
207-
exec_on_evals<2>(n / (2 * flen), [&](size_t k, point rt) {
209+
exec_on_evals<2>(n / (2 * flen), [&](size_t k, point rt) __attribute__((always_inline)) {
208210
k *= 2 * flen;
209211
vpoint vrt = {vz + real(rt), vz + imag(rt)};
210212
auto t = at(k + flen) * vrt;
@@ -214,7 +216,7 @@ namespace cp_algo::math::fft {
214216
}
215217
if constexpr (!partial) {
216218
point pi(0, 1);
217-
exec_on_evals<4>(n / 4, [&](size_t k, point rt) {
219+
exec_on_evals<4>(n / 4, [&](size_t k, point rt) __attribute__((always_inline)) {
218220
k *= 4;
219221
point v1 = rt;
220222
point v2 = v1 * v1;
@@ -252,4 +254,5 @@ namespace cp_algo::math::fft {
252254
return res;
253255
}();
254256
}
257+
#pragma GCC pop_options
255258
#endif // CP_ALGO_MATH_CVECTOR_HPP

cp-algo/math/factorials.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#ifndef CP_ALGO_MATH_FACTORIALS_HPP
22
#define CP_ALGO_MATH_FACTORIALS_HPP
3+
#pragma GCC push_options
4+
#pragma GCC target("avx2")
35
#include "../util/checkpoint.hpp"
46
#include "../util/bump_alloc.hpp"
57
#include "../util/simd.hpp"
@@ -9,7 +11,7 @@
911

1012
namespace cp_algo::math {
1113
template<bool use_bump_alloc = false, int maxn = -1>
12-
simd_target auto facts(auto const& args) {
14+
auto facts(auto const& args) {
1315
static_assert(!use_bump_alloc || maxn > 0, "maxn must be set if use_bump_alloc is true");
1416
constexpr int max_mod = 1'000'000'000;
1517
constexpr int accum = 4;
@@ -93,4 +95,5 @@ namespace cp_algo::math {
9395
return res;
9496
}
9597
}
98+
#pragma GCC pop_options
9699
#endif // CP_ALGO_MATH_FACTORIALS_HPP

cp-algo/math/fft.hpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#ifndef CP_ALGO_MATH_FFT_HPP
22
#define CP_ALGO_MATH_FFT_HPP
3+
#pragma GCC push_options
4+
#pragma GCC target("avx2")
35
#include "../number_theory/modint.hpp"
46
#include "../util/checkpoint.hpp"
57
#include "../random/rng.hpp"
@@ -29,7 +31,7 @@ namespace cp_algo::math::fft {
2931
}
3032
}
3133

32-
simd_target static std::pair<vftype, vftype>
34+
static std::pair<vftype, vftype>
3335
do_split(auto const& a, size_t idx, u64x4 mul) {
3436
if(idx >= std::size(a)) {
3537
return std::pair{vftype(), vftype()};
@@ -48,7 +50,7 @@ namespace cp_algo::math::fft {
4850
}
4951

5052
dft(size_t n): A(n), B(n) {init();}
51-
simd_target dft(auto const& a, size_t n, bool partial = true): A(n), B(n) {
53+
dft(auto const& a, size_t n, bool partial = true): A(n), B(n) {
5254
init();
5355
base b2x32 = bpow(base(2), 32);
5456
u64x4 cur = {
@@ -77,7 +79,7 @@ namespace cp_algo::math::fft {
7779
}
7880
}
7981
}
80-
simd_target static void do_dot_iter(point rt, vpoint& Cv, vpoint& Dv, vpoint const& Av, vpoint const& Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
82+
static void do_dot_iter(point rt, vpoint& Cv, vpoint& Dv, vpoint const& Av, vpoint const& Bv, vpoint& AC, vpoint& AD, vpoint& BC, vpoint& BD) {
8183
AC += Av * Cv; AD += Av * Dv;
8284
BC += Bv * Cv; BD += Bv * Dv;
8385
real(Cv) = rotate_right(real(Cv));
@@ -93,8 +95,8 @@ namespace cp_algo::math::fft {
9395
}
9496

9597
template<bool overwrite = true, bool partial = true>
96-
simd_target void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
97-
cvector::exec_on_evals<1>(A.size() / flen, [&](size_t k, point rt) {
98+
void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
99+
cvector::exec_on_evals<1>(A.size() / flen, [&](size_t k, point rt) __attribute__((always_inline)) {
98100
k *= flen;
99101
vpoint AC, AD, BC, BD;
100102
AC = AD = BC = BD = vz;
@@ -125,11 +127,11 @@ namespace cp_algo::math::fft {
125127
checkpoint("dot");
126128
}
127129

128-
[[gnu::target("avx2")]] void dot(auto &&C, auto const& D) {
130+
void dot(auto &&C, auto const& D) {
129131
dot(C, D, A, B, C);
130132
}
131133

132-
simd_target static void do_recover_iter(size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
134+
static void do_recover_iter(size_t idx, auto A, auto B, auto C, auto mul, uint64_t splitsplit, auto &res) {
133135
auto A0 = lround(A), A1 = lround(C), A2 = lround(B);
134136
auto Ai = A0 + A1 * split() + A2 * splitsplit + uint64_t(base::modmod());
135137
auto Au = montgomery_reduce(u64x4(Ai), mod, imod);
@@ -140,7 +142,7 @@ namespace cp_algo::math::fft {
140142
}
141143
}
142144

143-
simd_target void recover_mod(auto &&C, auto &res, size_t k) {
145+
void recover_mod(auto &&C, auto &res, size_t k) {
144146
size_t check = (k + flen - 1) / flen * flen;
145147
assert(res.size() >= check);
146148
size_t n = A.size();
@@ -168,7 +170,7 @@ namespace cp_algo::math::fft {
168170
checkpoint("recover mod");
169171
}
170172

171-
simd_target void mul(auto &&C, auto const& D, auto &res, size_t k) {
173+
void mul(auto &&C, auto const& D, auto &res, size_t k) {
172174
assert(A.size() == C.size());
173175
size_t n = A.size();
174176
if(!n) {
@@ -181,10 +183,10 @@ namespace cp_algo::math::fft {
181183
C.ifft();
182184
recover_mod(C, res, k);
183185
}
184-
simd_target void mul_inplace(auto &&B, auto& res, size_t k) {
186+
void mul_inplace(auto &&B, auto& res, size_t k) {
185187
mul(B.A, B.B, res, k);
186188
}
187-
simd_target void mul(auto const& B, auto& res, size_t k) {
189+
void mul(auto const& B, auto& res, size_t k) {
188190
mul(cvector(B.A), B.B, res, k);
189191
}
190192
big_vector<base> operator *= (dft &B) {
@@ -209,7 +211,7 @@ namespace cp_algo::math::fft {
209211
template<modint_type base> uint32_t dft<base>::mod = {};
210212
template<modint_type base> uint32_t dft<base>::imod = {};
211213

212-
[[gnu::target("avx2")]] void mul_slow(auto &a, auto const& b, size_t k) {
214+
void mul_slow(auto &a, auto const& b, size_t k) {
213215
if(std::empty(a) || std::empty(b)) {
214216
a.clear();
215217
} else {
@@ -230,7 +232,7 @@ namespace cp_algo::math::fft {
230232
}
231233
return std::max(flen, std::bit_ceil(as + bs - 1) / 2);
232234
}
233-
[[gnu::target("avx2")]] void mul_truncate(auto &a, auto const& b, size_t k) {
235+
void mul_truncate(auto &a, auto const& b, size_t k) {
234236
using base = std::decay_t<decltype(a[0])>;
235237
if(std::min({k, std::size(a), std::size(b)}) < magic) {
236238
mul_slow(a, b, k);
@@ -247,7 +249,7 @@ namespace cp_algo::math::fft {
247249
}
248250

249251
// store mod x^n-k in first half, x^n+k in second half
250-
simd_target void mod_split(auto &&x, size_t n, auto k) {
252+
void mod_split(auto &&x, size_t n, auto k) {
251253
using base = std::decay_t<decltype(k)>;
252254
dft<base>::init();
253255
assert(std::size(x) == 2 * n);
@@ -279,7 +281,7 @@ namespace cp_algo::math::fft {
279281
}
280282
cp_algo::checkpoint("mod split");
281283
}
282-
[[gnu::target("avx2")]] void cyclic_mul(auto &a, auto &&b, size_t k) {
284+
void cyclic_mul(auto &a, auto &&b, size_t k) {
283285
assert(std::popcount(k) == 1);
284286
assert(std::size(a) == std::size(b) && std::size(a) == k);
285287
using base = std::decay_t<decltype(a[0])>;
@@ -312,13 +314,13 @@ namespace cp_algo::math::fft {
312314
}
313315
cp_algo::checkpoint("mod join");
314316
}
315-
[[gnu::target("avx2")]] auto make_copy(auto &&x) {
317+
auto make_copy(auto &&x) {
316318
return x;
317319
}
318-
[[gnu::target("avx2")]] void cyclic_mul(auto &a, auto const& b, size_t k) {
320+
void cyclic_mul(auto &a, auto const& b, size_t k) {
319321
return cyclic_mul(a, make_copy(b), k);
320322
}
321-
[[gnu::target("avx2")]] void mul(auto &a, auto &&b) {
323+
void mul(auto &a, auto &&b) {
322324
size_t N = size(a) + size(b);
323325
if(N > (1 << 20)) {
324326
N--;
@@ -331,7 +333,7 @@ namespace cp_algo::math::fft {
331333
mul_truncate(a, b, N - 1);
332334
}
333335
}
334-
[[gnu::target("avx2")]] void mul(auto &a, auto const& b) {
336+
void mul(auto &a, auto const& b) {
335337
size_t N = size(a) + size(b);
336338
if(N > (1 << 20)) {
337339
mul(a, make_copy(b));
@@ -340,4 +342,5 @@ namespace cp_algo::math::fft {
340342
}
341343
}
342344
}
345+
#pragma GCC pop_options
343346
#endif // CP_ALGO_MATH_FFT_HPP

cp-algo/math/fft64.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#ifndef CP_ALGO_MATH_FFT64_HPP
22
#define CP_ALGO_MATH_FFT64_HPP
3+
#pragma GCC push_options
4+
#pragma GCC target("avx2")
35
#include "../random/rng.hpp"
46
#include "../math/common.hpp"
57
#include "../math/cvector.hpp"
@@ -46,7 +48,7 @@ namespace cp_algo::math::fft {
4648
}
4749
}
4850

49-
simd_target static void do_dot_iter(point rt, std::array<vpoint, 4>& B, std::array<vpoint, 4> const& A, std::array<vpoint, 4>& C) {
51+
static void do_dot_iter(point rt, std::array<vpoint, 4>& B, std::array<vpoint, 4> const& A, std::array<vpoint, 4>& C) {
5052
for(size_t k = 0; k < 4; k++) {
5153
for(size_t i = 0; i <= k; i++) {
5254
C[k] += A[i] * B[k - i];
@@ -63,7 +65,7 @@ namespace cp_algo::math::fft {
6365

6466
void dot(dft64 const& t) {
6567
size_t N = cv[0].size();
66-
cvector::exec_on_evals<1>(N / flen, [&](size_t k, point rt) {
68+
cvector::exec_on_evals<1>(N / flen, [&](size_t k, point rt) __attribute__((always_inline)) {
6769
k *= flen;
6870
auto [A0x, A0y] = cv[0].at(k);
6971
auto [A1x, A1y] = cv[1].at(k);
@@ -127,4 +129,5 @@ namespace cp_algo::math::fft {
127129
A.recover_mod(a, n + m - 1);
128130
}
129131
}
132+
#pragma GCC pop_options
130133
#endif // CP_ALGO_MATH_FFT64_HPP

0 commit comments

Comments
 (0)