Skip to content

Commit 8a40d08

Browse files
authored
[HLSL][SPIR-V] Implement vk::location for inputs (#169479)
This commit adds the support for vk::location attribute which can be applied to input and output variables. As in/inout parameters are not supported yet, vk::location on such parameters is not tested. As implemented in DXC, vk::location has the following rules: - input and outputs are handled independently. - input/output lowered to a SPIR-V builtins are not using the assigned vk::location and thus ignored. - input/output lowered to a Location decoration must either all have explicit locations, or none. Mixing is not allowed (except with builtins).
1 parent c2a0350 commit 8a40d08

17 files changed

+354
-44
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5172,6 +5172,14 @@ def HLSLVkConstantId : InheritableAttr {
51725172
let Documentation = [VkConstantIdDocs];
51735173
}
51745174

5175+
def HLSLVkLocation : HLSLAnnotationAttr {
5176+
let Spellings = [CXX11<"vk", "location">];
5177+
let Args = [IntArgument<"Location">];
5178+
let Subjects = SubjectList<[ParmVar, Field, Function], ErrorDiag>;
5179+
let LangOpts = [HLSL];
5180+
let Documentation = [HLSLVkLocationDocs];
5181+
}
5182+
51755183
def RandomizeLayout : InheritableAttr {
51765184
let Spellings = [GCC<"randomize_layout">];
51775185
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8981,6 +8981,18 @@ The descriptor set is optional and defaults to 0 if not provided.
89818981
}];
89828982
}
89838983

