@@ -1009,6 +1009,179 @@ OpenMPIRBuilder::createSection(const LocationDescription &Loc,
10091009 /* IsCancellable*/ true );
10101010}
10111011
1012+ // / Create a function with a unique name and a "void (i8*, i8*)" signature in
1013+ // / the given module and return it.
1014+ Function *getFreshReductionFunc (Module &M) {
1015+ Type *VoidTy = Type::getVoidTy (M.getContext ());
1016+ Type *Int8PtrTy = Type::getInt8PtrTy (M.getContext ());
1017+ auto *FuncTy =
1018+ FunctionType::get (VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false );
1019+ return Function::Create (FuncTy, GlobalVariable::InternalLinkage,
1020+ M.getDataLayout ().getDefaultGlobalsAddressSpace (),
1021+ " .omp.reduction.func" , &M);
1022+ }
1023+
1024+ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions (
1025+ const LocationDescription &Loc, InsertPointTy AllocaIP,
1026+ ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
1027+ for (const ReductionInfo &RI : ReductionInfos) {
1028+ (void )RI;
1029+ assert (RI.Variable && " expected non-null variable" );
1030+ assert (RI.PrivateVariable && " expected non-null private variable" );
1031+ assert (RI.ReductionGen && " expected non-null reduction generator callback" );
1032+ assert (RI.Variable ->getType () == RI.PrivateVariable ->getType () &&
1033+ " expected variables and their private equivalents to have the same "
1034+ " type" );
1035+ assert (RI.Variable ->getType ()->isPointerTy () &&
1036+ " expected variables to be pointers" );
1037+ }
1038+
1039+ if (!updateToLocation (Loc))
1040+ return InsertPointTy ();
1041+
1042+ BasicBlock *InsertBlock = Loc.IP .getBlock ();
1043+ BasicBlock *ContinuationBlock =
1044+ InsertBlock->splitBasicBlock (Loc.IP .getPoint (), " reduce.finalize" );
1045+ InsertBlock->getTerminator ()->eraseFromParent ();
1046+
1047+ // Create and populate array of type-erased pointers to private reduction
1048+ // values.
1049+ unsigned NumReductions = ReductionInfos.size ();
1050+ Type *RedArrayTy = ArrayType::get (Builder.getInt8PtrTy (), NumReductions);
1051+ Builder.restoreIP (AllocaIP);
1052+ Value *RedArray = Builder.CreateAlloca (RedArrayTy, nullptr , " red.array" );
1053+
1054+ Builder.SetInsertPoint (InsertBlock, InsertBlock->end ());
1055+
1056+ for (auto En : enumerate(ReductionInfos)) {
1057+ unsigned Index = En.index ();
1058+ const ReductionInfo &RI = En.value ();
1059+ Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64 (
1060+ RedArrayTy, RedArray, 0 , Index, " red.array.elem." + Twine (Index));
1061+ Value *Casted =
1062+ Builder.CreateBitCast (RI.PrivateVariable , Builder.getInt8PtrTy (),
1063+ " private.red.var." + Twine (Index) + " .casted" );
1064+ Builder.CreateStore (Casted, RedArrayElemPtr);
1065+ }
1066+
1067+ // Emit a call to the runtime function that orchestrates the reduction.
1068+ // Declare the reduction function in the process.
1069+ Function *Func = Builder.GetInsertBlock ()->getParent ();
1070+ Module *Module = Func->getParent ();
1071+ Value *RedArrayPtr =
1072+ Builder.CreateBitCast (RedArray, Builder.getInt8PtrTy (), " red.array.ptr" );
1073+ Constant *SrcLocStr = getOrCreateSrcLocStr (Loc);
1074+ bool CanGenerateAtomic =
1075+ llvm::all_of (ReductionInfos, [](const ReductionInfo &RI) {
1076+ return RI.AtomicReductionGen ;
1077+ });
1078+ Value *Ident = getOrCreateIdent (
1079+ SrcLocStr, CanGenerateAtomic ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
1080+ : IdentFlag (0 ));
1081+ Value *ThreadId = getOrCreateThreadID (Ident);
1082+ Constant *NumVariables = Builder.getInt32 (NumReductions);
1083+ const DataLayout &DL = Module->getDataLayout ();
1084+ unsigned RedArrayByteSize = DL.getTypeStoreSize (RedArrayTy);
1085+ Constant *RedArraySize = Builder.getInt64 (RedArrayByteSize);
1086+ Function *ReductionFunc = getFreshReductionFunc (*Module);
1087+ Value *Lock = getOMPCriticalRegionLock (" .reduction" );
1088+ Function *ReduceFunc = getOrCreateRuntimeFunctionPtr (
1089+ IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
1090+ : RuntimeFunction::OMPRTL___kmpc_reduce);
1091+ CallInst *ReduceCall =
1092+ Builder.CreateCall (ReduceFunc,
1093+ {Ident, ThreadId, NumVariables, RedArraySize,
1094+ RedArrayPtr, ReductionFunc, Lock},
1095+ " reduce" );
1096+
1097+ // Create final reduction entry blocks for the atomic and non-atomic case.
1098+ // Emit IR that dispatches control flow to one of the blocks based on the
1099+ // reduction supporting the atomic mode.
1100+ BasicBlock *NonAtomicRedBlock =
1101+ BasicBlock::Create (Module->getContext (), " reduce.switch.nonatomic" , Func);
1102+ BasicBlock *AtomicRedBlock =
1103+ BasicBlock::Create (Module->getContext (), " reduce.switch.atomic" , Func);
1104+ SwitchInst *Switch =
1105+ Builder.CreateSwitch (ReduceCall, ContinuationBlock, /* NumCases */ 2 );
1106+ Switch->addCase (Builder.getInt32 (1 ), NonAtomicRedBlock);
1107+ Switch->addCase (Builder.getInt32 (2 ), AtomicRedBlock);
1108+
1109+ // Populate the non-atomic reduction using the elementwise reduction function.
1110+ // This loads the elements from the global and private variables and reduces
1111+ // them before storing back the result to the global variable.
1112+ Builder.SetInsertPoint (NonAtomicRedBlock);
1113+ for (auto En : enumerate(ReductionInfos)) {
1114+ const ReductionInfo &RI = En.value ();
1115+ Type *ValueType = RI.getElementType ();
1116+ Value *RedValue = Builder.CreateLoad (ValueType, RI.Variable ,
1117+ " red.value." + Twine (En.index ()));
1118+ Value *PrivateRedValue =
1119+ Builder.CreateLoad (ValueType, RI.PrivateVariable ,
1120+ " red.private.value." + Twine (En.index ()));
1121+ Value *Reduced;
1122+ Builder.restoreIP (
1123+ RI.ReductionGen (Builder.saveIP (), RedValue, PrivateRedValue, Reduced));
1124+ if (!Builder.GetInsertBlock ())
1125+ return InsertPointTy ();
1126+ Builder.CreateStore (Reduced, RI.Variable );
1127+ }
1128+ Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr (
1129+ IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
1130+ : RuntimeFunction::OMPRTL___kmpc_end_reduce);
1131+ Builder.CreateCall (EndReduceFunc, {Ident, ThreadId, Lock});
1132+ Builder.CreateBr (ContinuationBlock);
1133+
1134+ // Populate the atomic reduction using the atomic elementwise reduction
1135+ // function. There are no loads/stores here because they will be happening
1136+ // inside the atomic elementwise reduction.
1137+ Builder.SetInsertPoint (AtomicRedBlock);
1138+ if (CanGenerateAtomic) {
1139+ for (const ReductionInfo &RI : ReductionInfos) {
1140+ Builder.restoreIP (RI.AtomicReductionGen (Builder.saveIP (), RI.Variable ,
1141+ RI.PrivateVariable ));
1142+ if (!Builder.GetInsertBlock ())
1143+ return InsertPointTy ();
1144+ }
1145+ Builder.CreateBr (ContinuationBlock);
1146+ } else {
1147+ Builder.CreateUnreachable ();
1148+ }
1149+
1150+ // Populate the outlined reduction function using the elementwise reduction
1151+ // function. Partial values are extracted from the type-erased array of
1152+ // pointers to private variables.
1153+ BasicBlock *ReductionFuncBlock =
1154+ BasicBlock::Create (Module->getContext (), " " , ReductionFunc);
1155+ Builder.SetInsertPoint (ReductionFuncBlock);
1156+ Value *LHSArrayPtr = Builder.CreateBitCast (ReductionFunc->getArg (0 ),
1157+ RedArrayTy->getPointerTo ());
1158+ Value *RHSArrayPtr = Builder.CreateBitCast (ReductionFunc->getArg (1 ),
1159+ RedArrayTy->getPointerTo ());
1160+ for (auto En : enumerate(ReductionInfos)) {
1161+ const ReductionInfo &RI = En.value ();
1162+ Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
1163+ RedArrayTy, LHSArrayPtr, 0 , En.index ());
1164+ Value *LHSI8Ptr = Builder.CreateLoad (Builder.getInt8PtrTy (), LHSI8PtrPtr);
1165+ Value *LHSPtr = Builder.CreateBitCast (LHSI8Ptr, RI.Variable ->getType ());
1166+ Value *LHS = Builder.CreateLoad (RI.getElementType (), LHSPtr);
1167+ Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
1168+ RedArrayTy, RHSArrayPtr, 0 , En.index ());
1169+ Value *RHSI8Ptr = Builder.CreateLoad (Builder.getInt8PtrTy (), RHSI8PtrPtr);
1170+ Value *RHSPtr =
1171+ Builder.CreateBitCast (RHSI8Ptr, RI.PrivateVariable ->getType ());
1172+ Value *RHS = Builder.CreateLoad (RI.getElementType (), RHSPtr);
1173+ Value *Reduced;
1174+ Builder.restoreIP (RI.ReductionGen (Builder.saveIP (), LHS, RHS, Reduced));
1175+ if (!Builder.GetInsertBlock ())
1176+ return InsertPointTy ();
1177+ Builder.CreateStore (Reduced, LHSPtr);
1178+ }
1179+ Builder.CreateRetVoid ();
1180+
1181+ Builder.SetInsertPoint (ContinuationBlock);
1182+ return Builder.saveIP ();
1183+ }
1184+
10121185OpenMPIRBuilder::InsertPointTy
10131186OpenMPIRBuilder::createMaster (const LocationDescription &Loc,
10141187 BodyGenCallbackTy BodyGenCB,
0 commit comments