Skip to content

Commit 91138d3

Browse files
committed
Add fast-compilation branches for gradients too
1 parent 1e96a1e commit 91138d3

File tree

2 files changed

+127
-57
lines changed

2 files changed

+127
-57
lines changed

src/EvaluateEquationDerivative.jl

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import ..EquationModule: AbstractExpressionNode, constructorof
44
import ..OperatorEnumModule: OperatorEnum
55
import ..UtilsModule: is_bad_array, fill_similar
66
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
7-
import ..EvaluateEquationModule: deg0_eval, get_nuna, get_nbin
7+
import ..EvaluateEquationModule:
8+
deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN
89
import ..ExtensionInterfaceModule: _zygote_gradient
910

1011
struct ResultOk2{A<:AbstractArray,B<:AbstractArray}
@@ -70,26 +71,42 @@ end
7071
)::ResultOk2 where {T<:Number}
7172
nuna = get_nuna(operators)
7273
nbin = get_nbin(operators)
73-
quote
74-
result = if tree.degree == 0
75-
diff_deg0_eval(tree, cX, direction)
76-
elseif tree.degree == 1
77-
op_idx = tree.op
74+
deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
75+
quote
76+
diff_deg1_eval(tree, cX, operators.unaops[op_idx], operators, direction)
77+
end
78+
else
79+
quote
7880
Base.Cartesian.@nif(
7981
$nuna,
8082
i -> i == op_idx,
8183
i ->
8284
diff_deg1_eval(tree, cX, operators.unaops[i], operators, direction)
8385
)
84-
else
85-
op_idx = tree.op
86+
end
87+
end
88+
deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
89+
quote diff_deg2_eval(tree, cX, operators.binops[op_idx], operators, direction) end
90+
else
91+
quote
8692
Base.Cartesian.@nif(
8793
$nbin,
8894
i -> i == op_idx,
8995
i ->
9096
diff_deg2_eval(tree, cX, operators.binops[i], operators, direction)
9197
)
9298
end
99+
end
100+
quote
101+
result = if tree.degree == 0
102+
diff_deg0_eval(tree, cX, direction)
103+
elseif tree.degree == 1
104+
op_idx = tree.op
105+
$deg1_branch
106+
else
107+
op_idx = tree.op
108+
$deg2_branch
109+
end
93110
!result.ok && return result
94111
return ResultOk2(
95112
result.x, result.dx, !(is_bad_array(result.x) || is_bad_array(result.dx))
@@ -251,39 +268,57 @@ end
251268
)::ResultOk2 where {T<:Number,variable}
252269
nuna = get_nuna(operators)
253270
nbin = get_nbin(operators)
271+
deg1_branch_skeleton = quote
272+
grad_deg1_eval(
273+
tree,
274+
n_gradients,
275+
index_tree,
276+
cX,
277+
operators.unaops[i],
278+
operators,
279+
Val(variable),
280+
)
281+
end
282+
deg2_branch_skeleton = quote
283+
grad_deg2_eval(
284+
tree,
285+
n_gradients,
286+
index_tree,
287+
cX,
288+
operators.binops[i],
289+
operators,
290+
Val(variable),
291+
)
292+
end
293+
deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
294+
quote
295+
i = tree.op
296+
$deg1_branch_skeleton
297+
end
298+
else
299+
quote
300+
op_idx = tree.op
301+
Base.Cartesian.@nif($nuna, i -> i == op_idx, i -> $deg1_branch_skeleton)
302+
end
303+
end
304+
deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
305+
quote
306+
i = tree.op
307+
$deg2_branch_skeleton
308+
end
309+
else
310+
quote
311+
op_idx = tree.op
312+
Base.Cartesian.@nif($nbin, i -> i == op_idx, i -> $deg2_branch_skeleton)
313+
end
314+
end
254315
quote
255316
if tree.degree == 0
256317
grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(variable))
257318
elseif tree.degree == 1
258-
op_idx = tree.op
259-
Base.Cartesian.@nif(
260-
$nuna,
261-
i -> i == op_idx,
262-
i -> grad_deg1_eval(
263-
tree,
264-
n_gradients,
265-
index_tree,
266-
cX,
267-
operators.unaops[i],
268-
operators,
269-
Val(variable),
270-
)
271-
)
319+
$deg1_branch
272320
else
273-
op_idx = tree.op
274-
Base.Cartesian.@nif(
275-
$nbin,
276-
i -> i == op_idx,
277-
i -> grad_deg2_eval(
278-
tree,
279-
n_gradients,
280-
index_tree,
281-
cX,
282-
operators.binops[i],
283-
operators,
284-
Val(variable),
285-
)
286-
)
321+
$deg2_branch
287322
end
288323
end
289324
end