8984+
def HLSLVkLocationDocs : Documentation {
8985+
let Category = DocCatVariable;
8986+
let Content = [{
8987+
Attribute used for specifying the location number for the stage input/output
8988+
variables. Allowed on function parameters, function returns, and struct
8989+
fields. This parameter has no effect when used outside of an entrypoint
8990+
parameter/parameter field/return value.
8991+
8992+
This attribute maps to the 'Location' SPIR-V decoration.
8993+
}];
8994+
}
8995+
89848996
def WebAssemblyFuncrefDocs : Documentation {
89858997
let Category = DocCatType;
89868998
let Content = [{

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13238,6 +13238,9 @@ def err_hlsl_semantic_index_overlap : Error<"semantic index overlap %0">;
1323813238
def err_hlsl_semantic_unsupported_iotype_for_stage
1323913239
: Error<"semantic %0 is unsupported in %2 shaders as %1, requires one of "
1324013240
"the following: %3">;
13241+
def err_hlsl_semantic_partial_explicit_indexing
13242+
: Error<"partial explicit stage input location assignment via "
13243+
"vk::location(X) unsupported">;
1324113244

1324213245
def warn_hlsl_user_defined_type_missing_member: Warning<"binding type '%select{t|u|b|s|c}0' only applies to types containing %select{SRV resources|UAV resources|constant buffer resources|sampler state|numeric types}0">, InGroup<LegacyConstantRegisterBinding>;
1324313246
def err_hlsl_binding_type_mismatch: Error<"binding type '%select{t|u|b|s|c}0' only applies to %select{SRV resources|UAV resources|constant buffer resources|sampler state|numeric variables in the global scope}0">;

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ class SemaHLSL : public SemaBase {
168168
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
169169
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
170170
void handleVkBindingAttr(Decl *D, const ParsedAttr &AL);
171+
void handleVkLocationAttr(Decl *D, const ParsedAttr &AL);
171172
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
172173
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
173174
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);
@@ -236,17 +237,33 @@ class SemaHLSL : public SemaBase {
236237

237238
IdentifierInfo *RootSigOverrideIdent = nullptr;
238239

240+
// Information about the current subtree being flattened.
239241
struct SemanticInfo {
240242
HLSLParsedSemanticAttr *Semantic;
241-
std::optional<uint32_t> Index;
243+
std::optional<uint32_t> Index = std::nullopt;
242244
};
243245

246+
// Bitmask used to recall if the current semantic subtree is
247+
// input, output or inout.
244248
enum IOType {
245249
In = 0b01,
246250
Out = 0b10,
247251
InOut = 0b11,
248252
};
249253

254+
// The context shared by all semantics with the same IOType during
255+
// flattening.
256+
struct SemanticContext {
257+
// Present if any semantic sharing the same IO type has an explicit or
258+
// implicit SPIR-V location index assigned.
259+
std::optional<bool> UsesExplicitVkLocations = std::nullopt;
260+
// The set of semantics found to be active during flattening. Used to detect
261+
// index collisions.
262+
llvm::StringSet<> ActiveSemantics = {};
263+
// The IOType of this semantic set.
264+
IOType CurrentIOType;
265+
};
266+
250267
struct SemanticStageInfo {
251268
llvm::Triple::EnvironmentType Stage;
252269
IOType AllowedIOTypesMask;
@@ -259,19 +276,17 @@ class SemaHLSL : public SemaBase {
259276

260277
void checkSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
261278
const HLSLAppliedSemanticAttr *SemanticAttr,
262-
bool IsInput);
279+
const SemanticContext &SC);
263280

264281
bool determineActiveSemanticOnScalar(FunctionDecl *FD,
265282
DeclaratorDecl *OutputDecl,
266283
DeclaratorDecl *D,
267284
SemanticInfo &ActiveSemantic,
268-
llvm::StringSet<> &ActiveSemantics,
269-
bool IsInput);
285+
SemanticContext &SC);
270286

271287
bool determineActiveSemantic(FunctionDecl *FD, DeclaratorDecl *OutputDecl,
272288
DeclaratorDecl *D, SemanticInfo &ActiveSemantic,
273-
llvm::StringSet<> &ActiveSemantics,
274-
bool IsInput);
289+
SemanticContext &SC);
275290

276291
void processExplicitBindingsOnDecl(VarDecl *D);
277292

@@ -282,7 +297,7 @@ class SemaHLSL : public SemaBase {
282297
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
283298

284299
void diagnoseSemanticStageMismatch(
285-
const Attr *A, llvm::Triple::EnvironmentType Stage, bool IsInput,
300+
const Attr *A, llvm::Triple::EnvironmentType Stage, IOType CurrentIOType,
286301
std::initializer_list<SemanticStageInfo> AllowedStages);
287302

288303
uint32_t getNextImplicitBindingOrderID() {

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -582,20 +582,22 @@ static llvm::Value *createSPIRVLocationLoad(IRBuilder<> &B, llvm::Module &M,
582582
return B.CreateLoad(Ty, GV);
583583
}
584584

585-
llvm::Value *
586-
CGHLSLRuntime::emitSPIRVUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
587-
HLSLAppliedSemanticAttr *Semantic,
588-
std::optional<unsigned> Index) {
585+
llvm::Value *CGHLSLRuntime::emitSPIRVUserSemanticLoad(
586+
llvm::IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
587+
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
589588
Twine BaseName = Twine(Semantic->getAttrName()->getName());
590589
Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));
591590

592591
unsigned Location = SPIRVLastAssignedInputSemanticLocation;
592+
if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
593+
Location = L->getLocation();
593594

594595
// DXC completely ignores the semantic/index pair. Location are assigned from
595596
// the first semantic to the last.
596597
llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Type);
597598
unsigned ElementCount = AT ? AT->getNumElements() : 1;
598599
SPIRVLastAssignedInputSemanticLocation += ElementCount;
600+
599601
return createSPIRVLocationLoad(B, CGM.getModule(), Type, Location,
600602
VariableName.str());
601603
}
@@ -616,10 +618,14 @@ static void createSPIRVLocationStore(IRBuilder<> &B, llvm::Module &M,
616618

617619
void CGHLSLRuntime::emitSPIRVUserSemanticStore(
618620
llvm::IRBuilder<> &B, llvm::Value *Source,
619-
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
621+
const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
622+
std::optional<unsigned> Index) {
620623
Twine BaseName = Twine(Semantic->getAttrName()->getName());
621624
Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));
625+
622626
unsigned Location = SPIRVLastAssignedOutputSemanticLocation;
627+
if (auto *L = Decl->getAttr<HLSLVkLocationAttr>())
628+
Location = L->getLocation();
623629

624630
// DXC completely ignores the semantic/index pair. Location are assigned from
625631
// the first semantic to the last.
@@ -671,7 +677,7 @@ llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
671677
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
672678
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
673679
if (CGM.getTarget().getTriple().isSPIRV())
674-
return emitSPIRVUserSemanticLoad(B, Type, Semantic, Index);
680+
return emitSPIRVUserSemanticLoad(B, Type, Decl, Semantic, Index);
675681

676682
if (CGM.getTarget().getTriple().isDXIL())
677683
return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
@@ -684,7 +690,7 @@ void CGHLSLRuntime::emitUserSemanticStore(IRBuilder<> &B, llvm::Value *Source,
684690
HLSLAppliedSemanticAttr *Semantic,
685691
std::optional<unsigned> Index) {
686692
if (CGM.getTarget().getTriple().isSPIRV())
687-
return emitSPIRVUserSemanticStore(B, Source, Semantic, Index);
693+
return emitSPIRVUserSemanticStore(B, Source, Decl, Semantic, Index);
688694

689695
if (CGM.getTarget().getTriple().isDXIL())
690696
return emitDXILUserSemanticStore(B, Source, Semantic, Index);
@@ -693,8 +699,9 @@ void CGHLSLRuntime::emitUserSemanticStore(IRBuilder<> &B, llvm::Value *Source,
693699
}
694700

695701
llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
696-
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
697-
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
702+
IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
703+
const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic,
704+
std::optional<unsigned> Index) {
698705

699706
std::string SemanticName = Semantic->getAttrName()->getName().upper();
700707
if (SemanticName == "SV_GROUPINDEX") {
@@ -730,8 +737,12 @@ llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
730737
return buildVectorInput(B, GroupIDIntrinsic, Type);
731738
}
732739

740+
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
741+
assert(ShaderAttr && "Entry point has no shader attribute");
742+
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
743+
733744
if (SemanticName == "SV_POSITION") {
734-
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Pixel) {
745+
if (ST == Triple::EnvironmentType::Pixel) {
735746
if (CGM.getTarget().getTriple().isSPIRV())
736747
return createSPIRVBuiltinLoad(B, CGM.getModule(), Type,
737748
Semantic->getAttrName()->getName(),
@@ -740,7 +751,7 @@ llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
740751
return emitDXILUserSemanticLoad(B, Type, Semantic, Index);
741752
}
742753

743-
if (CGM.getTriple().getEnvironment() == Triple::EnvironmentType::Vertex) {
754+
if (ST == Triple::EnvironmentType::Vertex) {
744755
return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
745756
}
746757
}
@@ -798,7 +809,7 @@ llvm::Value *CGHLSLRuntime::handleScalarSemanticLoad(
798809

799810
std::optional<unsigned> Index = Semantic->getSemanticIndex();
800811
if (Semantic->getAttrName()->getName().starts_with_insensitive("SV_"))
801-
return emitSystemSemanticLoad(B, Type, Decl, Semantic, Index);
812+
return emitSystemSemanticLoad(B, FD, Type, Decl, Semantic, Index);
802813
return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
803814
}
804815

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ class CGHLSLRuntime {
179179
protected:
180180
CodeGenModule &CGM;
181181

182-
llvm::Value *emitSystemSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
182+
llvm::Value *emitSystemSemanticLoad(llvm::IRBuilder<> &B,
183+
const FunctionDecl *FD, llvm::Type *Type,
183184
const clang::DeclaratorDecl *Decl,
184185
HLSLAppliedSemanticAttr *Semantic,
185186
std::optional<unsigned> Index);
@@ -278,6 +279,7 @@ class CGHLSLRuntime {
278279
HLSLResourceBindingAttr *RBA);
279280

280281
llvm::Value *emitSPIRVUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
282+
const clang::DeclaratorDecl *Decl,
281283
HLSLAppliedSemanticAttr *Semantic,
282284
std::optional<unsigned> Index);
283285
llvm::Value *emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
@@ -289,6 +291,7 @@ class CGHLSLRuntime {
289291
std::optional<unsigned> Index);
290292

291293
void emitSPIRVUserSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,
294+
const clang::DeclaratorDecl *Decl,
292295
HLSLAppliedSemanticAttr *Semantic,
293296
std::optional<unsigned> Index);
294297
void emitDXILUserSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7703,6 +7703,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
77037703
case ParsedAttr::AT_HLSLUnparsedSemantic:
77047704
S.HLSL().handleSemanticAttr(D, AL);
77057705
break;
7706+
case ParsedAttr::AT_HLSLVkLocation:
7707+
S.HLSL().handleVkLocationAttr(D, AL);
7708+
break;
77067709

77077710
case ParsedAttr::AT_AbiTag:
77087711
handleAbiTagAttr(S, D, AL);

0 commit comments

Comments
 (0)