11using Core. Compiler: IRInterpretationState, construct_postdomtree, PiNode,
22 is_known_call, argextype, postdominates
33
4- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , pantelides:: Vector{SSAValue} ; custom_diff! = (args... )-> nothing , diff_cache= Dict {SSAValue, SSAValue} ())
4+ #=
5+ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
56 Δs = SSAValue[]
67 rets = findall(@nospecialize(x)->isa(x, ReturnNode) && isdefined(x, :val), ir.stmts.inst)
78 postdomtree = construct_postdomtree(ir.cfg.blocks)
8- for ssa in pantelides
9- Δssa = forward_diff! (ir, interp, irsv, ssa; custom_diff!, diff_cache)
9+ for ( ssa, order) in to_diff
10+ Δssa = forward_diff!(ir, interp, irsv, ssa, order ; custom_diff!, diff_cache)
1011 Δblock = block_for_inst(ir, Δssa.id)
1112 for idx in rets
1213 retblock = block_for_inst(ir, idx)
@@ -18,31 +19,24 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelid
1819 end
1920 return (ir, Δs)
2021end
22+ =#
2123
22- function diff_unassigned_variable! (ir, ssa)
23- return insert_node! (ir, ssa, NewInstruction (
24- Expr (:call , GlobalRef (Intrinsics, :state_ddt ), ssa), Float64), #= attach_after=# true )
25- end
26-
27- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue ; custom_diff!, diff_cache)
24+ function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , order:: Int ; custom_diff!, diff_cache)
2825 if haskey (diff_cache, ssa)
2926 return diff_cache[ssa]
3027 end
3128 inst = ir[ssa]
3229 stmt = inst[:inst ]
33- if isa (stmt, SSAValue)
34- return forward_diff! (ir, interp, irsv, stmt; custom_diff!, diff_cache)
35- end
36- Δssa = forward_diff_uncached! (ir, interp, irsv, ssa, inst; custom_diff!, diff_cache)
30+ Δssa = forward_diff_uncached! (ir, interp, irsv, ssa, inst, order:: Int ; custom_diff!, diff_cache)
3731 @assert Δssa != = nothing
3832 if isa (Δssa, SSAValue)
3933 diff_cache[ssa] = Δssa
4034 end
4135 return Δssa
4236end
43- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , val:: Union{Integer, AbstractFloat} ; custom_diff!, diff_cache) = zero (val)
44- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , @nospecialize (arg); custom_diff!, diff_cache) = ChainRulesCore. NoTangent ()
45- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , arg:: Argument ; custom_diff!, diff_cache)
37+ forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , val:: Union{Integer, AbstractFloat} , order :: Int ; custom_diff!, diff_cache) = zero (val)
38+ forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , @nospecialize (arg), order :: Int ; custom_diff!, diff_cache) = ChainRulesCore. NoTangent ()
39+ function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , arg:: Argument , order :: Int ; custom_diff!, diff_cache)
4640 recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
4741 val = custom_diff! (ir, SSAValue (0 ), arg, recurse)
4842 if val != = nothing
@@ -51,13 +45,15 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg
5145 return ChainRulesCore. NoTangent ()
5246end
5347
54- function forward_diff_uncached! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , inst:: Core.Compiler.Instruction ; custom_diff!, diff_cache)
48+ function forward_diff_uncached! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , inst:: Core.Compiler.Instruction , order :: Int ; custom_diff!, diff_cache)
5549 stmt = inst[:inst ]
56- recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
50+ recurse (x) = forward_diff! (ir, interp, irsv, x, order ; custom_diff!, diff_cache)
5751 if (val = custom_diff! (ir, ssa, stmt, recurse)) != = nothing
5852 return val
5953 elseif isa (stmt, PiNode)
6054 return recurse (stmt. val)
55+ elseif isa (stmt, SSAValue)
56+ return recurse (stmt)
6157 elseif isa (stmt, PhiNode)
6258 Δphi = PhiNode (copy (stmt. edges), similar (stmt. values))
6359 T = Union{}
@@ -152,3 +148,108 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
152148 return Δssa
153149 end
154150end
151+
152+ function forward_visit! (ir:: IRCode , ssa:: SSAValue , order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!)
153+ if ssa_orders[ssa. id][1 ] >= order
154+ return
155+ end
156+ ssa_orders[ssa. id] = order => ssa_orders[ssa. id][2 ]
157+ inst = ir[ssa]
158+ stmt = inst[:inst ]
159+ recurse (@nospecialize (val)) = forward_visit! (ir, val, order, ssa_orders, visit_custom!)
160+ if visit_custom! (ir, stmt, order, recurse)
161+ ssa_orders[ssa. id] = order => true
162+ return
163+ elseif isa (stmt, PiNode)
164+ return recurse (stmt. val)
165+ elseif isa (stmt, PhiNode)
166+ for i = 1 : length (stmt. values)
167+ isassigned (stmt. values, i) || continue
168+ recurse (stmt. values[i])
169+ end
170+ return
171+ elseif isexpr (stmt, :new ) || isexpr (stmt, :invoke )
172+ foreach (recurse, stmt. args[2 : end ])
173+ elseif isexpr (stmt, :call )
174+ foreach (recurse, stmt. args)
175+ elseif isa (stmt, SSAValue)
176+ recurse (stmt)
177+ elseif ! isa (stmt, Expr)
178+ return
179+ else
180+ @show stmt
181+ error ()
182+ end
183+ end
184+ forward_visit! (ir:: IRCode , _, order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!) = nothing
185+ function forward_visit! (ir:: IRCode , a:: Argument , order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!)
186+ recurse (@nospecialize (val)) = forward_visit! (ir, val, order, ssa_orders, visit_custom!)
187+ return visit_custom! (ir, a, order, recurse)
188+ end
189+
190+
191+ function forward_diff_no_inf! (ir:: IRCode , interp, mi:: MethodInstance , world, to_diff:: Vector{Pair{SSAValue, Int}} ;
192+ visit_custom! = (args... )-> false , transform! = (args... )-> error ())
193+ # Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
194+ ssa_orders = [0 => false for i = 1 : length (ir. stmts)]
195+ for (ssa, order) in to_diff
196+ forward_visit! (ir, ssa, order, ssa_orders, visit_custom!)
197+ end
198+
199+ # Step 2: Transform
200+ function maparg (arg, ssa, order)
201+ if isa (arg, Argument)
202+ # TODO : Should we remember whether the callbacks wanted the arg?
203+ return transform! (ir, arg, order)
204+ elseif isa (arg, SSAValue)
205+ # TODO : Bundle truncation if necessary
206+ return arg
207+ end
208+ @assert ! isa (arg, Expr)
209+ return insert_node! (ir, ssa, NewInstruction (Expr (:call , ZeroBundle{order}, arg), Any))
210+ end
211+
212+ for (ssa, (order, custom)) in enumerate (ssa_orders)
213+ if order == 0
214+ # TODO : Bundle truncation?
215+ continue
216+ end
217+ if custom
218+ transform! (ir, SSAValue (ssa), order)
219+ else
220+ inst = ir[SSAValue (ssa)]
221+ stmt = inst[:inst ]
222+ if isexpr (stmt, :invoke )
223+ inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args[2 : end ])... )
224+ inst[:type ] = Any
225+ elseif ! isa (stmt, Expr)
226+ inst[:inst ] = maparg (stmt, ssa, order)
227+ inst[:type ] = Any
228+ else
229+ @show stmt
230+ error ()
231+ end
232+ end
233+ end
234+
235+ end
236+
237+ function forward_diff! (ir:: IRCode , interp, mi:: MethodInstance , world, to_diff:: Vector{Pair{SSAValue, Int}} ; kwargs... )
238+ forward_diff_no_inf! (ir, interp, mi, world, to_diff; kwargs... )
239+
240+ # Step 3: Re-inference
241+ ir = compact! (ir)
242+
243+ extra_reprocess = CC. BitSet ()
244+ for i = 1 : length (ir. stmts)
245+ if ir[SSAValue (i)][:type ] == Any
246+ CC. push! (extra_reprocess, i)
247+ end
248+ end
249+
250+ interp′ = enable_reinference (interp)
251+ irsv = IRInterpretationState (interp′, ir, mi, world, ir. argtypes[1 : mi. def. nargs])
252+ rt = CC. _ir_abstract_constant_propagation (interp′, irsv; extra_reprocess)
253+
254+ return ir
255+ end
0 commit comments