Skip to content

Commit 327c8d2

Browse files
committed
test: improve MultiScalarExpression test
1 parent 9061cd8 commit 327c8d2

File tree

3 files changed

+99
-89
lines changed

3 files changed

+99
-89
lines changed

src/ParametricExpression.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -331,15 +331,6 @@ end
331331
return node_type(; val=ex)
332332
end
333333
end
334-
335-
# And easy evaluation
336-
function (ex::ParametricExpression)(X, classes, operators=nothing; kws...)
337-
out, complete = eval_tree_array(ex, X, classes, operators; kws...)
338-
if !complete
339-
out .= NaN
340-
end
341-
return out
342-
end
343334
###############################################################################
344335

345336
end

test/test_multi_expression.jl

Lines changed: 98 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,89 +1,111 @@
1-
"""Test if we can create a multi-expression expression type."""
1+
@testitem "Test if we can create a multi-expression expression type." begin
2+
using DynamicExpressions
3+
using DynamicExpressions: DynamicExpressions as DE
4+
using DynamicExpressions: Metadata, ExpressionInterface
5+
using Interfaces: Interfaces, test, @implements, Arguments
26

3-
using Test
4-
using DynamicExpressions
5-
using DynamicExpressions: DynamicExpressions as DE
6-
using DynamicExpressions: Metadata
7+
struct MultiScalarExpression{
8+
T,N<:AbstractExpressionNode{T},TREES<:NamedTuple,D<:NamedTuple
9+
} <: AbstractExpression{T,N}
10+
trees::TREES
11+
metadata::Metadata{D}
712

8-
struct MultiScalarExpression{
9-
T,N<:AbstractExpressionNode{T},TREES<:NamedTuple,D<:NamedTuple
10-
} <: AbstractExpression{T,N}
11-
trees::TREES
12-
metadata::Metadata{D}
13+
"""
14+
Create a multi-expression expression type.
1315
14-
"""
15-
Create a multi-expression expression type.
16-
17-
The `tree_factory` is a function that takes the trees by keyword argument,
18-
and stitches them together into a single tree (for printing or evaluation).
19-
"""
20-
function MultiScalarExpression(
21-
trees::NamedTuple; tree_factory::F, operators, variable_names
22-
) where {F<:Function}
23-
example_tree = first(values(trees))
24-
N = typeof(example_tree)
25-
T = eltype(example_tree)
26-
@assert all(t -> eltype(t) == T, values(trees))
27-
metadata = (; tree_factory, operators, variable_names)
28-
return new{T,N,typeof(trees),typeof(metadata)}(trees, Metadata(metadata))
16+
The `tree_factory` is a function that takes the trees by keyword argument,
17+
and stitches them together into a single tree (for printing or evaluation).
18+
"""
19+
function MultiScalarExpression(
20+
trees::NamedTuple; tree_factory::F, operators, variable_names
21+
) where {F<:Function}
22+
example_tree = first(values(trees))
23+
N = typeof(example_tree)
24+
T = eltype(example_tree)
25+
@assert all(t -> eltype(t) == T, values(trees))
26+
metadata = (; tree_factory, operators, variable_names)
27+
return new{T,N,typeof(trees),typeof(metadata)}(trees, Metadata(metadata))
28+
end
2929
end
30-
end
31-
32-
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos, exp])
33-
variable_names = ["a", "b", "c"]
3430

35-
ex1 = @parse_expression(c * 2.5 - cos(a), operators, variable_names)
36-
ex2 = @parse_expression(b * b * b + c / 0.2, operators, variable_names)
31+
operators = OperatorEnum(;
32+
binary_operators=[+, -, *, /], unary_operators=[sin, cos, exp]
33+
)
34+
variable_names = ["a", "b", "c"]
3735

38-
multi_ex = MultiScalarExpression(
39-
(; f=ex1.tree, g=ex2.tree);
40-
tree_factory=(; f, g) -> :($f + cos($g)),
41-
# TODO: Can we have a custom evaluation routine here, to enable aggregations in the middle part?
42-
operators,
43-
variable_names,
44-
)
36+
ex1 = @parse_expression(c * 2.5 - cos(a), operators, variable_names)
37+
ex2 = @parse_expression(b * b * b + c / 0.2, operators, variable_names)
4538

