@@ -21,7 +21,11 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff:
2121end
2222=#
2323
24- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , order:: Int ; custom_diff!, diff_cache)
24+ # TODO interp::AbstractADInterpreter instead interp::AbstractInterpreter?
25+
26+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
27+ ssa:: SSAValue , order:: Int ;
28+ custom_diff!, diff_cache)
2529 if haskey (diff_cache, ssa)
2630 return diff_cache[ssa]
2731 end
@@ -34,9 +38,19 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSA
3438 end
3539 return Δssa
3640end
37- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , val:: Union{Integer, AbstractFloat} , order:: Int ; custom_diff!, diff_cache) = zero (val)
38- forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , @nospecialize (arg), order:: Int ; custom_diff!, diff_cache) = ChainRulesCore. NoTangent ()
39- function forward_diff! (ir:: IRCode , interp, irsv:: IRInterpretationState , arg:: Argument , order:: Int ; custom_diff!, diff_cache)
41+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
42+ val:: Union{Integer, AbstractFloat} , order:: Int ;
43+ custom_diff!, diff_cache)
44+ return zero (val)
45+ end
46+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
47+ @nospecialize (arg), order:: Int ;
48+ custom_diff!, diff_cache)
49+ return ChainRulesCore. NoTangent ()
50+ end
51+ function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
52+ arg:: Argument , order:: Int ;
53+ custom_diff!, diff_cache)
4054 recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
4155 val = custom_diff! (ir, SSAValue (0 ), arg, recurse)
4256 if val != = nothing
@@ -45,7 +59,9 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg
4559 return ChainRulesCore. NoTangent ()
4660end
4761
48- function forward_diff_uncached! (ir:: IRCode , interp, irsv:: IRInterpretationState , ssa:: SSAValue , inst:: Core.Compiler.Instruction , order:: Int ; custom_diff!, diff_cache)
62+ function forward_diff_uncached! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
63+ ssa:: SSAValue , inst:: Core.Compiler.Instruction , order:: Int ;
64+ custom_diff!, diff_cache)
4965 stmt = inst[:inst ]
5066 recurse (x) = forward_diff! (ir, interp, irsv, x, order; custom_diff!, diff_cache)
5167 if (val = custom_diff! (ir, ssa, stmt, recurse)) != = nothing
@@ -105,8 +121,7 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
105121 argtypes = Any[argextype (arg, ir) for arg in Δtpl. args[2 : end ]]
106122 tup_T = CC. tuple_tfunc (CC. typeinf_lattice (interp), argtypes)
107123
108- Δ = insert_node! (ir, ssa, NewInstruction (
109- Δtpl, tup_T))
124+ Δ = insert_node! (ir, ssa, NewInstruction (Δtpl, tup_T))
110125
111126 # Now that we know the arguments, do a proper typeinf for this particular callsite
112127 new_spec_types = Tuple{typeof (ChainRulesCore. frule), widenconst (tup_T), (widenconst (argextype (arg, ir)) for arg in args). .. }
@@ -175,15 +190,15 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
175190 error ()
176191 end
177192end
178- forward_visit! (ir :: IRCode , _, order :: Int , ssa_orders :: Vector{Pair{Int, Bool}} , visit_custom! ) = nothing
193+ forward_visit! (:: IRCode , @nospecialize (x), :: Int , :: Vector{Pair{Int, Bool}} , _ ) = nothing
179194function forward_visit! (ir:: IRCode , a:: Argument , order:: Int , ssa_orders:: Vector{Pair{Int, Bool}} , visit_custom!)
180195 recurse (@nospecialize (val)) = forward_visit! (ir, val, order, ssa_orders, visit_custom!)
181196 return visit_custom! (ir, a, order, recurse)
182197end
183198
184199
185200"""
186- forward_diff_no_inf!(ir, to_diff; visit_custom!, transform)
201+ forward_diff_no_inf!(ir::IRCode , to_diff::Vector{Pair{SSAValue,Int}} ; visit_custom!, transform! )
187202
188203Internal method which generates the code for forward mode diffentiation
189204
@@ -192,13 +207,14 @@ Internal method which generates the code for forward mode diffentiation
192207 - `to_diff`: collection of all SSA values for which the derivative is to be taken,
193208 paired with the order (first deriviative, second derivative etc)
194209
195- - `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
210+ - `visit_custom!(ir::IRCode , stmt, order::Int, recurse::Bool) -> Bool `:
196211 decides if the custom `transform!` should be applied to a `stmt` or not
197212 Default: `false` for all statements
198- - `transform!(ir, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
213+ - `transform!(ir::IRCode , ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
199214"""
200- function forward_diff_no_inf! (ir:: IRCode , to_diff:: Vector{Pair{SSAValue, Int}} ;
201- visit_custom! = (args... )-> false , transform! = (args... )-> error ())
215+ function forward_diff_no_inf! (ir:: IRCode , to_diff:: Vector{Pair{SSAValue,Int}} ;
216+ visit_custom! = (@nospecialize args... )-> false ,
217+ transform! = (@nospecialize args... )-> error ())
202218 # Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
203219 ssa_orders = [0 => false for i = 1 : length (ir. stmts)]
204220 for (ssa, order) in to_diff
@@ -208,7 +224,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
208224 truncation_map = Dict {Pair{SSAValue, Int}, SSAValue} ()
209225
210226 # Step 2: Transform
211- function maparg (arg, ssa, order)
227+ function maparg (@nospecialize ( arg) , ssa:: SSAValue , order:: Int )
212228 if isa (arg, SSAValue)
213229 if arg. id > length (ssa_orders)
214230 # This is possible if the custom transform touched another statement.
@@ -259,10 +275,16 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
259275 inst = ir[SSAValue (ssa)]
260276 stmt = inst[:inst ]
261277 if isexpr (stmt, :invoke )
262- inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args[2 : end ])... )
278+ newargs = map (stmt. args[2 : end ]) do @nospecialize arg
279+ maparg (arg, SSAValue (ssa), order)
280+ end
281+ inst[:inst ] = Expr (:call , ∂☆ {order} (), newargs... )
263282 inst[:type ] = Any
264283 elseif isexpr (stmt, :call )
265- inst[:inst ] = Expr (:call , ∂☆ {order} (), map (arg-> maparg (arg, SSAValue (ssa), order), stmt. args)... )
284+ newargs = map (stmt. args) do @nospecialize arg
285+ maparg (arg, SSAValue (ssa), order)
286+ end
287+ inst[:inst ] = Expr (:call , ∂☆ {order} (), newargs... )
266288 inst[:type ] = Any
267289 elseif isa (stmt, PiNode)
268290 # TODO : New PiNode that discriminates based on primal?
@@ -288,7 +310,6 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
288310 end
289311end
290312
291-
292313function forward_diff! (interp:: ADInterpreter , ir:: IRCode , src:: CodeInfo , mi:: MethodInstance ,
293314 to_diff:: Vector{Pair{SSAValue, Int}} ; kwargs... )
294315 forward_diff_no_inf! (ir, to_diff; kwargs... )
0 commit comments