Skip to content

Commit 1fcdfce

Browse files
committed
feat: generalize get_operators
1 parent dbb2866 commit 1fcdfce

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,18 @@ end
111111
function Base.convert(
112112
::typeof(SymbolicUtils.Symbolic),
113113
tree::Union{AbstractExpression,AbstractExpressionNode},
114-
operators::AbstractOperatorEnum;
114+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
115115
variable_names::Union{Array{String,1},Nothing}=nothing,
116116
index_functions::Bool=false,
117117
# Deprecated:
118118
varMap=nothing,
119119
)
120120
variable_names = deprecate_varmap(variable_names, varMap, :convert)
121121
return node_to_symbolic(
122-
tree, operators; variable_names=variable_names, index_functions=index_functions
122+
tree,
123+
get_operators(tree, operators);
124+
variable_names=variable_names,
125+
index_functions=index_functions,
123126
)
124127
end
125128

src/Expression.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,18 @@ function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T
179179
return preserve_sharing(N)
180180
end
181181

182-
function get_operators(ex::Expression, operators=nothing)
182+
function get_operators(
183+
tree::AbstractExpressionNode, operators::Union{AbstractOperatorEnum,Nothing}=nothing
184+
)
185+
if operators === nothing
186+
throw(ArgumentError("`operators` must be provided for $(typeof(tree)) types."))
187+
else
188+
return operators
189+
end
190+
end
191+
function get_operators(
192+
ex::Expression, operators::Union{AbstractOperatorEnum,Nothing}=nothing
193+
)
183194
return operators === nothing ? ex.metadata.operators : operators
184195
end
185196
function get_variable_names(ex::Expression, variable_names=nothing)
@@ -249,7 +260,10 @@ end
249260
import ..StringsModule: string_tree, print_tree
250261

251262
function string_tree(
252-
ex::AbstractExpression, operators=nothing; variable_names=nothing, kws...
263+
ex::AbstractExpression,
264+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
265+
variable_names=nothing,
266+
kws...,
253267
)
254268
return string_tree(
255269
get_tree(ex),
@@ -260,7 +274,11 @@ function string_tree(
260274
end
261275
for io in ((), (:(io::IO),))
262276
@eval function print_tree(
263-
$(io...), ex::AbstractExpression, operators=nothing; variable_names=nothing, kws...
277+
$(io...),
278+
ex::AbstractExpression,
279+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
280+
variable_names=nothing,
281+
kws...,
264282
)
265283
return println($(io...), string_tree(ex, operators; variable_names, kws...))
266284
end
@@ -283,7 +301,9 @@ function max_feature(ex::AbstractExpression)
283301
)
284302
end
285303

286-
function _validate_input(ex::AbstractExpression, X, operators)
304+
function _validate_input(
305+
ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing}
306+
)
287307
if get_operators(ex, operators) isa OperatorEnum
288308
@assert X isa AbstractMatrix
289309
@assert max_feature(ex) <= size(X, 1)
@@ -292,7 +312,10 @@ function _validate_input(ex::AbstractExpression, X, operators)
292312
end
293313

294314
function eval_tree_array(
295-
ex::AbstractExpression, cX::AbstractMatrix, operators=nothing; kws...
315+
ex::AbstractExpression,
316+
cX::AbstractMatrix,
317+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
318+
kws...,
296319
)
297320
_validate_input(ex, cX, operators)
298321
return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
@@ -305,7 +328,10 @@ import ..EvaluateDerivativeModule: eval_grad_tree_array
305328
# - differentiable_eval_tree_array
306329

307330
function eval_grad_tree_array(
308-
ex::AbstractExpression, cX::AbstractMatrix, operators=nothing; kws...
331+
ex::AbstractExpression,
332+
cX::AbstractMatrix,
333+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
334+
kws...,
309335
)
310336
_validate_input(ex, cX, operators)
311337
return eval_grad_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
@@ -319,14 +345,16 @@ end
319345
function _grad_evaluator(
320346
ex::AbstractExpression,
321347
cX::AbstractMatrix,
322-
operators=nothing;
348+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
323349
variable=Val(true),
324350
kws...,
325351
)
326352
_validate_input(ex, cX, operators)
327353
return _grad_evaluator(get_tree(ex), cX, get_operators(ex, operators); variable, kws...)
328354
end
329-
function (ex::AbstractExpression)(X, operators=nothing; kws...)
355+
function (ex::AbstractExpression)(
356+
X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...
357+
)
330358
_validate_input(ex, X, operators)
331359
return get_tree(ex)(X, get_operators(ex, operators); kws...)
332360
end

0 commit comments

Comments
 (0)