Skip to content

Commit 155f7fc

Browse files
committed
Add tests for extended operators
1 parent 6330054 commit 155f7fc

File tree

3 files changed

+34
-6
lines changed

3 files changed

+34
-6
lines changed

test/test_error_handling.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ T = Union{baseT,Vector{baseT},Matrix{baseT}}
77

88
scalar_add(x::T, y::T) where {T<:Real} = x + y
99

10-
operators = GenericOperatorEnum(; binary_operators=[scalar_add], extend_user_operators=true)
10+
operators = GenericOperatorEnum(; binary_operators=[scalar_add])
1111

1212
x1, x2, x3 = [Node(T; feature=i) for i in 1:3]
1313

14-
tree = Node(1, x1, x2)
14+
@extend_operators operators
15+
tree = scalar_add(x1, x2)
1516

1617
# With error handling:
1718
try

test/test_tensor_operators.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function vec_add(x, y)
88
return x .+ y
99
end
1010

11-
operators = GenericOperatorEnum(; binary_operators=[vec_add], extend_user_operators=true)
11+
operators = GenericOperatorEnum(; binary_operators=[vec_add])
1212

1313
x1, x2, x3 = [Node(T; feature=i) for i in 1:3]
1414
c1 = Node(T; val=[1.0, 2.0, 3.0])
@@ -22,20 +22,30 @@ tree = Node(1, x1, c1)
2222
@test repr(tree) == "vec_add(x1, [1.0, 2.0, 3.0])"
2323
@test tree(X) == [3.0, 4.0, 5.0]
2424

25+
# Try same things, but with constructors:
26+
@extend_operators operators
27+
tree = vec_add(c1, x2)
28+
@test repr(tree) == "vec_add([1.0, 2.0, 3.0], x2)"
29+
@test tree(X) == [4.0, 5.0, 6.0]
30+
tree = vec_add(x1, c1)
31+
@test repr(tree) == "vec_add(x1, [1.0, 2.0, 3.0])"
32+
@test tree(X) == [3.0, 4.0, 5.0]
33+
2534
# Also test unary operators:
2635
function vec_square(x)
2736
return x .* x
2837
end
2938

30-
operators = GenericOperatorEnum(;
31-
binary_operators=[vec_add], unary_operators=[vec_square], extend_user_operators=true
32-
)
39+
operators = GenericOperatorEnum(; binary_operators=[vec_add], unary_operators=[vec_square])
40+
@extend_operators operators
3341
tree = Node(1, c1)
3442
@test repr(tree) == "vec_square([1.0, 2.0, 3.0])"
3543
@test tree(X) == [1.0, 4.0, 9.0]
44+
@test vec_square(c1).val == [1.0, 4.0, 9.0]
3645
tree = Node(1, Node(1, c1), x1)
3746
@test repr(tree) == "vec_add(vec_square([1.0, 2.0, 3.0]), x1)"
3847
@test tree(X) == [3.0, 6.0, 11.0]
48+
@test (vec_add(vec_square(c1), x1))(X) == [3.0, 6.0, 11.0]
3949

4050
# Also test mixed scalar and floats:
4151
c2 = Node(T; val=2.0)

test/test_tree_construction.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,20 @@ s = String(take!(io))
106106
set_node!(tree, tree2)
107107
@test tree !== tree2
108108
@test repr(tree) == repr(tree2)
109+
110+
# Test that we can work with custom operators:
111+
function op1(x, y)
112+
return x + y
113+
end
114+
function op2(x, y)
115+
return x ^ 2 + 1/((y)^2 + 0.1)
116+
end
117+
function op3(x)
118+
return sin(x) + cos(x)
119+
end
120+
operators = OperatorEnum(; default_params..., binary_operators=(op1, op2), unary_operators=(op3,))
121+
@extend_operators operators
122+
x1 = Node(; feature=1)
123+
x2 = Node(; feature=2)
124+
tree = op1(op2(x1, x2), op3(x1))
125+
@test repr(tree) == "op1(op2(x1, x2), op3(x1))"

0 commit comments

Comments
 (0)