|
| 1 | +using Reactant, Test |
| 2 | +using Reactant: TracedRArray, TracedRNumber, MLIR, TracedUtils, ConcreteRArray |
| 3 | +using Reactant.MLIR: IR |
| 4 | +using Reactant.MLIR.Dialects: enzyme |
| 5 | + |
| 6 | +# `enzyme.randomSplit` op is not intended to be emitted directly in Reactant-land. |
| 7 | +# It is solely an intermediate representation within the `enzyme.mcmc` op lowering. |
| 8 | +function random_split(rng_state::TracedRArray{UInt64,1}, ::Val{N}) where {N} |
| 9 | + rng_mlir = TracedUtils.get_mlir_data(rng_state) |
| 10 | + rng_state_type = IR.TensorType([2], IR.Type(UInt64)) |
| 11 | + output_types = [rng_state_type for _ in 1:N] |
| 12 | + op = enzyme.randomSplit(rng_mlir; output_rng_states=output_types) |
| 13 | + return ntuple(i -> TracedRArray{UInt64,1}((), IR.result(op, i), (2,)), Val(N)) |
| 14 | +end |
| 15 | + |
| 16 | +@testset "enzyme.randomSplit op" begin |
| 17 | + @testset "N=2, Seed [0, 42]" begin |
| 18 | + seed = ConcreteRArray(UInt64[0, 42]) |
| 19 | + k1, k2 = @jit optimize = :probprog random_split(seed, Val(2)) |
| 20 | + |
| 21 | + @test Array(k1) == [0x99ba4efe6b200159, 0x4f6cc618de79f4b9] |
| 22 | + @test Array(k2) == [0xcddb151d375f238f, 0xf67a601be6bdada3] |
| 23 | + end |
| 24 | + |
| 25 | + @testset "N=2, Seed [42, 0]" begin |
| 26 | + seed = ConcreteRArray(UInt64[42, 0]) |
| 27 | + k1, k2 = @jit optimize = :probprog random_split(seed, Val(2)) |
| 28 | + |
| 29 | + @test Array(k1) == [0x4f6cc618de79f4b9, 0x99ba4efe6b200159] |
| 30 | + @test Array(k2) == [0xf67a601be6bdada3, 0xcddb151d375f238f] |
| 31 | + end |
| 32 | + |
| 33 | + @testset "N=3, Seed [0, 42]" begin |
| 34 | + seed = ConcreteRArray(UInt64[0, 42]) |
| 35 | + k1, k2, k3 = @jit optimize = :probprog random_split(seed, Val(3)) |
| 36 | + |
| 37 | + @test Array(k1) == [0x99ba4efe6b200159, 0x4f6cc618de79f4b9] |
| 38 | + @test Array(k2) == [0xcddb151d375f238f, 0xf67a601be6bdada3] |
| 39 | + @test Array(k3) == [0xa20e4081f71f4ea9, 0x2f36b83d4e83f1ba] |
| 40 | + end |
| 41 | + |
| 42 | + @testset "N=4, Seed [0, 42]" begin |
| 43 | + seed = ConcreteRArray(UInt64[0, 42]) |
| 44 | + k1, k2, k3, k4 = @jit optimize = :probprog random_split(seed, Val(4)) |
| 45 | + |
| 46 | + @test Array(k1) == [0x99ba4efe6b200159, 0x4f6cc618de79f4b9] |
| 47 | + @test Array(k2) == [0xcddb151d375f238f, 0xf67a601be6bdada3] |
| 48 | + @test Array(k3) == [0xa20e4081f71f4ea9, 0x2f36b83d4e83f1ba] |
| 49 | + @test Array(k4) == [0xe4e8dfbe9312778b, 0x982ff5502e6ccb51] |
| 50 | + end |
| 51 | +end |
0 commit comments