From dc5f652631763ae9f4af75a5205a02513885acad Mon Sep 17 00:00:00 2001 From: Felix Schlepper Date: Wed, 9 Jul 2025 07:51:48 +0200 Subject: [PATCH] ITS: fix TypedAllocator for cuda thrust Signed-off-by: Felix Schlepper --- .../ITS/tracking/GPU/cuda/TrackingKernels.cu | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/Detectors/ITSMFT/ITS/tracking/GPU/cuda/TrackingKernels.cu b/Detectors/ITSMFT/ITS/tracking/GPU/cuda/TrackingKernels.cu index 8245aee33718c..38c59d520aa76 100644 --- a/Detectors/ITSMFT/ITS/tracking/GPU/cuda/TrackingKernels.cu +++ b/Detectors/ITSMFT/ITS/tracking/GPU/cuda/TrackingKernels.cu @@ -58,30 +58,43 @@ namespace gpu { template -class TypedAllocator : public thrust::device_allocator -{ - public: +struct TypedAllocator { using value_type = T; - using pointer = T*; + using pointer = thrust::device_ptr; + using const_pointer = thrust::device_ptr; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + TypedAllocator() noexcept : mInternalAllocator(nullptr) {} + explicit TypedAllocator(ExternalAllocator* a) noexcept : mInternalAllocator(a) {} template - struct rebind { - using other = TypedAllocator; - }; + TypedAllocator(const TypedAllocator& o) noexcept : mInternalAllocator(o.mInternalAllocator) + { + } - explicit TypedAllocator(ExternalAllocator* allocPtr) - : mInternalAllocator(allocPtr) {} + pointer allocate(size_type n) + { + void* raw = mInternalAllocator->allocate(n * sizeof(T)); + return thrust::device_pointer_cast(static_cast(raw)); + } - T* allocate(size_t n) + void deallocate(pointer p, size_type n) noexcept { - return reinterpret_cast(mInternalAllocator->allocate(n * sizeof(T))); + if (!p) { + return; + } + void* raw = thrust::raw_pointer_cast(p); + mInternalAllocator->deallocate(static_cast(raw), n * sizeof(T)); } - void deallocate(T* p, size_t n) + bool operator==(TypedAllocator const& o) const noexcept + { + return mInternalAllocator == o.mInternalAllocator; + } + bool operator!=(TypedAllocator const& o) const noexcept { - char* raw_ptr = reinterpret_cast(p); - size_t bytes = n * sizeof(T); - mInternalAllocator->deallocate(raw_ptr, bytes); // redundant as internal dealloc is no-op. + return !(*this == o); } private: