|
1 | | -"""Test if we can create a multi-expression expression type.""" |
| 1 | +@testitem "Test if we can create a multi-expression expression type." begin |
| 2 | + using DynamicExpressions |
| 3 | + using DynamicExpressions: DynamicExpressions as DE |
| 4 | + using DynamicExpressions: Metadata, ExpressionInterface |
| 5 | + using Interfaces: Interfaces, test, @implements, Arguments |
2 | 6 |
|
3 | | -using Test |
4 | | -using DynamicExpressions |
5 | | -using DynamicExpressions: DynamicExpressions as DE |
6 | | -using DynamicExpressions: Metadata |
| 7 | + struct MultiScalarExpression{ |
| 8 | + T,N<:AbstractExpressionNode{T},TREES<:NamedTuple,D<:NamedTuple |
| 9 | + } <: AbstractExpression{T,N} |
| 10 | + trees::TREES |
| 11 | + metadata::Metadata{D} |
7 | 12 |
|
8 | | -struct MultiScalarExpression{ |
9 | | - T,N<:AbstractExpressionNode{T},TREES<:NamedTuple,D<:NamedTuple |
10 | | -} <: AbstractExpression{T,N} |
11 | | - trees::TREES |
12 | | - metadata::Metadata{D} |
| 13 | + """ |
| 14 | + Create a multi-expression expression type. |
13 | 15 |
|
14 | | - """ |
15 | | - Create a multi-expression expression type. |
16 | | -
|
17 | | - The `tree_factory` is a function that takes the trees by keyword argument, |
18 | | - and stitches them together into a single tree (for printing or evaluation). |
19 | | - """ |
20 | | - function MultiScalarExpression( |
21 | | - trees::NamedTuple; tree_factory::F, operators, variable_names |
22 | | - ) where {F<:Function} |
23 | | - example_tree = first(values(trees)) |
24 | | - N = typeof(example_tree) |
25 | | - T = eltype(example_tree) |
26 | | - @assert all(t -> eltype(t) == T, values(trees)) |
27 | | - metadata = (; tree_factory, operators, variable_names) |
28 | | - return new{T,N,typeof(trees),typeof(metadata)}(trees, Metadata(metadata)) |
| 16 | + The `tree_factory` is a function that takes the trees by keyword argument, |
| 17 | + and stitches them together into a single tree (for printing or evaluation). |
| 18 | + """ |
| 19 | + function MultiScalarExpression( |
| 20 | + trees::NamedTuple; tree_factory::F, operators, variable_names |
| 21 | + ) where {F<:Function} |
| 22 | + example_tree = first(values(trees)) |
| 23 | + N = typeof(example_tree) |
| 24 | + T = eltype(example_tree) |
| 25 | + @assert all(t -> eltype(t) == T, values(trees)) |
| 26 | + metadata = (; tree_factory, operators, variable_names) |
| 27 | + return new{T,N,typeof(trees),typeof(metadata)}(trees, Metadata(metadata)) |
| 28 | + end |
29 | 29 | end |
30 | | -end |
31 | | - |
32 | | -operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos, exp]) |
33 | | -variable_names = ["a", "b", "c"] |
34 | 30 |
|
35 | | -ex1 = @parse_expression(c * 2.5 - cos(a), operators, variable_names) |
36 | | -ex2 = @parse_expression(b * b * b + c / 0.2, operators, variable_names) |
| 31 | + operators = OperatorEnum(; |
| 32 | + binary_operators=[+, -, *, /], unary_operators=[sin, cos, exp] |
| 33 | + ) |
| 34 | + variable_names = ["a", "b", "c"] |
37 | 35 |
|
38 | | -multi_ex = MultiScalarExpression( |
39 | | - (; f=ex1.tree, g=ex2.tree); |
40 | | - tree_factory=(; f, g) -> :($f + cos($g)), |
41 | | - # TODO: Can we have a custom evaluation routine here, to enable aggregations in the middle part? |
42 | | - operators, |
43 | | - variable_names, |
44 | | -) |
| 36 | + ex1 = @parse_expression(c * 2.5 - cos(a), operators, variable_names) |
| 37 | + ex2 = @parse_expression(b * b * b + c / 0.2, operators, variable_names) |
45 | 38 |
|
46 | | -# Verify that the unimplemented methods raise an error |
47 | | -if VERSION >= v"1.9" |
48 | | - @test_throws "`get_operators` function must be implemented for" DE.get_operators( |
49 | | - multi_ex, nothing |
50 | | - ) |
51 | | - @test_throws "`get_variable_names` function must be implemented for" DE.get_variable_names( |
52 | | - multi_ex, nothing |
| 39 | + multi_ex = MultiScalarExpression( |
| 40 | + (; f=ex1.tree, g=ex2.tree); |
| 41 | + tree_factory=(; f, g) -> :($f + cos($g)), |
| 42 | + # TODO: Can we have a custom evaluation routine here, to enable aggregations in the middle part? |
| 43 | + operators, |
| 44 | + variable_names, |
53 | 45 | ) |
54 | | - @test_throws "`get_tree` function must be implemented for" DE.get_tree(multi_ex) |
55 | | - @test_throws "`copy` function must be implemented for" copy(multi_ex) |
56 | | - @test_throws "`hash` function must be implemented for" hash(multi_ex, UInt(0)) |
57 | | - @test_throws "`==` function must be implemented for" multi_ex == multi_ex |
58 | | - @test_throws "`get_constants` function must be implemented for" get_constants(multi_ex) |
59 | | - @test_throws "`set_constants!` function must be implemented for" set_constants!( |
60 | | - multi_ex, nothing, nothing |
61 | | - ) |
62 | | -end |
63 | 46 |
|
64 | | -tree_factory(f::F, trees) where {F} = f(; trees...) |
65 | | -function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N} |
66 | | - fused_expression = parse_expression( |
67 | | - tree_factory(ex.metadata.tree_factory, ex.trees)::Expr; |
68 | | - calling_module=@__MODULE__, # TODO: Not needed |
69 | | - operators=DE.get_operators(ex, nothing), |
70 | | - variable_names=nothing, |
71 | | - node_type=N, |
72 | | - expression_type=Expression, |
73 | | - )::Expression{T,N} |
74 | | - return fused_expression.tree |
75 | | -end |
76 | | -function DE.get_operators(ex::MultiScalarExpression, operators) |
77 | | - return operators === nothing ? ex.metadata.operators : operators |
78 | | -end |
79 | | -function DE.get_variable_names(ex::MultiScalarExpression, variable_names) |
80 | | - return variable_names === nothing ? ex.metadata.variable_names : variable_names |
81 | | -end |
| 47 | + # Verify that the unimplemented methods raise an error |
| 48 | + if VERSION >= v"1.9" |
| 49 | + @test_throws "`get_operators` function must be implemented for" DE.get_operators( |
| 50 | + multi_ex, nothing |
| 51 | + ) |
| 52 | + @test_throws "`get_variable_names` function must be implemented for" DE.get_variable_names( |
| 53 | + multi_ex, nothing |
| 54 | + ) |
| 55 | + @test_throws "`get_tree` function must be implemented for" DE.get_tree(multi_ex) |
| 56 | + @test_throws "`copy` function must be implemented for" copy(multi_ex) |
| 57 | + @test_throws "`hash` function must be implemented for" hash(multi_ex, UInt(0)) |
| 58 | + @test_throws "`==` function must be implemented for" multi_ex == multi_ex |
| 59 | + @test_throws "`get_constants` function must be implemented for" get_constants( |
| 60 | + multi_ex |
| 61 | + ) |
| 62 | + @test_throws "`set_constants!` function must be implemented for" set_constants!( |
| 63 | + multi_ex, nothing, nothing |
| 64 | + ) |
| 65 | + end |
82 | 66 |
|
83 | | -s = sprint((io, ex) -> show(io, MIME"text/plain"(), ex), multi_ex) |
| 67 | + tree_factory(f::F, trees) where {F} = f(; trees...) |
| 68 | + function DE.get_tree(ex::MultiScalarExpression{T,N}) where {T,N} |
| 69 | + fused_expression = parse_expression( |
| 70 | + tree_factory(ex.metadata.tree_factory, ex.trees)::Expr; |
| 71 | + calling_module=@__MODULE__, # TODO: Not needed |
| 72 | + operators=DE.get_operators(ex, nothing), |
| 73 | + variable_names=nothing, |
| 74 | + node_type=N, |
| 75 | + expression_type=Expression, |
| 76 | + )::Expression{T,N} |
| 77 | + return fused_expression.tree |
| 78 | + end |
| 79 | + function DE.get_operators(ex::MultiScalarExpression, operators=nothing) |
| 80 | + return operators === nothing ? ex.metadata.operators : operators |
| 81 | + end |
| 82 | + function DE.get_variable_names(ex::MultiScalarExpression, variable_names=nothing) |
| 83 | + return variable_names === nothing ? ex.metadata.variable_names : variable_names |
| 84 | + end |
| 85 | + function Base.copy(ex::MultiScalarExpression) |
| 86 | + t = NamedTuple{keys(ex.trees)}(map(copy, values(ex.trees))) |
| 87 | + m = ex.metadata |
| 88 | + return MultiScalarExpression( |
| 89 | + t; |
| 90 | + tree_factory=m.tree_factory, |
| 91 | + operators=copy(m.operators), |
| 92 | + variable_names=copy(m.variable_names), |
| 93 | + ) |
| 94 | + end |
| 95 | + function Base.:(==)(ex1::MultiScalarExpression, ex2::MultiScalarExpression) |
| 96 | + return isempty(Base.structdiff(ex1.trees, ex2.trees)) && |
| 97 | + all(i -> ex1.trees[i] == ex2.trees[i], keys(ex1.trees)) && |
| 98 | + ex1.metadata == ex2.metadata |
| 99 | + end |
| 100 | + |
| 101 | + s = sprint((io, ex) -> show(io, MIME"text/plain"(), ex), multi_ex) |
84 | 102 |
|
85 | | -@test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))" |
| 103 | + @test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))" |
86 | 104 |
|
87 | | -s = sprint((io, ex) -> print_tree(io, ex), multi_ex) |
| 105 | + s = sprint((io, ex) -> print_tree(io, ex), multi_ex) |
88 | 106 |
|
89 | | -@test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))\n" |
| 107 | + @test s == "((c * 2.5) - cos(a)) + cos(((b * b) * b) + (c / 0.2))\n" |
| 108 | + |
| 109 | + @implements ExpressionInterface MultiScalarExpression [Arguments()] |
| 110 | + test(ExpressionInterface, MultiScalarExpression, [multi_ex]) |
| 111 | +end |
0 commit comments