1- #![ allow( unused_imports) ]
2- #![ allow( unused_variables) ]
3- use std:: ffi:: { CStr , CString } ;
1+ use std:: ffi:: CString ;
42use std:: io:: { self , Write } ;
53use std:: path:: { Path , PathBuf } ;
64use std:: sync:: Arc ;
75use std:: { fs, slice, str} ;
86
97use libc:: { c_char, c_int, c_uint, c_void, size_t} ;
108use llvm:: {
11- IntPredicate , LLVMGetNextBasicBlock , LLVMRustDISetInstMetadata ,
9+ IntPredicate ,
1210 LLVMRustLLVMHasZlibCompressionForDebugSymbols , LLVMRustLLVMHasZstdCompressionForDebugSymbols ,
1311} ;
1412use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffItem , DiffActivity , DiffMode } ;
@@ -47,27 +45,26 @@ use crate::errors::{
4745} ;
4846use crate :: llvm:: diagnostic:: OptimizationDiagnosticKind ;
4947use crate :: llvm:: {
50- self , enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind , BasicBlock ,
48+ self , enzyme_rust_forward_diff, enzyme_rust_reverse_diff, AttributeKind ,
5149 CreateEnzymeLogic , CreateTypeAnalysis , DiagnosticInfo , EnzymeLogicRef , EnzymeTypeAnalysisRef ,
52- FreeTypeAnalysis , LLVMAddFunction , LLVMAppendBasicBlockInContext , LLVMBuildCall2 ,
50+ FreeTypeAnalysis , LLVMAppendBasicBlockInContext , LLVMBuildCall2 ,
5351 LLVMBuildCondBr , LLVMBuildExtractValue , LLVMBuildICmp , LLVMBuildRet , LLVMBuildRetVoid ,
5452 LLVMCountParams , LLVMCountStructElementTypes , LLVMCreateBuilderInContext ,
55- LLVMCreateStringAttribute , LLVMDeleteFunction , LLVMDisposeBuilder , LLVMDumpModule ,
56- LLVMGetBasicBlockTerminator , LLVMGetFirstBasicBlock , LLVMGetFirstFunction ,
57- LLVMGetModuleContext , LLVMGetNextFunction , LLVMGetParams , LLVMGetReturnType ,
53+ LLVMCreateStringAttribute , LLVMDisposeBuilder , LLVMDumpModule ,
54+ LLVMGetFirstBasicBlock , LLVMGetFirstFunction ,
55+ LLVMGetNextFunction , LLVMGetParams , LLVMGetReturnType ,
5856 LLVMGetStringAttributeAtIndex , LLVMGlobalGetValueType , LLVMIsEnumAttribute ,
5957 LLVMIsStringAttribute , LLVMMetadataAsValue , LLVMPositionBuilderAtEnd ,
60- LLVMRemoveStringAttributeAtIndex , LLVMReplaceAllUsesWith , LLVMRustAddEnumAttributeAtIndex ,
61- LLVMRustAddFunctionAttributes , LLVMRustDIGetInstMetadata , LLVMRustDIGetInstMetadataOfTy ,
62- LLVMRustEraseBBFromParent , LLVMRustEraseInstBefore , LLVMRustEraseInstFromParent ,
58+ LLVMRemoveStringAttributeAtIndex , LLVMRustAddEnumAttributeAtIndex ,
59+ LLVMRustAddFunctionAttributes , LLVMRustDIGetInstMetadata ,
60+ LLVMRustEraseInstBefore , LLVMRustEraseInstFromParent ,
6361 LLVMRustGetEnumAttributeAtIndex , LLVMRustGetFunctionType , LLVMRustGetLastInstruction ,
64- LLVMRustGetTerminator , LLVMRustHasDbgMetadata , LLVMRustHasMetadata ,
65- LLVMRustRemoveEnumAttributeAtIndex , LLVMRustRemoveFncAttr ,
66- LLVMRustgetFirstNonPHIOrDbgOrLifetime , LLVMSetValueName2 , LLVMVerifyFunction ,
62+ LLVMRustGetTerminator , LLVMRustHasMetadata ,
63+ LLVMRustRemoveEnumAttributeAtIndex ,
64+ LLVMVerifyFunction ,
6765 LLVMVoidTypeInContext , PassManager , Value ,
6866} ;
6967use crate :: type_:: Type ;
70- use crate :: typetree:: to_enzyme_typetree;
7168use crate :: { base, common, llvm_util, DiffTypeTree , LlvmCodegenBackend , ModuleLlvm } ;
7269
7370pub fn llvm_err < ' a > ( dcx : DiagCtxtHandle < ' _ > , err : LlvmError < ' a > ) -> FatalError {
@@ -669,50 +666,27 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
669666
670667// DESIGN:
671668// Today we have our placeholder function, and our Enzyme generated one.
672- // We create a wrapper function and delete the placeholder body.
673- // We then call the wrapper from the placeholder.
669+ // We create a wrapper function and delete the placeholder body. You can see the
670+ // placeholder by running `cargo expand` on an autodiff invocation. We call the wrapper
671+ // from the placeholder. This function is a bit longer, because it matches the Rust level
672+ // autodiff macro with LLVM level Enzyme autodiff expectations.
674673//
675- // Soon, we won't delete the whole placeholder, but just the loop,
676- // and the two inline asm sections. For now we can still call the wrapper.
677- // In the future we call our Enzyme generated function directly and unwrap the return
678- // struct in our original placeholder.
679- //
680- // define internal double @_ZN2ad3bar17ha38374e821680177E(ptr align 8 %0, ptr align 8 %1, double %2) unnamed_addr #17 !dbg !13678 {
681- // %4 = alloca double, align 8
682- // %5 = alloca ptr, align 8
683- // %6 = alloca ptr, align 8
684- // %7 = alloca { ptr, double }, align 8
685- // store ptr %0, ptr %6, align 8
686- // call void @llvm.dbg.declare(metadata ptr %6, metadata !13682, metadata !DIExpression()), !dbg !13685
687- // store ptr %1, ptr %5, align 8
688- // call void @llvm.dbg.declare(metadata ptr %5, metadata !13683, metadata !DIExpression()), !dbg !13685
689- // store double %2, ptr %4, align 8
690- // call void @llvm.dbg.declare(metadata ptr %4, metadata !13684, metadata !DIExpression()), !dbg !13686
691- // call void asm sideeffect alignstack inteldialect "NOP", "~{dirflag},~{fpsr},~{flags},~{memory}"(), !dbg !13687, !srcloc !23
692- // %8 = call double @_ZN2ad3foo17h95b548a9411653b2E(ptr align 8 %0), !dbg !13687
693- // %9 = call double @_ZN4core4hint9black_box17h7bd67a41b0f12bdfE(double %8), !dbg !13687
694- // store ptr %1, ptr %7, align 8, !dbg !13687
695- // %10 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687
696- // store double %2, ptr %10, align 8, !dbg !13687
697- // %11 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 0, !dbg !13687
698- // %12 = load ptr, ptr %11, align 8, !dbg !13687, !nonnull !23, !align !1047, !noundef !23
699- // %13 = getelementptr inbounds { ptr, double }, ptr %7, i32 0, i32 1, !dbg !13687
700- // %14 = load double, ptr %13, align 8, !dbg !13687, !noundef !23
701- // %15 = call { ptr, double } @_ZN4core4hint9black_box17h669f3b22afdcb487E(ptr align 8 %12, double %14), !dbg !13687
702- // %16 = extractvalue { ptr, double } %15, 0, !dbg !13687
703- // %17 = extractvalue { ptr, double } %15, 1, !dbg !13687
704- // br label %18, !dbg !13687
705- //
706- //18: ; preds = %18, %3
707- // br label %18, !dbg !13687
674+ // Think of computing the derivative with respect to &[f32] by marking it as duplicated.
675+ // The user will then pass an extra &mut [f32] and we want add the derivative to that.
676+ // On LLVM/Enzyme level, &[f32] however becomes `ptr, i64` and we mark ptr as duplicated,
677+ // and i64 (len) as const. Enzyme will then expect `ptr, ptr, i64` as arguments. See how the
678+ // second i64 from the mut slice isn't used? That's why we add a safety check to assert
679+ // that the second (mut) slice is at least as long as the first (const) slice. Otherwise,
680+ // Enzyme would write out of bounds if the first (const) slice is longer than the second.
708681
709682unsafe fn create_call < ' a > (
710683 tgt : & ' a Value ,
711684 src : & ' a Value ,
712- rev_mode : bool ,
713685 llmod : & ' a llvm:: Module ,
714686 llcx : & llvm:: Context ,
715- size_positions : & [ usize ] ,
687+ // FIXME: Instead of recomputing the positions as we do it below, we should
688+ // start using this list of positions that indicate length integers.
689+ _size_positions : & [ usize ] ,
716690 ad : & [ AutoDiff ] ,
717691) {
718692 unsafe {
@@ -756,9 +730,10 @@ unsafe fn create_call<'a>(
756730 inner_pos += 1 ;
757731 outer_pos += 1 ;
758732 } else {
759- // out: (ptr, <>int1, ptr, int2)
733+ // out: rust: (&[f32], &mut [f32])
734+ // out: llvm: (ptr, <>int1, ptr, int2)
760735 // inner: (ptr, <>ptr, int)
761- // goal: (ptr, ptr, int1), skipping int2
736+ // goal: call (ptr, ptr, int1), skipping int2
762737 // we are here: <>
763738 assert ! ( llvm:: LLVMRustGetTypeKind ( outer_arg_ty) == llvm:: TypeKind :: Integer ) ;
764739 assert ! ( llvm:: LLVMRustGetTypeKind ( inner_arg_ty) == llvm:: TypeKind :: Pointer ) ;
@@ -872,17 +847,17 @@ unsafe fn create_call<'a>(
872847 ) ;
873848
874849 // Add dummy dbg info to our newly generated call, if we have any.
875- let inst = LLVMRustgetFirstNonPHIOrDbgOrLifetime ( bb) . unwrap ( ) ;
876850 let md_ty = llvm:: LLVMGetMDKindIDInContext (
877851 llcx,
878852 "dbg" . as_ptr ( ) as * const c_char ,
879853 "dbg" . len ( ) as c_uint ,
880854 ) ;
881855
856+
882857 if LLVMRustHasMetadata ( last_inst, md_ty) {
883858 let md = LLVMRustDIGetInstMetadata ( last_inst) ;
884859 let md_val = LLVMMetadataAsValue ( llcx, md) ;
885- let md2 = llvm:: LLVMSetMetadata ( struct_ret, md_ty, md_val) ;
860+ let _md2 = llvm:: LLVMSetMetadata ( struct_ret, md_ty, md_val) ;
886861 } else {
887862 trace ! ( "No dbg info" ) ;
888863 }
@@ -938,8 +913,8 @@ unsafe fn get_panic_name(llmod: &llvm::Module) -> CString {
938913// For now we only check if shadow arguments are large enough. In this case we look for Rust panic
939914// functions in the module and call it. Due to hashing we can't hardcode the panic function name.
940915// Note: This worked even for panic=abort tests so seems solid enough for now.
941- // TODO : Pick a panic function which allows displaying an errormessage .
942- // TODO : We probably want to keep a handle at higher level and pass it down instead of searching.
916+ // FIXME : Pick a panic function which allows displaying an error message .
917+ // FIXME : We probably want to keep a handle at higher level and pass it down instead of searching.
943918unsafe fn add_panic_msg_to_global < ' a > (
944919 llmod : & ' a llvm:: Module ,
945920 llcx : & ' a llvm:: Context ,
@@ -961,7 +936,7 @@ unsafe fn add_panic_msg_to_global<'a>(
961936 let i8_array_type = LLVMArrayType2 ( LLVMInt8TypeInContext ( llcx) , msg_len as u64 ) ;
962937
963938 // Create the string constant
964- let string_const_val =
939+ let _string_const_val =
965940 LLVMConstStringInContext2 ( llcx, cmsg. as_ptr ( ) as * const i8 , msg_len as usize , 0 ) ;
966941
967942 // Create the array initializer
@@ -1098,8 +1073,7 @@ pub(crate) unsafe fn enzyme_ad(
10981073
10991074 let f_return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( res) ) ;
11001075
1101- let rev_mode = item. attrs . mode == DiffMode :: Reverse ;
1102- create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions, ad) ;
1076+ create_call ( target_fnc, res, llmod, llcx, & size_positions, ad) ;
11031077 // TODO: implement drop for wrapper type?
11041078 FreeTypeAnalysis ( type_analysis) ;
11051079 }
@@ -1133,10 +1107,6 @@ pub(crate) unsafe fn differentiate(
11331107
11341108 // Before dumping the module, we want all the tt to become part of the module.
11351109 for ( i, item) in diff_items. iter ( ) . enumerate ( ) {
1136- let llvm_data_layout = unsafe { llvm:: LLVMGetDataLayoutStr ( & * llmod) } ;
1137- let llvm_data_layout =
1138- std:: str:: from_utf8 ( unsafe { CStr :: from_ptr ( llvm_data_layout) } . to_bytes ( ) )
1139- . expect ( "got a non-UTF8 data-layout from LLVM" ) ;
11401110 let tt: FncTree = FncTree { args : item. inputs . clone ( ) , ret : item. output . clone ( ) } ;
11411111 let name = CString :: new ( item. source . clone ( ) ) . unwrap ( ) ;
11421112 let fn_def: & llvm:: Value =
0 commit comments