Skip to content

Commit 7b7ff69

Browse files
authored
[Flang] Support generic execution of parallel regions (#414)
This set of patches removes the early tagging of Generic-SPMD target regions from MLIR, so that it only distinguishes Generic from SPMD. This matches the behavior of Clang, which then relies on the OpenMPOpt pass to detect situations where Generic kernels can be executed in SPMD mode, potentially after certain transformations. Merging this PR results in split distribute + parallel do kernels running in Generic mode, which might cause performance regressions in these cases. This is because the OpenMPOpt pass is currently not prepared to properly SPMDize Generic kernels containing the new DeviceRTL loop functions that only Flang currently generates. Before these changes, Generic mode was broken whenever a parallel region was reached; with them, it should be possible to execute such regions properly.
1 parent 4f5eaa1 commit 7b7ff69

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2832
-942
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11392,8 +11392,8 @@ void CGOpenMPRuntime::emitTargetDataCalls(
1139211392
llvm::OpenMPIRBuilder::LocationDescription OmpLoc(CodeGenIP);
1139311393
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
1139411394
cantFail(OMPBuilder.createTargetData(
11395-
OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB,
11396-
CustomMapperCB,
11395+
OmpLoc, AllocaIP, CodeGenIP, /*DeallocIPs=*/{}, DeviceID, IfCondVal,
11396+
Info, GenMapInfoCB, CustomMapperCB,
1139711397
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc));
1139811398
CGF.Builder.restoreIP(AfterIP);
1139911399
}

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,20 +2238,21 @@ void CodeGenFunction::EmitOMPParallelDirective(const OMPParallelDirective &S) {
22382238
const CapturedStmt *CS = S.getCapturedStmt(OMPD_parallel);
22392239
const Stmt *ParallelRegionBodyStmt = CS->getCapturedStmt();
22402240

2241-
auto BodyGenCB = [&, this](InsertPointTy AllocaIP,
2242-
InsertPointTy CodeGenIP) {
2241+
auto BodyGenCB = [&, this](InsertPointTy AllocIP, InsertPointTy CodeGenIP,
2242+
ArrayRef<InsertPointTy> DeallocIPs) {
22432243
OMPBuilderCBHelpers::EmitOMPOutlinedRegionBody(
2244-
*this, ParallelRegionBodyStmt, AllocaIP, CodeGenIP, "parallel");
2244+
*this, ParallelRegionBodyStmt, AllocIP, CodeGenIP, "parallel");
22452245
return llvm::Error::success();
22462246
};
22472247

22482248
CGCapturedStmtInfo CGSI(*CS, CR_OpenMP);
22492249
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(*this, &CGSI);
22502250
llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
22512251
AllocaInsertPt->getParent(), AllocaInsertPt->getIterator());
2252-
llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(
2253-
OMPBuilder.createParallel(Builder, AllocaIP, BodyGenCB, PrivCB, FiniCB,
2254-
IfCond, NumThreads, ProcBind, S.hasCancel()));
2252+
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
2253+
cantFail(OMPBuilder.createParallel(
2254+
Builder, AllocaIP, /*DeallocIPs=*/{}, BodyGenCB, PrivCB, FiniCB,
2255+
IfCond, NumThreads, ProcBind, S.hasCancel()));
22552256
Builder.restoreIP(AfterIP);
22562257
return;
22572258
}
@@ -4936,21 +4937,23 @@ void CodeGenFunction::EmitOMPSectionsDirective(const OMPSectionsDirective &S) {
49364937
llvm::SmallVector<BodyGenCallbackTy, 4> SectionCBVector;
49374938
if (CS) {
49384939
for (const Stmt *SubStmt : CS->children()) {
4939-
auto SectionCB = [this, SubStmt](InsertPointTy AllocaIP,
4940-
InsertPointTy CodeGenIP) {
4941-
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
4942-
*this, SubStmt, AllocaIP, CodeGenIP, "section");
4940+
auto SectionCB = [this, SubStmt](InsertPointTy AllocIP,
4941+
InsertPointTy CodeGenIP,
4942+
ArrayRef<InsertPointTy> DeallocIPs) {
4943+
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(*this, SubStmt, AllocIP,
4944+
CodeGenIP, "section");
49434945
return llvm::Error::success();
49444946
};
49454947
SectionCBVector.push_back(SectionCB);
49464948
}
49474949
} else {
4948-
auto SectionCB = [this, CapturedStmt](InsertPointTy AllocaIP,
4949-
InsertPointTy CodeGenIP) {
4950-
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
4951-
*this, CapturedStmt, AllocaIP, CodeGenIP, "section");
4952-
return llvm::Error::success();
4953-
};
4950+
auto SectionCB =
4951+
[this, CapturedStmt](InsertPointTy AllocIP, InsertPointTy CodeGenIP,
4952+
ArrayRef<InsertPointTy> DeallocIPs) {
4953+
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
4954+
*this, CapturedStmt, AllocIP, CodeGenIP, "section");
4955+
return llvm::Error::success();
4956+
};
49544957
SectionCBVector.push_back(SectionCB);
49554958
}
49564959

