Skip to content

Commit 548f339

Browse files
vmaksimojsji
authored andcommitted
Preserve byval/sret typed pointer semantics with SPV_KHR_untyped_pointers (#3417)
Avoid translation of pointer function arguments with `byval`/`sret` attribute as untyped pointers to preserve the information about the pointer element type. Insert `OpBitcast` to further use such pointers in the untyped pointer semantics (vise-versa bitcast instructions are explicitly allowed by the [SPV_KHR_untyped_pointers ](https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_untyped_pointers.html) extension). This approach ensures valid reverse translation and correct OpenCL runtime behavior, especially for kernels translation. Without the fix `clSetKernelArg()` fails with `CL_INVALID_MEM_OBJECT` error. Original commit: KhronosGroup/SPIRV-LLVM-Translator@b65c96eeec4e2b3
1 parent 0c32400 commit 548f339

File tree

7 files changed

+181
-21
lines changed

7 files changed

+181
-21
lines changed

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,20 @@ SPIRVType *LLVMToSPIRVBase::transScavengedType(Value *V) {
870870
if (!Ty) {
871871
Ty = FnTy->getParamType(Arg.getArgNo());
872872
}
873+
// Preserve element type for byval/sret arguments even when
874+
// SPV_KHR_untyped_pointers is enabled. Losing pointee type would make it
875+
// impossible to reconstruct the original parameter and will lead to
876+
// OpenCL runtime failure due to mismatched memory object semantics.
877+
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) &&
878+
(Arg.hasByValAttr() || Arg.hasStructRetAttr())) {
879+
TypedPointerType *TPT = cast<TypedPointerType>(Ty);
880+
auto *NewType = BM->addPointerType(
881+
SPIRSPIRVAddrSpaceMap::map(
882+
static_cast<SPIRAddressSpace>(TPT->getAddressSpace())),
883+
transType(TPT->getElementType()));
884+
PT.push_back(NewType);
885+
continue;
886+
}
873887
PT.push_back(transType(Ty));
874888
}
875889

@@ -2219,7 +2233,43 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
22192233
unsigned ArgNo = Arg->getArgNo();
22202234
SPIRVFunction *BF = BB->getParent();
22212235
// assert(BF->existArgument(ArgNo));
2222-
return mapValue(V, BF->getArgument(ArgNo));
2236+
auto *SPVArg = BF->getArgument(ArgNo);
2237+
2238+
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) &&
2239+
(Arg->hasByValAttr() || Arg->hasStructRetAttr()) &&
2240+
SPVArg->getType()->isTypePointer() &&
2241+
!SPVArg->getType()->isTypeUntypedPointerKHR()) {
2242+
// When SPV_KHR_untyped_pointers extension is enabled, bitcast typed
2243+
// pointer function arguments to untyped pointers for further usage in the
2244+
// untyped pointers paradigm.
2245+
// Do this only for safe cases where it would not require tracking uses of
2246+
// the original typed pointer argument. Otherwise, just keep the original
2247+
// typed pointer argument to avoid complex transformations later that may
2248+
// break SPIR-V validity.
2249+
2250+
auto PerformBitcastForArg =
2251+
[&](SPIRVFunctionParameter *BA) -> SPIRVValue * {
2252+
// Position to insert bitcast should be right after variable insertion
2253+
// point in the entry basic block.
2254+
auto *InsertBB = BF->getBasicBlock(0);
2255+
auto *InsertPoint = InsertBB->getVariableInsertionPoint();
2256+
auto *UntypedPtrType = BM->addPointerType(
2257+
BA->getType()->getPointerStorageClass(), nullptr);
2258+
2259+
auto *Bitcast = BM->addUnaryInst(OpBitcast, UntypedPtrType, BA,
2260+
InsertBB, InsertPoint);
2261+
return Bitcast;
2262+
};
2263+
2264+
for (auto *U : V->users()) {
2265+
auto *Inst = U->stripPointerCasts();
2266+
if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst) ||
2267+
isa<MemCpyInst>(Inst)) {
2268+
return mapValue(V, PerformBitcastForArg(SPVArg));
2269+
}
2270+
}
2271+
}
2272+
return mapValue(V, SPVArg);
22232273
}
22242274

