Skip to content

Commit abc7e74

Browse files
committed
Add alignment parameter to simd_masked_{load,store}
1 parent 04535f2 commit abc7e74

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

core/src/intrinsics/simd.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
//!
33
//! In this module, a "vector" is any `repr(simd)` type.
44
5+
use crate::marker::ConstParamTy;
6+
57
/// Inserts an element into a vector, returning the updated vector.
68
///
79
/// `T` must be a vector with element type `U`, and `idx` must be `const`.
@@ -377,6 +379,19 @@ pub unsafe fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
377379
#[rustc_nounwind]
378380
pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
379381

382+
/// A type for alignment options for SIMD masked load/store intrinsics.
383+
#[derive(Debug, ConstParamTy, PartialEq, Eq)]
384+
pub enum SimdAlign {
385+
// These values must match the compiler's `SimdAlign` defined in
386+
// `rustc_middle/src/ty/consts/int.rs`!
387+
/// No alignment requirements on the pointer
388+
Unaligned = 0,
389+
/// The pointer must be aligned to the element type of the SIMD vector
390+
Element = 1,
391+
/// The pointer must be aligned to the SIMD vector type
392+
Vector = 2,
393+
}
394+
380395
/// Reads a vector of pointers.
381396
///
382397
/// `T` must be a vector.
@@ -392,13 +407,12 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
392407
/// `val`.
393408
///
394409
/// # Safety
395-
/// Unmasked values in `T` must be readable as if by `<ptr>::read` (e.g. aligned to the element
396-
/// type).
410+
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
397411
///
398412
/// `mask` must only contain `0` or `!0` values.
399413
#[rustc_intrinsic]
400414
#[rustc_nounwind]
401-
pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
415+
pub unsafe fn simd_masked_load<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T) -> T;
402416

403417
/// Writes to a vector of pointers.
404418
///
@@ -414,13 +428,12 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
414428
/// Otherwise if the corresponding value in `mask` is `0`, do nothing.
415429
///
416430
/// # Safety
417-
/// Unmasked values in `T` must be writeable as if by `<ptr>::write` (e.g. aligned to the element
418-
/// type).
431+
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
419432
///
420433
/// `mask` must only contain `0` or `!0` values.
421434
#[rustc_intrinsic]
422435
#[rustc_nounwind]
423-
pub unsafe fn simd_masked_store<V, U, T>(mask: V, ptr: U, val: T);
436+
pub unsafe fn simd_masked_store<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T);
424437

425438
/// Adds two simd vectors elementwise, with saturation.
426439
///

portable-simd/crates/core_simd/src/vector.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,14 @@ where
474474
or: Self,
475475
) -> Self {
476476
// SAFETY: The safety of reading elements through `ptr` is ensured by the caller.
477-
unsafe { core::intrinsics::simd::simd_masked_load(enable.to_int(), ptr, or) }
477+
unsafe {
478+
core::intrinsics::simd::simd_masked_load::<
479+
_,
480+
_,
481+
_,
482+
{ core::intrinsics::simd::SimdAlign::Element },
483+
>(enable.to_int(), ptr, or)
484+
}
478485
}
479486

480487
/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
@@ -723,7 +730,14 @@ where
723730
#[inline]
724731
pub unsafe fn store_select_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
725732
// SAFETY: The safety of writing elements through `ptr` is ensured by the caller.
726-
unsafe { core::intrinsics::simd::simd_masked_store(enable.to_int(), ptr, self) }
733+
unsafe {
734+
core::intrinsics::simd::simd_masked_store::<
735+
_,
736+
_,
737+
_,
738+
{ core::intrinsics::simd::SimdAlign::Element },
739+
>(enable.to_int(), ptr, self)
740+
}
727741
}
728742

729743
/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.

0 commit comments

Comments
 (0)