Skip to content

Commit 1aabe1d

Browse files
committed
feat: add ChainRulesCore.rrule for eval_tree_array
1 parent 27b6199 commit 1aabe1d

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
44
version = "0.16.0"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
910
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"

src/ChainRules.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
module ChainRulesModule
2+
3+
using ChainRulesCore:
4+
ChainRulesCore, AbstractTangent, NoTangent, ZeroTangent, Tangent, @thunk, canonicalize
5+
using ..OperatorEnumModule: OperatorEnum
6+
using ..EquationModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce
7+
using ..EvaluateEquationModule: eval_tree_array
8+
using ..EvaluateEquationDerivativeModule: eval_grad_tree_array
9+
10+
struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
11+
tree::N
12+
gradient::A
13+
end
14+
function Base.:+(a::NodeTangent, b::NodeTangent)
15+
@assert a.tree == b.tree
16+
return NodeTangent(a.tree, a.gradient + b.gradient)
17+
end
18+
Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)
19+
Base.:*(a::NodeTangent, b::Number) = NodeTangent(a.tree, a.gradient * b)
20+
Base.zero(::Union{Type{NodeTangent},NodeTangent}) = ZeroTangent()
21+
22+
function ChainRulesCore.rrule(
23+
::typeof(eval_tree_array),
24+
tree::AbstractExpressionNode,
25+
X::AbstractMatrix,
26+
operators::OperatorEnum;
27+
turbo=Val(false),
28+
bumper=Val(false),
29+
)
30+
primal, complete = eval_tree_array(tree, X, operators; turbo, bumper)
31+
32+
if !complete
33+
primal .= NaN
34+
end
35+
36+
# TODO: Preferable to use the primal in the pullback somehow
37+
function pullback((dY, _))
38+
dtree = let dY = dY, tree = tree, operators = operators
39+
@thunk(
40+
let
41+
_, gradient, complete = eval_grad_tree_array(
42+
tree, X, operators; variable=Val(false)
43+
)
44+
if !complete
45+
gradient .= NaN
46+
end
47+
48+
NodeTangent(
49+
tree,
50+
sum(j -> gradient[:, j] * dY[j], eachindex(dY, axes(gradient, 2))),
51+
)
52+
end
53+
)
54+
end
55+
dX = let dY = dY, tree = tree, operators = operators
56+
@thunk(
57+
let
58+
_, gradient, complete = eval_grad_tree_array(
59+
tree, X, operators; variable=Val(true)
60+
)
61+
if !complete
62+
gradient .= NaN
63+
end
64+
65+
gradient .* reshape(dY, 1, length(dY))
66+
end
67+
)
68+
end
69+
return (NoTangent(), dtree, dX, NoTangent())
70+
end
71+
72+
return (primal, complete), pullback
73+
end
74+
75+
end

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ include("EquationUtils.jl")
88
include("Strings.jl")
99
include("EvaluateEquation.jl")
1010
include("EvaluateEquationDerivative.jl")
11+
include("ChainRules.jl")
1112
include("EvaluationHelpers.jl")
1213
include("SimplifyEquation.jl")
1314
include("OperatorEnumConstruction.jl")

0 commit comments

Comments
 (0)