11module TestUtils
22
3- using .. Reactant: Reactant, TracedRArray
3+ using .. Reactant: Reactant, TracedRArray, TracedRNumber, TracedUtils
4+ using Reactant. Ops: @opcall
45using ReactantCore: ReactantCore
56using LinearAlgebra: LinearAlgebra
67
@@ -20,22 +21,154 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
2021 return reshape (collect (T, 1 : prod (dims)), dims... )
2122end
2223
23- function finite_difference_gradient (
24- f, x:: AbstractArray{T} ; epsilon= eps (T)^ (3 / 4 )
25- ) where {T}
# Step-size heuristics follow FiniteDiff.jl:
# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133
"""
    default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T}

Return the default finite-difference step size for the differencing scheme `fdtype`
and the number type `T`.

The step is always computed from `real(T)`, so complex element types yield a real step:

  - `:forward`  -> `sqrt(eps(real(T)))`
  - `:central`  -> `cbrt(eps(real(T)))`
  - `:hcentral` -> `eps(real(T))^(1/4)`
  - any other `fdtype` -> `one(real(T))`

The (misspelled) name is kept unchanged because other functions in this module call it.
"""
function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T}
    R = real(T)
    if fdtype == :forward
        return sqrt(eps(R))
    elseif fdtype == :central
        return cbrt(eps(R))
    elseif fdtype == :hcentral
        # Use the real type here as well: `eps` has no method for complex types.
        return eps(R)^(one(R) / 4)
    else
        return one(R)
    end
end
36+
"""
    get_perturbation(x::AbstractArray, epsilon)

Build the stack of single-element perturbations for `x`.

A `length(x) × length(x)` diagonal matrix with `epsilon` on the diagonal is promoted to a
2-D `TracedRArray`, reshaped to `(size(x)..., length(x))`, and permuted so that the
perturbation index becomes the leading dimension. Slice `k` along dimension 1 therefore
has the shape of `x` and holds `epsilon` at linear position `k` and zero elsewhere.
"""
function get_perturbation(x::AbstractArray{T}, epsilon) where {T}
    nelems = length(x)
    nd = ndims(x)
    scaled_basis = Reactant.promote_to(
        TracedRArray{Reactant.unwrapped_eltype(T),2},
        LinearAlgebra.Diagonal(fill(epsilon, nelems)),
    )
    stacked = reshape(scaled_basis, size(x)..., nelems)
    # Move the trailing "which element is perturbed" axis to the front.
    return permutedims(stacked, (nd + 1, (1:nd)...))
end
46+
"""
    generate_perturbed_array(::Val{:central}, x::AbstractArray, epsilon)

Return the batch of inputs needed for a central difference: along a new leading
dimension, the first `length(x)` slices are `x` with one element increased by `epsilon`
and the next `length(x)` slices are `x` with one element decreased by `epsilon`.
"""
function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T}
    delta = get_perturbation(x, epsilon)
    # Give `x` a singleton leading axis so it broadcasts against every perturbation slice.
    base = reshape(x, 1, size(x)...)
    shifted_up = base .+ delta
    shifted_down = base .- delta
    return cat(shifted_up, shifted_down; dims=1)
end
52+
"""
    generate_perturbed_array(::Val{:forward}, x::AbstractArray, epsilon)

Return the batch of inputs needed for a forward difference: along a new leading
dimension, the first `length(x)` slices are `x` with one element increased by `epsilon`,
followed by a single final slice holding the unperturbed `x`.
"""
function generate_perturbed_array(::Val{:forward}, x::AbstractArray{T}, epsilon) where {T}
    delta = get_perturbation(x, epsilon)
    # Give `x` a singleton leading axis so it broadcasts against every perturbation slice.
    base = reshape(x, 1, size(x)...)
    shifted_up = base .+ delta
    return cat(shifted_up, base; dims=1)
end
58+
# Central difference quotient for a batch laid out as `[f(x + h)...; f(x - h)...]`.
function _finite_difference_quotient(::Val{:central}, batched_res, bsize, epsilon)
    half = bsize ÷ 2
    diff = batched_res[1:half] - batched_res[(half + 1):end]
    return diff ./ (2 * epsilon)
end

# Forward difference quotient for a batch laid out as `[f(x + h)...; f(x)]`.
function _finite_difference_quotient(::Val{:forward}, batched_res, bsize, epsilon)
    diff = batched_res[1:(end - 1)] .- batched_res[end:end]
    return diff ./ epsilon
end

