Skip to content

Commit 846826d

Browse files
committed
enzyme.randomSplit op test
1 parent 2d5d9a5 commit 846826d

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

test/probprog/random.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using Reactant, Test
2+
using Reactant: TracedRArray, TracedRNumber, MLIR, TracedUtils, ConcreteRArray
3+
using Reactant.MLIR: IR
4+
using Reactant.MLIR.Dialects: enzyme
5+
6+
# `enzyme.randomSplit` op is not intended to be emitted directly in Reactant-land.
7+
# It is solely an intermediate representation within the `enzyme.mcmc` op lowering.
8+
function random_split(rng_state::TracedRArray{UInt64,1}, ::Val{N}) where {N}
9+
rng_mlir = TracedUtils.get_mlir_data(rng_state)
10+
rng_state_type = IR.TensorType([2], IR.Type(UInt64))
11+
output_types = [rng_state_type for _ in 1:N]
12+
op = enzyme.randomSplit(rng_mlir; output_rng_states=output_types)
13+
return ntuple(i -> TracedRArray{UInt64,1}((), IR.result(op, i), (2,)), Val(N))
14+
end
15+
16+
@testset "enzyme.randomSplit op" begin
17+
@testset "N=2, Seed [0, 42]" begin
18+
seed = ConcreteRArray(UInt64[0, 42])
19+
k1, k2 = @jit optimize = :probprog random_split(seed, Val(2))
20+
21+
@test Array(k1) == [0x99ba4efe6b200159, 0x4f6cc618de79f4b9]
22+
@test Array(k2) == [0xcddb151d375f238f, 0xf67a601be6bdada3]
23+
end
24+
25+
@testset "N=2, Seed [42, 0]" begin
26+
seed = ConcreteRArray(UInt64[42, 0])
27+
k1, k2 = @jit optimize = :probprog random_split(seed, Val(2))
28+
29+
@test Array(k1) == [0x4f6cc618de79f4b9, 0x99ba4efe6b200159]
30+
@test Array(k2) == [0xf67a601be6bdada3, 0xcddb151d375f238f]
31+
end
32+
33+
@testset "N=3, Seed [0, 42]" begin
34+
seed = ConcreteRArray(UInt64[0, 42])
35+
k1, k2, k3 = @jit optimize = :probprog random_split(seed, Val(3))
36+
37+
@test Array(k1) == [0x99ba4efe6b200159, 0x4f6cc618de79f4b9]
38+
@test Array(k2) == [0xcddb151d375f238f, 0xf67a601be6bdada3]
39+
@test Array(k3) == [0xa20e4081f71f4ea9, 0x2f36b83d4e83f1ba]
40+
end
41+
42+
@testset "N=4, Seed [0, 42]" begin
43+
seed = ConcreteRArray(UInt64[0, 42])
44+
k1, k2, k3, k4 = @jit optimize = :probprog random_split(seed, Val(4))
45+
46+
@test Array(k1) == [0x99ba4efe6b200159, 0x4f6cc618de79f4b9]
47+
@test Array(k2) == [0xcddb151d375f238f, 0xf67a601be6bdada3]
48+
@test Array(k3) == [0xa20e4081f71f4ea9, 0x2f36b83d4e83f1ba]
49+
@test Array(k4) == [0xe4e8dfbe9312778b, 0x982ff5502e6ccb51]
50+
end
51+
end

0 commit comments

Comments
 (0)