Skip to content

Commit 770c6cd

Browse files
authored
fix: enable batching except slice to batch passes (#1985)
* fix: enable batching by default * feat: finegrained control of options
1 parent 82c83b8 commit 770c6cd

File tree

4 files changed

+82
-34
lines changed

4 files changed

+82
-34
lines changed

docs/src/tutorials/raising.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ raising).
123123

124124
```@example raising_stablehlo
125125
@code_hlo compile_options = CompileOptions(;
126-
disable_auto_batching_passes=true
126+
disable_loop_raising_passes=true
127127
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
128128
```
129129

@@ -133,7 +133,7 @@ tensor IR.
133133

134134
```@example raising_stablehlo
135135
hlo = @code_hlo compile_options=CompileOptions(;
136-
disable_auto_batching_passes=false
136+
disable_loop_raising_passes=false
137137
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
138138
@assert !contains(repr(hlo), "stablehlo.while") #hide
139139
hlo
@@ -145,7 +145,7 @@ the values are identical.
145145
```@example raising_stablehlo
146146
y_jl = compute_attractive_force(positions, masses, 2.0f0)
147147
y_ra = @jit compile_options=CompileOptions(;
148-
disable_auto_batching_passes=false
148+
disable_loop_raising_passes=false
149149
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
150150
maximum(abs, Array(y_ra) .- y_jl)
151151
```
@@ -154,7 +154,7 @@ Let's time the execution of the two versions.
154154

155155
```@example raising_stablehlo
156156
fn1 = @compile sync=true compile_options=CompileOptions(;
157-
disable_auto_batching_passes=true
157+
disable_loop_raising_passes=true
158158
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
159159
fn2 = @compile sync=true compute_attractive_force(positions_ra, masses_ra, 2.0f0)
160160
```

src/CompileOptions.jl

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ Fine-grained control over the compilation options for the Reactant compiler.
9494
the computation graph. If `:down`, they will be propagated down. Defaults to `:up`.
9595
- `max_constant_threshold`: If the number of elements in a constant is greater than this
9696
threshold (for a non-splatted constant), we will throw an error.
97-
- `inline`: If `true`, all functions will be inlined. This is `true` by default.
97+
- `inline`: If `true`, all functions will be inlined. (Default: `true`).
9898
9999
## Raising Options
100100
@@ -107,7 +107,7 @@ Fine-grained control over the compilation options for the Reactant compiler.
107107
## Dialect Specific Options
108108
109109
- `legalize_chlo_to_stablehlo`: If `true`, `chlo` dialect ops will be converted to
110-
`stablehlo` ops. This is `false` by default.
110+
`stablehlo` ops. (Default: `false`).
111111
112112
## Backend Specific Options
113113
@@ -153,13 +153,21 @@ Fine-grained control over the compilation options for the Reactant compiler.
153153
notice or deprecation cycle.
154154
155155
- `disable_scatter_gather_optimization_passes`: Disables the scatter-gather
156-
optimization passes. This is `false` by default.
156+
optimization passes. (Default: `false`).
157157
- `disable_pad_optimization_passes`: Disables the pad optimization passes. This is
158158
`false` by default.
159159
- `disable_licm_optimization_passes`: Disables the Loop Invariant Code Motion (LICM)
160-
optimization passes. This is `false` by default.
161-
- `disable_auto_batching_passes`: Disables the auto-batching optimization passes. This
162-
is `false` by default.
160+
optimization passes. (Default: `false`).
161+
- `disable_reduce_slice_fusion_passes`: Disables fusion of slice elementwise and reduce
162+
operations. (Default `false`).
163+
- `disable_slice_to_batch_passes`: Disables the slice to batch fusion optimization passes.
164+
(Default: `true`). _(Note that this is generally an expensive pass to run)_
165+
- `disable_concat_to_batch_passes`: Disables concatenate to batch fusion passes.
166+
(Default: `false`).
167+
- `disable_loop_raising_passes`: Disables raising passes for `stablehlo.while`.
168+
(Default: `false`).
169+
- `disable_structured_tensors_passes`: Disables structured tensors detection and
170+
propagation passes. (Default `false`).
163171
"""
164172
struct CompileOptions
165173
optimization_passes::Union{Symbol,String}
@@ -188,7 +196,11 @@ struct CompileOptions
188196
disable_scatter_gather_optimization_passes::Bool
189197
disable_pad_optimization_passes::Bool
190198
disable_licm_optimization_passes::Bool
191-
disable_auto_batching_passes::Bool
199+
disable_reduce_slice_fusion_passes::Bool
200+
disable_slice_to_batch_passes::Bool
201+
disable_concat_to_batch_passes::Bool
202+
disable_loop_raising_passes::Bool
203+
disable_structured_tensors_passes::Bool
192204
end
193205

194206
function CompileOptions(;
@@ -212,7 +224,11 @@ function CompileOptions(;
212224
disable_scatter_gather_optimization_passes::Bool=false,
213225
disable_pad_optimization_passes::Bool=false,
214226
disable_licm_optimization_passes::Bool=false,
215-
disable_auto_batching_passes::Bool=true,
227+
disable_reduce_slice_fusion_passes::Bool=false,
228+
disable_slice_to_batch_passes::Bool=true, # expensive + introduces all-to-all in GB25
229+
disable_concat_to_batch_passes::Bool=false,
230+
disable_loop_raising_passes::Bool=false,
231+
disable_structured_tensors_passes::Bool=false,
216232
)
217233
optimization_passes isa Bool &&
218234
(optimization_passes = ifelse(optimization_passes, :all, :none))
@@ -261,7 +277,11 @@ function CompileOptions(;
261277
disable_scatter_gather_optimization_passes,
262278
disable_pad_optimization_passes,
263279
disable_licm_optimization_passes,
264-
disable_auto_batching_passes,
280+
disable_reduce_slice_fusion_passes,
281+
disable_slice_to_batch_passes,
282+
disable_concat_to_batch_passes,
283+
disable_loop_raising_passes,
284+
disable_structured_tensors_passes,
265285
)
266286
end
267287

@@ -303,7 +323,11 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt
303323
compile_options.disable_scatter_gather_optimization_passes,
304324
compile_options.disable_pad_optimization_passes,
305325
compile_options.disable_licm_optimization_passes,
306-
compile_options.disable_auto_batching_passes,
326+
compile_options.disable_reduce_slice_fusion_passes,
327+
compile_options.disable_slice_to_batch_passes,
328+
compile_options.disable_concat_to_batch_passes,
329+
compile_options.disable_loop_raising_passes,
330+
compile_options.disable_structured_tensors_passes,
307331
)
308332
end
309333

src/Compiler.jl

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -926,18 +926,49 @@ function optimization_passes(
926926
if !is_sharded
927927
# these passes don't have optimized sharding implementations
928928
if raise_shlo_to_blas_lapack
929-
append!(transform_passes_list, ["dot_general_to_syrk"])
929+
if !compile_options.disable_structured_tensors_passes
930+
append!(transform_passes_list, ["dot_general_to_syrk"])
931+
end
930932
end
931933
end
932934

933-
if !compile_options.disable_auto_batching_passes
935+
if !compile_options.disable_slice_to_batch_passes
936+
append!(
937+
transform_passes_list,
938+
[
939+
"dot_general_slice_to_batch",
940+
"gather_slice_to_batch",
941+
"iota_slice_to_batch",
942+
"reduce_slice_to_batch",
943+
"sort_slice_to_batch",
944+
"transpose_slice_to_batch",
945+
"broadcastindim_slice_to_batch",
946+
"reducewindow_slice_to_batch",
947+
"elementwise_slice_to_batch",
948+
"convolution_slice_to_batch",
949+
],
950+
)
951+
end
952+
953+
if !compile_options.disable_reduce_slice_fusion_passes
934954
append!(
935955
transform_passes_list,
936956
[
937957
"add_reduce_slice_fusion",
938958
"mul_reduce_slice_fusion",
939959
"min_reduce_slice_fusion",
940960
"max_reduce_slice_fusion",
961+
"and_reduce_slice_fusion",
962+
"xor_reduce_slice_fusion",
963+
"or_reduce_slice_fusion",
964+
],
965+
)
966+
end
967+
968+
if !compile_options.disable_concat_to_batch_passes
969+
append!(
970+
transform_passes_list,
971+
[
941972
"concat_insert_dim_dot_general",
942973
"concat_insert_dim_gather",
943974
"concat_insert_dim_iota",
@@ -946,21 +977,14 @@ function optimization_passes(
946977
"concat_insert_dim_reduce_window",
947978
"concat_insert_dim_elementwise",
948979
"concat_insert_dim_convolution",
949-
"dot_general_slice_to_batch",
950-
"gather_slice_to_batch",
951-
"iota_slice_to_batch",
952-
"reduce_slice_to_batch",
953-
"sort_slice_to_batch",
954-
"transpose_slice_to_batch",
955-
"broadcastindim_slice_to_batch",
956-
"reducewindow_slice_to_batch",
957-
"elementwise_slice_to_batch",
958-
"convolution_slice_to_batch",
959-
"greedy_while_loop_batch_fission",
960980
],
961981
)
962982
end
963983

984+
if !compile_options.disable_loop_raising_passes
985+
append!(transform_passes_list, ["greedy_while_loop_batch_fission"])
986+
end
987+
964988
if !compile_options.disable_licm_optimization_passes
965989
append!(
966990
transform_passes_list,

test/batching.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,23 @@ function run_auto_batching_tests(f::F, args...) where {F}
3535
@testset "$(nameof(F))" begin
3636
@testset "Correctness" begin
3737
res1 = @jit f(args...)
38-
res2 = @jit compile_options = CompileOptions(;
39-
disable_auto_batching_passes=true
40-
) f(args...)
38+
res2 = @jit compile_options = CompileOptions(; disable_loop_raising_passes=true) f(
39+
args...
40+
)
4141
@test res1 res2
4242
end
4343

4444
@testset "No while loops" begin
4545
hlo = repr(
4646
@code_hlo compile_options = CompileOptions(;
47-
disable_auto_batching_passes=true
47+
disable_loop_raising_passes=true
4848
) f(args...)
4949
)
5050
@test occursin("stablehlo.while", hlo)
5151

5252
hlo = repr(
5353
@code_hlo compile_options = CompileOptions(;
54-
disable_auto_batching_passes=false
54+
disable_loop_raising_passes=false
5555
) f(args...)
5656
)
5757
@test !occursin("stablehlo.while", hlo)
@@ -119,11 +119,11 @@ end
119119
input1 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
120120
input2 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
121121

122-
hlo = @code_hlo compile_options = CompileOptions(; disable_auto_batching_passes=true) mctr(
122+
hlo = @code_hlo compile_options = CompileOptions(; disable_loop_raising_passes=true) mctr(
123123
map_with_scalar_indexing, 1:8, input1, input2
124124
)
125125
@test contains(repr(hlo), "stablehlo.while")
126-
hlo = @code_hlo compile_options = CompileOptions(; disable_auto_batching_passes=false) mctr(
126+
hlo = @code_hlo compile_options = CompileOptions(; disable_loop_raising_passes=false) mctr(
127127
map_with_scalar_indexing, 1:8, input1, input2
128128
)
129129
@test !contains(repr(hlo), "stablehlo.while")

0 commit comments

Comments
 (0)