Skip to content

Commit 9061cd8

Browse files
committed
refactor: error message so parametric expression error easier to appear
1 parent ef10533 commit 9061cd8

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

src/DynamicExpressions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ import .NodeModule:
6666
@reexport import .EvaluationHelpersModule
6767
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
6868
@reexport import .RandomModule: NodeSampler
69-
@reexport import .ExpressionModule:
70-
AbstractExpression, Expression, with_tree, default_node_type, node_type
71-
import .ExpressionModule: get_tree, get_operators, get_variable_names, Metadata
69+
@reexport import .ExpressionModule: AbstractExpression, Expression, with_tree
70+
import .ExpressionModule:
71+
get_tree, get_operators, get_variable_names, Metadata, default_node_type, node_type
7272
@reexport import .ParseModule: @parse_expression, parse_expression
7373
import .ParseModule: parse_leaf
7474
@reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode

src/ParametricExpression.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,22 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
226226
Node{T},
227227
)
228228
end
229-
function eval_tree_array(
230-
::ParametricExpression{T}, ::AbstractMatrix{T}, operators; kws...
229+
#! format: off
230+
function (ex::ParametricExpression)(X::AbstractMatrix, operators=nothing; kws...)
231+
return eval_tree_array(ex, X, operators; kws...) # Will error
232+
end
233+
function eval_tree_array(::ParametricExpression{T}, ::AbstractMatrix{T}, operators=nothing; kws...) where {T}
234+
return error("Incorrect call. You must pass the `classes::Vector` argument when calling `eval_tree_array`.")
235+
end
236+
#! format: on
237+
function (ex::ParametricExpression)(
238+
X::AbstractMatrix{T}, classes::AbstractVector{<:Integer}, operators=nothing; kws...
231239
) where {T}
232-
return error(
233-
"Incorrect call. You must pass the `classes::Vector` argument when calling `eval_tree_array`.",
234-
)
240+
(output, flag) = eval_tree_array(ex, X, classes, operators; kws...) # Will error
241+
if !flag
242+
output .= NaN
243+
end
244+
return output
235245
end
236246
function eval_tree_array(
237247
ex::ParametricExpression{T},

test/test_parametric_expression.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ end
3737

3838
# Then, with different classes
3939
@test ex(X, [1, 2, 2, 3, 1]) [1.0, 3.0, 2.0, 2.0, 1.0]
40+
41+
# Helpful error if we use it incorrectly
42+
if VERSION >= v"1.9"
43+
@test_throws "Incorrect call. You must pass" ex(X)
44+
end
4045
end
4146

4247
@testitem "2 parameters, 2 variables" begin
@@ -162,6 +167,22 @@ end
162167
@test copy(ex) == ex
163168
end
164169

170+
@testitem "Passing node within ParametricExpression parsing" begin
171+
using DynamicExpressions
172+
173+
tree = ParametricNode{Float32}()
174+
tree.degree = 0
175+
tree.constant = true
176+
tree.val = 1.5
177+
ex = parse_expression(
178+
:($tree);
179+
expression_type=ParametricExpression,
180+
parameters=Float32[;;],
181+
parameter_names=nothing,
182+
)
183+
@test ex.tree == tree
184+
end
185+
165186
@testitem "Parametric expression conversion" begin
166187
using DynamicExpressions
167188

@@ -173,9 +194,7 @@ end
173194
extra_metadata = (; parameters=Float32[;;], parameter_names=nothing)
174195
)
175196
ex = @parse_expression(
176-
x1 + 1.5f0,
177-
binary_operators = [+, -, *],
178-
variable_names = ["x1"],
197+
x1 + 1.5f0, binary_operators = [+, -, *], variable_names = ["x1"],
179198
)
180199

181200
@test pex.tree isa ParametricNode{Float32}

0 commit comments

Comments
 (0)