46-
# Verify that the unimplemented methods raise an error
47-
if VERSION >= v"1.9"
48-
@test_throws "`get_operators` function must be implemented for" DE.get_operators(
49-
multi_ex, nothing
50-
)
51-
@test_throws "`get_variable_names` function must be implemented for" DE.get_variable_names(
52-
multi_ex, nothing
39+
multi_ex = MultiScalarExpression(
40+
(; f=ex1.tree, g=ex2.tree);
41+
tree_factory=(; f, g) -> :($f + cos($g)),
42+
# TODO: Can we have a custom evaluation routine here, to enable aggregations in the middle part?
43+
operators,
44+
variable_names,
5345
)
54-
@test_throws "`get_tree` function must be implemented for" DE.get_tree(multi_ex)
55-
@test_throws "`copy` function must be implemented for" copy(multi_ex)
56-
@test_throws "`hash` function must be implemented for" hash(multi_ex, UInt(0))
57-
@test_throws "`==` function must be implemented for" multi_ex == multi_ex
58-
@test_throws "`get_constants` function must be implemented for" get_constants(multi_ex)
59-
@test_throws "`set_constants!` function must be implemented for" set_constants!(
60-
multi_ex, nothing, nothing
61-
)
62-
end
6346

64-
tree_factory(f::F, trees) where {F} = f(; trees...)
65-
function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N}
66-
fused_expression = parse_expression(
67-
tree_factory(ex.metadata.tree_factory, ex.trees)::Expr;
68-
calling_module=@__MODULE__, # TODO: Not needed
69-
operators=DE.get_operators(ex, nothing),
70-
variable_names=nothing,
71-
node_type=N,
72-
expression_type=Expression,
73-
)::Expression{T,N}
74-
return fused_expression.tree
75-
end
76-
function DE.get_operators(ex::MultiScalarExpression, operators)
77-
return operators === nothing ? ex.metadata.operators : operators
78-
end
79-
function DE.get_variable_names(ex::MultiScalarExpression, variable_names)
80-
return variable_names === nothing ? ex.metadata.variable_names : variable_names
81-
end
47+
# Verify that the unimplemented methods raise an error
48+
if VERSION >= v"1.9"
49+
@test_throws "`get_operators` function must be implemented for" DE.get_operators(
50+
multi_ex, nothing
51+
)
52+
@test_throws "`get_variable_names` function must be implemented for" DE.get_variable_names(
53+
multi_ex, nothing
54+
)
55+
@test_throws "`get_tree` function must be implemented for" DE.get_tree(multi_ex)
56+
@test_throws "`copy` function must be implemented for" copy(multi_ex)
57+
@test_throws "`hash` function must be implemented for" hash(multi_ex, UInt(0))
58+
@test_throws "`==` function must be implemented for" multi_ex == multi_ex
59+
@test_throws "`get_constants` function must be implemented for" get_constants(
60+
multi_ex
61+
)
62+
@test_throws "`set_constants!` function must be implemented for" set_constants!(
63+
multi_ex, nothing, nothing
64+
)
65+
end
8266

83-
s = sprint((io, ex) -> show(io, MIME"text/plain"(), ex), multi_ex)
67+
tree_factory(f::F, trees) where {F} = f(; trees...)
68+
function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N}
69+
fused_expression = parse_expression(
70+
tree_factory(ex.metadata.tree_factory, ex.trees)::Expr;
71+
calling_module=@__MODULE__, # TODO: Not needed
72+
operators=DE.get_operators(ex, nothing),
73+
variable_names=nothing,
74+
node_type=N,
75+
expression_type=Expression,
76+
)::Expression{T,N}
77+
return fused_expression.tree
78+
end
79+
function DE.get_operators(ex::MultiScalarExpression, operators=nothing)
80+
return operators === nothing ? ex.metadata.operators : operators
81+
end
82+
function DE.get_variable_names(ex::MultiScalarExpression, variable_names=nothing)
83+
return variable_names === nothing ? ex.metadata.variable_names : variable_names
84+
end
85+
function Base.copy(ex::MultiScalarExpression)
86+
t = NamedTuple{keys(ex.trees)}(map(copy, values(ex.trees)))
87+
m = ex.metadata
88+
return MultiScalarExpression(
89+
t;
90+
tree_factory=m.tree_factory,
91+
operators=copy(m.operators),
92+
variable_names=copy(m.variable_names),
93+
)
94+
end
95+
function Base.:(==)(ex1::MultiScalarExpression, ex2::MultiScalarExpression)
96+
return isempty(Base.structdiff(ex1.trees, ex2.trees)) &&
97+
all(i -> ex1.trees[i] == ex2.trees[i], keys(ex1.trees)) &&
98+
ex1.metadata == ex2.metadata
99+
end
100+
101+
s = sprint((io, ex) -> show(io, MIME"text/plain"(), ex), multi_ex)
84102

85-
@test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))"
103+
@test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))"
86104

87-
s = sprint((io, ex) -> print_tree(io, ex), multi_ex)
105+
s = sprint((io, ex) -> print_tree(io, ex), multi_ex)
88106

89-
@test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))\n"
107+
@test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))\n"
108+
109+
@implements ExpressionInterface MultiScalarExpression [Arguments()]
110+
test(ExpressionInterface, MultiScalarExpression, [multi_ex])
111+
end

test/unittest.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,6 @@ end
104104
include("test_extra_node_fields.jl")
105105
end
106106

107-
@testitem "Test multi expression" begin
108-
include("test_multi_expression.jl")
109-
end
110-
111107
@testitem "Test containers preserved" begin
112108
include("test_container_preserved.jl")
113109
end
@@ -127,6 +123,7 @@ end
127123
end
128124

129125
include("test_expressions.jl")
126+
include("test_multi_expression.jl")
130127
include("test_parse.jl")
131128
include("test_parametric_expression.jl")
132129
include("test_operator_construction_edgecases.jl")

0 commit comments

Comments
 (0)