Skip to content

Commit bde6ac1

Browse files
committed
Iron out atomics and add tests for Float32
1 parent a00fad6 commit bde6ac1

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

src/device/atomics.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Atomic operation device overrides and fallbacks
2+
3+
# Fallback wrappers for Float32 atomic_inc!/atomic_dec!
4+
# Intel Level Zero doesn't support these directly for floating-point types,
5+
# so we implement them by calling atomic_add!/atomic_sub! with a constant 1
6+
7+
# Float32 method of `atomic_inc!`: forwards to `atomic_add!` on the same pointer
# with an addend of `Float32(1)`. The value of that call is the function's result
# (presumably the value stored before the update -- confirm against SPIRVIntrinsics).
@device_override @inline function SPIRVIntrinsics.atomic_inc!(p::LLVMPtr{Float32,AS}) where {AS}
8+
SPIRVIntrinsics.atomic_add!(p, Float32(1))
9+
end
10+
11+
# Float32 method of `atomic_dec!`: forwards to `atomic_sub!` on the same pointer
# with a subtrahend of `Float32(1)`. The value of that call is the function's result
# (presumably the value stored before the update -- confirm against SPIRVIntrinsics).
@device_override @inline function SPIRVIntrinsics.atomic_dec!(p::LLVMPtr{Float32,AS}) where {AS}
12+
SPIRVIntrinsics.atomic_sub!(p, Float32(1))
13+
end
14+
15+
# Float64 fallbacks (if Float64 is supported on device)
16+
# Float64 method of `atomic_inc!`: forwards to `atomic_add!` on the same pointer
# with an addend of `Float64(1)`. The value of that call is the function's result.
# NOTE(review): whether `atomic_add!` has a Float64 method on a given device is not
# visible here -- confirm before relying on this path.
@device_override @inline function SPIRVIntrinsics.atomic_inc!(p::LLVMPtr{Float64,AS}) where {AS}
17+
SPIRVIntrinsics.atomic_add!(p, Float64(1))
18+
end
19+
20+
# Float64 method of `atomic_dec!`: forwards to `atomic_sub!` on the same pointer
# with a subtrahend of `Float64(1)`. The value of that call is the function's result.
# NOTE(review): whether `atomic_sub!` has a Float64 method on a given device is not
# visible here -- confirm before relying on this path.
@device_override @inline function SPIRVIntrinsics.atomic_dec!(p::LLVMPtr{Float64,AS}) where {AS}
21+
SPIRVIntrinsics.atomic_sub!(p, Float64(1))
22+
end

src/oneAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Base.Experimental.@MethodTable(method_table)
3434
include("device/runtime.jl")
3535
include("device/array.jl")
3636
include("device/quirks.jl")
37+
include("device/atomics.jl")
3738

3839
# essential stuff
3940
include("context.jl")

test/device/intrinsics.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ end
276276

277277
@testset "atomics (low level)" begin
278278

279-
@testset "atomic_add($T)" for T in [Int32, UInt32]
279+
@testset "atomic_add($T)" for T in [Int32, UInt32, Float32]
280280
a = oneArray([zero(T)])
281281

282282
function kernel(a, b)
@@ -288,7 +288,7 @@ end
288288
@test Array(a)[1] == T(256)
289289
end
290290

291-
@testset "atomic_sub($T)" for T in [Int32, UInt32]
291+
@testset "atomic_sub($T)" for T in [Int32, UInt32, Float32]
292292
a = oneArray([T(256)])
293293

294294
function kernel(a, b)
@@ -300,7 +300,7 @@ end
300300
@test Array(a)[1] == T(0)
301301
end
302302

303-
@testset "atomic_inc($T)" for T in [Int32, UInt32]
303+
@testset "atomic_inc($T)" for T in [Int32, UInt32, Float32]
304304
a = oneArray([zero(T)])
305305

306306
function kernel(a)
@@ -312,7 +312,7 @@ end
312312
@test Array(a)[1] == T(256)
313313
end
314314

315-
@testset "atomic_dec($T)" for T in [Int32, UInt32]
315+
@testset "atomic_dec($T)" for T in [Int32, UInt32, Float32]
316316
a = oneArray([T(256)])
317317

318318
function kernel(a)
@@ -324,25 +324,25 @@ end
324324
@test Array(a)[1] == T(0)
325325
end
326326

327-
@testset "atomic_min($T)" for T in [Int32, UInt32]
327+
@testset "atomic_min($T)" for T in [Int32, UInt32, Float32]
328328
a = oneArray([T(256)])
329329

330330
function kernel(a, T)
331331
i = get_global_id()
332-
oneAPI.atomic_min!(pointer(a), i%T)
332+
oneAPI.atomic_min!(pointer(a), T(i))
333333
return
334334
end
335335

336336
@oneapi items=256 kernel(a, T)
337337
@test Array(a)[1] == one(T)
338338
end
339339

340-
@testset "atomic_max($T)" for T in [Int32, UInt32]
340+
@testset "atomic_max($T)" for T in [Int32, UInt32, Float32]
341341
a = oneArray([zero(T)])
342342

343343
function kernel(a, T)
344344
i = get_global_id()
345-
oneAPI.atomic_max!(pointer(a), i%T)
345+
oneAPI.atomic_max!(pointer(a), T(i))
346346
return
347347
end
348348

0 commit comments

Comments
 (0)