Skip to content

Commit ed5deaa

Browse files
committed
fix: extract_gradient for regular expressions
1 parent d976098 commit ed5deaa

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

src/Expression.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,7 @@ function set_constants!(ex::Expression{T}, constants, refs) where {T}
272272
return set_constants!(get_tree(ex), constants, refs)
273273
end
274274
function extract_gradient(
275-
gradient::@NamedTuple{
276-
tree::NT,
277-
metadata::@NamedTuple{
278-
_data::@NamedTuple{operators::Nothing, variable_names::Nothing}
279-
}
280-
},
281-
ex::Expression{T,N},
275+
gradient::@NamedTuple{tree::NT, metadata::Nothing}, ex::Expression{T,N}
282276
) where {T,N<:AbstractExpressionNode{T},NT<:NodeTangent{T,N}}
283277
# TODO: This messy gradient type is produced by ChainRules. There is probably a better way to do this.
284278
return extract_gradient(gradient.tree, get_tree(ex))

test/test_expressions.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ end
7676
end
7777
end
7878

79+
@testitem "Can also get derivatives of expression itself" begin
80+
using DynamicExpressions
81+
using Zygote: Zygote
82+
using DifferentiationInterface: AutoZygote, gradient
83+
84+
ex = @parse_expression(x1 + 1.5, binary_operators = [+], variable_names = ["x1"])
85+
d_ex = gradient(AutoZygote(), ex) do ex
86+
sum(ex(ones(1, 5)))
87+
end
88+
@test d_ex isa NamedTuple
89+
@test extract_gradient(d_ex, ex) [5.0]
90+
end
91+
7992
@testitem "Expression simplification" begin
8093
using DynamicExpressions
8194

0 commit comments

Comments
 (0)