Skip to content

Commit 78b200c

Browse files
committed
Ensure that evaluation helpers always overridden
1 parent e60787c commit 78b200c

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,14 @@ function create_evaluation_helpers!(operators::OperatorEnum)
1111
@eval begin
1212
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
1313
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
14-
function (tree::Node{T})(X::AbstractArray{T,2})::AbstractArray{T,1} where {T<:Real}
14+
function (tree::Node)(X; kws...)
15+
length(keys(kws)) > 1 && error("Unknown keyword argument: $(key)")
1516
out, did_finish = eval_tree_array(tree, X, $operators)
1617
if !did_finish
1718
out .= T(NaN)
1819
end
1920
return out
2021
end
21-
function (tree::Node{T1})(X::AbstractArray{T2,2}) where {T1<:Real,T2<:Real}
22-
if T1 != T2
23-
T = promote_type(T1, T2)
24-
tree = convert(Node{T}, tree)
25-
X = T.(X)
26-
end
27-
return tree(X)
28-
end
2922
# Gradients:
3023
function Base.adjoint(tree::Node{T}) where {T}
3124
return X -> begin
@@ -42,7 +35,15 @@ function create_evaluation_helpers!(operators::GenericOperatorEnum)
4235
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
4336
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
4437

45-
function (tree::Node)(X; throw_errors::Bool=true)
38+
function (tree::Node)(X; kws...)
39+
throw_errors = true
40+
for key in keys(kws)
41+
if key == :throw_errors
42+
throw_errors = kws[key]
43+
else
44+
error("Unknown keyword argument: $(key)")
45+
end
46+
end
4647
out, did_finish = eval_tree_array(
4748
tree, X, $operators; throw_errors=throw_errors
4849
)
@@ -51,6 +52,11 @@ function create_evaluation_helpers!(operators::GenericOperatorEnum)
5152
end
5253
return out
5354
end
55+
function Base.adjoint(::Node{T}) where {T}
56+
return _ -> begin
57+
error("Gradients are not implemented for `GenericOperatorEnum`.")
58+
end
59+
end
5460
end
5561
end
5662

0 commit comments

Comments
 (0)