@@ -5004,10 +5007,11 @@ void CodeGenFunction::EmitOMPSectionDirective(const OMPSectionDirective &S) {
50045007
return llvm::Error::success();
50055008
};
50065009

5007-
auto BodyGenCB = [SectionRegionBodyStmt, this](InsertPointTy AllocaIP,
5008-
InsertPointTy CodeGenIP) {
5010+
auto BodyGenCB = [SectionRegionBodyStmt,
5011+
this](InsertPointTy AllocIP, InsertPointTy CodeGenIP,
5012+
ArrayRef<InsertPointTy> DeallocIPs) {
50095013
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
5010-
*this, SectionRegionBodyStmt, AllocaIP, CodeGenIP, "section");
5014+
*this, SectionRegionBodyStmt, AllocIP, CodeGenIP, "section");
50115015
return llvm::Error::success();
50125016
};
50135017

@@ -5089,10 +5093,11 @@ void CodeGenFunction::EmitOMPMasterDirective(const OMPMasterDirective &S) {
50895093
return llvm::Error::success();
50905094
};
50915095

5092-
auto BodyGenCB = [MasterRegionBodyStmt, this](InsertPointTy AllocaIP,
5093-
InsertPointTy CodeGenIP) {
5096+
auto BodyGenCB = [MasterRegionBodyStmt,
5097+
this](InsertPointTy AllocIP, InsertPointTy CodeGenIP,
5098+
ArrayRef<InsertPointTy> DeallocIPs) {
50945099
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
5095-
*this, MasterRegionBodyStmt, AllocaIP, CodeGenIP, "master");
5100+
*this, MasterRegionBodyStmt, AllocIP, CodeGenIP, "master");
50965101
return llvm::Error::success();
50975102
};
50985103

@@ -5139,10 +5144,11 @@ void CodeGenFunction::EmitOMPMaskedDirective(const OMPMaskedDirective &S) {
51395144
return llvm::Error::success();
51405145
};
51415146

5142-
auto BodyGenCB = [MaskedRegionBodyStmt, this](InsertPointTy AllocaIP,
5143-
InsertPointTy CodeGenIP) {
5147+
auto BodyGenCB = [MaskedRegionBodyStmt,
5148+
this](InsertPointTy AllocIP, InsertPointTy CodeGenIP,
5149+
ArrayRef<InsertPointTy> DeallocIPs) {
51445150
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
5145-
*this, MaskedRegionBodyStmt, AllocaIP, CodeGenIP, "masked");
5151+
*this, MaskedRegionBodyStmt, AllocIP, CodeGenIP, "masked");
51465152
return llvm::Error::success();
51475153
};
51485154

