Skip to content

Commit e51d0c6

Browse files
Update EnzymeAD/Enzyme-JAX to commit 04c15dc7c3736cca6f0a93c9934829fde489cc11 (#1980)
* Update EnzymeAD/Enzyme-JAX to commit 04c15dc7c3736cca6f0a93c9934829fde489cc11 Diff: EnzymeAD/Enzyme-JAX@ddea75d...04c15dc * test: explicit options --------- Co-authored-by: enzymead-bot[bot] <238314553+enzymead-bot[bot]@users.noreply.github.com> Co-authored-by: Avik Pal <avikpal@mit.edu>
1 parent a0ebfd5 commit e51d0c6

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "ddea75dacb5b9ba708e0b44a0161d4acf9adae31"
7+
ENZYMEXLA_COMMIT = "04c15dc7c3736cca6f0a93c9934829fde489cc11"
88

99
ENZYMEXLA_SHA256 = ""
1010

test/batching.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ function run_auto_batching_tests(f::F, args...) where {F}
4949
)
5050
@test occursin("stablehlo.while", hlo)
5151

52-
hlo = repr(@code_hlo f(args...))
52+
hlo = repr(
53+
@code_hlo compile_options = CompileOptions(;
54+
disable_auto_batching_passes=false
55+
) f(args...)
56+
)
5357
@test !occursin("stablehlo.while", hlo)
5458
end
5559
end
@@ -115,9 +119,13 @@ end
115119
input1 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
116120
input2 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
117121

118-
hlo = @code_hlo optimize = false mctr(map_with_scalar_indexing, 1:8, input1, input2)
122+
hlo = @code_hlo compile_options = CompileOptions(; disable_auto_batching_passes=true) mctr(
123+
map_with_scalar_indexing, 1:8, input1, input2
124+
)
119125
@test contains(repr(hlo), "stablehlo.while")
120-
hlo = @code_hlo optimize = true mctr(map_with_scalar_indexing, 1:8, input1, input2)
126+
hlo = @code_hlo compile_options = CompileOptions(; disable_auto_batching_passes=false) mctr(
127+
map_with_scalar_indexing, 1:8, input1, input2
128+
)
121129
@test !contains(repr(hlo), "stablehlo.while")
122130

123131
res_ra = @jit mctr(map_with_scalar_indexing, 1:8, input1, input2)

0 commit comments

Comments
 (0)