@@ -501,30 +501,42 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
501501 else
502502 changes = LLVM. API. LLVMCloneFunctionChangeTypeLocalChangesOnly
503503 end
504- clone_into! (new_f, f; value_map, changes)
504+
505+ # use a value materializer for replacing uses of the function in constants
506+ # NOTE: we assume kernel functions can't be called. on-device kernel launches,
507+ # e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
508+ # and we update those constant expressions arguments here.
509+ function materializer (val)
510+ opcodes = (LLVM. API. LLVMPtrToInt, LLVM. API. LLVMAddrSpaceCast, LLVM. API. LLVMBitCast)
511+ if val isa LLVM. ConstantExpr && opcode (val) in opcodes
512+ target = operands (val)[1 ]
513+ if target == f
514+ return if opcode (val) == LLVM. API. LLVMPtrToInt
515+ LLVM. const_ptrtoint (new_f, llvmtype (val))
516+ elseif opcode (val) == LLVM. API. LLVMAddrSpaceCast
517+ LLVM. const_addrspacecast (new_f, llvmtype (val))
518+ elseif opcode (val) == LLVM. API. LLVMBitCast
519+ LLVM. const_bitcast (new_f, llvmtype (val))
520+ end
521+ end
522+ end
523+ return val
524+ end
525+
526+ # we don't want module-level changes, because otherwise LLVM will clone metadata,
527+ # resulting in mismatching references between `!dbg` metadata and `dbg` instructions
528+ clone_into! (new_f, f; value_map, changes, materializer)
505529
506530 # fall through
507531 br! (builder, blocks (new_f)[2 ])
508532 end
509533
510- # update uses of the kernel
511- # NOTE: we assume kernel functions can't be called. on-device kernel launches,
512- # e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
513- # and we update those constant expressions arguments here.
534+ # drop unused constants that may be referring to the old functions
535+ # XXX : can we do this differently?
514536 for use in uses (f)
515537 val = user (use)
516- if val isa LLVM. ConstantExpr && opcode (val) == LLVM. API. LLVMPtrToInt
517- target = operands (val)[1 ]
518- if target == f
519- new_val = LLVM. const_ptrtoint (new_f, llvmtype (val))
520- replace_uses! (val, new_val)
521-
522- # drop the old constant if it is unused
523- # XXX : can we do this differently?
524- if isempty (uses (val))
525- LLVM. unsafe_destroy! (val)
526- end
527- end
538+ if val isa LLVM. ConstantExpr && isempty (uses (val))
539+ LLVM. unsafe_destroy! (val)
528540 end
529541 end
530542
@@ -576,8 +588,30 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
576588 # this is both for extern uses, and to make this transformation a two-step process.
577589 state_intr = kernel_state_intr (mod, T_state)
578590
579- # add a state argument to every function
580- worklist = filter (! isdeclaration, collect (functions (mod)))
591+ # determine which functions need a kernel state argument
592+ #
593+ # previously, we add the argument to every function and relied on unused arg elim to
594+ # clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
595+ # function pointers. such IR is hard to rewrite, so instead be more conservative.
596+ worklist = Set {LLVM.Function} ([entry, state_intr])
597+ worklist_length = 0
598+ while worklist_length != length (worklist)
599+ # iteratively discover functions that use the intrinsic or any function calling it
600+ worklist_length = length (worklist)
601+ additions = LLVM. Function[]
602+ for f in worklist, use in uses (f)
603+ inst = user (use):: Instruction
604+ bb = LLVM. parent (inst)
605+ new_f = LLVM. parent (bb)
606+ in (new_f, worklist) || push! (additions, new_f)
607+ end
608+ for f in additions
609+ push! (worklist, f)
610+ end
611+ end
612+ delete! (worklist, state_intr)
613+
614+ # add a state argument
581615 workmap = Dict {LLVM.Function, LLVM.Function} ()
582616 for f in worklist
583617 fn = LLVM. name (f)
@@ -608,10 +642,17 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
608642
609643 # use a value materializer for replacing uses of the function in constants
610644 function materializer (val)
611- if val isa LLVM. ConstantExpr && opcode (val) == LLVM. API. LLVMPtrToInt
645+ opcodes = (LLVM. API. LLVMPtrToInt, LLVM. API. LLVMAddrSpaceCast, LLVM. API. LLVMBitCast)
646+ if val isa LLVM. ConstantExpr && opcode (val) in opcodes
612647 src = operands (val)[1 ]
613648 if haskey (workmap, src)
614- return LLVM. const_ptrtoint (workmap[src], llvmtype (val))
649+ return if opcode (val) == LLVM. API. LLVMPtrToInt
650+ LLVM. const_ptrtoint (workmap[src], llvmtype (val))
651+ elseif opcode (val) == LLVM. API. LLVMAddrSpaceCast
652+ LLVM. const_addrspacecast (workmap[src], llvmtype (val))
653+ elseif opcode (val) == LLVM. API. LLVMBitCast
654+ LLVM. const_bitcast (workmap[src], llvmtype (val))
655+ end
615656 end
616657 end
617658 return val
@@ -677,20 +718,6 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
677718 replace_uses! (val, new_val)
678719 @assert isempty (uses (val))
679720 unsafe_delete! (LLVM. parent (val), val)
680- elseif val isa LLVM. ConstantExpr && opcode (val) == LLVM. API. LLVMBitCast
681- # XXX : why isn't this caught by the value materializer above?
682- target = operands (val)[1 ]
683- @assert target == f
684- new_val = LLVM. const_bitcast (new_f, llvmtype (val))
685- rewrite_uses! (val, new_val)
686- # we can't simply replace this constant expression, as it may be used
687- # as a call, taking arguments (so we need to rewrite it to pass the state)
688-
689- # drop the old constant if it is unused
690- # XXX : can we do this differently?
691- if isempty (uses (val))
692- LLVM. unsafe_destroy! (val)
693- end
694721 else
695722 error (" Cannot rewrite unknown use of function: $val " )
696723 end
@@ -721,14 +748,10 @@ function lower_kernel_state!(fun::LLVM.Function)
721748 return false
722749 end
723750
724- # find the kernel state argument. this should be the first argument of the function.
725- state_arg = parameters (fun)[1 ]
726- T_state = convert (LLVMType, state; ctx)
727- @assert llvmtype (state_arg) == T_state
728-
729751 # fixup all uses of the state getter to use the newly introduced function state argument
730752 if haskey (functions (mod), " julia.gpu.state_getter" )
731753 state_intr = functions (mod)[" julia.gpu.state_getter" ]
754+ state_arg = nothing # only look-up when needed
732755
733756 Builder (ctx) do builder
734757 for use in uses (state_intr)
@@ -741,6 +764,14 @@ function lower_kernel_state!(fun::LLVM.Function)
741764 bb = LLVM. parent (inst)
742765 f = LLVM. parent (bb)
743766
767+ if state_arg === nothing
768+ # find the kernel state argument. this should be the first argument of
769+ # the function, but only when this function needs the state!
770+ state_arg = parameters (fun)[1 ]
771+ T_state = convert (LLVMType, state; ctx)
772+ @assert llvmtype (state_arg) == T_state
773+ end
774+
744775 replace_uses! (inst, state_arg)
745776
746777 @assert isempty (uses (inst))
0 commit comments