Skip to content

Commit d118673

Browse files
committed
Create test for operators defined within module
1 parent 977e0c7 commit d118673

File tree

3 files changed

+62
-23
lines changed

3 files changed

+62
-23
lines changed

test/test_custom_operators.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using DynamicExpressions
2+
using Test
3+
using Random
4+
5+
# Test that we can work with custom operators:
6+
function op1(x::T, y::T)::T where {T<:Real}
7+
return x + y
8+
end
9+
function op2(x::T, y::T)::T where {T<:Real}
10+
return x^2 + 1 / ((y)^2 + 0.1)
11+
end
12+
function op3(x::T)::T where {T<:Real}
13+
return sin(x) + cos(x)
14+
end
15+
local operators, tree
16+
operators = OperatorEnum(; binary_operators=(op1, op2), unary_operators=(op3,))
17+
@extend_operators operators
18+
x1 = Node(; feature=1)
19+
x2 = Node(; feature=2)
20+
tree = op1(op2(x1, x2), op3(x1))
21+
@test repr(tree) == "op1(op2(x1, x2), op3(x1))"
22+
# Test evaluation:
23+
X = randn(MersenneTwister(0), Float32, 2, 10);
24+
@test tree(X) ((x1, x2) -> op1(op2(x1, x2), op3(x1))).(X[1, :], X[2, :])
25+
26+
# Now, test that we can work with operators defined in modules
27+
module A
28+
29+
using DynamicExpressions
30+
using Random
31+
32+
function my_func_a(x::T, y::T) where {T<:Real}
33+
return x^2 * y
34+
end
35+
36+
function my_func_b(x::T) where {T<:Real}
37+
return x^3
38+
end
39+
40+
operators = OperatorEnum(; binary_operators=[my_func_a], unary_operators=[my_func_b])
41+
@extend_operators operators
42+
43+
function create_and_eval_tree()
44+
x1 = Node(Float64; feature=1)
45+
x2 = Node(Float64; feature=2)
46+
c1 = Node(Float64; val=0.2)
47+
tree = my_func_a(my_func_a(x2, 0.2), my_func_b(x1))
48+
func = (x1, x2) -> my_func_a(my_func_a(x2, 0.2), my_func_b(x1))
49+
X = randn(MersenneTwister(0), 2, 20)
50+
return tree(X), func.(X[1, :], X[2, :])
51+
end
52+
53+
end
54+
55+
# Now, test that we can work with operators defined in other modules
56+
import .A: create_and_eval_tree
57+
prediction, truth = create_and_eval_tree()
58+
@test prediction truth

test/test_tree_construction.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,26 +106,3 @@ s = String(take!(io))
106106
set_node!(tree, tree2)
107107
@test tree !== tree2
108108
@test repr(tree) == repr(tree2)
109-
110-
# Test that we can work with custom operators:
111-
function op1(x, y)
112-
return x + y
113-
end
114-
function op2(x, y)
115-
return x^2 + 1 / ((y)^2 + 0.1)
116-
end
117-
function op3(x)
118-
return sin(x) + cos(x)
119-
end
120-
local operators, tree
121-
operators = OperatorEnum(;
122-
default_params..., binary_operators=(op1, op2), unary_operators=(op3,)
123-
)
124-
@extend_operators operators
125-
x1 = Node(; feature=1)
126-
x2 = Node(; feature=2)
127-
tree = op1(op2(x1, x2), op3(x1))
128-
@test repr(tree) == "op1(op2(x1, x2), op3(x1))"
129-
# Test evaluation:
130-
X = randn(MersenneTwister(0), 2, 10);
131-
@test tree(X) ((x1, x2) -> op1(op2(x1, x2), op3(x1))).(X[1, :], X[2, :])

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,7 @@ end
5959
@safetestset "Test equality operator" begin
6060
include("test_equality.jl")
6161
end
62+
63+
@safetestset "Test operators within module" begin
64+
include("test_custom_operators.jl")
65+
end

0 commit comments

Comments
 (0)