Skip to content

Commit a758af9

Browse files
committed
Avoid specializing OperatorEnum
1 parent 6f75504 commit a758af9

File tree

3 files changed

+16
-21
lines changed

3 files changed

+16
-21
lines changed

src/EvaluateEquationDerivative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
88
import ..EvaluateEquationModule: deg0_eval
99

1010
function assert_autodiff_enabled(operators::OperatorEnum)
11-
if operators.diff_binops === nothing && operators.diff_unaops === nothing
11+
if length(operators.diff_binops) == 0 && length(operators.diff_unaops) == 0
1212
error(
1313
"Found no differential operators. Did you forget to set `enable_autodiff=true` when creating the `OperatorEnum`?",
1414
)

src/OperatorEnum.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@ Defines an enum over operators, along with their derivatives.
1212
- `diff_binops`: A tuple of Zygote-computed derivatives of the binary operators.
1313
- `diff_unaops`: A tuple of Zygote-computed derivatives of the unary operators.
1414
"""
15-
struct OperatorEnum{A<:Tuple,B<:Tuple,dA<:Union{Tuple,Nothing},dB<:Union{Tuple,Nothing}} <:
16-
AbstractOperatorEnum
17-
binops::A
18-
unaops::B
19-
diff_binops::dA
20-
diff_unaops::dB
15+
struct OperatorEnum <: AbstractOperatorEnum
16+
binops::Vector{Function}
17+
unaops::Vector{Function}
18+
diff_binops::Vector{Function}
19+
diff_unaops::Vector{Function}
2120
end
2221

2322
"""
@@ -30,9 +29,9 @@ Defines an enum over operators, along with their derivatives.
3029
- `diff_binops`: A tuple of Zygote-computed derivatives of the binary operators.
3130
- `diff_unaops`: A tuple of Zygote-computed derivatives of the unary operators.
3231
"""
33-
struct GenericOperatorEnum{A<:Tuple,B<:Tuple} <: AbstractOperatorEnum
34-
binops::A
35-
unaops::B
32+
struct GenericOperatorEnum <: AbstractOperatorEnum
33+
binops::Vector{Function}
34+
unaops::Vector{Function}
3635
end
3736

3837
end

src/OperatorEnumConstruction.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,13 @@ function OperatorEnum(;
217217
define_helper_functions::Bool=true,
218218
)
219219
@assert length(binary_operators) > 0 || length(unary_operators) > 0
220-
binary_operators = Tuple(binary_operators)
221-
unary_operators = Tuple(unary_operators)
220+
221+
binary_operators = convert(Vector{Function}, collect(binary_operators))
222+
unary_operators = convert(Vector{Function}, collect(unary_operators))
222223

223224
if enable_autodiff
224-
diff_binary_operators = Any[]
225-
diff_unary_operators = Any[]
225+
diff_binary_operators = Function[]
226+
diff_unary_operators = Function[]
226227

227228
test_inputs = map(x -> convert(Float32, x), LinRange(-100, 100, 99))
228229
# Create grid over [-100, 100]^2:
@@ -259,13 +260,11 @@ function OperatorEnum(;
259260
break
260261
end
261262
end
262-
diff_binary_operators = Tuple(diff_binary_operators)
263-
diff_unary_operators = Tuple(diff_unary_operators)
264263
end
265264

266265
if !enable_autodiff
267-
diff_binary_operators = nothing
268-
diff_unary_operators = nothing
266+
diff_binary_operators = Function[]
267+
diff_unary_operators = Function[]
269268
end
270269

271270
operators = OperatorEnum(
@@ -300,9 +299,6 @@ and `(::Node)(X)`.
300299
function GenericOperatorEnum(;
301300
binary_operators=[], unary_operators=[], define_helper_functions::Bool=true
302301
)
303-
binary_operators = Tuple(binary_operators)
304-
unary_operators = Tuple(unary_operators)
305-
306302
@assert length(binary_operators) > 0 || length(unary_operators) > 0
307303

308304
operators = GenericOperatorEnum(binary_operators, unary_operators)

0 commit comments

Comments
 (0)