test/test_derivatives.jl

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -142,28 +142,63 @@ for type in [Float16, Float32, Float64], turbo in [true, false]
142142
println("Done.")
143143
end
144144

145-
println("Testing NodeIndex.")
146-
147-
import DynamicExpressions: get_constants, NodeIndex, index_constants
148-
149-
operators = OperatorEnum(;
150-
binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin)
151-
)
152-
@extend_operators operators
153-
tree = equation3(nx1, nx2, nx3)
154-
155-
"""Check whether the ordering of constant_list is the same as the ordering of node_index."""
156-
function check_tree(tree::Node, node_index::NodeIndex, constant_list::AbstractVector)
157-
if tree.degree == 0
158-
(!tree.constant) || tree.val == constant_list[node_index.val::UInt16]
159-
elseif tree.degree == 1
160-
check_tree(tree.l, node_index.l, constant_list)
161-
else
162-
check_tree(tree.l, node_index.l, constant_list) &&
163-
check_tree(tree.r, node_index.r, constant_list)
145+
@testset "NodeIndex" begin
146+
import DynamicExpressions: get_constants, NodeIndex, index_constants
147+
148+
operators = OperatorEnum(;
149+
binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin)
150+
)
151+
@extend_operators operators
152+
tree = equation3(nx1, nx2, nx3)
153+
154+
"""Check whether the ordering of constant_list is the same as the ordering of node_index."""
155+
@eval function check_tree(
156+
tree::Node, node_index::NodeIndex, constant_list::AbstractVector
157+
)
158+
if tree.degree == 0
159+
(!tree.constant) || tree.val == constant_list[node_index.val::UInt16]
160+
elseif tree.degree == 1
161+
check_tree(tree.l, node_index.l, constant_list)
162+
else
163+
check_tree(tree.l, node_index.l, constant_list) &&
164+
check_tree(tree.r, node_index.r, constant_list)
165+
end
164166
end
167+
168+
@test check_tree(tree, index_constants(tree), get_constants(tree))
165169
end
166170

167-
@test check_tree(tree, index_constants(tree), get_constants(tree))
171+
@testset "Test many operators" begin
172+
# Since we use `@nif` in evaluating expressions,
173+
# we can see if there are any issues with LARGE numbers of operators.
174+
num_ops = 100
175+
binary_operators = [@eval function (x, y)
176+
return x + y
177+
end for i in 1:num_ops]
178+
unary_operators = [@eval function (x)
179+
return x^2
180+
end for i in 1:num_ops]
181+
182+
# This OperatorEnum will trigger the fallback code for fast compilation.
183+
many_ops_operators = OperatorEnum(;
184+
binary_operators=cat([+, -, *, /], binary_operators; dims=1),
185+
unary_operators=cat([sin, cos], unary_operators; dims=1),
186+
)
168187

169-
println("Done.")
188+
# This OperatorEnum will go through the regular evaluation code.
189+
only_basic_ops_operator = OperatorEnum(;
190+
binary_operators=[+, -, *, /], unary_operators=[sin, cos]
191+
)
192+
193+
# We want to compare their gradients
194+
num_tests = 100
195+
n_features = 3
196+
for _ in 1:num_tests
197+
tree = gen_random_tree_fixed_size(20, only_basic_ops_operator, n_features, Float64)
198+
X = randn(Float64, n_features, 10)
199+
basic_eval = tree'(X, only_basic_ops_operator)
200+
many_ops_eval = tree'(X, many_ops_operators)
201+
@test (all(isnan, basic_eval) && all(isnan, many_ops_eval)) ||
202+
basic_eval ≈ many_ops_eval
203+
end
204+
end

0 commit comments

Comments
 (0)