Skip to content

Commit 1443cbb

Browse files
committed
Include tests for undefined evaluation
1 parent 357d3c2 commit 1443cbb

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function create_evaluation_helpers!(operators::OperatorEnum)
1515
length(keys(kws)) > 1 && error("Unknown keyword argument: $(key)")
1616
out, did_finish = eval_tree_array(tree, X, $operators)
1717
if !did_finish
18-
out .= T(NaN)
18+
out .= convert(eltype(out), NaN)
1919
end
2020
return out
2121
end

test/test_evaluation.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,38 @@ truth = cos(3.0f0 + 4.0f0)
9898
@test DynamicExpressions.EvaluateEquationModule.deg1_l2_ll0_lr0_eval(
9999
tree, [0.0f0]', Val(1), Val(1), operators
100100
)[1][1] truth
101+
102+
# Test for presence of NaNs:
103+
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
104+
x1 = Node(Float64; feature=1)
105+
tree = sin(x1 / 0.0)
106+
X = randn(Float32, 3, 10);
107+
@test isnan(tree(X)[1])
108+
109+
# And, with generic operator enum, this should be an actual error:
110+
operators = GenericOperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
111+
x1 = Node(Float64; feature=1)
112+
tree = sin(x1 / 0.0)
113+
X = randn(Float32, 10);
114+
@noinline stack = try
115+
tree(X)[1]
116+
@test false
117+
catch e
118+
@test isa(e, ErrorException)
119+
# Check that "Failed to evaluate" is in the message:
120+
@test occursin("Failed to evaluate", e.msg)
121+
current_exceptions()
122+
end;
123+
@test length(stack) == 2
124+
@test isa(stack[1].exception, DomainError)
125+
126+
# If a method is not defined, we should get a nothing:
127+
X = randn(Float32, 1, 10);
128+
@test tree(X; throw_errors=false) === nothing
129+
# or a MethodError:
130+
try
131+
tree(X; throw_errors=true)
132+
@test false
133+
catch e
134+
@test isa(current_exceptions()[1].exception, MethodError)
135+
end

0 commit comments

Comments
 (0)