@@ -213,25 +213,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
213213 if (loop.lowerBound ().empty ())
214214 return failure ();
215215
216- if (loop.getNumLoops () != 1 )
217- return opInst.emitOpError (" collapsed loops not yet supported" );
218-
219216 // Static is the default.
220217 omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
221218 if (loop.schedule_val ().hasValue ())
222219 schedule =
223220 *omp::symbolizeClauseScheduleKind (loop.schedule_val ().getValue ());
224221
225- // Find the loop configuration.
226- llvm::Value *lowerBound = moduleTranslation.lookupValue (loop.lowerBound ()[0 ]);
227- llvm::Value *upperBound = moduleTranslation.lookupValue (loop.upperBound ()[0 ]);
228- llvm::Value *step = moduleTranslation.lookupValue (loop.step ()[0 ]);
229- llvm::Type *ivType = step->getType ();
230- llvm::Value *chunk =
231- loop.schedule_chunk_var ()
232- ? moduleTranslation.lookupValue (loop.schedule_chunk_var ())
233- : llvm::ConstantInt::get (ivType, 1 );
234-
235222 // Set up the source location value for OpenMP runtime.
236223 llvm::DISubprogram *subprogram =
237224 builder.GetInsertBlock ()->getParent ()->getSubprogram ();
@@ -240,22 +227,29 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
240227 llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder.saveIP (),
241228 llvm::DebugLoc (diLoc));
242229
243- // Generator of the canonical loop body. Produces an SESE region of basic
244- // blocks.
230+ // Generator of the canonical loop body.
245231 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
246232 // relying on captured variables.
233+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
234+ SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
247235 LogicalResult bodyGenStatus = success ();
248236 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
249- llvm::IRBuilder<>::InsertPointGuard guard (builder);
250-
251237 // Make sure further conversions know about the induction variable.
252- moduleTranslation.mapValue (loop.getRegion ().front ().getArgument (0 ), iv);
238+ moduleTranslation.mapValue (
239+ loop.getRegion ().front ().getArgument (loopInfos.size ()), iv);
240+
241+ // Capture the body insertion point for use in nested loops. BodyIP of the
242+ // CanonicalLoopInfo always points to the beginning of the entry block of
243+ // the body.
244+ bodyInsertPoints.push_back (ip);
245+
246+ if (loopInfos.size () != loop.getNumLoops () - 1 )
247+ return ;
253248
249+ // Convert the body of the loop.
254250 llvm::BasicBlock *entryBlock = ip.getBlock ();
255251 llvm::BasicBlock *exitBlock =
256252 entryBlock->splitBasicBlock (ip.getPoint (), " omp.wsloop.exit" );
257-
258- // Convert the body of the loop.
259253 convertOmpOpRegions (loop.region (), " omp.wsloop.region" , *entryBlock,
260254 *exitBlock, builder, moduleTranslation, bodyGenStatus);
261255 };
@@ -264,17 +258,46 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
264258 // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
265259 // i.e. it has a positive step, uses signed integer semantics. Reconsider
266260 // this code when WsLoop clearly supports more cases.
261+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
262+ for (unsigned i = 0 , e = loop.getNumLoops (); i < e; ++i) {
263+ llvm::Value *lowerBound =
264+ moduleTranslation.lookupValue (loop.lowerBound ()[i]);
265+ llvm::Value *upperBound =
266+ moduleTranslation.lookupValue (loop.upperBound ()[i]);
267+ llvm::Value *step = moduleTranslation.lookupValue (loop.step ()[i]);
268+
269+ // Make sure the loop trip counts are emitted in the preheader of the
270+ // outermost loop at the latest so that they are all available for the new
271+ // collapsed loop that will be created below.
272+ llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
273+ llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP ;
274+ if (i != 0 ) {
275+ loc = llvm::OpenMPIRBuilder::LocationDescription (bodyInsertPoints.back (),
276+ llvm::DebugLoc (diLoc));
277+ computeIP = loopInfos.front ()->getPreheaderIP ();
278+ }
279+ loopInfos.push_back (ompBuilder->createCanonicalLoop (
280+ loc, bodyGen, lowerBound, upperBound, step,
281+ /* IsSigned=*/ true , loop.inclusive (), computeIP));
282+
283+ if (failed (bodyGenStatus))
284+ return failure ();
285+ }
286+
287+ // Collapse loops. Store the insertion point because LoopInfos may get
288+ // invalidated.
289+ llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front ()->getAfterIP ();
267290 llvm::CanonicalLoopInfo *loopInfo =
268- moduleTranslation.getOpenMPBuilder ()->createCanonicalLoop (
269- ompLoc, bodyGen, lowerBound, upperBound, step, /* IsSigned=*/ true ,
270- /* InclusiveStop=*/ loop.inclusive ());
271- if (failed (bodyGenStatus))
272- return failure ();
291+ ompBuilder->collapseLoops (diLoc, loopInfos, {});
273292
293+ // Find the loop configuration.
294+ llvm::Type *ivType = loopInfo->getIndVar ()->getType ();
295+ llvm::Value *chunk =
296+ loop.schedule_chunk_var ()
297+ ? moduleTranslation.lookupValue (loop.schedule_chunk_var ())
298+ : llvm::ConstantInt::get (ivType, 1 );
274299 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
275300 findAllocaInsertPoint (builder, moduleTranslation);
276- llvm::OpenMPIRBuilder::InsertPointTy afterIP;
277- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
278301
279302 bool isSimd = false ;
280303 if (auto simd = loop.simd_modifier ()) {
@@ -283,9 +306,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
283306 }
284307
285308 if (schedule == omp::ClauseScheduleKind::Static) {
286- loopInfo = ompBuilder->createStaticWorkshareLoop (ompLoc, loopInfo, allocaIP,
287- !loop.nowait (), chunk);
288- afterIP = loopInfo->getAfterIP ();
309+ ompBuilder->createStaticWorkshareLoop (ompLoc, loopInfo, allocaIP,
310+ !loop.nowait (), chunk);
289311 } else {
290312 llvm::omp::OMPScheduleType schedType;
291313 switch (schedule) {
@@ -328,11 +350,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
328350 break ;
329351 }
330352 }
331- afterIP = ompBuilder->createDynamicWorkshareLoop (
353+ ompBuilder->createDynamicWorkshareLoop (
332354 ompLoc, loopInfo, allocaIP, schedType, !loop.nowait (), chunk);
333355 }
334356
335- // Continue building IR after the loop.
357+ // Continue building IR after the loop. Note that the LoopInfo returned by
358+ // `collapseLoops` points inside the outermost loop and is intended for
359+ // potential further loop transformations. Use the insertion point stored
360+ // before collapsing loops instead.
336361 builder.restoreIP (afterIP);
337362 return success ();
338363}
0 commit comments