22252275
if (CreateForward)

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,8 @@ class SPIRVModuleImpl : public SPIRVModule {
517517
SPIRVInstruction *addTransposeInst(SPIRVType *TheType, SPIRVId TheMatrix,
518518
SPIRVBasicBlock *BB) override;
519519
SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
520-
SPIRVBasicBlock *) override;
520+
SPIRVBasicBlock *,
521+
SPIRVInstruction * = nullptr) override;
521522
SPIRVInstruction *addVariable(SPIRVType *, SPIRVType *, bool,
522523
SPIRVLinkageTypeKind, SPIRVValue *,
523524
const std::string &, SPIRVStorageClassKind,
@@ -1748,16 +1749,16 @@ SPIRVInstruction *SPIRVModuleImpl::addReturnValueInst(SPIRVValue *ReturnValue,
17481749
return addInstruction(new SPIRVReturnValue(ReturnValue, BB), BB);
17491750
}
17501751

1751-
SPIRVInstruction *SPIRVModuleImpl::addUnaryInst(Op TheOpCode,
1752-
SPIRVType *TheType,
1753-
SPIRVValue *Op,
1754-
SPIRVBasicBlock *BB) {
1752+
SPIRVInstruction *
1753+
SPIRVModuleImpl::addUnaryInst(Op TheOpCode, SPIRVType *TheType, SPIRVValue *Op,
1754+
SPIRVBasicBlock *BB,
1755+
SPIRVInstruction *InsertBefore) {
17551756
if (TheType->isTypeFloat(16, FPEncodingBFloat16KHR) && TheOpCode != OpDot)
17561757
addCapability(internal::CapabilityBFloat16ArithmeticINTEL);
17571758
return addInstruction(
17581759
SPIRVInstTemplateBase::create(TheOpCode, TheType, getId(),
17591760
getVec(Op->getId()), BB, this),
1760-
BB);
1761+
BB, InsertBefore);
17611762
}
17621763

17631764
SPIRVInstruction *SPIRVModuleImpl::addVectorExtractDynamicInst(

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,8 @@ class SPIRVModule {
495495
SPIRVId TheMatrix,
496496
SPIRVBasicBlock *BB) = 0;
497497
virtual SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
498-
SPIRVBasicBlock *) = 0;
498+
SPIRVBasicBlock *,
499+
SPIRVInstruction * = nullptr) = 0;
499500
virtual SPIRVInstruction *addVariable(SPIRVType *, SPIRVType *, bool,
500501
SPIRVLinkageTypeKind, SPIRVValue *,
501502
const std::string &,
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; Ensure that a typed pointer passed by value is converted to an untyped pointer prior usage.
2+
3+
; RUN: llvm-spirv %s -spirv-text -o %t.spt --spirv-ext=+SPV_KHR_untyped_pointers
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv %s -o %t.spv --spirv-ext=+SPV_KHR_untyped_pointers
6+
; RUN: spirv-val %t.spv
7+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
8+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
9+
10+
; CHECK-SPIRV-DAG: Name [[#Fun:]] "kernel"
11+
; CHECK-SPIRV-DAG: Decorate [[#Param:]] FuncParamAttr 2
12+
; CHECK-SPIRV-DAG: TypeUntypedPointerKHR [[#UntypedPtrTy:]] 7
13+
; CHECK-SPIRV-DAG: TypeStruct [[#StructTy:]]
14+
; CHECK-SPIRV-DAG: TypePointer [[#PtrTy:]] 7 [[#StructTy]]
15+
; CHECK-SPIRV-DAG: TypeInt [[#I32Ty:]] 32 0
16+
17+
; CHECK-SPIRV: Function [[#]] [[#Fun]]
18+
; CHECK-SPIRV: FunctionParameter [[#PtrTy]] [[#Param]]
19+
20+
; CHECK-SPIRV: Bitcast [[#UntypedPtrTy]] [[#BC:]] [[#Param]]
21+
; CHECK-SPIRV: Load [[#I32Ty]] [[#]] [[#BC]]
22+
23+
; CHECK-LLVM: @kernel(ptr %arg0, ptr byval(%struct.Example) align 8 %arg1)
24+
25+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
26+
target triple = "spir-unknown-unknown"
27+
28+
%struct.Example = type { }
29+
30+
define spir_kernel void @kernel(ptr %arg0, ptr byval(%struct.Example) align 8 %arg1) {
31+
entry:
32+
%0 = load i32, ptr %arg1, align 8
33+
ret void
34+
}

llvm-spirv/test/llvm-intrinsics/memset-opaque.ll

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,22 @@
88
; RUN: llvm-spirv %t.bc -spirv-text -o %t.spt --spirv-ext=+SPV_KHR_untyped_pointers
99
; RUN: FileCheck < %t.spt %s --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-UNTYPED-PTR
1010
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_KHR_untyped_pointers
11-
; RUN: spirv-val %t.spv
11+
12+
; TODO: run validator once it's fixed or once we relax type scavenger for untyped pointers.
13+
; Now it fails with the following error, which does not contradict the extension specification:
14+
; error: line 116: Expected input and Result Type to point to the same type: GenericCastToPtr
15+
; %a = OpGenericCastToPtr %_ptr_Workgroup %agg_result
16+
; RUNx: spirv-val %t.spv
17+
1218
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis | FileCheck %s --check-prefix=CHECK-LLVM-OPAQUE
1319

1420
; CHECK-SPIRV: Decorate [[#NonConstMemset:]] LinkageAttributes "spirv.llvm_memset_p3_i32"
1521
; CHECK-SPIRV: TypeInt [[Int8:[0-9]+]] 8 0
1622
; CHECK-SPIRV: Constant {{[0-9]+}} [[Lenmemset21:[0-9]+]] 4
1723
; CHECK-SPIRV: Constant {{[0-9]+}} [[Lenmemset0:[0-9]+]] 12
1824
; CHECK-SPIRV: Constant {{[0-9]+}} [[Const21:[0-9]+]] 21
19-
; CHECK-SPIRV-UNTYPED-PTR: TypeUntypedPointerKHR [[Int8Ptr:[0-9]+]] 8
2025
; CHECK-SPIRV: TypeArray [[Int8x4:[0-9]+]] [[Int8]] [[Lenmemset21]]
26+
; CHECK-SPIRV-UNTYPED-PTR: TypeUntypedPointerKHR [[Int8Ptr:[0-9]+]] 8
2127
; CHECK-SPIRV-TYPED-PTR: TypePointer [[Int8Ptr:[0-9]+]] 8 [[Int8]]
2228
; CHECK-SPIRV: TypeArray [[Int8x12:[0-9]+]] [[Int8]] [[Lenmemset0]]
2329
; CHECK-SPIRV-TYPED-PTR: TypePointer [[Int8PtrConst:[0-9]+]] 7 [[Int8]]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; Ensure that a pointer passed by value is translated as a typed pointer even
2+
; with the SPV_KHR_untyped_pointers extension enabled to preserve byval semantics.
3+
4+
; RUN: llvm-spirv %s -spirv-text -o %t.txt
5+
; RUN: FileCheck < %t.txt %s --check-prefix=CHECK-SPIRV
6+
; RUN: llvm-spirv %s -o %t.spv
7+
; RUN: spirv-val %t.spv
8+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
9+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
10+
11+
; RUN: llvm-spirv %s -spirv-text -o %t.txt --spirv-ext=+SPV_KHR_untyped_pointers
12+
; RUN: FileCheck < %t.txt %s --check-prefix=CHECK-SPIRV
13+
; RUN: llvm-spirv %s -o %t.spv --spirv-ext=+SPV_KHR_untyped_pointers
14+
; RUN: spirv-val %t.spv
15+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
16+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
17+
18+
; CHECK-SPIRV-DAG: Name [[#Fun:]] "kernel"
19+
; CHECK-SPIRV-DAG: Decorate [[#Param:]] FuncParamAttr 2
20+
; CHECK-SPIRV-DAG: TypeStruct [[#StructTy:]]
21+
; CHECK-SPIRV-DAG: TypePointer [[#PtrTy:]] [[#]] [[#StructTy]]
22+
; CHECK-SPIRV: Function [[#]] [[#Fun]]
23+
; CHECK-SPIRV: FunctionParameter [[#PtrTy]] [[#Param]]
24+
25+
; CHECK-LLVM: @kernel(ptr %arg0, ptr byval(%struct.Example) align 8 %arg1)
26+
27+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
28+
target triple = "spir-unknown-unknown"
29+
30+
%struct.Example = type { }
31+
32+
define spir_kernel void @kernel(ptr %arg0, ptr byval(%struct.Example) align 8 %arg1) {
33+
entry:
34+
ret void
35+
}
Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,44 @@
1-
; RUN: llvm-as %s -o %t.bc
2-
; RUN: llvm-spirv %t.bc -spirv-text -o %t.txt
1+
; Ensure that sret pointer semantics is preserved when the parameter is unused
2+
; (even with the SPV_KHR_untyped_pointers extension enabled).
3+
4+
; RUN: llvm-spirv %s -spirv-text -o %t.txt
35
; RUN: FileCheck < %t.txt %s --check-prefix=CHECK-SPIRV
4-
; RUN: llvm-spirv %t.bc -o %t.spv
6+
; RUN: llvm-spirv %s -o %t.spv
57
; RUN: spirv-val %t.spv
68
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
79
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
810

9-
; CHECK-SPIRV-DAG: Name [[#Fun:]] "_Z3booi"
10-
; CHECK-SPIRV-DAG: Decorate [[#Param:]] FuncParamAttr 3
11-
; CHECK-SPIRV-DAG: TypePointer [[#PtrTy:]] [[#]] [[#StructTy:]]
11+
; RUN: llvm-spirv %s -spirv-text -o %t.txt --spirv-ext=+SPV_KHR_untyped_pointers
12+
; RUN: FileCheck < %t.txt %s --check-prefix=CHECK-SPIRV
13+
; RUN: llvm-spirv %s -o %t.spv --spirv-ext=+SPV_KHR_untyped_pointers
14+
; RUNx: spirv-val %t.spv
15+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
16+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
17+
18+
; CHECK-SPIRV-DAG: Name [[#FunBoo:]] "boo"
19+
; CHECK-SPIRV-DAG: Name [[#FunBaz:]] "baz"
20+
; CHECK-SPIRV-DAG: Decorate [[#ParamBoo:]] FuncParamAttr 3
21+
; CHECK-SPIRV-DAG: Decorate [[#ParamBaz:]] FuncParamAttr 3
22+
; CHECK-SPIRV-DAG: TypePointer [[#PtrTy7:]] 7 [[#StructTy:]]
23+
; CHECK-SPIRV-DAG: TypePointer [[#PtrTy8:]] 8 [[#StructTy:]]
1224
; CHECK-SPIRV-DAG: TypeStruct [[#StructTy]]
13-
; CHECK-SPIRV: Function [[#]] [[#Fun]]
14-
; CHECK-SPIRV: FunctionParameter [[#PtrTy:]] [[#Param]]
1525

16-
; CHECK-LLVM: call spir_func void @_Z3booi(ptr sret(%struct.Example) align 8
26+
; CHECK-SPIRV-DAG: TypeFunction [[#BooTy:]] [[#]] [[#PtrTy7]] [[#]] {{$}}
27+
; CHECK-SPIRV-DAG: TypeFunction [[#BazTy:]] [[#]] [[#PtrTy8]] {{$}}
28+
29+
; CHECK-SPIRV: Function [[#]] [[#FunBoo]] [[#]] [[#BooTy]]
30+
; CHECK-SPIRV: FunctionParameter [[#PtrTy7]] [[#ParamBoo]]
31+
32+
; CHECK-SPIRV: Function [[#]] [[#FunBaz]] [[#]] [[#BazTy]]
33+
; CHECK-SPIRV: FunctionParameter [[#PtrTy8]] [[#ParamBaz]]
34+
35+
; CHECK-SPIRV: FunctionParameter [[#PtrTy7]] [[#ParamBar:]]
36+
; With untyped extension enabled addrspacecast is done to untyped pointer type in addrspace 8.
37+
; CHECK-SPIRV: PtrCastToGeneric [[#]] [[#Cast:]] [[#ParamBar]]
38+
; CHECK-SPIRV: FunctionCall [[#]] [[#]] [[#FunBaz]] [[#Cast]]
39+
40+
; CHECK-LLVM: call spir_func void @boo(ptr sret(%struct.Example) align 8
41+
; CHECK-LLVM: call spir_func void @baz(ptr addrspace(4) sret(%struct.Example) %cast)
1742

1843
source_filename = "/app/example.cpp"
1944
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
@@ -23,8 +48,16 @@ target triple = "spir-unknown-unknown"
2348

2449
define spir_func i32 @foo() {
2550
%1 = alloca %struct.Example, align 8
26-
call void @_Z3booi(ptr sret(%struct.Example) align 8 %1, i32 noundef 42)
51+
call void @boo(ptr sret(%struct.Example) align 8 %1, i32 noundef 42)
2752
ret i32 0
2853
}
2954

30-
declare void @_Z3booi(ptr sret(%struct.Example) align 8, i32 noundef)
55+
define spir_func void @bar(ptr sret(%struct.Example) %ret_ptr) {
56+
%cast = addrspacecast ptr %ret_ptr to ptr addrspace(4)
57+
call void @baz(ptr addrspace(4) sret(%struct.Example) %cast)
58+
ret void
59+
}
60+
61+
62+
declare void @boo(ptr sret(%struct.Example) align 8, i32 noundef)
63+
declare void @baz(ptr addrspace(4) sret(%struct.Example))

0 commit comments

Comments
 (0)