Skip to content

Commit 1e96a1e

Browse files
committed
Proper testing of extension error messages
1 parent c54966e commit 1e96a1e

File tree

7 files changed

+53
-28
lines changed

7 files changed

+53
-28
lines changed

ext/DynamicExpressionsZygoteExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DynamicExpressionsZygoteExt
22

33
import Zygote: gradient
4-
import DynamicExpressions.EvaluateEquationDerivativeModule: _zygote_gradient
4+
import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient
55

66
function _zygote_gradient(op::F, ::Val{1}) where {F}
77
function (x)

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DynamicExpressions
22

33
include("Utils.jl")
4+
include("ExtensionInterface.jl")
45
include("OperatorEnum.jl")
56
include("Equation.jl")
67
include("EquationUtils.jl")
@@ -9,7 +10,6 @@ include("EvaluateEquationDerivative.jl")
910
include("EvaluationHelpers.jl")
1011
include("SimplifyEquation.jl")
1112
include("OperatorEnumConstruction.jl")
12-
include("ExtensionInterface.jl")
1313
include("Random.jl")
1414

1515
import PackageExtensionCompat: @require_extensions

src/EvaluateEquationDerivative.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@ import ..OperatorEnumModule: OperatorEnum
55
import ..UtilsModule: is_bad_array, fill_similar
66
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
77
import ..EvaluateEquationModule: deg0_eval, get_nuna, get_nbin
8+
import ..ExtensionInterfaceModule: _zygote_gradient
89

910
struct ResultOk2{A<:AbstractArray,B<:AbstractArray}
1011
x::A
1112
dx::B
1213
ok::Bool
1314
end
1415

15-
_zygote_gradient(args...) = error("Please load the Zygote.jl package.")
16-
1716
"""
1817
eval_diff_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Bool=false)
1918

src/ExtensionInterface.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
module ExtensionInterfaceModule
22

33
function node_to_symbolic(args...; kws...)
4-
return error(
5-
"Please load the `SymbolicUtils` package to use `node_to_symbolic(::AbstractExpressionNode, ::AbstractOperatorEnum; kws...)`.",
6-
)
4+
return error("Please load the `SymbolicUtils` package to use `node_to_symbolic`.")
75
end
86
function symbolic_to_node(args...; kws...)
9-
return error(
10-
"Please load the `SymbolicUtils` package to use `symbolic_to_node(::Symbolic, ::AbstractOperatorEnum; kws...)`.",
11-
)
7+
return error("Please load the `SymbolicUtils` package to use `symbolic_to_node`.")
8+
end
9+
10+
function _zygote_gradient(args...)
11+
return error("Please load the Zygote.jl package.")
1212
end
1313

1414
end

test/test_deprecations.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using DynamicExpressions
2+
using Test
3+
using Zygote
4+
5+
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
6+
x1, x2 = Node{Float64}(; feature=1), Node{Float64}(; feature=2)
7+
tree = cos(2.1 * x1)
8+
9+
# Also test warnings:
10+
for constructor in (OperatorEnum, GenericOperatorEnum)
11+
operators = constructor(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
12+
VERSION >= v"1.9" &&
13+
@test_warn "The `tree(X; kws...)` syntax is deprecated" tree([1.0; 2.0;;])
14+
15+
constructor == GenericOperatorEnum && continue
16+
17+
tree'([1.0; 2.0;;])
18+
VERSION >= v"1.9" &&
19+
@test_warn "The `tree(X; kws...)` syntax is deprecated" tree'([1.0; 2.0;;])
20+
end

test/test_initial_errors.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
11
using DynamicExpressions
22
using Test
3-
using Zygote
43

54
# Before defining OperatorEnum, calling the implicit (deprecated)
65
# syntax should fail:
7-
tree = Node(; feature=1)
6+
tree = Node{Float64}(; feature=1)
87

98
if VERSION >= v"1.8"
10-
@test_throws ErrorException tree([1.0 2.0]')
11-
@test_throws "Please use the " tree([1.0 2.0]')
12-
@test_throws ErrorException tree'([1.0 2.0]')
13-
@test_throws "Please use the " tree'([1.0 2.0]')
9+
@test_throws ErrorException tree([1.0; 2.0;;])
10+
@test_throws "Please use the " tree([1.0; 2.0;;])
11+
@test_throws ErrorException tree'([1.0; 2.0;;])
12+
@test_throws "Please use the " tree'([1.0; 2.0;;])
1413
end
1514

15+
# Initial strings are still somewhat useful
1616
@test string(tree) == "x1"
1717
@test string(Node(1, tree)) == "unary_operator[1](x1)"
1818
@test string(Node(1, tree, tree)) == "binary_operator[1](x1, x1)"
1919

20-
# Also test warnings:
21-
for constructor in (OperatorEnum, GenericOperatorEnum)
22-
operators = constructor(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
23-
tree([1.0 2.0]')
24-
# Can't test for this:
25-
# expected_warn_msg = "The `tree(X; kws...)` syntax is deprecated"
26-
# @test occursin(expected_warn_msg, msg)
20+
# Before loading extensions, should fail with helpful message:
21+
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
22+
x1, x2 = Node{Float64}(; feature=1), Node{Float64}(; feature=2)
23+
tree = cos(2.1 * x1) + sin(x2)
2724

28-
constructor == GenericOperatorEnum && continue
25+
if VERSION >= v"1.9"
26+
@test_throws(
27+
"Please load the `SymbolicUtils` package to use `node_to_symbolic`.",
28+
node_to_symbolic(tree, operators)
29+
)
30+
@test_throws(
31+
"Please load the `SymbolicUtils` package to use `symbolic_to_node`.",
32+
symbolic_to_node(tree, operators)
33+
)
2934

30-
tree'([1.0 2.0]')
31-
# Can't test for this:
32-
# expected_warn_msg = "The `tree'(X; kws...)` syntax is deprecated"
33-
# @test occursin(expected_warn_msg, msg)
35+
@test_throws("Please load the Zygote.jl package.", tree'(ones(2, 10)))
3436
end

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ end
88
include("test_initial_errors.jl")
99
end
1010

11+
@safetestset "Test deprecations" begin
12+
include("test_deprecations.jl")
13+
end
14+
1115
@safetestset "Test tree construction and scoring" begin
1216
include("test_tree_construction.jl")
1317
end

0 commit comments

Comments
 (0)