Skip to content

Commit 9d3de40

Browse files
author
KristofferC
committed
WIP: Add support for HyperHessians backend
1 parent 0dd1abf commit 9d3de40

File tree

10 files changed

+339
-2
lines changed

10 files changed

+339
-2
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ jobs:
107107
- FiniteDifferences
108108
- ForwardDiff
109109
- GTPSA
110+
- HyperHessians
110111
- Mooncake
111112
- PolyesterForwardDiff
112113
- ReverseDiff

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
1414
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1515
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1616
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
17+
HyperHessians = "06b494a0-c8e0-40cc-ad32-d99506a00a6c"
1718
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1819
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1920
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -38,6 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3839
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3940
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
4041
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
42+
DifferentiationInterfaceHyperHessiansExt = "HyperHessians"
4143
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
4244
DifferentiationInterfaceGTPSAExt = "GTPSA"
4345
DifferentiationInterfaceMooncakeExt = "Mooncake"
@@ -63,6 +65,7 @@ Diffractor = "=0.2.6"
6365
Enzyme = "0.13.39"
6466
EnzymeCore = "0.8.8"
6567
FastDifferentiation = "0.4.3"
68+
HyperHessians = "0.1"
6669
FiniteDiff = "2.27.0"
6770
FiniteDifferences = "0.12.31"
6871
ForwardDiff = "0.10.36,1"

