Skip to content

Commit 4b5d681

Browse files
committed
test: fix type instability in multi expression
1 parent 3fe8aa0 commit 4b5d681

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

test/test_multi_expression.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ using DynamicExpressions
55
using DynamicExpressions: DynamicExpressions as DE
66
using DynamicExpressions: Metadata
77

8-
struct MultiScalarExpression{T,TREES<:NamedTuple,D<:NamedTuple} <: AbstractExpression{T}
8+
struct MultiScalarExpression{
9+
T,N<:AbstractExpressionNode{T},TREES<:NamedTuple,D<:NamedTuple
10+
} <: AbstractExpression{T}
911
trees::TREES
1012
metadata::Metadata{D}
1113

@@ -18,10 +20,12 @@ struct MultiScalarExpression{T,TREES<:NamedTuple,D<:NamedTuple} <: AbstractExpre
1820
function MultiScalarExpression(
1921
trees::NamedTuple; tree_factory::F, operators, variable_names
2022
) where {F<:Function}
21-
T = eltype(first(values(trees)))
23+
example_tree = first(values(trees))
24+
N = typeof(example_tree)
25+
T = eltype(example_tree)
2226
@assert all(t -> eltype(t) == T, values(trees))
2327
metadata = (; tree_factory, operators, variable_names)
24-
return new{T,typeof(trees),typeof(metadata)}(trees, Metadata(metadata))
28+
return new{T,N,typeof(trees),typeof(metadata)}(trees, Metadata(metadata))
2529
end
2630
end
2731

@@ -54,15 +58,15 @@ if VERSION >= v"1.9"
5458
end
5559

5660
tree_factory(f::F, trees) where {F} = f(; trees...)
57-
function DE.get_tree(ex::MultiScalarExpression{N}) where {N}
61+
function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N}
5862
fused_expression = parse_expression(
5963
tree_factory(ex.metadata.tree_factory, ex.trees)::Expr;
6064
calling_module=@__MODULE__, # TODO: Not needed
6165
operators=DE.get_operators(ex, nothing),
6266
variable_names=nothing,
63-
node_type=typeof(first(values(ex.trees))),
67+
node_type=N,
6468
expression_type=Expression,
65-
)
69+
)::Expression{T,N}
6670
return fused_expression.tree
6771
end
6872
function DE.get_operators(ex::MultiScalarExpression, operators)

0 commit comments

Comments
 (0)