@@ -179,7 +179,18 @@ function preserve_sharing(::Union{E,Type{E}}) where {T,N,E<:AbstractExpression{T
179179 return preserve_sharing (N)
180180end
181181
182- function get_operators (ex:: Expression , operators= nothing )
182+ function get_operators (
183+ tree:: AbstractExpressionNode , operators:: Union{AbstractOperatorEnum,Nothing} = nothing
184+ )
185+ if operators === nothing
186+ throw (ArgumentError (" `operators` must be provided for $(typeof (tree)) types." ))
187+ else
188+ return operators
189+ end
190+ end
191+ function get_operators (
192+ ex:: Expression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing
193+ )
183194 return operators === nothing ? ex. metadata. operators : operators
184195end
185196function get_variable_names (ex:: Expression , variable_names= nothing )
249260import .. StringsModule: string_tree, print_tree
250261
251262function string_tree (
252- ex:: AbstractExpression , operators= nothing ; variable_names= nothing , kws...
263+ ex:: AbstractExpression ,
264+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
265+ variable_names= nothing ,
266+ kws... ,
253267)
254268 return string_tree (
255269 get_tree (ex),
@@ -260,7 +274,11 @@ function string_tree(
260274end
261275for io in ((), (:(io:: IO ),))
262276 @eval function print_tree (
263- $ (io... ), ex:: AbstractExpression , operators= nothing ; variable_names= nothing , kws...
277+ $ (io... ),
278+ ex:: AbstractExpression ,
279+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
280+ variable_names= nothing ,
281+ kws... ,
264282 )
265283 return println ($ (io... ), string_tree (ex, operators; variable_names, kws... ))
266284 end
@@ -283,7 +301,9 @@ function max_feature(ex::AbstractExpression)
283301 )
284302end
285303
286- function _validate_input (ex:: AbstractExpression , X, operators)
304+ function _validate_input (
305+ ex:: AbstractExpression , X, operators:: Union{AbstractOperatorEnum,Nothing}
306+ )
287307 if get_operators (ex, operators) isa OperatorEnum
288308 @assert X isa AbstractMatrix
289309 @assert max_feature (ex) <= size (X, 1 )
@@ -292,7 +312,10 @@ function _validate_input(ex::AbstractExpression, X, operators)
292312end
293313
294314function eval_tree_array (
295- ex:: AbstractExpression , cX:: AbstractMatrix , operators= nothing ; kws...
315+ ex:: AbstractExpression ,
316+ cX:: AbstractMatrix ,
317+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
318+ kws... ,
296319)
297320 _validate_input (ex, cX, operators)
298321 return eval_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
@@ -305,7 +328,10 @@ import ..EvaluateDerivativeModule: eval_grad_tree_array
305328# - differentiable_eval_tree_array
306329
307330function eval_grad_tree_array (
308- ex:: AbstractExpression , cX:: AbstractMatrix , operators= nothing ; kws...
331+ ex:: AbstractExpression ,
332+ cX:: AbstractMatrix ,
333+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
334+ kws... ,
309335)
310336 _validate_input (ex, cX, operators)
311337 return eval_grad_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
@@ -319,14 +345,16 @@ end
319345function _grad_evaluator (
320346 ex:: AbstractExpression ,
321347 cX:: AbstractMatrix ,
322- operators= nothing ;
348+ operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
323349 variable= Val (true ),
324350 kws... ,
325351)
326352 _validate_input (ex, cX, operators)
327353 return _grad_evaluator (get_tree (ex), cX, get_operators (ex, operators); variable, kws... )
328354end
329- function (ex:: AbstractExpression )(X, operators= nothing ; kws... )
355+ function (ex:: AbstractExpression )(
356+ X, operators:: Union{AbstractOperatorEnum,Nothing} = nothing ; kws...
357+ )
330358 _validate_input (ex, X, operators)
331359 return get_tree (ex)(X, get_operators (ex, operators); kws... )
332360end
0 commit comments