11using Core. Compiler: IRInterpretationState, construct_postdomtree, PiNode,
2- is_known_call, argextype, postdominates
2+ is_known_call, argextype, postdominates, userefs
33
44#=
55function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
@@ -93,12 +93,6 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
9393 return Δtangent
9494 else # general frule handling
9595 info = inst[:info ]
96- if ! isa (info, FRuleCallInfo)
97- @show info
98- @show inst[:inst ]
99- display (ir)
100- error ()
101- end
10296 if isexpr (stmt, :invoke )
10397 args = stmt. args[2 : end ]
10498 else
@@ -196,22 +190,50 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
196190 forward_visit! (ir, ssa, order, ssa_orders, visit_custom!)
197191 end
198192
193+ truncation_map = Dict {Pair{SSAValue, Int}, SSAValue} ()
194+
199195 # Step 2: Transform
200196 function maparg (arg, ssa, order)
201- if isa (arg, Argument)
197+ if isa (arg, SSAValue)
198+ if arg. id > length (ssa_orders)
199+ # This is possible if the custom transform touched another statement.
200+ # In that case just pass this through and assume the `transform!` did
201+ # it correctly.
202+ return arg
203+ end
204+ (argorder, _) = ssa_orders[arg. id]
205+ if argorder != order
206+ @assert order < argorder
207+ return get! (truncation_map, arg=> order) do
208+ # TODO : Other orders
209+ @assert order == 0
210+ insert_node! (ir, arg, NewInstruction (Expr (:call , primal, arg), Any), #= attach_after=# true )
211+ end
212+ end
213+ return arg
214+ elseif order == 0
215+ return arg
216+ elseif isa (arg, Argument)
202217 # TODO : Should we remember whether the callbacks wanted the arg?
203218 return transform! (ir, arg, order)
204- elseif isa (arg, SSAValue)
205- # TODO : Bundle truncation if necessary
206- return arg
219+ elseif isa (arg, GlobalRef)
220+ return insert_node! (ir, ssa, NewInstruction (Expr (:call , ZeroBundle{order}, arg), Any))
221+ elseif isa (arg, QuoteNode)
222+ return ZeroBundle {order} (arg. value)
207223 end
208224 @assert ! isa (arg, Expr)
209- return insert_node! (ir, ssa, NewInstruction ( Expr ( :call , ZeroBundle{order}, arg), Any) )
225+ return ZeroBundle {order} ( arg)
210226 end
211227
212228 for (ssa, (order, custom)) in enumerate (ssa_orders)
213229 if order == 0
214- # TODO : Bundle truncation?
230+ inst = ir[SSAValue (ssa)]
231+ stmt = inst[:inst ]
232+ urs = userefs (stmt)
233+ for ur in urs
234+ ur[] = maparg (ur[], SSAValue (ssa), order)
235+ end
236+ inst[:inst ] = urs[]
215237 continue
216238 end
217239 if custom
@@ -222,12 +244,16 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
222244 if isexpr (stmt, :invoke )
223245 inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args[2 : end ])... )
224246 inst[:type ] = Any
225- elseif ! isa (stmt, Expr )
226- inst[:inst ] = maparg (stmt, ssa, order)
247+ elseif isexpr (stmt, :call )
248+ inst[:inst ] = Expr ( :call , ∂☆ {order} (), map (arg -> maparg (arg, SSAValue ( ssa) , order), stmt . args) ... )
227249 inst[:type ] = Any
228250 else
229- @show stmt
230- error ()
251+ urs = userefs (stmt)
252+ for ur in urs
253+ ur[] = maparg (ur[], SSAValue (ssa), order)
254+ end
255+ inst[:inst ] = urs[]
256+ inst[:type ] = Any
231257 end
232258 end
233259 end
0 commit comments