@@ -5182,10 +5188,11 @@ void CodeGenFunction::EmitOMPCriticalDirective(const OMPCriticalDirective &S) {
51825188
return llvm::Error::success();
51835189
};
51845190

5185-
auto BodyGenCB = [CriticalRegionBodyStmt, this](InsertPointTy AllocaIP,
5186-
InsertPointTy CodeGenIP) {
5191+
auto BodyGenCB = [CriticalRegionBodyStmt,
5192+
this](InsertPointTy AllocIP, InsertPointTy CodeGenIP,
5193+
ArrayRef<InsertPointTy> DeallocIPs) {
51875194
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
5188-
*this, CriticalRegionBodyStmt, AllocaIP, CodeGenIP, "critical");
5195+
*this, CriticalRegionBodyStmt, AllocIP, CodeGenIP, "critical");
51895196
return llvm::Error::success();
51905197
};
51915198

@@ -6152,8 +6159,8 @@ void CodeGenFunction::EmitOMPTaskgroupDirective(
61526159
InsertPointTy AllocaIP(AllocaInsertPt->getParent(),
61536160
AllocaInsertPt->getIterator());
61546161

6155-
auto BodyGenCB = [&, this](InsertPointTy AllocaIP,
6156-
InsertPointTy CodeGenIP) {
6162+
auto BodyGenCB = [&, this](InsertPointTy AllocIP, InsertPointTy CodeGenIP,
6163+
ArrayRef<InsertPointTy> DeallocIPs) {
61576164
Builder.restoreIP(CodeGenIP);
61586165
EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
61596166
return llvm::Error::success();
@@ -6162,7 +6169,8 @@ void CodeGenFunction::EmitOMPTaskgroupDirective(
61626169
if (!CapturedStmtInfo)
61636170
CapturedStmtInfo = &CapStmtInfo;
61646171
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
6165-
cantFail(OMPBuilder.createTaskgroup(Builder, AllocaIP, BodyGenCB));
6172+
cantFail(OMPBuilder.createTaskgroup(Builder, AllocaIP,
6173+
/*DeallocIPs=*/{}, BodyGenCB));
61666174
Builder.restoreIP(AfterIP);
61676175
return;
61686176
}
@@ -6879,8 +6887,9 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
68796887
return llvm::Error::success();
68806888
};
68816889

6882-
auto BodyGenCB = [&S, C, this](InsertPointTy AllocaIP,
6883-
InsertPointTy CodeGenIP) {
6890+
auto BodyGenCB = [&S, C, this](InsertPointTy AllocIP,
6891+
InsertPointTy CodeGenIP,
6892+
ArrayRef<InsertPointTy> DeallocIPs) {
68846893
Builder.restoreIP(CodeGenIP);
68856894

68866895
const CapturedStmt *CS = S.getInnermostCapturedStmt();
@@ -6898,7 +6907,7 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
68986907
OutlinedFn, CapturedVars);
68996908
} else {
69006909
OMPBuilderCBHelpers::EmitOMPInlinedRegionBody(
6901-
*this, CS->getCapturedStmt(), AllocaIP, CodeGenIP, "ordered");
6910+
*this, CS->getCapturedStmt(), AllocIP, CodeGenIP, "ordered");
69026911
}
69036912
return llvm::Error::success();
69046913
};

