diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index dd0cb3c42ba26..fa3fa01ea0a35 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1241,11 +1241,20 @@ bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { omp::clause::Linear>([&](const omp::clause::Linear &clause, const parser::CharBlock &) { auto &objects = std::get(clause.t); + static std::vector typeAttrs; + + if (!result.linearVars.size()) + typeAttrs.clear(); + for (const omp::Object &object : objects) { semantics::Symbol *sym = object.sym(); const mlir::Value variable = converter.getSymbolAddress(*sym); result.linearVars.push_back(variable); + mlir::Type ty = converter.genType(*sym); + typeAttrs.push_back(mlir::TypeAttr::get(ty)); } + result.linearVarTypes = + mlir::ArrayAttr::get(&converter.getMLIRContext(), typeAttrs); if (objects.size()) { if (auto &mod = std::get>( diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 0a200388a36e5..410674434478b 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1636,8 +1636,7 @@ static void genSimdClauses( cp.processReduction(loc, clauseOps, reductionSyms); cp.processSafelen(clauseOps); cp.processSimdlen(clauseOps); - - cp.processTODO(loc, llvm::omp::Directive::OMPD_simd); + cp.processLinear(clauseOps); } static void genSingleClauses(lower::AbstractConverter &converter, @@ -1831,9 +1830,9 @@ static void genWsloopClauses( cp.processOrdered(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); cp.processSchedule(stmtCtx, clauseOps); + cp.processLinear(clauseOps); - cp.processTODO( - loc, llvm::omp::Directive::OMPD_do); + cp.processTODO(loc, llvm::omp::Directive::OMPD_do); } //===----------------------------------------------------------------------===// diff --git a/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90 b/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90 deleted file mode 100644 index db8f5c293b40e..0000000000000 --- a/flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90 +++ /dev/null @@ -1,14 +0,0 @@ -! This test checks lowering of OpenMP do simd linear() pragma - -! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -subroutine testDoSimdLinear(int_array) - integer :: int_array(:) -!CHECK: not yet implemented: Unhandled clause LINEAR in SIMD construct -!$omp do simd linear(int_array) - do index_ = 1, 10 - end do -!$omp end do simd - -end subroutine testDoSimdLinear - diff --git a/flang/test/Lower/OpenMP/simd-linear.f90 b/flang/test/Lower/OpenMP/simd-linear.f90 new file mode 100644 index 0000000000000..b6c7668af998b --- /dev/null +++ b/flang/test/Lower/OpenMP/simd-linear.f90 @@ -0,0 +1,57 @@ +! This test checks lowering of OpenMP SIMD Directive +! with linear clause + +! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s + +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[const:.*]] = arith.constant 1 : i32 +subroutine simple_linear + implicit none + integer :: x, y, i + !CHECK: omp.simd linear(%[[X]]#0 = %[[const]] : !fir.ref) {{.*}} + !$omp simd linear(x) + !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 2 : i32 + !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32 + do i = 1, 10 + y = x + 2 + end do + !CHECK: } {linear_var_types = [i32]} +end subroutine + + +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +subroutine linear_step + implicit none + integer :: x, y, i + !CHECK: %[[const:.*]] = arith.constant 4 : i32 + !CHECK: omp.simd linear(%[[X]]#0 = %[[const]] : !fir.ref) {{.*}} + !$omp simd linear(x:4) + !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 2 : i32 + !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32 + do i = 1, 10 + y = x + 2 + end do + !CHECK: } {linear_var_types = [i32]} +end subroutine + +!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"} +!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +subroutine linear_expr + implicit none + integer :: x, y, i, a + !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 4 : i32 + !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32 + !CHECK: omp.simd linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref) {{.*}} + !$omp simd linear(x:a+4) + do i = 1, 10 + y = x + 2 + end do + !CHECK: } {linear_var_types = [i32]} +end subroutine diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90 new file mode 100644 index 0000000000000..0145be6a7c4e6 --- /dev/null +++ b/flang/test/Lower/OpenMP/wsloop-linear.f90 @@ -0,0 +1,60 @@ +! This test checks lowering of OpenMP DO Directive (Worksharing) +! with linear clause + +! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s + +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[const:.*]] = arith.constant 1 : i32 +subroutine simple_linear + implicit none + integer :: x, y, i + !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref) {{.*}} + !$omp do linear(x) + !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 2 : i32 + !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32 + do i = 1, 10 + y = x + 2 + end do + !$omp end do + !CHECK: } {linear_var_types = [i32]} +end subroutine + + +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +subroutine linear_step + implicit none + integer :: x, y, i + !CHECK: %[[const:.*]] = arith.constant 4 : i32 + !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref) {{.*}} + !$omp do linear(x:4) + !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 2 : i32 + !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32 + do i = 1, 10 + y = x + 2 + end do + !$omp end do + !CHECK: } {linear_var_types = [i32]} +end subroutine + +!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"} +!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +subroutine linear_expr + implicit none + integer :: x, y, i, a + !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 4 : i32 + !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32 + !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref) {{.*}} + !$omp do linear(x:a+4) + do i = 1, 10 + y = x + 2 + end do + !$omp end do + !CHECK: } {linear_var_types = [i32]} +end subroutine diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 8e43c4284d078..05e2ee4e5632b 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -21,6 +21,7 @@ include "mlir/Dialect/OpenMP/OpenMPOpBase.td" include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/BuiltinAttributes.td" //===----------------------------------------------------------------------===// // V5.2: [6.3] `align` clause @@ -723,10 +724,9 @@ class OpenMP_LinearClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause { - let arguments = (ins - Variadic:$linear_vars, - Variadic:$linear_step_vars - ); + let arguments = (ins Variadic:$linear_vars, + Variadic:$linear_step_vars, + OptionalAttr:$linear_var_types); let optAssemblyFormat = [{ `linear` `(` diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 0d6b2870c625a..5d1f4f319eb02 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2825,6 +2825,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, ArrayRef attributes) { build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), + /*linear_var_types*/ nullptr, /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr, /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr, /*private_needs_barrier=*/false, @@ -2843,8 +2844,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, WsloopOp::build( builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars, - clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod, - clauses.ordered, clauses.privateVars, + clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait, + clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), @@ -2889,17 +2890,16 @@ LogicalResult WsloopOp::verifyRegions() { void SimdOp::build(OpBuilder &builder, OperationState &state, const SimdOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: linearVars, linearStepVars - SimdOp::build(builder, state, clauses.alignedVars, - makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, - /*linear_vars=*/{}, /*linear_step_vars=*/{}, - clauses.nontemporalVars, clauses.order, clauses.orderMod, - clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.reductionMod, - clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, - clauses.simdlen); + SimdOp::build( + builder, state, clauses.alignedVars, + makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, + clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes, + clauses.nontemporalVars, clauses.order, clauses.orderMod, + clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, + clauses.simdlen); } LogicalResult SimdOp::verify() { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 869bde69d5cdc..4f185f804c1f9 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -137,28 +137,31 @@ class LinearClauseProcessor { private: SmallVector linearPreconditionVars; SmallVector linearLoopBodyTemps; - SmallVector linearOrigVars; SmallVector linearOrigVal; SmallVector linearSteps; + SmallVector linearVarTypes; llvm::BasicBlock *linearFinalizationBB; llvm::BasicBlock *linearExitBB; llvm::BasicBlock *linearLastIterExitBB; public: + // Register type for the linear variables + void registerType(LLVM::ModuleTranslation &moduleTranslation, + mlir::Attribute &ty) { + linearVarTypes.push_back(moduleTranslation.convertType( + mlir::cast(ty).getValue())); + } + // Allocate space for linear variabes void createLinearVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - mlir::Value &linearVar) { - if (llvm::AllocaInst *linearVarAlloca = dyn_cast( - moduleTranslation.lookupValue(linearVar))) { - linearPreconditionVars.push_back(builder.CreateAlloca( - linearVarAlloca->getAllocatedType(), nullptr, ".linear_var")); - llvm::Value *linearLoopBodyTemp = builder.CreateAlloca( - linearVarAlloca->getAllocatedType(), nullptr, ".linear_result"); - linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar)); - linearLoopBodyTemps.push_back(linearLoopBodyTemp); - linearOrigVars.push_back(linearVarAlloca); - } + mlir::Value &linearVar, int idx) { + linearPreconditionVars.push_back( + builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var")); + llvm::Value *linearLoopBodyTemp = + builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result"); + linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar)); + linearLoopBodyTemps.push_back(linearLoopBodyTemp); } // Initialize linear step @@ -168,20 +171,15 @@ class LinearClauseProcessor { } // Emit IR for initialization of linear variables - llvm::OpenMPIRBuilder::InsertPointOrErrorTy - initLinearVar(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - llvm::BasicBlock *loopPreHeader) { + void initLinearVar(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::BasicBlock *loopPreHeader) { builder.SetInsertPoint(loopPreHeader->getTerminator()); - for (size_t index = 0; index < linearOrigVars.size(); index++) { - llvm::LoadInst *linearVarLoad = builder.CreateLoad( - linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]); + for (size_t index = 0; index < linearOrigVal.size(); index++) { + llvm::LoadInst *linearVarLoad = + builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]); builder.CreateStore(linearVarLoad, linearPreconditionVars[index]); } - llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = - moduleTranslation.getOpenMPBuilder()->createBarrier( - builder.saveIP(), llvm::omp::OMPD_barrier); - return afterBarrierIP; } // Emit IR for updating Linear variables @@ -190,20 +188,24 @@ class LinearClauseProcessor { builder.SetInsertPoint(loopBody->getTerminator()); for (size_t index = 0; index < linearPreconditionVars.size(); index++) { // Emit increments for linear vars - llvm::LoadInst *linearVarStart = - builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), - - linearPreconditionVars[index]); - auto *mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]); - auto *addInst = builder.CreateAdd(linearVarStart, mulInst); - builder.CreateStore(addInst, linearLoopBodyTemps[index]); + llvm::LoadInst *linearVarStart = builder.CreateLoad( + linearVarTypes[index], linearPreconditionVars[index]); + auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]); + if (linearVarTypes[index]->isIntegerTy()) { + auto addInst = builder.CreateAdd(linearVarStart, mulInst); + builder.CreateStore(addInst, linearLoopBodyTemps[index]); + } else if (linearVarTypes[index]->isFloatingPointTy()) { + auto cvt = builder.CreateSIToFP(mulInst, linearVarTypes[index]); + auto addInst = builder.CreateFAdd(linearVarStart, cvt); + builder.CreateStore(addInst, linearLoopBodyTemps[index]); + } } } // Linear variable finalization is conditional on the last logical iteration. // Create BB splits to manage the same. - void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder, - llvm::BasicBlock *loopExit) { + void splitLinearFiniBB(llvm::IRBuilderBase &builder, + llvm::BasicBlock *loopExit) { linearFinalizationBB = loopExit->splitBasicBlock( loopExit->getTerminator(), "omp_loop.linear_finalization"); linearExitBB = linearFinalizationBB->splitBasicBlock( @@ -227,11 +229,10 @@ class LinearClauseProcessor { llvm::Type::getInt32Ty(builder.getContext()), 0)); // Store the linear variable values to original variables. builder.SetInsertPoint(linearLastIterExitBB->getTerminator()); - for (size_t index = 0; index < linearOrigVars.size(); index++) { + for (size_t index = 0; index < linearOrigVal.size(); index++) { llvm::LoadInst *linearVarTemp = - builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), - linearLoopBodyTemps[index]); - builder.CreateStore(linearVarTemp, linearOrigVars[index]); + builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]); + builder.CreateStore(linearVarTemp, linearOrigVal[index]); } // Create conditional branch such that the linear variable @@ -255,7 +256,8 @@ class LinearClauseProcessor { users.push_back(user); for (auto *user : users) { if (auto *userInst = dyn_cast(user)) { - if (userInst->getParent()->getName().str() == BBName) + if (userInst->getParent()->getName().str().find(BBName) != + std::string::npos) user->replaceUsesOfWith(linearOrigVal[varIndex], linearLoopBodyTemps[varIndex]); } @@ -334,10 +336,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (!op.getIsDevicePtrVars().empty()) result = todo("is_device_ptr"); }; - auto checkLinear = [&todo](auto op, LogicalResult &result) { - if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty()) - result = todo("linear"); - }; auto checkNowait = [&todo](auto op, LogicalResult &result) { if (op.getNowait()) result = todo("nowait"); @@ -420,7 +418,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::WsloopOp op) { checkAllocate(op, result); - checkLinear(op, result); checkOrder(op, result); checkReduction(op, result); }) @@ -428,10 +425,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkReduction(op, result); }) - .Case([&](omp::SimdOp op) { - checkLinear(op, result); - checkReduction(op, result); - }) + .Case([&](omp::SimdOp op) { checkReduction(op, result); }) .Case([&](auto op) { checkHint(op, result); }) .Case( @@ -2627,10 +2621,15 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // Initialize linear variables and linear step LinearClauseProcessor linearClauseProcessor; + if (!wsloopOp.getLinearVars().empty()) { - for (mlir::Value linearVar : wsloopOp.getLinearVars()) + auto linearVarTypes = wsloopOp.getLinearVarTypes().value(); + for (mlir::Attribute linearVarType : linearVarTypes) + linearClauseProcessor.registerType(moduleTranslation, linearVarType); + + for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars())) linearClauseProcessor.createLinearVar(builder, moduleTranslation, - linearVar); + linearVar, idx); for (mlir::Value linearStep : wsloopOp.getLinearStepVars()) linearClauseProcessor.initLinearStep(moduleTranslation, linearStep); } @@ -2645,16 +2644,17 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // Emit Initialization and Update IR for linear variables if (!wsloopOp.getLinearVars().empty()) { + linearClauseProcessor.initLinearVar(builder, moduleTranslation, + loopInfo->getPreheader()); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = - linearClauseProcessor.initLinearVar(builder, moduleTranslation, - loopInfo->getPreheader()); + moduleTranslation.getOpenMPBuilder()->createBarrier( + builder.saveIP(), llvm::omp::OMPD_barrier); if (failed(handleError(afterBarrierIP, *loopOp))) return failure(); builder.restoreIP(*afterBarrierIP); linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(), loopInfo->getIndVar()); - linearClauseProcessor.outlineLinearFinalizationBB(builder, - loopInfo->getExit()); + linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit()); } builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); @@ -2947,6 +2947,20 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); + // Initialize linear variables and linear step + LinearClauseProcessor linearClauseProcessor; + + if (!simdOp.getLinearVars().empty()) { + auto linearVarTypes = simdOp.getLinearVarTypes().value(); + for (mlir::Attribute linearVarType : linearVarTypes) + linearClauseProcessor.registerType(moduleTranslation, linearVarType); + for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) + linearClauseProcessor.createLinearVar(builder, moduleTranslation, + linearVar, idx); + for (mlir::Value linearStep : simdOp.getLinearStepVars()) + linearClauseProcessor.initLinearStep(moduleTranslation, linearStep); + } + llvm::Expected afterAllocas = allocatePrivateVars( builder, moduleTranslation, privateVarsInfo, allocaIP); if (handleError(afterAllocas, opInst).failed()) @@ -3016,14 +3030,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, if (failed(handleError(regionBlock, opInst))) return failure(); - builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation); + // Emit Initialization for linear variables + if (simdOp.getLinearVars().size()) { + linearClauseProcessor.initLinearVar(builder, moduleTranslation, + loopInfo->getPreheader()); + + linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(), + loopInfo->getIndVar()); + } + builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); + ompBuilder->applySimd(loopInfo, alignedVars, simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr()) : nullptr, order, simdlen, safelen); + for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) + linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region", + index); + // We now need to reduce the per-simd-lane reduction variable into the // original variable. This works a bit differently to other reductions (e.g. // wsloop) because we don't need to call into the OpenMP runtime to handle diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 8bd33a382197e..1eb501ca02703 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -328,6 +328,52 @@ llvm.func @test_omp_masked(%arg0: i32)-> () { // ----- +llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { +// CHECK-LABEL: @wsloop_linear + +// CHECK: %p.lastiter = alloca i32, align 4 +// CHECK: %p.lowerbound = alloca i32, align 4 +// CHECK: %p.upperbound = alloca i32, align 4 +// CHECK: %p.stride = alloca i32, align 4 +// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4 +// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4 + +// CHECK: omp_loop.preheader: +// CHECK: %[[LOAD:.*]] = load i32, ptr %{{.*}}, align 4 +// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4 + +// CHECK: omp_loop.body: +// CHECK: %[[LOOP_IV_CALC:.*]] = add i32 %omp_loop.iv, {{.*}} +// CHECK: %[[LINEAR_VAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4 +// CHECK: %[[MUL:.*]] = mul i32 %[[LOOP_IV_CALC]], {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_VAR_LOAD]], %[[MUL]] +// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4 + +// CHECK: omp_loop.linear_finalization: +// CHECK: %[[ITER:.*]] = load i32, ptr %p.lastiter, align 4 +// CHECK: %[[CMP:.*]] = icmp ne i32 %[[ITER]], 0 +// CHECK: br i1 %[[CMP]], label %omp_loop.linear_lastiter_exit, label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_lastiter_exit: +// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4 +// CHECK: store i32 %[[LOAD]], ptr {{.*}}, align 4 +// CHECK: br label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_exit: +// CHECK: %[[THREAD_ID:.*]] = call i32 @__kmpc_global_thread_num(ptr {{.*}}) +// CHECK: call void @__kmpc_barrier(ptr {{.*}}, i32 %[[THREAD_ID]]) +// CHECK: br label %omp_loop.after + + omp.wsloop linear(%x = %step : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {linear_var_types = [i32]} + llvm.return +} + +// ----- + // CHECK: %struct.ident_t = type // CHECK: @[[$loc:.*]] = private unnamed_addr constant {{.*}} c";unknown;unknown;{{[0-9]+}};{{[0-9]+}};;\00" // CHECK: @[[$loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$loc]] {{.*}} @@ -695,6 +741,34 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) { // ----- +llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { + +// CHECK-LABEL: @simd_linear + +// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4 +// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4 + +// CHECK: omp_loop.preheader: +// CHECK: %[[LOAD:.*]] = load i32, ptr {{.*}}, align 4 +// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4 + +// CHECK: omp_loop.body: +// CHECK: %[[LOOP_IV_CALC:.*]] = mul i32 %omp_loop.iv, {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LOOP_IV_CALC]], {{.*}} +// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4, !llvm.access.group !1 +// CHECK: %[[MUL:.*]] = mul i32 %omp_loop.iv, {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LOAD]], %[[MUL]] +// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4, !llvm.access.group !1 + omp.simd linear(%x = %step : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {linear_var_types = [i32]} + llvm.return +} + +// ----- + // CHECK-LABEL: @simd_simple_multiple llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd { diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 731a6322736d4..fe35499e8944d 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -101,19 +101,6 @@ llvm.func @sections_private(%x : !llvm.ptr) { } -// ----- - -llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause linear in omp.simd operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.simd}} - omp.simd linear(%x = %step : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - // ----- omp.declare_reduction @add_f32 : f32 @@ -434,19 +421,6 @@ llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { llvm.return } -// ----- - -llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause linear in omp.wsloop operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} - omp.wsloop linear(%x = %step : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - // ----- llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) { // expected-error@below {{not yet implemented: Unhandled clause order in omp.wsloop operation}}