Skip to content

Commit 9963ef5

Browse files
committed
feat: introduce generic extract_gradients for flattening gradients
1 parent ff0725a commit 9963ef5

File tree

6 files changed

+50
-3
lines changed

6 files changed

+50
-3
lines changed

src/ChainRules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: Abstra
1717
tree::N
1818
gradient::A
1919
end
20+
"""
    extract_gradient(gradient::NodeTangent, ::AbstractExpressionNode)

Return the flat gradient array stored inside a `NodeTangent`. The node
argument is used only for dispatch; the tangent already carries the data.
"""
function extract_gradient(gradient::NodeTangent, ::AbstractExpressionNode)
    # The tangent stores the flattened gradient directly in its `gradient` field.
    return gradient.gradient
end
2023
function Base.:+(a::NodeTangent, b::NodeTangent)
2124
# @assert a.tree == b.tree
2225
return NodeTangent(a.tree, a.gradient + b.gradient)

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ import .NodeModule:
6161
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
6262
@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
6363
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
64-
@reexport import .ChainRulesModule: NodeTangent
64+
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
6565
@reexport import .SimplifyModule: combine_operators, simplify_tree!
6666
@reexport import .EvaluationHelpersModule
6767
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

src/Expression.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DispatchDoctor: @unstable
55
using ..NodeModule: AbstractExpressionNode, Node
66
using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
77
using ..UtilsModule: Undefined
8+
using ..ChainRulesModule: NodeTangent
89

910
import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof
1011
import ..NodeUtilsModule:
@@ -16,6 +17,7 @@ import ..NodeUtilsModule:
1617
has_constants,
1718
get_constants,
1819
set_constants!
20+
import ..ChainRulesModule: extract_gradient
1921

2022
"""A wrapper for a named tuple to avoid piracy."""
2123
struct Metadata{NT<:NamedTuple}
@@ -140,6 +142,12 @@ end
140142
function set_constants!(ex::AbstractExpression{T}, constants, refs) where {T}
141143
return error("`set_constants!` function must be implemented for $(typeof(ex)) types.")
142144
end
145+
"""
    extract_gradient(gradient, ex::AbstractExpression)

Fallback that errors for expression types without a specialized method.
Implementations must return a flat vector whose layout matches the output
of `get_constants` for the same expression type.
"""
function extract_gradient(gradient, ex::AbstractExpression)
    # Should match `get_constants`
    msg = "`extract_gradient` function must be implemented for $(typeof(ex)) types with $(typeof(gradient)) gradient."
    return error(msg)
end
143151
function get_contents(ex::AbstractExpression)
144152
return error("`get_contents` function must be implemented for $(typeof(ex)) types.")
145153
end
@@ -263,6 +271,18 @@ end
263271
function set_constants!(ex::Expression{T}, constants, refs) where {T}
264272
return set_constants!(get_tree(ex), constants, refs)
265273
end
274+
"""
    extract_gradient(gradient, ex::Expression)

Unwrap the structural tangent that ChainRules produces for an `Expression`
(a nested `NamedTuple` whose metadata fields are all `Nothing`) and delegate
to the `NodeTangent` method on the underlying tree.
"""
function extract_gradient(
    gradient::@NamedTuple{
        tree::NT,
        metadata::@NamedTuple{
            _data::@NamedTuple{operators::Nothing, variable_names::Nothing}
        }
    },
    ex::Expression{T,N},
) where {T,N<:AbstractExpressionNode{T},NT<:NodeTangent{T,N}}
    # TODO: This messy gradient type is produced by ChainRules. There is probably a better way to do this.
    tree_tangent = gradient.tree
    return extract_gradient(tree_tangent, get_tree(ex))
end
266286

267287
import ..StringsModule: string_tree, print_tree
268288

src/Interfaces.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ ei_components = (
159159
default_node_type = "returns the default node type for the expression" => _check_default_node,
160160
constructorof = "gets the constructor function for a type" => _check_constructorof,
161161
tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce
162+
# TODO: add extract_gradient(gradient, ex::AbstractExpression)
162163
)
163164
)
164165
ei_description = (

src/ParametricExpression.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import ..StringsModule: string_tree
2020
import ..EvaluateModule: eval_tree_array
2121
import ..EvaluateDerivativeModule: eval_grad_tree_array
2222
import ..EvaluationHelpersModule: _grad_evaluator
23+
import ..ChainRulesModule: extract_gradient
2324
import ..ExpressionModule:
2425
get_contents,
2526
get_metadata,
@@ -207,7 +208,7 @@ has_constants(ex::ParametricExpression) = _interface_error()
207208
has_operators(ex::ParametricExpression) = has_operators(get_tree(ex))
208209
function get_constants(ex::ParametricExpression{T}) where {T}
209210
constants, constant_refs = get_constants(get_tree(ex))
210-
parameters = ex.metadata.parameters
211+
parameters = get_metadata(ex).parameters
211212
flat_parameters = parameters[:]
212213
num_constants = length(constants)
213214
num_parameters = length(flat_parameters)
@@ -218,9 +219,27 @@ function set_constants!(ex::ParametricExpression{T}, x, refs) where {T}
218219
# First, set the usual constants
219220
set_constants!(get_tree(ex), @view(x[1:(refs.num_constants)]), refs.constant_refs)
220221
# Then, copy in the parameters
221-
ex.metadata.parameters[:] .= @view(x[(refs.num_constants + 1):end])
222+
get_metadata(ex).parameters[:] .= @view(x[(refs.num_constants + 1):end])
222223
return ex
223224
end
225+
"""
    extract_gradient(gradient, ex::ParametricExpression)

Flatten the ChainRules structural tangent of a `ParametricExpression` into a
single vector: first the gradients of the tree constants, then the (flattened)
gradients of the parameter matrix — the same layout that `get_constants`
produces for this expression type.
"""
function extract_gradient(
    gradient::@NamedTuple{
        tree::NT,
        metadata::@NamedTuple{
            _data::@NamedTuple{
                operators::Nothing,
                variable_names::Nothing,
                parameters::PARAM,
                parameter_names::Nothing,
            }
        }
    },
    ex::ParametricExpression{T,N},
) where {T,N<:ParametricNode{T},NT<:NodeTangent{T,N},PARAM<:AbstractMatrix{T}}
    # Gradient w.r.t. the tree's constants, via the NodeTangent method.
    constant_grad = extract_gradient(gradient.tree, get_tree(ex))
    # Gradient w.r.t. the parameter matrix, flattened column-major.
    parameter_grad = vec(gradient.metadata._data.parameters)
    return vcat(constant_grad, parameter_grad)  # Same shape as `get_constants`
end
224243

225244
function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T}
226245
num_params = UInt16(size(ex.metadata.parameters, 1))

test/test_parametric_expression.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,4 +316,8 @@ end
316316
@test grad.tree.gradient ≈ true_grad[3]
317317
# Gradient w.r.t. the parameters:
318318
@test grad.metadata._data.parameters ≈ true_grad[2]
319+
320+
# Gradient extractor
321+
@test extract_gradient(grad, ex) ≈ vcat(true_grad[3], true_grad[2][:])
322+
@test axes(extract_gradient(grad, ex)) == axes(first(get_constants(ex)))
319323
end

0 commit comments

Comments
 (0)