@@ -3,7 +3,6 @@ use std::cmp::Ordering;
33
44use rustc_abi:: { Align , BackendRepr , ExternAbi , Float , HasDataLayout , Primitive , Size } ;
55use rustc_codegen_ssa:: base:: { compare_simd_types, wants_msvc_seh, wants_wasm_eh} ;
6- use rustc_codegen_ssa:: codegen_attrs:: autodiff_attrs;
76use rustc_codegen_ssa:: common:: { IntPredicate , TypeKind } ;
87use rustc_codegen_ssa:: errors:: { ExpectedPointerMutability , InvalidMonomorphization } ;
98use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
@@ -1157,7 +1156,13 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11571156 Instance :: try_resolve ( tcx, bx. cx . typing_env ( ) , * diff_id, diff_args) . unwrap ( ) . unwrap ( ) ;
11581157 let diff_symbol = symbol_name_for_instance_in_crate ( tcx, fn_diff. clone ( ) , LOCAL_CRATE ) ;
11591158
1160- let diff_attrs = autodiff_attrs ( tcx, fn_diff. def_id ( ) ) ;
1159+ // TODO(Sa4dUs): Store autodiff items in a single pass and just get them here
1160+ // in a O(1) step
1161+ let diff_attrs = tcx
1162+ . collect_and_partition_mono_items ( ( ) )
1163+ . autodiff_items
1164+ . iter ( )
1165+ . find ( |item| item. target == diff_symbol) ;
11611166 let Some ( diff_attrs) = diff_attrs else { bug ! ( "could not find autodiff attrs" ) } ;
11621167
11631168 // Build body
@@ -1168,7 +1173,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11681173 & diff_symbol,
11691174 llret_ty,
11701175 & val_arr,
1171- diff_attrs. clone ( ) ,
1176+ diff_attrs. attrs . clone ( ) ,
11721177 result,
11731178 ) ;
11741179}
@@ -1185,11 +1190,22 @@ fn get_args_from_tuple<'ll, 'tcx>(
11851190 for i in 0 ..tuple_place. layout . layout . 0 . fields . count ( ) {
11861191 let field_place = tuple_place. project_field ( bx, i) ;
11871192 let field_layout = tuple_place. layout . field ( bx, i) ;
1193+ let field_ty = field_layout. ty ;
11881194 let llvm_ty = field_layout. llvm_type ( bx. cx ) ;
11891195
11901196 let field_val = bx. load ( llvm_ty, field_place. val . llval , field_place. val . align ) ;
11911197
1192- ret_arr. push ( field_val)
1198+ match field_ty. kind ( ) {
1199+ ty:: Ref ( _, inner_ty, _) if matches ! ( inner_ty. kind( ) , ty:: Slice ( _) ) => {
1200+ let ptr = bx. extract_value ( field_val, 0 ) ;
1201+ let len = bx. extract_value ( field_val, 1 ) ;
1202+ ret_arr. push ( ptr) ;
1203+ ret_arr. push ( len) ;
1204+ }
1205+ _ => {
1206+ ret_arr. push ( field_val) ;
1207+ }
1208+ }
11931209 }
11941210
11951211 ret_arr
0 commit comments