Skip to content

Commit 1bb3ca2

Browse files
vmaksimojsji
authored andcommitted
Fix OCL builtin mangling for the indirect pointer usage (#3460)
This continues #2723 to support correct OCL builtins translation with the `SPV_KHR_untyped_pointers` extension enabled. It reuses the approach from #2924 to get the "true" pointer element type and restore proper SPIR-V builtin mangling. The approach has been refactored into a separate function, which is now also used for OCL builtins translation. Original commit: KhronosGroup/SPIRV-LLVM-Translator@333f956817c386a
1 parent d6c0f78 commit 1bb3ca2

File tree

4 files changed

+217
-61
lines changed

4 files changed

+217
-61
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 88 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3684,17 +3684,17 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
36843684
std::vector<Type *> ArgTys =
36853685
transTypeVector(SPIRVInstruction::getOperandTypes(Ops), true);
36863686

3687-
auto Ptr = findFirstPtrType(ArgTys);
3688-
if (Ptr < ArgTys.size() &&
3689-
BI->getValueType(Ops[Ptr]->getId())->isTypeUntypedPointerKHR()) {
3687+
unsigned PtrIdx = findFirstPtrType(ArgTys);
3688+
if (PtrIdx < ArgTys.size() &&
3689+
BI->getValueType(Ops[PtrIdx]->getId())->isTypeUntypedPointerKHR()) {
36903690
// Special handling for "truly" untyped pointers to preserve correct
36913691
// builtin mangling of atomic and matrix operations.
36923692
if (isAtomicOpCodeUntypedPtrSupported(OC)) {
36933693
auto *AI = static_cast<SPIRVAtomicInstBase *>(BI);
3694-
ArgTys[Ptr] = TypedPointerType::get(
3694+
ArgTys[PtrIdx] = TypedPointerType::get(
36953695
transType(AI->getSemanticType()),
3696-
SPIRSPIRVAddrSpaceMap::rmap(
3697-
BI->getValueType(Ops[Ptr]->getId())->getPointerStorageClass()));
3696+
SPIRSPIRVAddrSpaceMap::rmap(BI->getValueType(Ops[PtrIdx]->getId())
3697+
->getPointerStorageClass()));
36983698
}
36993699
}
37003700

@@ -3709,51 +3709,8 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
37093709
continue;
37103710
}
37113711
if (OpTy->isTypeUntypedPointerKHR()) {
3712-
auto *Val = transValue(Ops[I], BB->getParent(), BB);
3713-
Val = Val->stripPointerCasts();
3714-
if (isUntypedAccessChainOpCode(Ops[I]->getOpCode())) {
3715-
SPIRVType *BaseTy =
3716-
reinterpret_cast<SPIRVAccessChainBase *>(Ops[I])->getBaseType();
3717-
3718-
Type *Ty = nullptr;
3719-
if (BaseTy->isTypeArray())
3720-
Ty = transType(BaseTy->getArrayElementType());
3721-
else if (BaseTy->isTypeVector())
3722-
Ty = transType(BaseTy->getVectorComponentType());
3723-
else
3724-
Ty = transType(BaseTy);
3725-
ArgTys[I] = TypedPointerType::get(
3726-
Ty, SPIRSPIRVAddrSpaceMap::rmap(OpTy->getPointerStorageClass()));
3727-
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(Val)) {
3728-
ArgTys[I] = TypedPointerType::get(
3729-
GEP->getSourceElementType(),
3730-
SPIRSPIRVAddrSpaceMap::rmap(OpTy->getPointerStorageClass()));
3731-
} else if (Ops[I]->getOpCode() == OpUntypedVariableKHR) {
3732-
SPIRVUntypedVariableKHR *UV =
3733-
static_cast<SPIRVUntypedVariableKHR *>(Ops[I]);
3734-
Type *Ty = transType(UV->getDataType());
3735-
ArgTys[I] = TypedPointerType::get(
3736-
Ty, SPIRSPIRVAddrSpaceMap::rmap(OpTy->getPointerStorageClass()));
3737-
} else if (auto *AI = dyn_cast<AllocaInst>(Val)) {
3738-
ArgTys[I] = TypedPointerType::get(
3739-
AI->getAllocatedType(),
3740-
SPIRSPIRVAddrSpaceMap::rmap(OpTy->getPointerStorageClass()));
3741-
} else if (Ops[I]->getOpCode() == OpFunctionParameter &&
3742-
!RetTy->isVoidTy()) {
3743-
// Pointer could be a function parameter. Assume that the type of
3744-
// the pointer is the same as the return type.
3745-
Type *Ty = nullptr;
3746-
// it return type is array type, assign its element type to Ty
3747-
if (RetTy->isArrayTy())
3748-
Ty = RetTy->getArrayElementType();
3749-
else if (RetTy->isVectorTy())
3750-
Ty = cast<VectorType>(RetTy)->getElementType();
3751-
else
3752-
Ty = RetTy;
3753-
3754-
ArgTys[I] = TypedPointerType::get(
3755-
Ty, SPIRSPIRVAddrSpaceMap::rmap(OpTy->getPointerStorageClass()));
3756-
}
3712+
if (Type *NewPtrTy = getTypedPtrFromUntypedOperand(Ops[I], RetTy))
3713+
ArgTys[I] = NewPtrTy;
37573714
}
37583715
}
37593716
}
@@ -3825,6 +3782,55 @@ SPIRVToLLVM::SPIRVToLLVM(Module *LLVMModule, SPIRVModule *TheSPIRVModule)
38253782
DbgTran.reset(new SPIRVToLLVMDbgTran(TheSPIRVModule, LLVMModule, this));
38263783
}
38273784

