Skip to content

Commit 5664f79

Browse files
committed
Automatically disable @nif if too long
1 parent 88b0a7a commit 5664f79

File tree

2 files changed

+59
-15
lines changed

2 files changed

+59
-15
lines changed

src/EvaluateEquation.jl

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
66
import ..UtilsModule: @maybe_turbo, is_bad_array, fill_similar, counttuple
77
import ..EquationUtilsModule: is_constant
88

9+
const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
10+
911
struct ResultOk{A<:AbstractArray}
1012
x::A
1113
ok::Bool
@@ -160,7 +162,20 @@ end
160162
::Val{turbo},
161163
) where {T<:Number,turbo}
162164
nbin = get_nbin(operators)
163-
quote
165+
long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
166+
if long_compilation_time
167+
return quote
168+
result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo))
169+
!result_l.ok && return result_l
170+
@return_on_nonfinite_array result_l.x
171+
result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo))
172+
!result_r.ok && return result_r
173+
@return_on_nonfinite_array result_r.x
174+
# op(x, y), for any x or y
175+
deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], Val(turbo))
176+
end
177+
end
178+
return quote
164179
return Base.Cartesian.@nif(
165180
$nbin,
166181
i -> i == op_idx,
@@ -201,9 +216,18 @@ end
201216
::Val{turbo},
202217
) where {T<:Number,turbo}
203218
nuna = get_nuna(operators)
219+
long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
220+
if long_compilation_time
221+
return quote
222+
result = _eval_tree_array(tree.l, cX, operators, Val(turbo))
223+
!result.ok && return result
224+
@return_on_nonfinite_array result.x
225+
deg1_eval(result.x, operators.unaops[op_idx], Val(turbo))
226+
end
227+
end
204228
# This @nif lets us generate an if statement over choice of operator,
205229
# which means the compiler will be able to completely avoid type inference on operators.
206-
quote
230+
return quote
207231
Base.Cartesian.@nif(
208232
$nuna,
209233
i -> i == op_idx,
@@ -240,6 +264,8 @@ end
240264
::Val{turbo},
241265
) where {T<:Number,F,turbo}
242266
nbin = counttuple(binops)
267+
# (Note this is only called from dispatch_deg1_eval, which has already
268+
# checked for long compilation times, so we don't need to check here)
243269
quote
244270
Base.Cartesian.@nif(
245271
$nbin,
@@ -450,21 +476,28 @@ over an entire array when the values are all the same.
450476
) where {T<:Number}
451477
nuna = get_nuna(operators)
452478
nbin = get_nbin(operators)
453-
quote
454-
if tree.degree == 0
455-
return deg0_eval_constant(tree)::ResultOk{Vector{T}}
456-
elseif tree.degree == 1
457-
op_idx = tree.op
458-
return Base.Cartesian.@nif(
479+
deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
480+
quote
481+
deg1_eval_constant(tree, operators.unaops[op_idx], operators)::ResultOk{Vector{T}}
482+
end
483+
else
484+
quote
485+
Base.Cartesian.@nif(
459486
$nuna,
460487
i -> i == op_idx,
461488
i -> deg1_eval_constant(
462489
tree, operators.unaops[i], operators
463490
)::ResultOk{Vector{T}}
464491
)
465-
else
466-
op_idx = tree.op
467-
return Base.Cartesian.@nif(
492+
end
493+
end
494+
deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
495+
quote
496+
deg2_eval_constant(tree, operators.binops[op_idx], operators)::ResultOk{Vector{T}}
497+
end
498+
else
499+
quote
500+
Base.Cartesian.@nif(
468501
$nbin,
469502
i -> i == op_idx,
470503
i -> deg2_eval_constant(
@@ -473,6 +506,17 @@ over an entire array when the values are all the same.
473506
)
474507
end
475508
end
509+
return quote
510+
if tree.degree == 0
511+
return deg0_eval_constant(tree)::ResultOk{Vector{T}}
512+
elseif tree.degree == 1
513+
op_idx = tree.op
514+
return $deg1_branch
515+
else
516+
op_idx = tree.op
517+
return $deg2_branch
518+
end
519+
end
476520
end
477521

478522
@inline function deg0_eval_constant(tree::AbstractExpressionNode{T}) where {T<:Number}

src/OperatorEnumConstruction.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module OperatorEnumConstructionModule
22

33
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
44
import ..EquationModule: string_tree, Node, GraphNode, AbstractExpressionNode, constructorof
5-
import ..EvaluateEquationModule: eval_tree_array
5+
import ..EvaluateEquationModule: eval_tree_array, OPERATOR_LIMIT_BEFORE_SLOWDOWN
66
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
77
import ..EvaluationHelpersModule: _grad_evaluator
88

@@ -365,10 +365,10 @@ function OperatorEnum(;
365365
:OperatorEnum,
366366
)
367367
for (op, s) in ((binary_operators, "binary"), (unary_operators, "unary"))
368-
if length(op) > 15
368+
if length(op) > OPERATOR_LIMIT_BEFORE_SLOWDOWN
369369
@warn(
370-
"You have passed over 15 $(s) operators. " *
371-
"Note that this will result in very slow compilation times. " *
370+
"You have passed over $(OPERATOR_LIMIT_BEFORE_SLOWDOWN) $(s) operators. " *
371+
"To prevent long compilation times, some optimizations will be disabled. " *
372372
"If this presents an issue, please open an issue on https://github.com/SymbolicML/DynamicExpressions.jl"
373373
)
374374
break

0 commit comments

Comments
 (0)