Skip to content

Commit efe65c0

Browse files
committed
Fix type assertion to be of Node's type
1 parent 1581a51 commit efe65c0

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/EvaluateEquation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,11 +472,11 @@ function eval(current_node)
472472
that it was not defined for.
473473
"""
474474
function eval_tree_array(
475-
tree, cX::AbstractArray{T,N}, operators::GenericOperatorEnum
476-
) where {T,N}
475+
tree::Node{T1}, cX::AbstractArray{T2,N}, operators::GenericOperatorEnum
476+
) where {T1,T2,N}
477477
if tree.degree == 0
478478
if tree.constant
479-
return (tree.val::T), true
479+
return (tree.val::T1), true
480480
else
481481
if N == 1
482482
return cX[tree.feature], true

test/test_tensor_operators.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,11 @@ tree = Node(1, c1)
3434
tree = Node(1, Node(1, c1), x1)
3535
@test repr(tree) == "vec_add(vec_square([1.0, 2.0, 3.0]), x1)"
3636
@test tree(X) == [3.0, 6.0, 11.0]
37+
38+
# Also test mixed scalar and floats:
39+
c2 = Node(T; val=2.0)
40+
@test repr(c2) == "2.0"
41+
tree = vec_add(vec_add(c1, x1), c2)
42+
@test repr(tree) == "vec_add(vec_add([1.0, 2.0, 3.0], x1), 2.0)"
43+
tree(X)
44+
@test tree(X) == [5.0, 6.0, 7.0]

0 commit comments

Comments
 (0)