Skip to content

Commit 9931203

Browse files
committed
Fix some minor details in comments for AVX2 decompose
- The floor() in floor((f + 127) >> 7) was somewhat unecessary as the usual semantic for the right-shift operator (>>) has integer output anyway. Seeing as the right-shift operator is not used in other explanation comments, we decided to rewrite it as division by 2^7 for better consistency. - The bound of f1'' is correct but the proof was misleading. The new proof should be clearer. Signed-off-by: jammychiou1 <jammy.chiou1@gmail.com>
1 parent c707054 commit 9931203

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

dev/x86_64/src/poly_decompose_32_avx2.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
6161
* range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
6262
*/
6363

64-
/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
64+
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
6565
f1 = _mm256_add_epi32(f, off);
6666
f1 = _mm256_srli_epi32(f1, 7);
6767
/*
@@ -87,7 +87,9 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
8787
*/
8888
f1 = _mm256_mulhi_epu16(f1, v);
8989
/*
90-
* range: 0 <= f1'' < floor(2^16 * 1025 / 2^16) = 1025
90+
* range: 0 <= f1'' = floor(f1' * 1025 / 2^16)
91+
* <= f1' * 1025 / 2^16
92+
* < 2^16 * 1025 / 2^16 = 1025
9193
*
9294
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
9395
* is, no erroneous sign-extension occurs.

dev/x86_64/src/poly_decompose_88_avx2.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
6262
* range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
6363
*/
6464

65-
/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
65+
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
6666
f1 = _mm256_add_epi32(f, off);
6767
f1 = _mm256_srli_epi32(f1, 7);
6868
/*
@@ -88,7 +88,9 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
8888
*/
8989
f1 = _mm256_mulhi_epu16(f1, v);
9090
/*
91-
* range: 0 <= f1'' < floor(2^16 * 11275 / 2^16) = 11275
91+
* range: 0 <= f1'' = floor(f1' * 11275 / 2^16)
92+
* <= f1' * 11275 / 2^16
93+
* < 2^16 * 11275 / 2^16 = 11275
9294
*
9395
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
9496
* is, no erroneous sign-extension occurs.

mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
6161
* range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
6262
*/
6363

64-
/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
64+
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
6565
f1 = _mm256_add_epi32(f, off);
6666
f1 = _mm256_srli_epi32(f1, 7);
6767
/*
@@ -87,7 +87,9 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
8787
*/
8888
f1 = _mm256_mulhi_epu16(f1, v);
8989
/*
90-
* range: 0 <= f1'' < floor(2^16 * 1025 / 2^16) = 1025
90+
* range: 0 <= f1'' = floor(f1' * 1025 / 2^16)
91+
* <= f1' * 1025 / 2^16
92+
* < 2^16 * 1025 / 2^16 = 1025
9193
*
9294
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
9395
* is, no erroneous sign-extension occurs.

mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
6262
* range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
6363
*/
6464

65-
/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
65+
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
6666
f1 = _mm256_add_epi32(f, off);
6767
f1 = _mm256_srli_epi32(f1, 7);
6868
/*
@@ -88,7 +88,9 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
8888
*/
8989
f1 = _mm256_mulhi_epu16(f1, v);
9090
/*
91-
* range: 0 <= f1'' < floor(2^16 * 11275 / 2^16) = 11275
91+
* range: 0 <= f1'' = floor(f1' * 11275 / 2^16)
92+
* <= f1' * 11275 / 2^16
93+
* < 2^16 * 11275 / 2^16 = 11275
9294
*
9395
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
9496
* is, no erroneous sign-extension occurs.

0 commit comments

Comments
 (0)