Skip to content

Commit 4befd36

Browse files
committed
ForwardDiffExt: switched to NNlib activation functions
1 parent 971f309 commit 4befd36

File tree

3 files changed

+39
-21
lines changed

3 files changed

+39
-21
lines changed

Project.toml

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
3-
authors = ["Chris Elrod <elrodc@gmail.com>"]
43
version = "0.12.172"
4+
authors = ["Chris Elrod <elrodc@gmail.com>"]
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -30,9 +30,10 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
3030
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
33+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3334

3435
[extensions]
35-
ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"]
36+
ForwardDiffExt = ["ChainRulesCore", "ForwardDiff", "NNlib"]
3637
SpecialFunctionsExt = "SpecialFunctions"
3738

3839
[compat]
@@ -46,6 +47,7 @@ HostCPUFeatures = "0.1.10"
4647
IfElse = "0.1"
4748
LayoutPointers = "0.1.11"
4849
LinearAlgebra = "1"
50+
NNlib = "0.9.31"
4951
OffsetArrays = "1.4.1"
5052
PolyesterWeave = "0.1.10, 0.2"
5153
PrecompileTools = "1"
@@ -57,4 +59,8 @@ StaticArrayInterface = "1"
5759
ThreadingUtilities = "0.5"
5860
UnPack = "1"
5961
VectorizationBase = "0.21.72"
60-
julia = "1.6"
62+
julia = "1.10"
63+
64+
[extras]
65+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
66+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

ext/ForwardDiffExt.jl

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
11
module ForwardDiffExt
22
import ForwardDiff, ChainRulesCore
3-
using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff
3+
using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff, NNlib
4+
using SLEEFPirates: tanh_fast, sigmoid_fast
45

56
import IfElse: ifelse
67
using VectorizationBase: AbstractSIMD, AbstractMask, zero_offsets
78

89
using LoopVectorization:
910
AbstractSIMD,
1011
AbstractStridedPointer,
11-
relu,
1212
vmap,
1313
VectorizationBase,
1414
vmapt,
@@ -140,7 +140,8 @@ end
140140
)
141141
end
142142
end
143-
@generated function VectorizationBase.relu(
143+
144+
@generated function NNlib.relu(
144145
x::ForwardDiff.Dual{T,S,N}
145146
) where {T,S,N}
146147
quote
@@ -157,6 +158,27 @@ end
157158
end
158159
end
159160

161+
# Leaky-ReLU method for `ForwardDiff.Dual` inputs.
#
# Arguments:
#   x - dual number with `N` partials; `x.value` and `x.partials` are read.
#   a - slope applied on the negative side (default 0.01); converted to
#       `typeof(x.value)` before use.
#
# Returns a `ForwardDiff.Dual{T}` whose value is `α * v` where `v < 0` and `v`
# otherwise, and whose `N` partials are scaled by `α` under the same comparison.
# Both the value and the partials are chosen with `ifelse` on the single
# comparison result `cmp` rather than with an `if` branch
# (presumably so that `v` may be a SIMD vector - TODO confirm).
#
# The partials tuple is unrolled to exactly `N` entries via
# `Base.Cartesian.@ntuple $N`, and `$(Expr(:meta, :inline))` places an inline
# hint in the generated body.
#
# NOTE(review): the signature matches every `ForwardDiff.Dual{T,S,N}`, not only
# duals whose value type is SIMD - confirm that overriding the scalar case is
# intended.
# NOTE(review): at `v == 0` the comparison `v < z` is false, so the partials
# pass through unscaled - confirm this matches the derivative convention of
# NNlib's own `leakyrelu` at zero.
@generated function NNlib.leakyrelu(
162+
x::ForwardDiff.Dual{T,S,N},
163+
a = 0.01
164+
) where {T,S,N}
165+
quote
166+
$(Expr(:meta, :inline))
167+
v = x.value
168+
z = zero(v)
169+
170+
α = convert(typeof(v), a)
171+
cmp = v < z
172+
r = ifelse(cmp, α * v, v)
173+
p = x.partials
174+
ForwardDiff.Dual{T}(
175+
r,
176+
ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, α * p[n], p[n]))
177+
)
178+
end
179+
end
180+
181+
160182
@generated function _ifelse(
161183
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
162184
x::ForwardDiff.Dual{TAG,V,P},

test/forwarddiffext.jl

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -16,21 +16,6 @@ function tovec(x::ForwardDiff.Dual{T,V,N}) where {T,V,N}
1616
return ret
1717
end
1818

19-
if LoopVectorization.ifelse !== Base.ifelse
20-
@inline function NNlib.leakyrelu(
21-
x::LoopVectorization.AbstractSIMD,
22-
a = NNlib.oftf(x, NNlib.leakyrelu_a),
23-
)
24-
LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower
25-
end
26-
@inline function NNlib.leakyrelu(
27-
x::ForwardDiff.Dual{<:Any,<:LoopVectorization.AbstractSIMD},
28-
a = NNlib.oftf(x, NNlib.leakyrelu_a),
29-
)
30-
LoopVectorization.ifelse(x > zero(x), float(x), NNlib.oftf(x, a * x)) # max(a*x, x) is 3x slower
31-
end
32-
end
33-
3419
vx0 = randnvec()
3520
vx1 = randnvec()
3621
vx2 = randnvec()
@@ -50,3 +35,8 @@ vud = ForwardDiff.Dual(vu0, vu1, vu2)
5035
reinterpret(Float64, NNlib.leakyrelu.(tovec(vd0)))
5136
@test reinterpret(Float64, tovec(NNlib.leakyrelu(vud)))
5237
reinterpret(Float64, NNlib.leakyrelu.(tovec(vud)))
38+
39+
@test reinterpret(Float64, tovec(NNlib.relu(vd0)))
40+
reinterpret(Float64, NNlib.relu.(tovec(vd0)))
41+
@test reinterpret(Float64, tovec(NNlib.relu(vud)))
42+
reinterpret(Float64, NNlib.relu.(tovec(vud)))

0 commit comments

Comments
 (0)