Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions dev/x86_64/src/poly_decompose_32_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -72,10 +72,17 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* _mm256_mulhi_epu16() below.
*/

/* check-magic: 2046 == intdiv(4092, 2) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
* 1 / 4092.
* 1 / 4092, so (for example) f1' = B / 2 = 2046 is mapped to
*
* round(2046 * 1025 / 2^22) = round(2046 * (1 / 4092 - epsilon))
* = round(1 / 2 - epsilon') = 0,
*
* where epsilon = 1 / 4092 - 1025 / 2^22 and epsilon' = 2046 * epsilon are
* both tiny but positive numbers.
Comment on lines 77 to +85
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hanno-becker May I have your review on this explanation? If this looks good, I plan to also apply this to the aarch64 implementation and resolve #654. Thank you!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jammychiou1 Apologies for the slow reply.

I think if we want to fix #654 we should provide a general explanation.

Here's an attempt:

The approximation error for `f1' / B ≈ f1' * 1025 / 2^22` is `f1' * (1025/2^22 - 1/B) = -f1' * eps`,
where `eps := 1/B - 1025/2^22 = 1/4290772992 ~ 2^{-31.99} < 2^{-31}` is positive (note `1025 * 4092 = 2^22 - 4`).
Hence, since `0 <= f1' < 2^16`, `|f1'| * eps < 2^{-15}`. Given `B = 4092 ~ 2^12`, we thus have `|f1'| * eps < 1/B`.
On the other hand, `1/B` is the spacing between the integral multiples of `1/B`, which
include all rounding boundaries `n + 1/2` (since `B` is even). Hence, if `f1' / B` is not
of the form `n + 1/2`, then moving from `f1' / B` to `f1' * 1025 / 2^22` does not cross 
a rounding boundary, and hence `round(f1' / B) = round(f1' * 1025 / 2^22)`, and it doesn't
matter on either side which version of rounding one uses. 
If `f1' / B` _is_ of the form `n + 1/2`, then `f1' * 1025 / 2^22` is slightly below it
(and _not_ itself of the form `m + 1/2`), hence `round-(f1' / B) = round(f1' * 1025 / 2^22)`;
where the round-down on the LHS is essential, and on the RHS the type of rounding 
again does not matter.

*
* round(f1' * 1025 / 2^22) is in turn computed in 2 steps as
* round(floor(f1' * 1025 / 2^16) / 2^6). The mulhi computes f1'' =
Expand All @@ -87,7 +94,9 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 1025 / 2^16) = 1025
* range: 0 <= f1'' = floor(f1' * 1025 / 2^16)
* <= f1' * 1025 / 2^16
* < 2^16 * 1025 / 2^16 = 1025
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down
15 changes: 12 additions & 3 deletions dev/x86_64/src/poly_decompose_88_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -73,10 +73,17 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* _mm256_mulhi_epu16() below.
*/

/* check-magic: 744 == intdiv(1488, 2) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
* 1 / 1488.
* 1 / 1488, so (for example) f1' = B / 2 = 744 is mapped to
*
* round(744 * 11275 / 2^24) = round(744 * (1 / 1488 - epsilon))
* = round(1 / 2 - epsilon') = 0,
*
* where epsilon = 1 / 1488 - 11275 / 2^24 and epsilon' = 744 * epsilon are
* both tiny but positive numbers.
Comment on lines -79 to +86
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

*
* round(f1' * 11275 / 2^24) is in turn computed in 2 steps as
* round(floor(f1' * 11275 / 2^16) / 2^8). The mulhi computes f1'' =
Expand All @@ -88,7 +95,9 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 11275 / 2^16) = 11275
* range: 0 <= f1'' = floor(f1' * 11275 / 2^16)
* <= f1' * 11275 / 2^16
* < 2^16 * 11275 / 2^16 = 11275
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down
15 changes: 12 additions & 3 deletions mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -72,10 +72,17 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* _mm256_mulhi_epu16() below.
*/

/* check-magic: 2046 == intdiv(4092, 2) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
* 1 / 4092.
* 1 / 4092, so (for example) f1' = B / 2 = 2046 is mapped to
*
* round(2046 * 1025 / 2^22) = round(2046 * (1 / 4092 - epsilon))
* = round(1 / 2 - epsilon') = 0,
*
* where epsilon = 1 / 4092 - 1025 / 2^22 and epsilon' = 2046 * epsilon are
* both tiny but positive numbers.
*
* round(f1' * 1025 / 2^22) is in turn computed in 2 steps as
* round(floor(f1' * 1025 / 2^16) / 2^6). The mulhi computes f1'' =
Expand All @@ -87,7 +94,9 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 1025 / 2^16) = 1025
* range: 0 <= f1'' = floor(f1' * 1025 / 2^16)
* <= f1' * 1025 / 2^16
* < 2^16 * 1025 / 2^16 = 1025
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down
15 changes: 12 additions & 3 deletions mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -73,10 +73,17 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
* _mm256_mulhi_epu16() below.
*/

/* check-magic: 744 == intdiv(1488, 2) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
* 1 / 1488.
* 1 / 1488, so (for example) f1' = B / 2 = 744 is mapped to
*
* round(744 * 11275 / 2^24) = round(744 * (1 / 1488 - epsilon))
* = round(1 / 2 - epsilon') = 0,
*
* where epsilon = 1 / 1488 - 11275 / 2^24 and epsilon' = 744 * epsilon are
* both tiny but positive numbers.
*
* round(f1' * 11275 / 2^24) is in turn computed in 2 steps as
* round(floor(f1' * 11275 / 2^16) / 2^8). The mulhi computes f1'' =
Expand All @@ -88,7 +95,9 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 11275 / 2^16) = 11275
* range: 0 <= f1'' = floor(f1' * 11275 / 2^16)
* <= f1' * 11275 / 2^16
* < 2^16 * 11275 / 2^16 = 11275
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down