Skip to content

Commit 685d2cf

Browse files
authored
[Flang][OpenMP] Add lowering support for is_device_ptr clause (llvm#169331) (#764)
Add support for the OpenMP is_device_ptr clause for target directives. This commit also includes "[MLIR][OpenMP] Add OpenMPToLLVMIRTranslation support for is_device_ptr" (llvm#169367): that PR adds support for the OpenMP is_device_ptr clause in the MLIR to LLVM IR translation for target regions. The is_device_ptr clause allows device pointers (allocated via OpenMP runtime APIs) to be used directly in target regions without implicit mapping.
1 parent ef88ee8 commit 685d2cf

File tree

11 files changed

+200
-52
lines changed

11 files changed

+200
-52
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
4444
return mlir::omp::ReductionModifier::defaultmod;
4545
}
4646

47-
/// Check for unsupported map operand types.
48-
static void checkMapType(mlir::Location location, mlir::Type type) {
49-
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
50-
type = refType.getElementType();
51-
if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type))
52-
if (!mlir::isa<fir::PointerType>(boxType.getElementType()))
53-
TODO(location, "OMPD_target_data MapOperand BoxType");
54-
}
55-
5647
static mlir::omp::ScheduleModifier
5748
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
5849
switch (m) {
@@ -211,18 +202,6 @@ getIfClauseOperand(lower::AbstractConverter &converter,
211202
ifVal);
212203
}
213204

214-
static void addUseDeviceClause(
215-
lower::AbstractConverter &converter, const omp::ObjectList &objects,
216-
llvm::SmallVectorImpl<mlir::Value> &operands,
217-
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
218-
genObjectList(objects, converter, operands);
219-
for (mlir::Value &operand : operands)
220-
checkMapType(operand.getLoc(), operand.getType());
221-
222-
for (const omp::Object &object : objects)
223-
useDeviceSyms.push_back(object.sym());
224-
}
225-
226205
//===----------------------------------------------------------------------===//
227206
// ClauseProcessor unique clauses
228207
//===----------------------------------------------------------------------===//
@@ -1225,14 +1204,26 @@ bool ClauseProcessor::processInReduction(
12251204
}
12261205

12271206
bool ClauseProcessor::processIsDevicePtr(
1228-
mlir::omp::IsDevicePtrClauseOps &result,
1207+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
12291208
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
1230-
return findRepeatableClause<omp::clause::IsDevicePtr>(
1231-
[&](const omp::clause::IsDevicePtr &devPtrClause,
1232-
const parser::CharBlock &) {
1233-
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
1234-
isDeviceSyms);
1209+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1210+
bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>(
1211+
[&](const omp::clause::IsDevicePtr &clause,
1212+
const parser::CharBlock &source) {
1213+
mlir::Location location = converter.genLocation(source);
1214+
// Force a map so the descriptor is materialized on the device with the
1215+
// device address inside.
1216+
mlir::omp::ClauseMapFlags mapTypeBits =
1217+
mlir::omp::ClauseMapFlags::is_device_ptr |
1218+
mlir::omp::ClauseMapFlags::to;
1219+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1220+
parentMemberIndices, result.isDevicePtrVars,
1221+
isDeviceSyms);
12351222
});
1223+
1224+
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1225+
result.isDevicePtrVars, isDeviceSyms);
1226+
return clauseFound;
12361227
}
12371228

12381229
bool ClauseProcessor::processLink(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class ClauseProcessor {
134134
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
135135
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
136136
bool processIsDevicePtr(
137-
mlir::omp::IsDevicePtrClauseOps &result,
137+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
138138
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
139139
bool
140140
processLink(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,7 +1677,7 @@ static void genTargetClauses(
16771677
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
16781678
}
16791679
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
1680-
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
1680+
cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms);
16811681
cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
16821682
&mapSyms);
16831683
cp.processNowait(clauseOps);
@@ -2499,13 +2499,15 @@ static bool isDuplicateMappedSymbol(
24992499
const semantics::Symbol &sym,
25002500
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
25012501
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
2502-
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
2502+
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms,
2503+
const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) {
25032504
llvm::SmallVector<const semantics::Symbol *> concatSyms;
25042505
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
2505-
mappedSyms.size());
2506+
mappedSyms.size() + isDevicePtrSyms.size());
25062507
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
25072508
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
25082509
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
2510+
concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end());
25092511

