Skip to content

Commit 0de6c6a

Browse files
committed
Include tensor operator test
1 parent 480b0aa commit 0de6c6a

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

src/Equation.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ function Base.convert(
8383
get!(id_map, tree) do
8484
if tree.degree == 0
8585
if tree.constant
86-
Node(0, tree.constant, convert(T1, (tree.val::T2)))
86+
val = tree.val::T2
87+
if !(T2 <: T1)
88+
# e.g., we don't want to convert Float32 to Union{Float32,Vector{Float32}}!
89+
val = convert(T1, val)
90+
end
91+
Node(T1, 0, tree.constant, val)
8792
else
8893
Node(T1, 0, tree.constant, nothing, tree.feature)
8994
end
@@ -138,6 +143,10 @@ function Node(
138143
"You must specify either `val` or `feature` when creating a leaf node, not both.",
139144
)
140145
elseif T2 <: Nothing
146+
if !(T1 <: T)
147+
# Only convert if not already in the type union.
148+
val = convert(T, val)
149+
end
141150
return Node(T, 0, true, val)
142151
else
143152
return Node(T, 0, false, nothing, feature)

test/test_tensor_operators.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using DynamicExpressions
2+
using Test
3+
4+
baseT = Float64
5+
T = Union{baseT,Vector{baseT},Matrix{baseT}}
6+
7+
function vec_add(x, y)
8+
return x .+ y
9+
end
10+
11+
operators = GenericOperatorEnum(; binary_operators=[vec_add], extend_user_operators=true)
12+
13+
x1, x2, x3 = [Node(T; feature=i) for i in 1:3]
14+
c1 = Node(T; val=[1.0, 2.0, 3.0])
15+
16+
X = [[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
17+
18+
tree = Node(1, c1, x2)
19+
@test repr(tree) == "vec_add([1.0, 2.0, 3.0], x2)"
20+
@test tree(X) == [4.0, 5.0, 6.0]
21+
tree = Node(1, x1, c1)
22+
@test repr(tree) == "vec_add(x1, [1.0, 2.0, 3.0])"
23+
@test tree(X) == [3.0, 4.0, 5.0]

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,7 @@ end
4747
@safetestset "Test generic operators" begin
4848
include("test_generic_operators.jl")
4949
end
50+
51+
@safetestset "Test tensor operators" begin
52+
include("test_tensor_operators.jl")
53+
end

0 commit comments

Comments
 (0)