flang/include/flang/Optimizer/Support/InitFIR.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Dialect/Math/IR/Math.h"
3535
#include "mlir/Dialect/OpenACC/OpenACC.h"
3636
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
37+
#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
3738
#include "mlir/Dialect/SCF/IR/SCF.h"
3839
#include "mlir/Dialect/SCF/Transforms/Passes.h"
3940
#include "mlir/InitAllDialects.h"
@@ -106,6 +107,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
106107
/// but is a smaller set since we aren't using many of the passes found there.
107108
inline void registerMLIRPassesForFortranTools() {
108109
mlir::acc::registerOpenACCPasses();
110+
mlir::omp::registerOpenMPPasses();
109111
mlir::registerCanonicalizerPass();
110112
mlir::registerCSEPass();
111113
mlir::affine::registerAffineLoopFusionPass();

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,7 @@ struct TargetAllocMemOpConversion
245245
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
246246
for (mlir::Value opnd : adaptor.getOperands().drop_front())
247247
size = mlir::LLVM::MulOp::create(
248-
rewriter, loc, ity, size,
249-
integerCast(lowerTy(), loc, rewriter, ity, opnd));
248+
rewriter, loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
250249
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
251250
auto mallocTy =
252251
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
/// common to flang and the test tools.
1111

1212
#include "flang/Optimizer/Passes/Pipelines.h"
13+
#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
1314
#include "llvm/Support/CommandLine.h"
1415

1516
/// Force setting the no-alias attribute on function arguments when possible.
@@ -408,6 +409,9 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
408409
}
409410

410411
fir::addFIRToLLVMPass(pm, config);
412+
413+
if (config.EnableOpenMP && !config.EnableOpenMPSimd)
414+
pm.addPass(mlir::omp::createStackToSharedPass());
411415
}
412416

413417
/// Create a pass pipeline for lowering from MLIR to LLVM IR

flang/test/Fir/basic-program.fir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,5 +161,7 @@ func.func @_QQmain() {
161161
// PASSES-NEXT: LowerNontemporalPass
162162
// PASSES-NEXT: FIRToLLVMLowering
163163
// PASSES-NEXT: ReconcileUnrealizedCasts
164+
// PASSES-NEXT: 'llvm.func' Pipeline
165+
// PASSES-NEXT: StackToSharedPass
164166
// PASSES-NEXT: PrepareForOMPOffloadPrivatizationPass
165167
// PASSES-NEXT: LLVMIRLoweringPass

flang/test/Integration/OpenMP/threadprivate-target-device.f90

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,14 @@
1414
! target code in the same function.
1515

1616
! CHECK: define weak_odr protected amdgpu_kernel void @{{.*}}(ptr %{{.*}}, ptr %[[ARG1:.*]], ptr %[[ARG2:.*]]) #{{[0-9]+}} {
17-
! CHECK: %[[ALLOCA_X:.*]] = alloca ptr, align 8, addrspace(5)
18-
! CHECK: %[[ASCAST_X:.*]] = addrspacecast ptr addrspace(5) %[[ALLOCA_X]] to ptr
19-
! CHECK: store ptr %[[ARG1]], ptr %[[ASCAST_X]], align 8
17+
! CHECK: %[[ALLOC_N:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
18+
! CHECK: store ptr %[[ARG2]], ptr %[[ALLOC_N]], align 8
2019

21-
! CHECK: %[[ALLOCA_N:.*]] = alloca ptr, align 8, addrspace(5)
22-
! CHECK: %[[ASCAST_N:.*]] = addrspacecast ptr addrspace(5) %[[ALLOCA_N]] to ptr
23-
! CHECK: store ptr %[[ARG2]], ptr %[[ASCAST_N]], align 8
20+
! CHECK: %[[ALLOC_X:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
21+
! CHECK: store ptr %[[ARG1]], ptr %[[ALLOC_X]], align 8
2422

25-
! CHECK: %[[LOAD_X:.*]] = load ptr, ptr %[[ASCAST_X]], align 8
26-
! CHECK: call void @bar_(ptr %[[LOAD_X]], ptr %[[ASCAST_N]])
23+
! CHECK: %[[LOAD_X:.*]] = load ptr, ptr %[[ALLOC_X]], align 8
24+
! CHECK: call void @bar_(ptr %[[LOAD_X]], ptr %[[ALLOC_N]])
2725

2826
module test
2927
implicit none

0 commit comments

Comments
 (0)