"""
    finite_difference_gradient(f, args...; method=Val(:central))

Approximate the gradient of the scalar-valued function `f` with respect to the traced
arrays found in `args`, using finite differences.

`f` is traced once into an MLIR function. For every `TracedRArray` reachable from `args`
a batch of perturbed copies is generated (see `generate_perturbed_array`), all remaining
traced inputs are broadcast along the batch dimension, and the traced function is
evaluated over the whole batch with a single `batch` op. The difference quotient of the
batched outputs is reshaped to the size of the argument.

# Keywords

  - `method`: `Val(:central)` (default) or `Val(:forward)`; selects both the perturbation
    layout and the default step size (via `default_epslion`).

# Returns

A deep copy of `args` in which the data of every differentiated array has been set to
the corresponding gradient. If exactly one argument was passed, that single entry is
returned instead of the tuple.

# Throws

An error if the traced `f` does not produce exactly one `TracedRNumber` result.
"""
function finite_difference_gradient(
    f::F, args...; method::Union{Val{:central},Val{:forward}}=Val(:central)
) where {F}
    argprefix = gensym("finitediffarg")
    resprefix = gensym("finitediffresult")
    resargprefix = gensym("finitediffresarg")

    # TODO: can we detect and prevent using functions that mutate their arguments?
    mlir_fn_res = TracedUtils.make_mlir_fn(
        f,
        args,
        (),
        "finite_difference_gradient_fn",
        false;
        args_in_result=:none,
        argprefix,
        resprefix,
        resargprefix,
    )

    # Re-walk `f` and `args` to collect every traced value in a stable order.
    seenargs = Reactant.OrderedIdDict()
    Reactant.make_tracer(seenargs, f, (argprefix,), Reactant.TracedSetPath)
    for (i, arg) in enumerate(args)
        Reactant.make_tracer(seenargs, arg, (argprefix, i), Reactant.TracedSetPath)
    end

    linear_args = Reactant.TracedType[]
    for (_, v) in seenargs
        v isa Reactant.TracedType || continue
        push!(linear_args, v)
    end

    linear_results = mlir_fn_res.linear_results
    if length(linear_results) != 1 || !(linear_results[1] isa TracedRNumber)
        error("`finite_difference_gradient` only supports functions with a single scalar \
               output. Received: $(linear_results)")
    end

    gradient_results = TracedRArray[]
    gradient_result_map_path = Any[]
    for i in eachindex(linear_args)
        arg = linear_args[i]
        (arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix)) || continue

        path = TracedUtils.get_idx(arg, argprefix)
        # NOTE(review): presumably skips traced values that belong to the wrapped
        # function itself rather than to a user argument; confirm against make_mlir_fn.
        if mlir_fn_res.fnwrapped && length(path) > 1 && path[2] == 1
            continue
        end

        # We need the gradient wrt this argument
        # we will naively insert the args here, cse will take care of the rest
        new_arguments = TracedRArray[]

        epsilon = default_epslion(method, Reactant.unwrapped_eltype(arg))
        perturbed_arg = generate_perturbed_array(method, arg, epsilon)

        bsize = size(perturbed_arg, 1)
        for j in eachindex(linear_args)
            if i == j
                new_arg = perturbed_arg
            elseif linear_args[j] isa TracedRNumber
                # Scalars are replicated into a length-`bsize` vector.
                new_arg = @opcall broadcast_in_dim(
                    linear_args[j], Int64[], Int64[bsize]
                )
            else
                # Other arrays gain a leading batch dimension of size `bsize`.
                new_arg = @opcall broadcast_in_dim(
                    linear_args[j],
                    collect(Int64, 2:(ndims(linear_args[j]) + 1)),
                    Int64[bsize, size(linear_args[j])...],
                )
            end
            # Keep the batch axis first and reverse the order of the remaining axes.
            new_arg = @opcall transpose(new_arg, Int64[1, ((ndims(new_arg)):-1:2)...];)
            push!(new_arguments, new_arg)
        end

        batched_res = @opcall batch(
            new_arguments,
            [
                Reactant.MLIR.IR.TensorType(
                    Int64[bsize],
                    Reactant.MLIR.IR.Type(Reactant.unwrapped_eltype(linear_results[1])),
                ),
            ],
            Int64[bsize];
            fn=mlir_fn_res.f,
        )
        batched_res = only(batched_res)

        grad_res = _finite_difference_quotient(method, batched_res, bsize, epsilon)

        push!(gradient_result_map_path, path)
        push!(
            gradient_results,
            ReactantCore.materialize_traced_array(reshape(grad_res, size(arg))),
        )
    end

    results = deepcopy(args)
    for (path, grad_res) in zip(gradient_result_map_path, gradient_results)
        # Drop the leading `argprefix` entry; the rest addresses a location inside `args`.
        TracedUtils.set!(results, path[2:end], grad_res.mlir_data)
    end
    length(args) == 1 && return results[1]
    return results
end
40173
41174end
0 commit comments