From bc120d74cbf303b84f3d53b82765c8b9d4aa4e51 Mon Sep 17 00:00:00 2001 From: Christian Sonnabend Date: Thu, 13 Mar 2025 09:48:32 +0100 Subject: [PATCH] Making float16 variables compatible with GPU types --- Common/ML/include/ML/3rdparty/GPUORTFloat16.h | 126 ++++++++++-------- 1 file changed, 72 insertions(+), 54 deletions(-) diff --git a/Common/ML/include/ML/3rdparty/GPUORTFloat16.h b/Common/ML/include/ML/3rdparty/GPUORTFloat16.h index db65328409d3c..76fd6734cf9db 100644 --- a/Common/ML/include/ML/3rdparty/GPUORTFloat16.h +++ b/Common/ML/include/ML/3rdparty/GPUORTFloat16.h @@ -5,10 +5,18 @@ // - https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_float16.h // - https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_cxx_api.h +#ifndef GPUORTFLOAT16_H +#define GPUORTFLOAT16_H + +#ifndef GPUCA_GPUCODE_DEVICE #include #include #include #include +#endif + +#include "GPUCommonDef.h" +#include "GPUCommonMath.h" namespace o2 { @@ -50,19 +58,19 @@ struct Float16Impl { /// /// /// - constexpr static uint16_t ToUint16Impl(float v) noexcept; + GPUd() constexpr static uint16_t ToUint16Impl(float v) noexcept; /// /// Converts float16 to float /// /// float representation of float16 value - float ToFloatImpl() const noexcept; + GPUd() float ToFloatImpl() const noexcept; /// /// Creates an instance that represents absolute value. /// /// Absolute value - uint16_t AbsImpl() const noexcept + GPUd() uint16_t AbsImpl() const noexcept { return static_cast(val & ~kSignMask); } @@ -71,7 +79,7 @@ struct Float16Impl { /// Creates a new instance with the sign flipped. /// /// Flipped sign instance - uint16_t NegateImpl() const noexcept + GPUd() uint16_t NegateImpl() const noexcept { return IsNaN() ? val : static_cast(val ^ kSignMask); } @@ -92,13 +100,13 @@ struct Float16Impl { uint16_t val{0}; - Float16Impl() = default; + GPUdDefault() Float16Impl() = default; /// /// Checks if the value is negative /// /// true if negative - bool IsNegative() const noexcept + GPUd() bool IsNegative() const noexcept { return static_cast(val) < 0; } @@ -107,7 +115,7 @@ struct Float16Impl { /// Tests if the value is NaN /// /// true if NaN - bool IsNaN() const noexcept + GPUd() bool IsNaN() const noexcept { return AbsImpl() > kPositiveInfinityBits; } @@ -116,7 +124,7 @@ struct Float16Impl { /// Tests if the value is finite /// /// true if finite - bool IsFinite() const noexcept + GPUd() bool IsFinite() const noexcept { return AbsImpl() < kPositiveInfinityBits; } @@ -125,7 +133,7 @@ struct Float16Impl { /// Tests if the value represents positive infinity. /// /// true if positive infinity - bool IsPositiveInfinity() const noexcept + GPUd() bool IsPositiveInfinity() const noexcept { return val == kPositiveInfinityBits; } @@ -134,7 +142,7 @@ struct Float16Impl { /// Tests if the value represents negative infinity /// /// true if negative infinity - bool IsNegativeInfinity() const noexcept + GPUd() bool IsNegativeInfinity() const noexcept { return val == kNegativeInfinityBits; } @@ -143,7 +151,7 @@ struct Float16Impl { /// Tests if the value is either positive or negative infinity. /// /// True if absolute value is infinity - bool IsInfinity() const noexcept + GPUd() bool IsInfinity() const noexcept { return AbsImpl() == kPositiveInfinityBits; } @@ -152,7 +160,7 @@ struct Float16Impl { /// Tests if the value is NaN or zero. Useful for comparisons. /// /// True if NaN or zero. - bool IsNaNOrZero() const noexcept + GPUd() bool IsNaNOrZero() const noexcept { auto abs = AbsImpl(); return (abs == 0 || abs > kPositiveInfinityBits); @@ -162,7 +170,7 @@ struct Float16Impl { /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). /// /// True if so - bool IsNormal() const noexcept + GPUd() bool IsNormal() const noexcept { auto abs = AbsImpl(); return (abs < kPositiveInfinityBits) // is finite @@ -174,7 +182,7 @@ struct Float16Impl { /// Tests if the value is subnormal (denormal). /// /// True if so - bool IsSubnormal() const noexcept + GPUd() bool IsSubnormal() const noexcept { auto abs = AbsImpl(); return (abs < kPositiveInfinityBits) // is finite @@ -186,13 +194,13 @@ struct Float16Impl { /// Creates an instance that represents absolute value. /// /// Absolute value - Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + GPUd() Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } /// /// Creates a new instance with the sign flipped. /// /// Flipped sign instance - Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + GPUd() Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } /// /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check @@ -202,12 +210,12 @@ struct Float16Impl { /// first value /// second value /// True if both arguments represent zero - static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept + GPUd() static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept { return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; } - bool operator==(const Float16Impl& rhs) const noexcept + GPUd() bool operator==(const Float16Impl& rhs) const noexcept { if (IsNaN() || rhs.IsNaN()) { // IEEE defines that NaN is not equal to anything, including itself. @@ -216,9 +224,9 @@ struct Float16Impl { return val == rhs.val; } - bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); } + GPUd() bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); } - bool operator<(const Float16Impl& rhs) const noexcept + GPUd() bool operator<(const Float16Impl& rhs) const noexcept { if (IsNaN() || rhs.IsNaN()) { // IEEE defines that NaN is unordered with respect to everything, including itself. @@ -267,7 +275,7 @@ union float32_bits { }; // namespace detail template -inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept +GPUdi() constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept { detail::float32_bits f{}; f.f = v; @@ -316,7 +324,7 @@ inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept } template -inline float Float16Impl::ToFloatImpl() const noexcept +GPUdi() float Float16Impl::ToFloatImpl() const noexcept { constexpr detail::float32_bits magic = {113 << 23}; constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift @@ -356,19 +364,19 @@ struct BFloat16Impl { /// /// /// - static uint16_t ToUint16Impl(float v) noexcept; + GPUd() static uint16_t ToUint16Impl(float v) noexcept; /// /// Converts bfloat16 to float /// /// float representation of bfloat16 value - float ToFloatImpl() const noexcept; + GPUd() float ToFloatImpl() const noexcept; /// /// Creates an instance that represents absolute value. /// /// Absolute value - uint16_t AbsImpl() const noexcept + GPUd() uint16_t AbsImpl() const noexcept { return static_cast(val & ~kSignMask); } @@ -377,7 +385,7 @@ struct BFloat16Impl { /// Creates a new instance with the sign flipped. /// /// Flipped sign instance - uint16_t NegateImpl() const noexcept + GPUd() uint16_t NegateImpl() const noexcept { return IsNaN() ? val : static_cast(val ^ kSignMask); } @@ -400,13 +408,13 @@ struct BFloat16Impl { uint16_t val{0}; - BFloat16Impl() = default; + GPUdDefault() BFloat16Impl() = default; /// /// Checks if the value is negative /// /// true if negative - bool IsNegative() const noexcept + GPUd() bool IsNegative() const noexcept { return static_cast(val) < 0; } @@ -415,7 +423,7 @@ struct BFloat16Impl { /// Tests if the value is NaN /// /// true if NaN - bool IsNaN() const noexcept + GPUd() bool IsNaN() const noexcept { return AbsImpl() > kPositiveInfinityBits; } @@ -424,7 +432,7 @@ struct BFloat16Impl { /// Tests if the value is finite /// /// true if finite - bool IsFinite() const noexcept + GPUd() bool IsFinite() const noexcept { return AbsImpl() < kPositiveInfinityBits; } @@ -433,7 +441,7 @@ struct BFloat16Impl { /// Tests if the value represents positive infinity. /// /// true if positive infinity - bool IsPositiveInfinity() const noexcept + GPUd() bool IsPositiveInfinity() const noexcept { return val == kPositiveInfinityBits; } @@ -442,7 +450,7 @@ struct BFloat16Impl { /// Tests if the value represents negative infinity /// /// true if negative infinity - bool IsNegativeInfinity() const noexcept + GPUd() bool IsNegativeInfinity() const noexcept { return val == kNegativeInfinityBits; } @@ -451,7 +459,7 @@ struct BFloat16Impl { /// Tests if the value is either positive or negative infinity. /// /// True if absolute value is infinity - bool IsInfinity() const noexcept + GPUd() bool IsInfinity() const noexcept { return AbsImpl() == kPositiveInfinityBits; } @@ -460,7 +468,7 @@ struct BFloat16Impl { /// Tests if the value is NaN or zero. Useful for comparisons. /// /// True if NaN or zero. - bool IsNaNOrZero() const noexcept + GPUd() bool IsNaNOrZero() const noexcept { auto abs = AbsImpl(); return (abs == 0 || abs > kPositiveInfinityBits); @@ -470,7 +478,7 @@ struct BFloat16Impl { /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). /// /// True if so - bool IsNormal() const noexcept + GPUd() bool IsNormal() const noexcept { auto abs = AbsImpl(); return (abs < kPositiveInfinityBits) // is finite @@ -482,7 +490,7 @@ struct BFloat16Impl { /// Tests if the value is subnormal (denormal). /// /// True if so - bool IsSubnormal() const noexcept + GPUd() bool IsSubnormal() const noexcept { auto abs = AbsImpl(); return (abs < kPositiveInfinityBits) // is finite @@ -494,13 +502,13 @@ struct BFloat16Impl { /// Creates an instance that represents absolute value. /// /// Absolute value - Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + GPUd() Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } /// /// Creates a new instance with the sign flipped. /// /// Flipped sign instance - Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + GPUd() Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } /// /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check @@ -510,7 +518,7 @@ struct BFloat16Impl { /// first value /// second value /// True if both arguments represent zero - static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept + GPUd() static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept { // IEEE defines that positive and negative zero are equal, this gives us a quick equality check // for two values by or'ing the private bits together and stripping the sign. They are both zero, @@ -520,14 +528,17 @@ struct BFloat16Impl { }; template -inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept +GPUdi() uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept { uint16_t result; - if (std::isnan(v)) { + if (o2::gpu::CAMath::IsNaN(v)) { result = kPositiveQNaNBits; } else { auto get_msb_half = [](float fl) { uint16_t result; +#ifdef GPUCA_GPUCODE + o2::gpu::CAMath::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); +#else #ifdef __cpp_if_constexpr if constexpr (detail::endian::native == detail::endian::little) #else @@ -538,6 +549,7 @@ inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept } else { std::memcpy(&result, &fl, sizeof(uint16_t)); } +#endif return result; }; @@ -554,14 +566,18 @@ inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept } template -inline float BFloat16Impl::ToFloatImpl() const noexcept +GPUdi() float BFloat16Impl::ToFloatImpl() const noexcept { if (IsNaN()) { - return std::numeric_limits::quiet_NaN(); + return o2::gpu::CAMath::QuietNaN(); } float result; char* const first = reinterpret_cast(&result); char* const second = first + sizeof(uint16_t); +#ifdef GPUCA_GPUCODE + first[0] = first[1] = 0; + o2::gpu::CAMath::memcpy(second, &val, sizeof(uint16_t)); +#else #ifdef __cpp_if_constexpr if constexpr (detail::endian::native == detail::endian::little) #else @@ -574,6 +590,7 @@ inline float BFloat16Impl::ToFloatImpl() const noexcept std::memcpy(first, &val, sizeof(uint16_t)); std::memset(second, 0, sizeof(uint16_t)); } +#endif return result; } @@ -610,26 +627,26 @@ struct Float16_t : OrtDataType::Float16Impl { /// /// Default constructor /// - Float16_t() = default; + GPUdDefault() Float16_t() = default; /// /// Explicit conversion to uint16_t representation of float16. /// /// uint16_t bit representation of float16 /// new instance of Float16_t - constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } + GPUd() constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } /// /// __ctor from float. Float is converted into float16 16-bit representation. /// /// float value - explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + GPUd() explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } /// /// Converts float16 to float /// /// float representation of float16 value - float ToFloat() const noexcept { return Base::ToFloatImpl(); } + GPUd() float ToFloat() const noexcept { return Base::ToFloatImpl(); } /// /// Checks if the value is negative @@ -710,7 +727,7 @@ struct Float16_t : OrtDataType::Float16Impl { /// /// User defined conversion operator. Converts Float16_t to float. /// - explicit operator float() const noexcept { return ToFloat(); } + GPUdi() explicit operator float() const noexcept { return ToFloat(); } using Base::operator==; using Base::operator!=; @@ -751,26 +768,26 @@ struct BFloat16_t : OrtDataType::BFloat16Impl { public: using Base = OrtDataType::BFloat16Impl; - BFloat16_t() = default; + GPUdDefault() BFloat16_t() = default; /// /// Explicit conversion to uint16_t representation of bfloat16. /// /// uint16_t bit representation of bfloat16 /// new instance of BFloat16_t - static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } + GPUd() static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } /// /// __ctor from float. Float is converted into bfloat16 16-bit representation. /// /// float value - explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } + GPUd() explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } /// /// Converts bfloat16 to float /// /// float representation of bfloat16 value - float ToFloat() const noexcept { return Base::ToFloatImpl(); } + GPUd() float ToFloat() const noexcept { return Base::ToFloatImpl(); } /// /// Checks if the value is negative @@ -851,7 +868,7 @@ struct BFloat16_t : OrtDataType::BFloat16Impl { /// /// User defined conversion operator. Converts BFloat16_t to float. /// - explicit operator float() const noexcept { return ToFloat(); } + GPUdi() explicit operator float() const noexcept { return ToFloat(); } // We do not have an inherited impl for the below operators // as the internal class implements them a little differently @@ -864,4 +881,5 @@ static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); } // namespace OrtDataType -} // namespace o2 \ No newline at end of file +} // namespace o2 +#endif \ No newline at end of file