25102512
auto checkSymbol = [&](const semantics::Symbol &checkSym) {
25112513
return std::any_of(concatSyms.begin(), concatSyms.end(),
@@ -2545,6 +2547,38 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25452547
loc, clauseOps, defaultMaps, hasDeviceAddrSyms,
25462548
isDevicePtrSyms, mapSyms);
25472549

2550+
if (!isDevicePtrSyms.empty()) {
2551+
// is_device_ptr maps get duplicated so the clause and synthesized
2552+
// has_device_addr entry each own a unique MapInfoOp user, keeping
2553+
// MapInfoFinalization happy while still wiring the symbol into
2554+
// has_device_addr when the user didn't spell it explicitly.
2555+
auto insertionPt = firOpBuilder.saveInsertionPoint();
2556+
auto alreadyPresent = [&](const semantics::Symbol *sym) {
2557+
return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) {
2558+
return s && sym && s->GetUltimate() == sym->GetUltimate();
2559+
});
2560+
};
2561+
2562+
for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) {
2563+
mlir::Value mapVal = clauseOps.isDevicePtrVars[idx];
2564+
assert(sym && "expected symbol for is_device_ptr");
2565+
assert(mapVal && "expected map value for is_device_ptr");
2566+
auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>();
2567+
assert(mapInfo && "expected map info op");
2568+
2569+
if (!alreadyPresent(sym)) {
2570+
clauseOps.hasDeviceAddrVars.push_back(mapVal);
2571+
hasDeviceAddrSyms.push_back(sym);
2572+
}
2573+
2574+
firOpBuilder.setInsertionPointAfter(mapInfo);
2575+
mlir::Operation *clonedOp = firOpBuilder.clone(*mapInfo.getOperation());
2576+
auto clonedMapInfo = mlir::cast<mlir::omp::MapInfoOp>(clonedOp);
2577+
clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult();
2578+
}
2579+
firOpBuilder.restoreInsertionPoint(insertionPt);
2580+
}
2581+
25482582
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
25492583
/*shouldCollectPreDeterminedSymbols=*/
25502584
lower::omp::isLastItemInQueue(item, queue),
@@ -2584,7 +2618,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25842618
return;
25852619

25862620
if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
2587-
hasDeviceAddrSyms, mapSyms)) {
2621+
hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) {
25882622
if (const auto *details =
25892623
sym.template detailsIf<semantics::HostAssocDetails>())
25902624
converter.copySymbolBinding(details->symbol(), sym);

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ subroutine mapType_array
3434
!$omp end target
3535
end subroutine mapType_array
3636

37+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
38+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 33]
39+
subroutine mapType_is_device_ptr
40+
use iso_c_binding, only : c_ptr
41+
type(c_ptr) :: p
42+
!$omp target is_device_ptr(p)
43+
!$omp end target
44+
end subroutine mapType_is_device_ptr
45+
3746
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [3 x i64] [i64 24, i64 0, i64 24]
3847
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [3 x i64] [i64 16933, i64 515, i64 32772]
3948
subroutine mapType_ptr

flang/test/Lower/OpenMP/target.f90

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,36 @@ subroutine omp_target_device_addr
567567
end subroutine omp_target_device_addr
568568

569569

570+
!===============================================================================
571+
! Target `is_device_ptr` clause
572+
!===============================================================================
573+
574+
!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
575+
subroutine omp_target_is_device_ptr
576+
use iso_c_binding, only: c_ptr
577+
implicit none
578+
integer :: i
579+
integer :: arr(4)
580+
type(c_ptr) :: p
581+
582+
i = 0
583+
arr = 0
584+
585+
!CHECK: %[[P_STORAGE:.*]] = omp.map.info {{.*}}{name = "p"}
586+
!CHECK: %[[P_IS:.*]] = omp.map.info {{.*}}{name = "p"}
587+
!CHECK: %[[ARR_MAP:.*]] = omp.map.info {{.*}}{name = "arr"}
588+
!CHECK: omp.target is_device_ptr(%[[P_IS]] :
589+
!CHECK-SAME: has_device_addr(%[[P_STORAGE]] ->
590+
!CHECK-SAME: map_entries({{.*}}%[[ARR_MAP]] ->
591+
!$omp target is_device_ptr(p)
592+
i = i + 1
593+
arr(1) = i
594+
!$omp end target
595+
!CHECK: omp.terminator
596+
!CHECK: }
597+
end subroutine omp_target_is_device_ptr
598+
599+
570600
!===============================================================================
571601
! Target Data with unstructured code
572602
!===============================================================================

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>;
129129
def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>;
130130
def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>;
131131
def ClauseMapFlagsDescriptor : I32BitEnumAttrCaseBit<"descriptor", 19>;
132+
def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 20>;
132133

133134
def ClauseMapFlags : OpenMP_BitEnumAttr<
134135
"ClauseMapFlags",
@@ -153,7 +154,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr<
153154
ClauseMapFlagsRefPtr,
154155
ClauseMapFlagsRefPtee,
155156
ClauseMapFlagsRefPtrPtee,
156-
ClauseMapFlagsDescriptor
157+
ClauseMapFlagsDescriptor,
158+
ClauseMapFlagsIsDevicePtr
157159
]>;
158160

159161
def ClauseMapFlagsAttr : OpenMP_EnumAttr<ClauseMapFlags,

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,6 +1874,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
18741874
if (mapTypeMod == "ref_ptr_ptee")
18751875
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
18761876

1877+
if (mapTypeMod == "is_device_ptr")
1878+
mapTypeBits |= ClauseMapFlags::is_device_ptr;
1879+
18771880
return success();
18781881
};
18791882