3785+
Type *SPIRVToLLVM::getTypedPtrFromUntypedOperand(SPIRVValue *Val, Type *RetTy) {
3786+
Type *Ty = nullptr;
3787+
Op OC = Val->getOpCode();
3788+
if (isUntypedAccessChainOpCode(OC)) {
3789+
SPIRVType *BaseTy =
3790+
reinterpret_cast<SPIRVAccessChainBase *>(Val)->getBaseType();
3791+
if (BaseTy->isTypeArray())
3792+
Ty = transType(BaseTy->getArrayElementType());
3793+
else if (BaseTy->isTypeVector())
3794+
Ty = transType(BaseTy->getVectorComponentType());
3795+
else
3796+
Ty = transType(BaseTy);
3797+
} else if (OC == OpUntypedVariableKHR) {
3798+
auto *UV = static_cast<SPIRVUntypedVariableKHR *>(Val);
3799+
Ty = transType(UV->getDataType());
3800+
} else if (OC == OpFunctionParameter && !RetTy->isVoidTy()) {
3801+
// Pointer could be a function parameter. Assume that the type of
3802+
// the pointer is the same as the return type.
3803+
// If return type is array/vector type, assign its element type to Ty.
3804+
if (RetTy->isArrayTy())
3805+
Ty = RetTy->getArrayElementType();
3806+
else if (RetTy->isVectorTy())
3807+
Ty = cast<VectorType>(RetTy)->getElementType();
3808+
else
3809+
Ty = RetTy;
3810+
}
3811+
3812+
unsigned AddrSpace =
3813+
SPIRSPIRVAddrSpaceMap::rmap(Val->getType()->getPointerStorageClass());
3814+
if (Ty)
3815+
return TypedPointerType::get(Ty, AddrSpace);
3816+
3817+
// If we couldn't infer a better element type, attempt to derive from an
3818+
// already translated LLVM value (GEP, Alloca, etc.).
3819+
if (Value *V = getTranslatedValue(Val)) {
3820+
V = V->stripPointerCasts();
3821+
if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
3822+
Ty = GEP->getSourceElementType();
3823+
else if (auto *AI = dyn_cast<AllocaInst>(V))
3824+
Ty = AI->getAllocatedType();
3825+
}
3826+
3827+
if (Ty)
3828+
return TypedPointerType::get(Ty, AddrSpace);
3829+
if (!RetTy->isVoidTy())
3830+
return TypedPointerType::get(RetTy, AddrSpace);
3831+
return nullptr;
3832+
}
3833+
38283834
std::string getSPIRVFuncSuffix(SPIRVInstruction *BI) {
38293835
std::string Suffix = "";
38303836
if (BI->getOpCode() == OpCreatePipeFromPipeStorage) {
@@ -5299,20 +5305,41 @@ Instruction *SPIRVToLLVM::transOCLBuiltinFromExtInst(SPIRVExtInst *BC,
52995305
assert(BM->getBuiltinSet(BC->getExtSetId()) == SPIRVEIS_OpenCL &&
53005306
"Not OpenCL extended instruction");
53015307

5308+
Type *RetTy = transType(BC->getType());
53025309
std::vector<Type *> ArgTypes = transTypeVector(BC->getArgTypes(), true);
5303-
for (unsigned I = 0; I < ArgTypes.size(); I++) {
5304-
// Special handling for "truly" untyped pointers to preserve correct OCL
5305-
// bultin mangling.
5306-
if (isa<PointerType>(ArgTypes[I]) &&
5307-
BC->getArgValue(I)->isUntypedVariable()) {
5308-
auto *BVar = static_cast<SPIRVUntypedVariableKHR *>(BC->getArgValue(I));
5309-
ArgTypes[I] = TypedPointerType::get(
5310-
transType(BVar->getDataType()),
5311-
SPIRSPIRVAddrSpaceMap::rmap(BVar->getStorageClass()));
5310+
// Special handling for "truly" untyped pointers to preserve correct
5311+
// OCL builtin mangling.
5312+
unsigned PtrIdx = findFirstPtrType(ArgTypes);
5313+
if (PtrIdx < ArgTypes.size() &&
5314+
BC->getArgValue(PtrIdx)->getType()->isTypeUntypedPointerKHR()) {
5315+
switch (ExtOp) {
5316+
case OpenCLLIB::Frexp:
5317+
case OpenCLLIB::Remquo:
5318+
case OpenCLLIB::Lgamma_r: {
5319+
// These builtins require their pointer arguments to point to i32 or
5320+
// vector of i32 values.
5321+
Type *DataType = Type::getInt32Ty(*Context);
5322+
if (RetTy->isVectorTy())
5323+
DataType = VectorType::get(DataType,
5324+
cast<VectorType>(RetTy)->getElementCount());
5325+
ArgTypes[PtrIdx] = TypedPointerType::get(
5326+
DataType, cast<PointerType>(ArgTypes[PtrIdx])->getAddressSpace());
5327+
} break;
5328+
case OpenCLLIB::Printf: {
5329+
// Printf's format argument type is always i8*.
5330+
ArgTypes[PtrIdx] = TypedPointerType::get(
5331+
Type::getInt8Ty(*Context),
5332+
cast<PointerType>(ArgTypes[PtrIdx])->getAddressSpace());
5333+
} break;
5334+
default: {
5335+
Type *NewPtrTy =
5336+
getTypedPtrFromUntypedOperand(BC->getArgValue(PtrIdx), RetTy);
5337+
if (NewPtrTy)
5338+
ArgTypes[PtrIdx] = NewPtrTy;
5339+
}
53125340
}
53135341
}
53145342

5315-
Type *RetTy = transType(BC->getType());
53165343
std::string MangledName =
53175344
getSPIRVFriendlyIRFunctionName(ExtOp, ArgTypes, RetTy);
53185345
opaquifyTypedPointers(ArgTypes);

llvm-spirv/lib/SPIRV/SPIRVReader.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ class SPIRVToLLVM : private BuiltinCallHelper {
9090
std::string transTypeToOCLTypeName(SPIRVType *BT, bool IsSigned = true);
9191
std::vector<Type *> transTypeVector(const std::vector<SPIRVType *> &,
9292
bool UseTypedPointerTypes = false);
93+
// Build a typed LLVM pointer type for a SPIR-V untyped pointer operand by
94+
// inferring an element LLVM type from the operand's SPIR-V value or from
95+
// translated LLVM value. Returns nullptr if no element type can be found.
96+
// This is needed to preserve correct mangling for builtins.
97+
Type *getTypedPtrFromUntypedOperand(SPIRVValue *Op, Type *RetTy);
9398
bool translate();
9499
bool transAddressingModel();
95100

llvm-spirv/test/extensions/EXT/SPV_EXT_relaxed_printf_string_address_space/non-constant-printf.ll

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
1010
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
1111

12+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_EXT_relaxed_printf_string_address_space,+SPV_KHR_untyped_pointers
13+
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
14+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
15+
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
16+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
17+
1218
; CHECK-WO-EXT: RequiresExtension: Feature requires the following SPIR-V extension:
1319
; CHECK-WO-EXT: SPV_EXT_relaxed_printf_string_address_space extension should be allowed to translate this module, because this LLVM module contains the printf function with format string, whose address space is not equal to 2 (constant).
1420

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
; This test checks that OCL builtins preserve the type information for opaque
2+
; pointer parameters in their mangling even when untyped pointers extension is enabled.
3+
; Also check for the cases where the pointer parameter is used indirectly to emulate real-life usage.
4+
5+
; Instructions from OpenCL.std extended instruction set that have pointer arguments:
6+
; Math extended instructions: fract, frexp, lgamma_r, modf, remquo, sincos
7+
; Misc instructions: printf, prefetch (covered by separate tests)
8+
9+
; RUN: llvm-spirv %s -o %t.spv
10+
; RUN: spirv-val %t.spv
11+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
12+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
13+
14+
; RUN: llvm-spirv %s -o %t.spv --spirv-ext=+SPV_KHR_untyped_pointers
15+
; TODO: enable back once spirv-tools are updated
16+
; RUNx: spirv-val %t.spv
17+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
18+
; RUN: llvm-dis %t.rev.bc
19+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
20+
21+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
22+
target triple = "spirv64-unknown-unknown"
23+
24+
; CHECK-LLVM-LABEL: define spir_func void @fract
25+
; CHECK-LLVM-COUNT-3: @_Z5fractfPf
26+
; CHECK-LLVM-LABEL: ret void
27+
define spir_func void @fract(ptr %arg) {
28+
entry:
29+
%ptr = alloca float
30+
%p = bitcast ptr %ptr to ptr
31+
%p2 = bitcast ptr %arg to ptr
32+
%res = call spir_func float @_Z17__spirv_ocl_fractfPf(float 1.250000e+00, ptr %ptr)
33+
%res1 = call spir_func float @_Z17__spirv_ocl_fractfPf(float 1.250000e+00, ptr %p)
34+
%res2 = call spir_func float @_Z17__spirv_ocl_fractfPf(float 1.250000e+00, ptr %p2)
35+
ret void
36+
}
37+
38+
declare spir_func float @_Z17__spirv_ocl_fractfPf(float, ptr)
39+
40+
; CHECK-LLVM-LABEL: define spir_func void @modf
41+
; CHECK-LLVM-COUNT-3: @_Z4modffPf
42+
; CHECK-LLVM-LABEL: ret void
43+
define spir_func void @modf(ptr %arg) {
44+
entry:
45+
%iptr = alloca float
46+
%p = bitcast ptr %iptr to ptr
47+
%p2 = bitcast ptr %arg to ptr
48+
%res = call spir_func float @_Z16__spirv_ocl_modffPf(float 1.250000e+00, ptr %iptr)
49+
%res1 = call spir_func float @_Z16__spirv_ocl_modffPf(float 1.250000e+00, ptr %p)
50+
%res2 = call spir_func float @_Z16__spirv_ocl_modffPf(float 1.250000e+00, ptr %p2)
51+
ret void
52+
}
53+
54+
declare spir_func float @_Z16__spirv_ocl_modffPf(float, ptr)
55+
56+
; CHECK-LLVM-LABEL: define spir_func void @sincos
57+
; CHECK-LLVM-COUNT-3: @_Z6sincosfPf
58+
; CHECK-LLVM-LABEL: ret void
59+
define spir_func void @sincos(ptr %arg) {
60+
entry:
61+
%cosval = alloca float
62+
%p = bitcast ptr %cosval to ptr
63+
%p2 = bitcast ptr %arg to ptr
64+
%res = call spir_func float @_Z18__spirv_ocl_sincosfPf(float 1.250000e+00, ptr %cosval)
65+
%res1 = call spir_func float @_Z18__spirv_ocl_sincosfPf(float 1.250000e+00, ptr %p)
66+
%res2 = call spir_func float @_Z18__spirv_ocl_sincosfPf(float 1.250000e+00, ptr %p2)
67+
ret void
68+
}
69+
70+
declare spir_func float @_Z18__spirv_ocl_sincosfPf(float, ptr)
71+
72+
; CHECK-LLVM-LABEL: define spir_func void @frexp
73+
; CHECK-LLVM-COUNT-3: @_Z5frexpfPi
74+
; CHECK-LLVM-LABEL: ret void
75+
define spir_func void @frexp(ptr %arg) {
76+
entry:
77+
%exp = alloca i32
78+
%p = bitcast ptr %exp to ptr
79+
%p2 = bitcast ptr %arg to ptr
80+
%res = call spir_func float @_Z17__spirv_ocl_frexpfPi(float 1.250000e+00, ptr %exp)
81+
%res1 = call spir_func float @_Z17__spirv_ocl_frexpfPi(float 1.250000e+00, ptr %p)
82+
%res2 = call spir_func float @_Z17__spirv_ocl_frexpfPi(float 1.250000e+00, ptr %p2)
83+
ret void
84+
}
85+
86+
declare spir_func float @_Z17__spirv_ocl_frexpfPi(float, ptr)
87+
88+
; CHECK-LLVM-LABEL: define spir_func void @lgamma_r
89+
; CHECK-LLVM-COUNT-3: @_Z8lgamma_rfPi
90+
; CHECK-LLVM-LABEL: ret void
91+
define spir_func void @lgamma_r(ptr %arg) {
92+
entry:
93+
%signp = alloca i32
94+
%p = bitcast ptr %signp to ptr
95+
%p2 = bitcast ptr %arg to ptr
96+
%res = call spir_func float @_Z20__spirv_ocl_lgamma_rfPi(float 1.250000e+00, ptr %signp)
97+
%res1 = call spir_func float @_Z20__spirv_ocl_lgamma_rfPi(float 1.250000e+00, ptr %p)
98+
%res2 = call spir_func float @_Z20__spirv_ocl_lgamma_rfPi(float 1.250000e+00, ptr %p2)
99+
ret void
100+
}
101+
102+
declare spir_func float @_Z20__spirv_ocl_lgamma_rfPi(float, ptr)
103+
104+
; CHECK-LLVM-LABEL: define spir_func void @remquo
105+
; CHECK-LLVM-COUNT-3: @_Z6remquoffPi
106+
; CHECK-LLVM-LABEL: ret void
107+
define spir_func void @remquo(ptr %arg) {
108+
entry:
109+
%quo = alloca i32
110+
%p = bitcast ptr %quo to ptr
111+
%p2 = bitcast ptr %arg to ptr
112+
%res = call spir_func float @_Z18__spirv_ocl_remquoffPi(float 1.250000e+00, float 1.250000e+00, ptr %quo)
113+
%res1 = call spir_func float @_Z18__spirv_ocl_remquoffPi(float 1.250000e+00, float 1.250000e+00, ptr %p)
114+
%res2 = call spir_func float @_Z18__spirv_ocl_remquoffPi(float 1.250000e+00, float 1.250000e+00, ptr %p2)
115+
ret void
116+
}
117+
118+
declare spir_func float @_Z18__spirv_ocl_remquoffPi(float, float, ptr)

0 commit comments

Comments
 (0)