Commit c5514d8
fix: finite diff gradient accidental type promotion + generalize inputs (#1847)
* fix: finite diff gradient accidental type promotion
* fix: use better epsilon
* feat: support multiple args for finitediff
* feat: preserve the correct return type
* test: against enzyme
* test: incorrect usage
* Update src/TestUtils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix: single arg return type
* test: use analytic gradients
* Update test/nn/luxlib.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b5137d7 commit c5514d8
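For context, a minimal usage sketch of the reworked Reactant.TestUtils.finite_difference_gradient, mirroring the tests added in this commit; the loss function, array shapes, and variable names below are illustrative and not part of the change:

using Reactant

# Hypothetical scalar-valued loss over two array arguments (illustrative only).
loss(a, b) = sum(abs2, a .- b)

a = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float64, 3, 4))
b = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float64, 3, 4))

# Central differences by default; with multiple arguments the result mirrors the
# input structure (a tuple of gradients here).
grads = @jit Reactant.TestUtils.finite_difference_gradient(loss, a, b)

# With a single argument the gradient is returned directly, and the input eltype
# is preserved (no accidental promotion), as exercised in test/autodiff.jl below.
x16 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2))
g16 = @jit Reactant.TestUtils.finite_difference_gradient(sum, x16)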

8 files changed: +291, -83 lines changed

src/Compiler.jl

Lines changed: 2 additions & 1 deletion
@@ -702,9 +702,10 @@ function optimization_passes(
     dus_to_concat::Bool=false,
     recognize_comms::Bool=true,
     lower_comms::Bool=true,
-    max_constant_threshold::Int=1024,
     backend::String="gpu",
 )
+    (; max_constant_threshold) = compile_options
+
     transform_passes_list = [
         "patterns=compare_op_canon<16>",
         "transpose_transpose<16>",

src/TestUtils.jl

Lines changed: 147 additions & 14 deletions
@@ -1,6 +1,7 @@
 module TestUtils
 
-using ..Reactant: Reactant, TracedRArray
+using ..Reactant: Reactant, TracedRArray, TracedRNumber, TracedUtils
+using Reactant.Ops: @opcall
 using ReactantCore: ReactantCore
 using LinearAlgebra: LinearAlgebra
 
@@ -20,22 +21,154 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
     return reshape(collect(T, 1:prod(dims)), dims...)
 end
 
-function finite_difference_gradient(
-    f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4)
-) where {T}
+# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133
+function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T}
+    if fdtype == :forward
+        return sqrt(eps(real(T)))
+    elseif fdtype == :central
+        return cbrt(eps(real(T)))
+    elseif fdtype == :hcentral
+        return eps(T)^(T(1 / 4))
+    else
+        return one(real(T))
+    end
+end
+
+function get_perturbation(x::AbstractArray{T}, epsilon) where {T}
     onehot_matrix = Reactant.promote_to(
-        TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
+        TracedRArray{Reactant.unwrapped_eltype(T),2},
+        LinearAlgebra.Diagonal(fill(epsilon, length(x)));
+    )
+    return permutedims(
+        reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...)
+    )
+end
+
+function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T}
+    perturbation = get_perturbation(x, epsilon)
+    x_ = reshape(x, 1, size(x)...)
+    return cat(x_ .+ perturbation, x_ .- perturbation; dims=1)
+end
+
+function generate_perturbed_array(::Val{:forward}, x::AbstractArray{T}, epsilon) where {T}
+    perturbation = get_perturbation(x, epsilon)
+    x_ = reshape(x, 1, size(x)...)
+    return cat(x_ .+ perturbation, x_; dims=1)
+end
+
+function finite_difference_gradient(
+    f::F, args...; method::Union{Val{:central},Val{:forward}}=Val(:central)
+) where {F}
+    argprefix = gensym("finitediffarg")
+    resprefix = gensym("finitediffresult")
+    resargprefix = gensym("finitediffresarg")
+
+    # TODO: can we detect and prevent using functions that mutate their arguments?
+    mlir_fn_res = TracedUtils.make_mlir_fn(
+        f,
+        args,
+        (),
+        "finite_difference_gradient_fn",
+        false;
+        args_in_result=:none,
+        argprefix,
+        resprefix,
+        resargprefix,
     )
-    perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x))
-    f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1)
-
-    f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x)))
-    return ReactantCore.materialize_traced_array(
-        reshape(
-            (f_evaluated[1:length(x)] - f_evaluated[(length(x) + 1):end]) ./ (2 * epsilon),
-            size(x),
-        ),
+
+    seenargs = Reactant.OrderedIdDict()
+    Reactant.make_tracer(seenargs, f, (argprefix,), Reactant.TracedSetPath)
+    for (i, arg) in enumerate(args)
+        Reactant.make_tracer(seenargs, arg, (argprefix, i), Reactant.TracedSetPath)
+    end
+
+    linear_args = Reactant.TracedType[]
+    for (k, v) in seenargs
+        v isa Reactant.TracedType || continue
+        push!(linear_args, v)
+    end
+
+    if (
+        length(mlir_fn_res.linear_results) != 1 ||
+        !(mlir_fn_res.linear_results[1] isa TracedRNumber)
     )
+        error("`finite_difference_gradient` only supports functions with a single scalar \
+              output. Received : $(mlir_fn_res.linear_results)")
+    end
+
+    gradient_results = TracedRArray[]
+    gradient_result_map_path = []
+    for i in 1:length(linear_args)
+        arg = linear_args[i]
+        if arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix)
+            path = TracedUtils.get_idx(arg, argprefix)
+            if mlir_fn_res.fnwrapped && length(path) > 1 && path[2] == 1
+                continue
+            end
+
+            # We need the gradient wrt this argument
+            # we will naively insert the args here, cse will take care of the rest
+            new_arguments = TracedRArray[]
+
+            epsilon = default_epslion(method, Reactant.unwrapped_eltype(arg))
+            pertubed_arg = generate_perturbed_array(method, arg, epsilon)
+
+            bsize = size(pertubed_arg, 1)
+            for j in 1:length(linear_args)
+                if i == j
+                    new_arg = pertubed_arg
+                elseif linear_args[j] isa TracedRNumber
+                    new_arg = @opcall broadcast_in_dim(
+                        linear_args[j], Int64[], Int64[bsize]
+                    )
+                else
+                    new_arg = @opcall broadcast_in_dim(
+                        linear_args[j],
+                        collect(Int64, 2:(ndims(linear_args[j]) + 1)),
+                        Int64[bsize, size(linear_args[j])...],
+                    )
+                end
+                new_arg = @opcall transpose(new_arg, Int64[1, ((ndims(new_arg)):-1:2)...];)
+                push!(new_arguments, new_arg)
+            end
+
+            batched_res = @opcall batch(
+                new_arguments,
+                [
+                    Reactant.MLIR.IR.TensorType(
+                        Int64[bsize],
+                        Reactant.MLIR.IR.Type(
+                            Reactant.unwrapped_eltype(mlir_fn_res.linear_results[1])
+                        ),
+                    ),
+                ],
+                Int64[bsize];
+                fn=mlir_fn_res.f,
+            )
+            batched_res = only(batched_res)
+
+            if method isa Val{:central}
+                diff = batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end]
+                grad_res = diff ./ (2 * epsilon)
+            elseif method isa Val{:forward}
+                diff = batched_res[1:(end - 1)] .- batched_res[end:end]
+                grad_res = diff ./ epsilon
+            end
+
+            push!(gradient_result_map_path, TracedUtils.get_idx(arg, argprefix))
+            push!(
+                gradient_results,
+                ReactantCore.materialize_traced_array(reshape(grad_res, size(arg))),
+            )
+        end
+    end
+
+    results = deepcopy(args)
+    for (path, grad_res) in zip(gradient_result_map_path, gradient_results)
+        TracedUtils.set!(results, path[2:end], grad_res.mlir_data)
+    end
+    length(args) == 1 && return results[1]
+    return results
 end
 
 end
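For intuition, the numerics behind the traced implementation above reduce to the textbook central-difference rule with step h = cbrt(eps(T)), the "better epsilon" used by the default :central method. Below is a plain-Julia sketch of just that rule, not the traced/batched code in the diff, assuming a non-mutating scalar-valued f:

# Central differences: df/dx_i ≈ (f(x + h*e_i) - f(x - h*e_i)) / (2h).
function naive_central_fd_gradient(f, x::AbstractArray{T}) where {T}
    h = cbrt(eps(real(T)))             # default step for the :central method
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)  # arithmetic stays in T, so no promotion
    end
    return g
end

naive_central_fd_gradient(sum, ones(Float32, 2, 2))  # ≈ ones(Float32, 2, 2)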

src/TracedRArray.jl

Lines changed: 10 additions & 1 deletion
@@ -33,6 +33,15 @@ Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T,
 Base.complex(x::TracedRArray{<:Real}) = complex.(x)
 Base.complex(x::TracedRArray{<:Complex}) = x
 
+function Base.deepcopy_internal(x::TracedRArray, stackdict::IdDict)
+    if haskey(stackdict, x)
+        return stackdict[x]::typeof(x)
+    end
+    y = copy(x)
+    stackdict[x] = y
+    return y
+end
+
 TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x)
 
 function maybe_assert_scalar_setindexing(
@@ -1109,7 +1118,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
     return accumulate!(op, A, B; dims=1)
 end
 
-if isdefined(Base, :_accumulate_promote_op)
+@static if isdefined(Base, :_accumulate_promote_op)
     function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
         if init !== nothing
            init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
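The Base.deepcopy_internal method added above is the standard hook for customizing deepcopy (the new deepcopy(args) call in TestUtils.jl exercises it); the stackdict::IdDict memoizes objects already copied so aliased references and cycles stay consistent. A minimal sketch of the same pattern on a hypothetical handle type, not part of this commit:

# Hypothetical mutable wrapper whose deepcopy should just shallow-copy the handle,
# mirroring the TracedRArray method above.
mutable struct DemoHandle
    data::Vector{Float64}
end

Base.copy(h::DemoHandle) = DemoHandle(h.data)

function Base.deepcopy_internal(h::DemoHandle, stackdict::IdDict)
    haskey(stackdict, h) && return stackdict[h]::DemoHandle
    h2 = copy(h)
    stackdict[h] = h2   # memoize so shared references remain shared in the copy
    return h2
end

a = DemoHandle([1.0, 2.0])
pair = (a, a)
pair2 = deepcopy(pair)
@assert pair2[1] === pair2[2]   # aliasing preserved through the IdDict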

src/TracedRNumber.jl

Lines changed: 1 addition & 0 deletions
@@ -491,6 +491,7 @@ for (jlop, hloop) in (
     (:(Base.log), :log),
     (:(Base.log1p), :log_plus_one),
     (:(Base.sqrt), :sqrt),
+    (:(Base.cbrt), :cbrt),
     (:(Base.acos), :acos),
     (:(Base.acosh), :acosh),
     (:(Base.asin), :asin),

src/stdlibs/LinearAlgebra.jl

Lines changed: 2 additions & 2 deletions
@@ -273,7 +273,7 @@ function overloaded_mul!(
     return C
 end
 
-if isdefined(LinearAlgebra, :_triu)
+@static if isdefined(LinearAlgebra, :_triu)
     function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
         return overloaded_triu(materialize_traced_array(A), k)
     end
@@ -284,7 +284,7 @@ if isdefined(LinearAlgebra, :_triu)
     end
 end
 
-if isdefined(LinearAlgebra, :_tril)
+@static if isdefined(LinearAlgebra, :_tril)
     function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
         return overloaded_tril(materialize_traced_array(A), k)
     end
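A note on the `if` to `@static if` changes in this file and in src/TracedRArray.jl: `@static if` evaluates its condition at macro-expansion time and keeps only the taken branch, whereas a plain top-level `if` is lowered with both branches and decided when the module body runs. A small self-contained sketch; the version check is illustrative only:

module StaticIfDemo

# Only the selected branch is ever lowered/compiled.
@static if VERSION >= v"1.9"
    chosen_path() = "1.9+ branch"
else
    chosen_path() = "fallback branch"
end

end # module

StaticIfDemo.chosen_path()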

test/autodiff.jl

Lines changed: 41 additions & 0 deletions
@@ -366,3 +366,44 @@ end
 
     @test @jit(jvp_vjp_cubic(v_r, x_r, lambdas_r)) ≈ fill(6, (3, 2))
 end
+
+@testset "Finite Difference Gradient" begin
+    x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2))
+    res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x)
+    @test res isa Reactant.ConcreteRArray{Float16,2}
+end
+
+function fdiff_multiple_args(f, nt, x)
+    return sum(abs2, f(nt.y .+ x .- nt.x))
+end
+
+struct WrapperFunc{T}
+    x::T
+end
+
+(f::WrapperFunc)(x) = x .^ 3 .+ f.x
+
+@testset "Finite Difference Gradient (non vector inputs)" begin
+    nt = (;
+        x=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
+        y=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
+    )
+    fn = WrapperFunc(Reactant.TestUtils.construct_test_array(Float64, 3, 4))
+    x = Reactant.TestUtils.construct_test_array(Float64, 3, 4)
+
+    nt_ra = Reactant.to_rarray(nt)
+    fn_ra = Reactant.to_rarray(fn)
+    x_ra = Reactant.to_rarray(x)
+
+    results_fd = @jit Reactant.TestUtils.finite_difference_gradient(
+        fdiff_multiple_args, fn_ra, nt_ra, x_ra
+    )
+    @test results_fd isa typeof((fn_ra, nt_ra, x_ra))
+
+    results_enz = @jit Enzyme.gradient(Reverse, fdiff_multiple_args, fn_ra, nt_ra, x_ra)
+
+    @test results_fd[1].x ≈ results_enz[1].x
+    @test results_fd[2].x ≈ results_enz[2].x
+    @test results_fd[2].y ≈ results_enz[2].y
+    @test results_fd[3] ≈ results_enz[3]
+end
