diff --git a/clang/include/clang/Basic/AddressSpaces.h b/clang/include/clang/Basic/AddressSpaces.h index 48e4a1c61fe02..7280b8fc923d2 100644 --- a/clang/include/clang/Basic/AddressSpaces.h +++ b/clang/include/clang/Basic/AddressSpaces.h @@ -62,6 +62,7 @@ enum class LangAS : unsigned { hlsl_private, hlsl_device, hlsl_input, + hlsl_push_constant, // Wasm specific address spaces. wasm_funcref, diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index c929da7d538bd..71595dbe90bd9 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -5164,6 +5164,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr { let Documentation = [HLSLVkExtBuiltinInputDocs]; } +def HLSLVkPushConstant : InheritableAttr { + let Spellings = [CXX11<"vk", "push_constant">]; + let Args = []; + let Subjects = SubjectList<[GlobalVar], ErrorDiag>; + let LangOpts = [HLSL]; + let Documentation = [HLSLVkPushConstantDocs]; +} + def HLSLVkConstantId : InheritableAttr { let Spellings = [CXX11<"vk", "constant_id">]; let Args = [IntArgument<"Id">]; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index fa365da3ed9aa..201187ae08e36 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -8794,6 +8794,22 @@ https://github.com/microsoft/hlsl-specs/blob/main/proposals/0011-inline-spirv.md }]; } +def HLSLVkPushConstantDocs : Documentation { + let Category = DocCatVariable; + let Content = [{ +Vulkan shaders have `PushConstants`. + +The ``[[vk::push_constant]]`` attribute allows you to declare this +global variable as a push constant when targeting Vulkan. +This attribute is ignored otherwise. + +This attribute must be applied to the variable, not the underlying type. +The variable type must be a struct. Per the requirements of Vulkan, "there +must be no more than one push constant block statically used per shader entry +point." 
+}]; +} + def AnnotateTypeDocs : Documentation { let Category = DocCatType; let Heading = "annotate_type"; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 69ed958a2a2aa..703e759530fa1 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -13216,6 +13216,9 @@ def err_hlsl_attr_invalid_type : Error< "attribute %0 only applies to a field or parameter of type '%1'">; def err_hlsl_attr_invalid_ast_node : Error< "attribute %0 only applies to %1">; +def err_hlsl_attr_incompatible + : Error<"%0 attribute is not compatible with %1 attribute">; + def err_hlsl_entry_shader_attr_mismatch : Error< "%0 attribute on entry function does not match the target profile">; def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">; @@ -13333,6 +13336,9 @@ def err_hlsl_incomplete_resource_array_in_function_param: Error< def err_hlsl_assign_to_global_resource: Error< "assignment to global resource variable %0 is not allowed">; +def err_hlsl_push_constant_unique + : Error<"cannot have more than one push constant block">; + // Layout randomization diagnostics. 
def err_non_designated_init_used : Error< "a randomized struct can only be initialized with a designated initializer">; diff --git a/clang/include/clang/Basic/HLSLRuntime.h b/clang/include/clang/Basic/HLSLRuntime.h index 03166805daa6a..f6a1cf9636467 100644 --- a/clang/include/clang/Basic/HLSLRuntime.h +++ b/clang/include/clang/Basic/HLSLRuntime.h @@ -14,6 +14,7 @@ #ifndef CLANG_BASIC_HLSLRUNTIME_H #define CLANG_BASIC_HLSLRUNTIME_H +#include "clang/Basic/AddressSpaces.h" #include "clang/Basic/LangOptions.h" #include @@ -30,6 +31,10 @@ getStageFromEnvironment(const llvm::Triple::EnvironmentType &E) { return static_cast(Pipeline); } +constexpr bool isInitializedByPipeline(LangAS AS) { + return AS == LangAS::hlsl_input || AS == LangAS::hlsl_push_constant; +} + #define ENUM_COMPARE_ASSERT(Value) \ static_assert( \ getStageFromEnvironment(llvm::Triple::Value) == ShaderStage::Value, \ diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index a2faa91d1e54d..99d8ed137b0c2 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -188,6 +188,7 @@ class SemaHLSL : public SemaBase { void handleSemanticAttr(Decl *D, const ParsedAttr &AL); void handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL); + void handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL); bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall); QualType ProcessResourceTypeAttributes(QualType Wrapped); @@ -237,6 +238,8 @@ class SemaHLSL : public SemaBase { IdentifierInfo *RootSigOverrideIdent = nullptr; + bool HasDeclaredAPushConstant = false; + // Information about the current subtree being flattened. 
struct SemanticInfo { HLSLParsedSemanticAttr *Semantic; diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index 4548af17e37f2..53082bcf78f6a 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -101,6 +101,7 @@ bool Qualifiers::isTargetAddressSpaceSupersetOf(LangAS A, LangAS B, (A == LangAS::Default && B == LangAS::hlsl_private) || (A == LangAS::Default && B == LangAS::hlsl_device) || (A == LangAS::Default && B == LangAS::hlsl_input) || + (A == LangAS::Default && B == LangAS::hlsl_push_constant) || // Conversions from target specific address spaces may be legal // depending on the target information. Ctx.getTargetInfo().isAddressSpaceSupersetOf(A, B); diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp index d2881d5ac518a..c7ac9c2a2124f 100644 --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -2694,6 +2694,8 @@ std::string Qualifiers::getAddrSpaceAsString(LangAS AS) { return "hlsl_device"; case LangAS::hlsl_input: return "hlsl_input"; + case LangAS::hlsl_push_constant: + return "hlsl_push_constant"; case LangAS::wasm_funcref: return "__funcref"; default: diff --git a/clang/lib/Basic/TargetInfo.cpp b/clang/lib/Basic/TargetInfo.cpp index c0ed900ebd45c..92ed0c461ff82 100644 --- a/clang/lib/Basic/TargetInfo.cpp +++ b/clang/lib/Basic/TargetInfo.cpp @@ -52,6 +52,7 @@ static const LangASMap FakeAddrSpaceMap = { 15, // hlsl_private 16, // hlsl_device 17, // hlsl_input + 18, // hlsl_push_constant 20, // wasm_funcref }; diff --git a/clang/lib/Basic/Targets/AArch64.h b/clang/lib/Basic/Targets/AArch64.h index 1a7aa658e9d87..8e8d8e6ae86b5 100644 --- a/clang/lib/Basic/Targets/AArch64.h +++ b/clang/lib/Basic/Targets/AArch64.h @@ -48,6 +48,7 @@ static const unsigned ARM64AddrSpaceMap[] = { 0, // hlsl_private 0, // hlsl_device 0, // hlsl_input + 0, // hlsl_push_constant // Wasm address space values for this target are dummy values, // as it is only enabled for Wasm targets. 
20, // wasm_funcref diff --git a/clang/lib/Basic/Targets/AMDGPU.cpp b/clang/lib/Basic/Targets/AMDGPU.cpp index d4d696b8456b6..993a73a89c9e9 100644 --- a/clang/lib/Basic/Targets/AMDGPU.cpp +++ b/clang/lib/Basic/Targets/AMDGPU.cpp @@ -63,6 +63,7 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsGenMap = { llvm::AMDGPUAS::PRIVATE_ADDRESS, // hlsl_private llvm::AMDGPUAS::GLOBAL_ADDRESS, // hlsl_device llvm::AMDGPUAS::PRIVATE_ADDRESS, // hlsl_input + llvm::AMDGPUAS::GLOBAL_ADDRESS, // hlsl_push_constant }; const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = { @@ -91,6 +92,7 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = { llvm::AMDGPUAS::PRIVATE_ADDRESS, // hlsl_private llvm::AMDGPUAS::GLOBAL_ADDRESS, // hlsl_device llvm::AMDGPUAS::PRIVATE_ADDRESS, // hlsl_input + llvm::AMDGPUAS::GLOBAL_ADDRESS, // hlsl_push_constant }; } // namespace targets } // namespace clang diff --git a/clang/lib/Basic/Targets/DirectX.h b/clang/lib/Basic/Targets/DirectX.h index a21a593365773..c0799a6f7610f 100644 --- a/clang/lib/Basic/Targets/DirectX.h +++ b/clang/lib/Basic/Targets/DirectX.h @@ -46,6 +46,7 @@ static const unsigned DirectXAddrSpaceMap[] = { 0, // hlsl_private 0, // hlsl_device 0, // hlsl_input + 0, // hlsl_push_constant // Wasm address space values for this target are dummy values, // as it is only enabled for Wasm targets. 20, // wasm_funcref diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h index f5c8396f398aa..6338a4f2f9036 100644 --- a/clang/lib/Basic/Targets/NVPTX.h +++ b/clang/lib/Basic/Targets/NVPTX.h @@ -50,6 +50,7 @@ static const unsigned NVPTXAddrSpaceMap[] = { 0, // hlsl_private 0, // hlsl_device 0, // hlsl_input + 0, // hlsl_push_constant // Wasm address space values for this target are dummy values, // as it is only enabled for Wasm targets. 
20, // wasm_funcref diff --git a/clang/lib/Basic/Targets/SPIR.h b/clang/lib/Basic/Targets/SPIR.h index 332bf79e2babd..02fd6d13958c0 100644 --- a/clang/lib/Basic/Targets/SPIR.h +++ b/clang/lib/Basic/Targets/SPIR.h @@ -51,6 +51,7 @@ static const unsigned SPIRDefIsPrivMap[] = { 10, // hlsl_private 11, // hlsl_device 7, // hlsl_input + 13, // hlsl_push_constant // Wasm address space values for this target are dummy values, // as it is only enabled for Wasm targets. 20, // wasm_funcref @@ -87,6 +88,7 @@ static const unsigned SPIRDefIsGenMap[] = { 10, // hlsl_private 11, // hlsl_device 7, // hlsl_input + 13, // hlsl_push_constant // Wasm address space values for this target are dummy values, // as it is only enabled for Wasm targets. 20, // wasm_funcref diff --git a/clang/lib/Basic/Targets/SystemZ.h b/clang/lib/Basic/Targets/SystemZ.h index 4e15d5af1cde6..4ce515b31a001 100644 --- a/clang/lib/Basic/Targets/SystemZ.h +++ b/clang/lib/Basic/Targets/SystemZ.h @@ -46,6 +46,7 @@ static const unsigned ZOSAddressMap[] = { 0, // hlsl_private 0, // hlsl_device 0, // hlsl_input + 0, // hlsl_push_constant 0 // wasm_funcref }; diff --git a/clang/lib/Basic/Targets/TCE.h b/clang/lib/Basic/Targets/TCE.h index 005cab9819472..161025378c471 100644 --- a/clang/lib/Basic/Targets/TCE.h +++ b/clang/lib/Basic/Targets/TCE.h @@ -55,6 +55,7 @@ static const unsigned TCEOpenCLAddrSpaceMap[] = { 0, // hlsl_private 0, // hlsl_device 0, // hlsl_input + 0, // hlsl_push_constant // Wasm address space values for this target are dummy values, // as it is only enabled for Wasm targets. 
20, // wasm_funcref diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h index 4de6ce6bb5a21..c8065843aeb42 100644 --- a/clang/lib/Basic/Targets/WebAssembly.h +++ b/clang/lib/Basic/Targets/WebAssembly.h @@ -46,6 +46,7 @@ static const unsigned WebAssemblyAddrSpaceMap[] = { 0, // hlsl_private 0, // hlsl_device 0, // hlsl_input + 0, // hlsl_push_constant 20, // wasm_funcref }; diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h index e7da2622e78b5..7b88ac70e234f 100644 --- a/clang/lib/Basic/Targets/X86.h +++ b/clang/lib/Basic/Targets/X86.h @@ -50,6 +50,7 @@ static const unsigned X86AddrSpaceMap[] = { 0, // hlsl_private 0, // hlsl_device 0, // hlsl_input + 0, // hlsl_push_constant // Wasm address space values for this target are dummy values, // as it is only enabled for Wasm targets. 20, // wasm_funcref diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp index a04d606b6a0e5..c3ababa950d83 100644 --- a/clang/lib/CodeGen/CodeGenModule.cpp +++ b/clang/lib/CodeGen/CodeGenModule.cpp @@ -6100,9 +6100,11 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D, getCUDARuntime().handleVarRegistration(D, *GV); } - if (LangOpts.HLSL && GetGlobalVarAddressSpace(D) == LangAS::hlsl_input) { + if (LangOpts.HLSL && + hlsl::isInitializedByPipeline(GetGlobalVarAddressSpace(D))) { // HLSL Input variables are considered to be set by the driver/pipeline, but - // only visible to a single thread/wave. + // only visible to a single thread/wave. Push constants are also externally + // initialized, but constant, hence cross-wave visibility is not relevant. GV->setExternallyInitialized(true); } else { GV->setInitializer(Init); @@ -6153,10 +6155,11 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D, !D->hasAttr()) Linkage = llvm::GlobalValue::InternalLinkage; - // HLSL variables in the input address space maps like memory-mapped - // variables. 
Even if they are 'static', they are externally initialized and - // read/write by the hardware/driver/pipeline. - if (LangOpts.HLSL && GetGlobalVarAddressSpace(D) == LangAS::hlsl_input) + // HLSL variables in the input or push-constant address spaces behave like + // memory-mapped variables. Even if they are 'static', they are externally + // initialized and read/written by the hardware/driver/pipeline. + if (LangOpts.HLSL && + hlsl::isInitializedByPipeline(GetGlobalVarAddressSpace(D))) Linkage = llvm::GlobalValue::ExternalLinkage; GV->setLinkage(Linkage); diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 7e7abcee40a56..303e7c3ce7832 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -30,6 +30,7 @@ #include "clang/AST/Type.h" #include "clang/Basic/Builtins.h" #include "clang/Basic/DiagnosticComment.h" +#include "clang/Basic/HLSLRuntime.h" #include "clang/Basic/PartialDiagnostic.h" #include "clang/Basic/SourceManager.h" #include "clang/Basic/TargetInfo.h" @@ -14586,10 +14587,10 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) { if (getLangOpts().HLSL && HLSL().ActOnUninitializedVarDecl(Var)) return; - // HLSL input variables are expected to be externally initialized, even - // when marked `static`. + // HLSL input & push-constant variables are expected to be externally + // initialized, even when marked `static`. 
if (getLangOpts().HLSL && - Var->getType().getAddressSpace() == LangAS::hlsl_input) + hlsl::isInitializedByPipeline(Var->getType().getAddressSpace())) return; // C++03 [dcl.init]p9: diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index c9d1ee76a2e52..f1041cdfc9174 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7679,6 +7679,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLVkExtBuiltinInput: S.HLSL().handleVkExtBuiltinInputAttr(D, AL); break; + case ParsedAttr::AT_HLSLVkPushConstant: + S.HLSL().handleVkPushConstantAttr(D, AL); + break; case ParsedAttr::AT_HLSLVkConstantId: S.HLSL().handleVkConstantIdAttr(D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 89645e3b67db3..73610a3a28346 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1768,6 +1768,11 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) { HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID)); } +void SemaHLSL::handleVkPushConstantAttr(Decl *D, const ParsedAttr &AL) { + D->addAttr(::new (getASTContext()) + HLSLVkPushConstantAttr(getASTContext(), AL)); +} + void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) { uint32_t Id; if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id)) @@ -3934,12 +3939,15 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { return Ty; } -static bool IsDefaultBufferConstantDecl(VarDecl *VD) { +static bool IsDefaultBufferConstantDecl(const ASTContext &Ctx, VarDecl *VD) { + bool IsVulkan = + Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::Vulkan; + bool IsVKPushConstant = IsVulkan && VD->hasAttr(); QualType QT = VD->getType(); return VD->getDeclContext()->isTranslationUnit() && QT.getAddressSpace() == LangAS::Default && VD->getStorageClass() != SC_Static && - !VD->hasAttr() && + !VD->hasAttr() && !IsVKPushConstant && 
!isInvalidConstantBufferLeafElementType(QT.getTypePtr()); } @@ -3960,6 +3968,19 @@ void SemaHLSL::deduceAddressSpace(VarDecl *Decl) { return; } + bool IsVulkan = getASTContext().getTargetInfo().getTriple().getOS() == + llvm::Triple::Vulkan; + if (IsVulkan && Decl->hasAttr()) { + if (HasDeclaredAPushConstant) + SemaRef.Diag(Decl->getLocation(), diag::err_hlsl_push_constant_unique); + + LangAS ImplAS = LangAS::hlsl_push_constant; + Type = SemaRef.getASTContext().getAddrSpaceQualType(Type, ImplAS); + Decl->setType(Type); + HasDeclaredAPushConstant = true; + return; + } + if (Type->isSamplerT() || Type->isVoidType()) return; @@ -3992,7 +4013,7 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) { // Global variables outside a cbuffer block that are not a resource, static, // groupshared, or an empty array or struct belong to the default constant // buffer $Globals (to be created at the end of the translation unit). - if (IsDefaultBufferConstantDecl(VD)) { + if (IsDefaultBufferConstantDecl(getASTContext(), VD)) { // update address space to hlsl_constant QualType NewTy = getASTContext().getAddrSpaceQualType( VD->getType(), LangAS::hlsl_constant); @@ -4293,8 +4314,11 @@ void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) { bool HasBinding = false; for (Attr *A : VD->attrs()) { - if (isa(A)) + if (isa(A)) { HasBinding = true; + if (auto PA = VD->getAttr()) + Diag(PA->getLoc(), diag::err_hlsl_attr_incompatible) << A << PA; + } HLSLResourceBindingAttr *RBA = dyn_cast(A); if (!RBA || !RBA->hasRegisterSlot()) diff --git a/clang/test/AST/HLSL/vk.pushconstant.hlsl b/clang/test/AST/HLSL/vk.pushconstant.hlsl new file mode 100644 index 0000000000000..6e7179e887143 --- /dev/null +++ b/clang/test/AST/HLSL/vk.pushconstant.hlsl @@ -0,0 +1,16 @@ +// RUN: %clang_cc1 -triple spirv-unknown-vulkan1.3-compute -x hlsl -ast-dump -o - %s | FileCheck %s --check-prefix=CHECK-VK +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
--check-prefix=CHECK-DX + +struct S { + int value; +}; + +[[vk::push_constant]] S PC; +// CHECK-VK: VarDecl 0x[[A:[0-9a-f]+]] col:25 PC 'hlsl_push_constant S' +// CHECK-VK-NEXT: HLSLVkPushConstantAttr 0x[[A:[0-9a-f]+]] + +// CHECK-DX: VarDecl 0x[[A:[0-9a-f]+]] col:25 PC 'hlsl_constant S' +// CHECK-DX-NEXT: HLSLVkPushConstantAttr 0x[[A:[0-9a-f]+]] + +[numthreads(1, 1, 1)] +void main() { } diff --git a/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.access.bitfield.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.access.bitfield.hlsl new file mode 100644 index 0000000000000..412ec4dffc572 --- /dev/null +++ b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.access.bitfield.hlsl @@ -0,0 +1,20 @@ +// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +struct S { + uint32_t a : 1; + uint32_t b : 1; +}; +// CHECK: %struct.S = type { i8 } + +[[vk::push_constant]] S buffer; +// CHECK: @buffer = external hidden addrspace(13) externally_initialized global %struct.S, align 1 + +[numthreads(1, 1, 1)] +void main() { + uint32_t v = buffer.b; +// CHECK: %bf.load = load i8, ptr addrspace(13) @buffer, align 1 +// CHECK: %bf.lshr = lshr i8 %bf.load, 1 +// CHECK: %bf.clear = and i8 %bf.lshr, 1 +// CHECK: %bf.cast = zext i8 %bf.clear to i32 +// CHECK: store i32 %bf.cast +} diff --git a/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.access.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.access.hlsl new file mode 100644 index 0000000000000..890e9a6c1262a --- /dev/null +++ b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.access.hlsl @@ -0,0 +1,16 @@ +// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +struct S { + uint a; +}; +// CHECK: %struct.S = type { i32 } + +[[vk::push_constant]] S buffer; +// CHECK: @buffer = external hidden addrspace(13) externally_initialized global 
%struct.S, align 1 + +[numthreads(1, 1, 1)] +void main() { + uint32_t v = buffer.a; +// CHECK: %[[#REG:]] = load i32, ptr addrspace(13) @buffer, align 1 +// CHECK: store i32 %[[#REG]], ptr %v, align 4 +} diff --git a/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.anon-struct.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.anon-struct.hlsl new file mode 100644 index 0000000000000..2b2e9d09c7ab0 --- /dev/null +++ b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.anon-struct.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +[[vk::push_constant]] +struct { + int a; + float b; + float3 c; +} +PushConstants; + +// CHECK: %struct.anon = type <{ i32, float, <3 x float> }> +// CHECK: @PushConstants = external hidden addrspace(13) externally_initialized global %struct.anon, align 1 + +[numthreads(1, 1, 1)] +void main() { + float tmp = PushConstants.b; +} diff --git a/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.dxil.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.dxil.hlsl new file mode 100644 index 0000000000000..47719b5f28e23 --- /dev/null +++ b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.dxil.hlsl @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-compute -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +struct S { + uint value; +}; +// CHECK: %struct.S = type { i32 } + +// When targeting DXIL, the attribute is ignored, meaning this variable +// is part of the implicit cbuffer. 
+[[vk::push_constant]] S buffer; +// CHECK: @buffer = external hidden addrspace(2) global %struct.S, align 1 + +[numthreads(1, 1, 1)] +void main() { + uint32_t v = buffer.value; +// CHECK: %[[#REG:]] = load i32, ptr addrspace(2) @buffer, align 4 +// CHECK: store i32 %[[#REG]], ptr %v, align 4 +} diff --git a/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.layout.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.layout.hlsl new file mode 100644 index 0000000000000..c671e6effe3e4 --- /dev/null +++ b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.layout.hlsl @@ -0,0 +1,31 @@ +// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +struct T { + float2 f1[3]; + // FIXME(): matrix support. + // column_major float3x2 f2[2]; + // row_major int3x2 f4[2]; + // row_major float3x2 f3[2]; +}; +// %struct.T = type { [3 x <2 x float>] } + +struct S { + float f1; + float3 f2; + T f4; + // FIXME(): matrix support. 
+ // row_major int2x3 f5; + // row_major float2x3 f3; +}; +// %struct.S = type <{ float, <3 x float>, %struct.T }> + +[[vk::push_constant]] +S pcs; +// CHECK: @pcs = external hidden addrspace(13) externally_initialized global %struct.S, align 1 + +[numthreads(1, 1, 1)] +void main() { + float a = pcs.f1; +// CHECK: %[[#TMP:]] = load float, ptr addrspace(13) @pcs, align 1 +// CHECK: store float %[[#TMP]], ptr %a, align 4 +} diff --git a/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.static.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.static.hlsl new file mode 100644 index 0000000000000..ca05fda1d0ee1 --- /dev/null +++ b/clang/test/CodeGenHLSL/vk-features/vk.pushconstant.static.hlsl @@ -0,0 +1,25 @@ +// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +struct S +{ + const static uint a = 1; + uint b; +}; +// CHECK: %struct.S = type { i32 } + +[[vk::push_constant]] S s; +// CHECK: @s = external hidden addrspace(13) externally_initialized global %struct.S, align 1 + +[numthreads(1,1,1)] +void main() +{ + uint32_t v = s.b; + // CHECK: %[[#TMP:]] = load i32, ptr addrspace(13) @s, align 1 + // CHECK: store i32 %[[#TMP]], ptr %v, align 4 + + uint32_t w = S::a; + // CHECK: store i32 1, ptr %w, align 4 + + uint32_t x = s.a; + // CHECK: store i32 1, ptr %x, align 4 +} diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test index 1e1d4a356f515..0d57a0937111a 100644 --- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test +++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test @@ -90,6 +90,7 @@ // CHECK-NEXT: GNUInline (SubjectMatchRule_function) // CHECK-NEXT: HIPManaged (SubjectMatchRule_variable) // CHECK-NEXT: HLSLVkLocation (SubjectMatchRule_variable_is_parameter, SubjectMatchRule_field, SubjectMatchRule_function) +// CHECK-NEXT: HLSLVkPushConstant 
(SubjectMatchRule_variable_is_global) // CHECK-NEXT: Hot (SubjectMatchRule_function) // CHECK-NEXT: HybridPatchable (SubjectMatchRule_function) // CHECK-NEXT: IBAction (SubjectMatchRule_objc_method_is_instance) diff --git a/clang/test/SemaHLSL/vk.pushconstant.invalid.hlsl b/clang/test/SemaHLSL/vk.pushconstant.invalid.hlsl new file mode 100644 index 0000000000000..6b58decfa5188 --- /dev/null +++ b/clang/test/SemaHLSL/vk.pushconstant.invalid.hlsl @@ -0,0 +1,13 @@ +// RUN: %clang_cc1 -triple spirv-unknown-vulkan-compute -x hlsl -emit-llvm -disable-llvm-passes -o - -hlsl-entry main %s -verify + +struct S { + float f; +}; + +// expected-error@+1 {{'vk::binding' attribute is not compatible with 'vk::push_constant' attribute}} +[[vk::push_constant, vk::binding(5)]] +S pcs; + +[numthreads(1, 1, 1)] +void main() { +} diff --git a/clang/test/SemaHLSL/vk.pushconstant.multiple.hlsl b/clang/test/SemaHLSL/vk.pushconstant.multiple.hlsl new file mode 100644 index 0000000000000..9a2ae1266e69c --- /dev/null +++ b/clang/test/SemaHLSL/vk.pushconstant.multiple.hlsl @@ -0,0 +1,13 @@ +// RUN: %clang_cc1 -triple spirv-unknown-vulkan-compute -x hlsl -emit-llvm -disable-llvm-passes -o - -hlsl-entry main %s -verify + +struct S { + float f; +}; + +[[vk::push_constant]] S a; + +// expected-error@+1 {{cannot have more than one push constant block}} +[[vk::push_constant]] S b; + +[numthreads(1, 1, 1)] +void main() {} diff --git a/clang/test/SemaTemplate/address_space-dependent.cpp b/clang/test/SemaTemplate/address_space-dependent.cpp index e17bf60e6a200..cba21b416bb48 100644 --- a/clang/test/SemaTemplate/address_space-dependent.cpp +++ b/clang/test/SemaTemplate/address_space-dependent.cpp @@ -43,7 +43,7 @@ void neg() { template void tooBig() { - __attribute__((address_space(I))) int *bounds; // expected-error {{address space is larger than the maximum supported (8388582)}} + __attribute__((address_space(I))) int *bounds; // expected-error {{address space is larger than the maximum supported 
(8388581)}} } template @@ -101,7 +101,7 @@ int main() { car<1, 2, 3>(); // expected-note {{in instantiation of function template specialization 'car<1, 2, 3>' requested here}} HasASTemplateFields<1> HASTF; neg<-1>(); // expected-note {{in instantiation of function template specialization 'neg<-1>' requested here}} - correct<0x7FFFE6>(); + correct<0x7FFFE5>(); tooBig<8388650>(); // expected-note {{in instantiation of function template specialization 'tooBig<8388650L>' requested here}} __attribute__((address_space(1))) char *x; diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 366f8cf36d75c..c05f236197c77 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -179,7 +179,10 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] : DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty, llvm_i32_ty], [IntrNoMem]>; -def int_spv_resource_nonuniformindex + def int_spv_pushconstant_getpointer + : DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty], [IntrNoMem]>; + + def int_spv_resource_nonuniformindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem]>; // Read a value from the image buffer. 
It does not translate directly to a diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index 79b76165cd57a..dfdae7c514757 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -51,6 +51,7 @@ add_llvm_target(SPIRVCodeGen SPIRVUtils.cpp SPIRVEmitNonSemanticDI.cpp SPIRVCBufferAccess.cpp + SPIRVPushConstantAccess.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index fa85ee781c249..d7e8bfc3f179e 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -22,6 +22,7 @@ class RegisterBankInfo; ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM); FunctionPass *createSPIRVStructurizerPass(); ModulePass *createSPIRVCBufferAccessLegacyPass(); +ModulePass *createSPIRVPushConstantAccessLegacyPass(SPIRVTargetMachine *TM); FunctionPass *createSPIRVMergeRegionExitTargetsPass(); FunctionPass *createSPIRVStripConvergenceIntrinsicsPass(); ModulePass *createSPIRVLegalizeImplicitBindingPass(); @@ -46,6 +47,7 @@ void initializeSPIRVPreLegalizerCombinerPass(PassRegistry &); void initializeSPIRVPostLegalizerPass(PassRegistry &); void initializeSPIRVStructurizerPass(PassRegistry &); void initializeSPIRVCBufferAccessLegacyPass(PassRegistry &); +void initializeSPIRVPushConstantAccessLegacyPass(PassRegistry &); void initializeSPIRVEmitIntrinsicsPass(PassRegistry &); void initializeSPIRVEmitNonSemanticDIPass(PassRegistry &); void initializeSPIRVLegalizePointerCastPass(PassRegistry &); diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 87ebee6a14eac..534cc99c9ac5c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -3363,6 +3363,15 @@ static SPIRVType *getVulkanBufferType(const TargetExtType *ExtensionType, return GR->getOrCreateVulkanBufferType(MIRBuilder, T, SC, IsWritable); } +static SPIRVType 
*getVulkanPushConstantType(const TargetExtType *ExtensionType, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + assert(ExtensionType->getNumTypeParameters() == 1 && + "Vulkan push constants have exactly one type as argument."); + auto *T = ExtensionType->getTypeParameter(0); + return GR->getOrCreateVulkanPushConstantType(MIRBuilder, T); +} + static SPIRVType *getLayoutType(const TargetExtType *ExtensionType, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { @@ -3448,6 +3457,8 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType, TargetType = getVulkanBufferType(BuiltinType, MIRBuilder, GR); } else if (Name == "spirv.Padding") { TargetType = GR->getOrCreatePaddingType(MIRBuilder); + } else if (Name == "spirv.PushConstant") { + TargetType = getVulkanPushConstantType(BuiltinType, MIRBuilder, GR); } else if (Name == "spirv.Layout") { TargetType = getLayoutType(BuiltinType, MIRBuilder, GR); } else { diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index ae81d38579c18..3d3dbd898a483 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -1467,6 +1467,27 @@ SPIRVGlobalRegistry::getOrCreatePaddingType(MachineIRBuilder &MIRBuilder) { return R; } +SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanPushConstantType( + MachineIRBuilder &MIRBuilder, Type *T) { + const auto SC = SPIRV::StorageClass::PushConstant; + + auto Key = SPIRV::irhandle_vkbuffer(T, SC, /* IsWritable= */ false); + if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF())) + return MI; + + // We need to get the SPIR-V type for the element here, so we can add the + // decoration to it. 
+ auto *BlockType = getOrCreateSPIRVType( + T, MIRBuilder, SPIRV::AccessQualifier::None, + /* ExplicitLayoutRequired= */ true, /* EmitIr= */ false); + + buildOpDecorate(BlockType->defs().begin()->getReg(), MIRBuilder, + SPIRV::Decoration::Block, {}); + SPIRVType *R = BlockType; + add(Key, R); + return R; +} + SPIRVType *SPIRVGlobalRegistry::getOrCreateLayoutType( MachineIRBuilder &MIRBuilder, const TargetExtType *T, bool EmitIr) { auto Key = SPIRV::handle(T); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index e5a1a2aa8d70f..ac444c45d6f4f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -613,6 +613,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { SPIRVType *getOrCreatePaddingType(MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateVulkanPushConstantType(MachineIRBuilder &MIRBuilder, + Type *ElemType); + SPIRVType *getOrCreateLayoutType(MachineIRBuilder &MIRBuilder, const TargetExtType *T, bool EmitIr = false); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index a2e29366dc4cc..4c66c8f1b7636 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -324,6 +324,8 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectImageWriteIntrinsic(MachineInstr &I) const; bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectPushConstantGetPointer(Register &ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectResourceNonUniformIndex(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; @@ -3843,6 +3845,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_resource_getpointer: { return selectResourceGetPointer(ResVReg, ResType, I); } + case 
Intrinsic::spv_pushconstant_getpointer: { + return selectPushConstantGetPointer(ResVReg, ResType, I); + } case Intrinsic::spv_discard: { return selectDiscard(ResVReg, ResType, I); } @@ -4113,6 +4118,12 @@ bool SPIRVInstructionSelector::selectResourceGetPointer( .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectPushConstantGetPointer( + Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const { + MRI->replaceRegWith(ResVReg, I.getOperand(2).getReg()); + return true; +} + bool SPIRVInstructionSelector::selectResourceNonUniformIndex( Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const { Register ObjReg = I.getOperand(2).getReg(); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index b5912c27316c9..2ec4a74d3f6cc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -95,14 +95,15 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { const LLT p10 = LLT::pointer(10, PSize); // Private const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer const LLT p12 = LLT::pointer(12, PSize); // Uniform + const LLT p13 = LLT::pointer(13, PSize); // PushConstant // TODO: remove copy-pasting here by using concatenation in some way. 
auto allPtrsScalarsAndVectors = { - p0, p1, p2, p3, p4, p5, p6, p7, p8, - p9, p10, p11, p12, s1, s8, s16, s32, s64, - v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, - v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, - v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + p0, p1, p2, p3, p4, p5, p6, p7, p8, + p9, p10, p11, p12, p13, s1, s8, s16, s32, + s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, + v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, + v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, @@ -133,10 +134,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; - auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3, p4, - p5, p6, p7, p8, p9, p10, p11, p12}; + auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, + p2, p3, p4, p5, p6, p7, + p8, p9, p10, p11, p12, p13}; - auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12}; + auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13}; auto &allowedVectorTypes = ST.isShader() ? 
allShaderVectors : allVectors;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPassRegistry.def b/llvm/lib/Target/SPIRV/SPIRVPassRegistry.def
index 1ce131fe7b1bf..9bd61627db765 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPassRegistry.def
+++ b/llvm/lib/Target/SPIRV/SPIRVPassRegistry.def
@@ -17,6 +17,7 @@
 #define MODULE_PASS(NAME, CREATE_PASS)
 #endif
 MODULE_PASS("spirv-cbuffer-access", SPIRVCBufferAccess())
+MODULE_PASS("spirv-pushconstant-access", SPIRVPushConstantAccess(*static_cast<const SPIRVTargetMachine *>(this)))
 #undef MODULE_PASS
 
 #ifndef FUNCTION_PASS
diff --git a/llvm/lib/Target/SPIRV/SPIRVPushConstantAccess.cpp b/llvm/lib/Target/SPIRV/SPIRVPushConstantAccess.cpp
new file mode 100644
index 0000000000000..809d690f68307
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPushConstantAccess.cpp
@@ -0,0 +1,117 @@
+//===- SPIRVPushConstantAccess.cpp - Push constant accesses -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass changes the types of all the globals in the PushConstant
+// address space into a target extension type, and makes all references
+// to these globals go through a custom SPIR-V intrinsic.
+//
+// This allows the backend to properly lower the push constant struct type
+// to a fully laid out type, and generate the proper OpAccessChain.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVPushConstantAccess.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/Frontend/HLSL/CBuffer.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/ReplaceConstant.h"
+
+#define DEBUG_TYPE "spirv-pushconstant-access"
+using namespace llvm;
+
+// Replaces every global variable in the PushConstant address space with a new
+// global whose value type is wrapped in the "spirv.PushConstant" target
+// extension type, and redirects instruction users of the old global to a
+// spv_pushconstant_getpointer call on the new one.
+// Returns true if at least one global was rewritten.
+static bool replacePushConstantAccesses(Module &M, SPIRVGlobalRegistry *GR) {
+  SmallVector<GlobalVariable *> PushConstants;
+  for (GlobalVariable &GV : M.globals()) {
+    if (GV.getAddressSpace() ==
+        storageClassToAddressSpace(SPIRV::StorageClass::PushConstant))
+      PushConstants.push_back(&GV);
+  }
+
+  for (GlobalVariable *GV : PushConstants) {
+    Type *PCType = llvm::TargetExtType::get(
+        M.getContext(), "spirv.PushConstant", {GV->getValueType()});
+    GlobalVariable *NewGV = new GlobalVariable(
+        M, PCType, GV->isConstant(), GV->getLinkage(),
+        /* initializer= */ nullptr, GV->getName(),
+        /* InsertBefore= */ GV, GV->getThreadLocalMode(), GV->getAddressSpace(),
+        GV->isExternallyInitialized());
+
+    // Constant expressions (e.g. a constant GEP into the struct) cannot take
+    // an instruction result as operand. Expand them into instructions first
+    // so that no use of GV is left behind when it is erased below.
+    convertUsersOfConstantsToInstructions({GV});
+
+    // Snapshot the users: the rewrite below mutates GV's use list.
+    SmallVector<User *> Users(GV->user_begin(), GV->user_end());
+    for (User *U : Users) {
+      Instruction *I = dyn_cast<Instruction>(U);
+      if (!I)
+        continue;
+
+      IRBuilder<> Builder(I);
+      Value *GetPointerCall = Builder.CreateIntrinsic(
+          NewGV->getType(), Intrinsic::spv_pushconstant_getpointer, {NewGV});
+      GR->buildAssignPtr(Builder, GV->getValueType(), GetPointerCall);
+      I->replaceUsesOfWith(GV, GetPointerCall);
+    }
+
+    // GV still owned its name while NewGV was created, so NewGV received a
+    // uniquified one. Transfer the original name before erasing GV.
+    NewGV->takeName(GV);
+    GV->eraseFromParent();
+  }
+
+  return !PushConstants.empty();
+}
+
+PreservedAnalyses SPIRVPushConstantAccess::run(Module &M,
+                                               ModuleAnalysisManager &AM) {
+  const SPIRVSubtarget *ST = TM.getSubtargetImpl();
+  SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
+  return replacePushConstantAccesses(M, GR) ? PreservedAnalyses::none()
+                                            : PreservedAnalyses::all();
+}
+
+namespace {
+class SPIRVPushConstantAccessLegacy : public ModulePass {
+  SPIRVTargetMachine *TM = nullptr;
+
+public:
+  bool runOnModule(Module &M) override {
+    const SPIRVSubtarget *ST = TM->getSubtargetImpl();
+    SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
+    return replacePushConstantAccesses(M, GR);
+  }
+  StringRef getPassName() const override {
+    return "SPIRV push constant Access";
+  }
+  SPIRVPushConstantAccessLegacy(SPIRVTargetMachine *TM)
+      : ModulePass(ID), TM(TM) {}
+
+  static char ID; // Pass identification.
+};
+char SPIRVPushConstantAccessLegacy::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS(SPIRVPushConstantAccessLegacy, DEBUG_TYPE,
+                "SPIRV push constant Access", false, false)
+
+ModulePass *
+llvm::createSPIRVPushConstantAccessLegacyPass(SPIRVTargetMachine *TM) {
+  return new SPIRVPushConstantAccessLegacy(TM);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPushConstantAccess.h b/llvm/lib/Target/SPIRV/SPIRVPushConstantAccess.h
new file mode 100644
index 0000000000000..53cfc62a3a8b9
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPushConstantAccess.h
@@ -0,0 +1,27 @@
+//===- SPIRVPushConstantAccess.h - Translate Push constant loads ----------*-
+// C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVPUSHCONSTANTACCESS_H_
+#define LLVM_LIB_TARGET_SPIRV_SPIRVPUSHCONSTANTACCESS_H_
+
+#include "SPIRVTargetMachine.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class SPIRVPushConstantAccess : public PassInfoMixin<SPIRVPushConstantAccess> {
+  const SPIRVTargetMachine &TM; // Gives run() access to the subtarget.
+
+public:
+  SPIRVPushConstantAccess(const SPIRVTargetMachine &TM) : TM(TM) {}
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVPUSHCONSTANTACCESS_H_
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index 10bbca225b20a..6deadcd451b26 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -15,6 +15,7 @@
 #include "SPIRVCBufferAccess.h"
 #include "SPIRVGlobalRegistry.h"
 #include "SPIRVLegalizerInfo.h"
+#include "SPIRVPushConstantAccess.h"
 #include "SPIRVStructurizerWrapper.h"
 #include "SPIRVTargetObjectFile.h"
 #include "SPIRVTargetTransformInfo.h"
@@ -50,6 +51,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
   initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PR);
   initializeSPIRVStructurizerPass(PR);
   initializeSPIRVCBufferAccessLegacyPass(PR);
+  initializeSPIRVPushConstantAccessLegacyPass(PR);
   initializeSPIRVPreLegalizerCombinerPass(PR);
   initializeSPIRVLegalizePointerCastPass(PR);
   initializeSPIRVRegularizerPass(PR);
@@ -211,6 +213,8 @@ void SPIRVPassConfig::addISelPrepare() {
   addPass(createSPIRVStripConvergenceIntrinsicsPass());
   addPass(createSPIRVLegalizeImplicitBindingPass());
   addPass(createSPIRVCBufferAccessLegacyPass());
+  addPass(
+      createSPIRVPushConstantAccessLegacyPass(&getTM<SPIRVTargetMachine>()));
   addPass(createSPIRVEmitIntrinsicsPass(&getTM<SPIRVTargetMachine>()));
   if (TM.getSubtargetImpl()->isLogicalSPIRV())
     addPass(createSPIRVLegalizePointerCastPass(&getTM<SPIRVTargetMachine>()));
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 7fdb0fafa3719..a822cc8dd623c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -290,6 +290,8 @@ addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) { return SPIRV::StorageClass::StorageBuffer; case 12: return SPIRV::StorageClass::Uniform; + case 13: + return SPIRV::StorageClass::PushConstant; default: report_fatal_error("Unknown address space"); } diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 45e211a1e5d2a..6247207f078fb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -264,6 +264,8 @@ storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) { return 11; case SPIRV::StorageClass::Uniform: return 12; + case SPIRV::StorageClass::PushConstant: + return 13; default: report_fatal_error("Unable to get address space id"); } diff --git a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll index 6db375445e4a3..d6f4646dfb670 100644 --- a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll +++ b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll @@ -38,6 +38,7 @@ ; SPIRV-O0-NEXT: SPIRV strip convergent intrinsics ; SPIRV-O0-NEXT: SPIRV Legalize Implicit Binding ; SPIRV-O0-NEXT: SPIRV CBuffer Access +; SPIRV-O0-NEXT: SPIRV push constant Access ; SPIRV-O0-NEXT: SPIRV emit intrinsics ; SPIRV-O0-NEXT: FunctionPass Manager ; SPIRV-O0-NEXT: SPIRV legalize bitcast pass @@ -141,6 +142,7 @@ ; SPIRV-Opt-NEXT: SPIRV strip convergent intrinsics ; SPIRV-Opt-NEXT: SPIRV Legalize Implicit Binding ; SPIRV-Opt-NEXT: SPIRV CBuffer Access +; SPIRV-Opt-NEXT: SPIRV push constant Access ; SPIRV-Opt-NEXT: SPIRV emit intrinsics ; SPIRV-Opt-NEXT: FunctionPass Manager ; SPIRV-Opt-NEXT: SPIRV legalize bitcast pass diff --git a/llvm/test/CodeGen/SPIRV/vk-pushconstant-access.ll b/llvm/test/CodeGen/SPIRV/vk-pushconstant-access.ll new file mode 100644 index 0000000000000..079d9025ec4df --- 
/dev/null
+++ b/llvm/test/CodeGen/SPIRV/vk-pushconstant-access.ll
@@ -0,0 +1,36 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+
+; Verifies that a global in addrspace(13) is emitted as an OpVariable in the
+; PushConstant storage class, and that its struct type is decorated with Block
+; and an explicit member offset.
+
+%struct.S = type <{ float }>
+
+; CHECK-DAG: %[[#F32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#UINT:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#S_S:]] = OpTypeStruct %[[#F32]]
+
+; CHECK-DAG: %[[#PTR_PCS_F:]] = OpTypePointer PushConstant %[[#F32]]
+; CHECK-DAG: %[[#PTR_PCS_S:]] = OpTypePointer PushConstant %[[#S_S]]
+
+
+; CHECK-DAG: OpMemberDecorate %[[#S_S]] 0 Offset 0
+; CHECK-DAG: OpDecorate %[[#S_S]] Block
+
+
+@pcs = external hidden addrspace(13) externally_initialized global %struct.S, align 1
+; CHECK: %[[#PCS:]] = OpVariable %[[#PTR_PCS_S]] PushConstant
+
+define void @main() #1 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %1 = alloca float, align 4
+  %2 = load float, ptr addrspace(13) @pcs, align 1
+  store float %2, ptr %1
+  ret void
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+
+attributes #1 = { convergent noinline norecurse optnone "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
diff --git a/llvm/test/CodeGen/SPIRV/vk-pushconstant-layout.ll b/llvm/test/CodeGen/SPIRV/vk-pushconstant-layout.ll
new file mode 100644
index 0000000000000..3c5391b532fd1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/vk-pushconstant-layout.ll
@@ -0,0 +1,43 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; XFAIL: *
+; FIXME(168401): fix the offset of last struct S field.
+
+; Verifies the explicit layout decorations (member offsets, array stride and
+; Block) expected on a push constant struct that nests another struct.
+
+%struct.T = type { [3 x <2 x float>] }
+%struct.S = type <{ float, <3 x float>, %struct.T }>
+
+; CHECK-DAG: %[[#PTR_PCS:]] = OpTypePointer PushConstant %[[#S_S:]]
+
+; CHECK-DAG: %[[#F32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#V3F32:]] = OpTypeVector %[[#F32]] 3
+; CHECK-DAG: %[[#V2F32:]] = OpTypeVector %[[#F32]] 2
+; CHECK-DAG: %[[#UINT:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#UINT_3:]] = OpConstant %[[#UINT]] 3
+
+; CHECK-DAG: %[[#S_S]] = OpTypeStruct %[[#F32]] %[[#V3F32]] %[[#S_T:]]
+; CHECK-DAG: %[[#S_T]] = OpTypeStruct %[[#ARR:]]
+; CHECK-DAG: %[[#ARR]] = OpTypeArray %[[#V2F32]] %[[#UINT_3]]
+
+; CHECK-DAG: OpMemberDecorate %[[#S_T]] 0 Offset 0
+; CHECK-DAG: OpMemberDecorate %[[#S_S]] 0 Offset 0
+; CHECK-DAG: OpMemberDecorate %[[#S_S]] 1 Offset 4
+; CHECK-DAG: OpMemberDecorate %[[#S_S]] 2 Offset 16
+; CHECK-DAG: OpDecorate %[[#S_S]] Block
+; CHECK-DAG: OpDecorate %[[#ARR]] ArrayStride 8
+
+
+@pcs = external hidden addrspace(13) externally_initialized global %struct.S, align 1
+; CHECK: %[[#PCS:]] = OpVariable %[[#PTR_PCS]] PushConstant
+
+define void @main() #1 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  ret void
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+
+attributes #1 = { convergent noinline norecurse optnone "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }