Skip to content

Commit 40d3c1d

Browse files
committed
Test equality across types
1 parent f4ffa63 commit 40d3c1d

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

src/base.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ function any(f::F, tree::AbstractNode) where {F<:Function}
153153
end
154154
end
155155

156-
function Base.:(==)(a::AbstractExpressionNode, b::AbstractExpressionNode)::Bool
157-
if constructorof(typeof(a)) !== constructorof(typeof(b))
158-
return false
159-
end
160-
if preserve_sharing(typeof(a)) || preserve_sharing(typeof(b))
156+
function Base.:(==)(a::AbstractExpressionNode, b::AbstractExpressionNode)
157+
return Base.:(==)(promote(a, b)...)
158+
end
159+
function Base.:(==)(a::N, b::N)::Bool where {N<:AbstractExpressionNode}
160+
if preserve_sharing(N)
161161
return inner_is_equal_shared(a, b, Dict{UInt,Nothing}(), Dict{UInt,Nothing}())
162162
else
163163
return inner_is_equal(a, b)

test/test_graphs.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,29 @@ end
368368
@test (x + 1).l === x
369369
end
370370
end
371+
372+
@testset "Joint operations" begin
373+
operators = OperatorEnum(;
374+
binary_operators=(+, -, *, ^, /), unary_operators=(cos, exp, sin)
375+
)
376+
x = GraphNode(Float64; feature=1)
377+
y = Node(Float64; feature=1)
378+
379+
@test x == y
380+
381+
@test promote(x, y) isa Tuple{typeof(x),typeof(x)}
382+
383+
# Node with GraphNode - will convert both
384+
tree1 = sin(x) * x
385+
tree2 = sin(y) * y
386+
@test tree1 != tree2
387+
388+
# GraphNode against GraphNode
389+
tree1 = sin(x) * x
390+
tree2 = sin(x) * x
391+
@test tree1 == tree2
392+
393+
# Is aware of different shared structure
394+
tree2 = sin(x) * GraphNode(Float64; feature=1)
395+
@test tree1 != tree2
396+
end

0 commit comments

Comments
 (0)