@@ -1945,6 +1948,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
19451948
mapTypeStrs.push_back("ref_ptee");
19461949
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
19471950
mapTypeStrs.push_back("ref_ptr_ptee");
1951+
if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
1952+
mapTypeStrs.push_back("is_device_ptr");
19481953
if (mapFlags == ClauseMapFlags::none)
19491954
mapTypeStrs.push_back("none");
19501955

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
338338
op.getInReductionSyms())
339339
result = todo("in_reduction");
340340
};
341-
auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
342-
if (!op.getIsDevicePtrVars().empty())
343-
result = todo("is_device_ptr");
344-
};
345341
auto checkLinear = [&todo](auto op, LogicalResult &result) {
346342
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
347343
result = todo("linear");
@@ -454,7 +450,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
454450
checkBare(op, result);
455451
checkDevice(op, result);
456452
checkInReduction(op, result);
457-
checkIsDevicePtr(op, result);
458453
})
459454
.Default([](Operation &) {
460455
// Assume all clauses for an operation can be translated unless they are
@@ -2728,7 +2723,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
27282723
return failure();
27292724

27302725
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2731-
2726+
27322727
// Emit Initialization and Update IR for linear variables
27332728
if (wsloopOp.getLinearVars().size()) {
27342729
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
@@ -4034,6 +4029,9 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
40344029
auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
40354030
return (mlirFlags & flag) == flag;
40364031
};
4032+
const bool hasExplicitMap =
4033+
(mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
4034+
omp::ClauseMapFlags::none;
40374035

40384036
llvm::omp::OpenMPOffloadMappingFlags mapType =
40394037
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
@@ -4077,6 +4075,12 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
40774075
if (mapTypeToBool(omp::ClauseMapFlags::descriptor))
40784076
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DESCRIPTOR;
40794077

4078+
if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4079+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4080+
if (!hasExplicitMap)
4081+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4082+
}
4083+
40804084
return mapType;
40814085
}
40824086

@@ -4221,6 +4225,9 @@ static void collectMapDataFromMapOperands(
42214225
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
42224226
auto mapType = convertClauseMapFlags(mapOp.getMapType());
42234227
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4228+
bool isDevicePtr =
4229+
(mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4230+
omp::ClauseMapFlags::none;
42244231

42254232
mapData.OriginalValue.push_back(origValue);
42264233
mapData.BasePointers.push_back(origValue);
@@ -4247,14 +4254,18 @@ static void collectMapDataFromMapOperands(
42474254
mapData.Mappers.push_back(nullptr);
42484255
}
42494256
} else {
4257+
// For is_device_ptr we need the map type to propagate so the runtime
4258+
// can materialize the device-side copy of the pointer container.
42504259
mapData.Types.push_back(
4251-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4260+
isDevicePtr ? mapType
4261+
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
42524262
mapData.Mappers.push_back(nullptr);
42534263
}
42544264
mapData.Names.push_back(LLVM::createMappingInformation(
42554265
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
42564266
mapData.DevicePointers.push_back(
4257-
llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4267+
isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4268+
: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
42584269
mapData.IsAMapping.push_back(false);
42594270
mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
42604271
}

mlir/test/Target/LLVMIR/omptarget-llvm.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
622622
// CHECK: br label %[[VAL_40]]
623623
// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]]
624624
// CHECK: ret void
625+
626+
// -----
627+
628+
module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
629+
llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) {
630+
%map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr)
631+
map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""}
632+
omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) {
633+
omp.terminator
634+
}
635+
llvm.return
636+
}
637+
}
638+
639+
// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8]
640+
// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288]
641+
// CHECK-LABEL: define void @_QPomp_target_is_device_ptr

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -222,17 +222,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {
222222

223223
// -----
224224

225-
llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
226-
// expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}}
227-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
228-
omp.target is_device_ptr(%x : !llvm.ptr) {
229-
omp.terminator
230-
}
231-
llvm.return
232-
}
233-
234-
// -----
235-
236225
llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
237226
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
238227
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}

0 commit comments

Comments
 (0)