@@ -11,21 +11,14 @@ function create_evaluation_helpers!(operators::OperatorEnum)
1111 @eval begin
1212 Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
1313 Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
14- function (tree:: Node{T} )(X:: AbstractArray{T,2} ):: AbstractArray{T,1} where {T<: Real }
14+ function (tree:: Node )(X; kws... )
15+ length (keys (kws)) > 1 && error (" Unknown keyword argument: $(key) " )
1516 out, did_finish = eval_tree_array (tree, X, $ operators)
1617 if ! did_finish
1718 out .= T (NaN )
1819 end
1920 return out
2021 end
21- function (tree:: Node{T1} )(X:: AbstractArray{T2,2} ) where {T1<: Real ,T2<: Real }
22- if T1 != T2
23- T = promote_type (T1, T2)
24- tree = convert (Node{T}, tree)
25- X = T .(X)
26- end
27- return tree (X)
28- end
2922 # Gradients:
3023 function Base. adjoint (tree:: Node{T} ) where {T}
3124 return X -> begin
@@ -42,7 +35,15 @@ function create_evaluation_helpers!(operators::GenericOperatorEnum)
4235 Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
4336 Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
4437
45- function (tree:: Node )(X; throw_errors:: Bool = true )
38+ function (tree:: Node )(X; kws... )
39+ throw_errors = true
40+ for key in keys (kws)
41+ if key == :throw_errors
42+ throw_errors = kws[key]
43+ else
44+ error (" Unknown keyword argument: $(key) " )
45+ end
46+ end
4647 out, did_finish = eval_tree_array (
4748 tree, X, $ operators; throw_errors= throw_errors
4849 )
@@ -51,6 +52,11 @@ function create_evaluation_helpers!(operators::GenericOperatorEnum)
5152 end
5253 return out
5354 end
55+ function Base. adjoint (:: Node{T} ) where {T}
56+ return _ -> begin
57+ error (" Gradients are not implemented for `GenericOperatorEnum`." )
58+ end
59+ end
5460 end
5561end
5662
0 commit comments