DifferentiationInterface/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ We support the following backends defined by [ADTypes.jl](https://github.com/Sci
3737
- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl)
3838
- [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)
3939
- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)
40+
- [HyperHessians.jl](https://github.com/KristofferC/HyperHessians.jl)
4041
- [GTPSA.jl](https://github.com/bmad-sim/GTPSA.jl)
4142
- [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl)
4243
- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl)

DifferentiationInterface/docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
77
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
HyperHessians = "06b494a0-c8e0-40cc-ad32-d99506a00a6c"
910
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1011
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -21,6 +22,7 @@ Documenter = "1"
2122
DocumenterInterLinks = "1.1"
2223
FiniteDiff = "2.29"
2324
ForwardDiff = "1.2.2"
25+
HyperHessians = "0.1"
2426
PrettyTables = "3.1"
2527
SparseConnectivityTracer = "1.1.2"
2628
SparseMatrixColorings = "0.4.23"

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
1212
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
1313
- [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
1414
- [`AutoGTPSA`](@extref ADTypes.AutoGTPSA)
15+
- [`AutoHyperHessians`](https://github.com/KristofferC/HyperHessians.jl)
1516
- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) (the latter is experimental)
1617
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
1718
- [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
@@ -32,11 +33,11 @@ In practice, many AD backends have custom implementations for high-level operato
3233
!!! details
3334

3435
In the rough summary table below,
35-
36+
3637
- ✅ means that we reuse the custom implementation from the backend;
3738
- ❌ means that a custom implementation doesn't exist, so we use our default fallbacks;
3839
- 🔀 means it's complicated or not done yet.
39-
40+
4041
| | `pf` | `pb` | `der` | `grad` | `jac` | `hess` | `hvp` | `der2` |
4142
|:-------------------------- |:---- |:---- |:----- |:------ |:----- |:------ |:----- |:------ |
4243
| `AutoChainRules` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -48,6 +49,7 @@ In practice, many AD backends have custom implementations for high-level operato
4849
| `AutoFiniteDifferences` | 🔀 | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
4950
| `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
5051
| `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
52+
| `AutoHyperHessians` | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ |
5153
| `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
5254
| `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
5355
| `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
@@ -69,6 +71,7 @@ Moreover, each context type is supported by a specific subset of backends:
6971
| `AutoFiniteDifferences` |||
7072
| `AutoForwardDiff` |||
7173
| `AutoGTPSA` |||
74+
| `AutoHyperHessians` |||
7275
| `AutoMooncake` |||
7376
| `AutoMooncakeForward` |||
7477
| `AutoPolyesterForwardDiff` |||
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
module DifferentiationInterfaceHyperHessiansExt
2+
3+
import DifferentiationInterface as DI
4+
import .DI: AutoHyperHessians
5+
using ADTypes: ForwardMode
6+
using HyperHessians:
7+
DirectionalHVPConfig,
8+
HessianConfig,
9+
Chunk,
10+
chunksize,
11+
pickchunksize,
12+
hessian,
13+
hessian!,
14+
hessian_gradient_value,
15+
hessian_gradient_value!,
16+
hessian,
17+
hvp,
18+
hvp!,
19+
hvp_gradient_value,
20+
hvp_gradient_value!
21+
22+
## Traits
23+
# Availability trait: always `true` for `AutoHyperHessians`.
function DI.check_available(::DI.AutoHyperHessians)
    return true
end
24+
# In-place trait: this backend reports `DI.InPlaceSupported()`.
function DI.inplace_support(::DI.AutoHyperHessians)
    return DI.InPlaceSupported()
end
25+
# HVP trait: this backend reports `DI.ForwardOverForward()`.
function DI.hvp_mode(::DI.AutoHyperHessians)
    return DI.ForwardOverForward()
end
26+
# Mode trait: this backend reports ADTypes' `ForwardMode()`.
function DI.mode(::DI.AutoHyperHessians)
    return ForwardMode()
end
27+
28+
# Build the HyperHessians `Chunk` to use for input `x`.
# With no explicit chunk size on the backend, defer to `Chunk(x)`;
# otherwise construct `Chunk{backend.chunksize}()`.
function chunk_from_backend(backend::DI.AutoHyperHessians, x)
    requested = backend.chunksize
    if isnothing(requested)
        return Chunk(x)
    end
    return Chunk{requested}()
end
30+
# Variant taking an input length `N` and element type `T` instead of the input itself.
# With no explicit chunk size on the backend, the size comes from `pickchunksize(N, T)`;
# otherwise construct `Chunk{backend.chunksize}()`.
# NOTE(review): no caller of this method is visible in this module — confirm it is needed.
function chunk_from_backend(
        backend::DI.AutoHyperHessians, N::Integer, ::Type{T}
    ) where {T}
    requested = backend.chunksize
    if isnothing(requested)
        return Chunk(pickchunksize(N, T), T)
    end
    return Chunk{requested}()
end
32+
33+
# Batch size for `x`: the chunk size chosen by `chunk_from_backend`, wrapped in
# `DI.BatchSizeSettings` together with `length(x)`.
function DI.pick_batchsize(backend::DI.AutoHyperHessians, x::AbstractArray)
    chunk = chunk_from_backend(backend, x)
    return DI.BatchSizeSettings{chunksize(chunk)}(length(x))
end
37+
38+
## Second derivative (scalar input)
39+
40+
# Preparation object for second derivatives with scalar input.
# It carries only the signature tag produced by `DI.signature`; no config is stored.
struct HyperHessiansSecondDerivativePrep{SIG} <: DI.SecondDerivativePrep{SIG}
    # Signature tag, later passed to `DI.check_prep` through the prep object.
    _sig::Val{SIG}
end
43+
44+
# Prepare a second-derivative computation for scalar `x`.
# Only the call signature is recorded; no HyperHessians config is built here.
function DI.prepare_second_derivative_nokwarg(
        strict::Val,
        f,
        backend::DI.AutoHyperHessians,
        x::Number,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    return HyperHessiansSecondDerivativePrep(
        DI.signature(f, backend, x, contexts...; strict)
    )
end
50+
51+
# Second derivative of `f` at scalar `x`, computed by HyperHessians' `hessian`
# on `f` with the unwrapped context values fixed as trailing arguments.
function DI.second_derivative(
        f,
        prep::HyperHessiansSecondDerivativePrep,
        backend::DI.AutoHyperHessians,
        x::Number,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, contexts...)
    context_values = map(DI.unwrap, contexts)
    f_closed = DI.fix_tail(f, context_values...)
    return hessian(f_closed, x)
end
62+
63+
# In-place second derivative for scalar `x`: computes the out-of-place result and
# copies it into `der2`, which is returned.
function DI.second_derivative!(
        f,
        der2,
        prep::HyperHessiansSecondDerivativePrep,
        backend::DI.AutoHyperHessians,
        x::Number,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    # `DI.second_derivative` for this prep type runs `DI.check_prep` as its first
    # statement, so the check is not repeated here.
    new_der2 = DI.second_derivative(f, prep, backend, x, contexts...)
    copyto!(der2, new_der2)
    return der2
end
75+
76+
# Value, first derivative and second derivative of `f` at scalar `x`, obtained from a
# single `hessian_gradient_value` call and returned as a 3-tuple in that order.
function DI.value_derivative_and_second_derivative(
        f,
        prep::HyperHessiansSecondDerivativePrep,
        backend::DI.AutoHyperHessians,
        x::Number,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, contexts...)
    context_values = map(DI.unwrap, contexts)
    f_closed = DI.fix_tail(f, context_values...)
    out = hessian_gradient_value(f_closed, x)
    return (out.value, out.gradient, out.hessian)
end
88+
89+
# In-place variant: computes value, derivative and second derivative out of place,
# copies the derivatives into `der` and `der2`, and returns `(y, der, der2)`.
function DI.value_derivative_and_second_derivative!(
        f,
        der,
        der2,
        prep::HyperHessiansSecondDerivativePrep,
        backend::DI.AutoHyperHessians,
        x::Number,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    # The out-of-place method for this prep type runs `DI.check_prep` as its first
    # statement, so the check is not repeated here.
    y, new_der, new_der2 = DI.value_derivative_and_second_derivative(
        f, prep, backend, x, contexts...
    )
    copyto!(der, new_der)
    copyto!(der2, new_der2)
    return y, der, der2
end
104+
105+
## Preparation structs
106+
107+
# Preparation object for Hessian computations: the signature tag plus a stored config
# (built as a `HessianConfig` by `DI.prepare_hessian_nokwarg` in this module).
struct HyperHessiansHessianPrep{SIG, C} <: DI.HessianPrep{SIG}
    # Signature tag produced by `DI.signature`.
    _sig::Val{SIG}
    # Config object forwarded to the HyperHessians calls.
    cfg::C
end
111+
112+
# Preparation object for Hessian-vector products: the signature tag plus a stored config
# (built as a `DirectionalHVPConfig` by `DI.prepare_hvp_nokwarg` in this module).
struct HyperHessiansHVPPrep{SIG, C} <: DI.HVPPrep{SIG}
    # Signature tag produced by `DI.signature`.
    _sig::Val{SIG}
    # Config object forwarded to the HyperHessians calls.
    cfg::C
end
116+
117+
## Hessian
118+
119+
# Prepare a Hessian computation for array input `x`: record the call signature and
# build a `HessianConfig` using the chunk selected by `chunk_from_backend`.
function DI.prepare_hessian_nokwarg(
        strict::Val,
        f,
        backend::DI.AutoHyperHessians,
        x::AbstractArray,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    sig = DI.signature(f, backend, x, contexts...; strict)
    chunk = chunk_from_backend(backend, x)
    return HyperHessiansHessianPrep(sig, HessianConfig(x, chunk))
end
126+
127+
# Out-of-place Hessian of `f` at `x`, using the config stored in `prep`.
function DI.hessian(
        f,
        prep::HyperHessiansHessianPrep,
        backend::DI.AutoHyperHessians,
        x,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, contexts...)
    context_values = map(DI.unwrap, contexts)
    f_closed = DI.fix_tail(f, context_values...)
    return hessian(f_closed, x, prep.cfg)
end
138+
139+
# In-place Hessian of `f` at `x`: fills `hess` through HyperHessians' `hessian!`
# using the config stored in `prep`, then returns `hess`.
function DI.hessian!(
        f,
        hess,
        prep::HyperHessiansHessianPrep,
        backend::DI.AutoHyperHessians,
        x,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, contexts...)
    fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
    hessian!(hess, fc, x, prep.cfg)
    # Return the caller's buffer explicitly, as the other in-place methods in this
    # module do, rather than relying on what `hessian!` happens to return.
    return hess
end
151+
152+
# Value, gradient and Hessian of `f` at `x` from a single `hessian_gradient_value`
# call, returned as a 3-tuple in that order.
function DI.value_gradient_and_hessian(
        f,
        prep::HyperHessiansHessianPrep,
        backend::DI.AutoHyperHessians,
        x,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, contexts...)
    context_values = map(DI.unwrap, contexts)
    f_closed = DI.fix_tail(f, context_values...)
    out = hessian_gradient_value(f_closed, x, prep.cfg)
    return (out.value, out.gradient, out.hessian)
end
164+
165+
# In-place variant: `hessian_gradient_value!` receives `hess` and `grad` as output
# buffers; its return value is passed back as the function value alongside them.
function DI.value_gradient_and_hessian!(
        f,
        grad,
        hess,
        prep::HyperHessiansHessianPrep,
        backend::DI.AutoHyperHessians,
        x,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, contexts...)
    context_values = map(DI.unwrap, contexts)
    f_closed = DI.fix_tail(f, context_values...)
    y = hessian_gradient_value!(hess, grad, f_closed, x, prep.cfg)
    return (y, grad, hess)
end
179+
180+
## HVP
181+
182+
# Prepare a Hessian-vector product for array input `x` and tangents `tx`: record the
# call signature and build a `DirectionalHVPConfig` with the chunk selected by
# `chunk_from_backend`.
function DI.prepare_hvp_nokwarg(
        strict::Val,
        f,
        backend::DI.AutoHyperHessians,
        x::AbstractArray,
        tx::NTuple,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    sig = DI.signature(f, backend, x, tx, contexts...; strict)
    chunk = chunk_from_backend(backend, x)
    return HyperHessiansHVPPrep(sig, DirectionalHVPConfig(x, tx, chunk))
end
189+
190+
# Same-point preparation: nothing extra is computed for a fixed `x`; after the
# preparation check, the existing `prep` is handed back unchanged.
function DI.prepare_hvp_same_point(
        f,
        prep::HyperHessiansHVPPrep,
        backend::DI.AutoHyperHessians,
        x,
        tx::NTuple,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    same_point_prep = prep
    return same_point_prep
end
201+
202+
# Out-of-place Hessian-vector products of `f` at `x` along the tangents `tx`,
# using the config stored in `prep`.
function DI.hvp(
        f,
        prep::HyperHessiansHVPPrep,
        # Qualified as `DI.AutoHyperHessians`, matching every other method in this module.
        backend::DI.AutoHyperHessians,
        x,
        tx::NTuple,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
    return hvp(fc, x, tx, prep.cfg)
end
214+
215+
# In-place Hessian-vector products: fills `tg` through HyperHessians' `hvp!`
# using the config stored in `prep`, then returns `tg`.
function DI.hvp!(
        f,
        tg::NTuple,
        prep::HyperHessiansHVPPrep,
        backend::DI.AutoHyperHessians,
        x,
        tx::NTuple,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
    hvp!(tg, fc, x, tx, prep.cfg)
    # Return the caller's buffers explicitly, as `DI.gradient_and_hvp!` does, rather
    # than relying on what `hvp!` happens to return.
    return tg
end
228+
229+
# Gradient and Hessian-vector products of `f` at `x` from a single
# `hvp_gradient_value` call, returned as `(gradient, hvp)`.
function DI.gradient_and_hvp(
        f,
        prep::HyperHessiansHVPPrep,
        backend::DI.AutoHyperHessians,
        x,
        tx::NTuple,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    context_values = map(DI.unwrap, contexts)
    f_closed = DI.fix_tail(f, context_values...)
    out = hvp_gradient_value(f_closed, x, tx, prep.cfg)
    return (out.gradient, out.hvp)
end
242+
243+
# In-place variant: `hvp_gradient_value!` receives `tg` and `grad` as output buffers;
# both buffers are then returned as `(grad, tg)`.
function DI.gradient_and_hvp!(
        f,
        grad,
        tg::NTuple,
        prep::HyperHessiansHVPPrep,
        backend::DI.AutoHyperHessians,
        x,
        tx::NTuple,
        contexts::Vararg{DI.Context, C},
    ) where {C}
    DI.check_prep(f, prep, backend, x, tx, contexts...)
    context_values = map(DI.unwrap, contexts)
    f_closed = DI.fix_tail(f, context_values...)
    hvp_gradient_value!(tg, grad, f_closed, x, tx, prep.cfg)
    return (grad, tg)
end
258+
259+
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ include("second_order/hessian.jl")
6464

6565
include("misc/differentiate_with.jl")
6666
include("misc/from_primitive.jl")
67+
include("misc/autohyperhessians.jl")
6768
include("misc/sparsity_detector.jl")
6869
include("misc/simple_finite_diff.jl")
6970
include("misc/zero_backends.jl")
@@ -122,6 +123,7 @@ export AutoReverseDiff
122123
export AutoSymbolics
123124
export AutoTracker
124125
export AutoZygote
126+
export AutoHyperHessians
125127

126128
export AutoSparse
127129

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
AutoHyperHessians(; chunksize = nothing)
3+
4+
Lightweight ADTypes backend tag for HyperHessians. The `chunksize` keyword can
5+
be set to a positive `Int` to override HyperHessians' chunk heuristic; `nothing`
6+
lets HyperHessians choose.
7+
"""
8+
struct AutoHyperHessians{CS} <: ADTypes.AbstractADType
    # Either `nothing` (no override) or a positive `Int`; `CS` is the field's type.
    chunksize::CS

    # Keyword-only constructor. Throws `ArgumentError` when an `Int` chunk size is
    # not strictly positive.
    function AutoHyperHessians(; chunksize::Union{Nothing, Int} = nothing)
        if !isnothing(chunksize) && chunksize <= 0
            throw(ArgumentError("chunksize must be positive, got $chunksize"))
        end
        return new{typeof(chunksize)}(chunksize)
    end
end

0 commit comments

Comments
 (0)