Skip to content

Commit a11702e

Browse files
authored
[SPIRV] Add the derivative group execution mode only on shader types that allow it. (microsoft#7628)
DXC allows user to use decrivative instruction in shader models that do not allow it, but they must be dead code that will be removed. However, when we see a derivative instruction in the SPIR-V backend that is not in a pixel shader we assume it need the DerivativeGroup execution mode, and we fail when we try to add it to a vertex shader. To allow out implementation to match DXIL, we will not assume we can add the execution mode. We will only add it for shader that we know can use is, and skip the other. If the derivative instruction is not removed during optimizations, there will be a validation error. While fixing this, we observed another bug that is fixed at the same time since they are closely related. The TaskNV and TaskEXT shader types do not have the same id, and the SPV_KHR_compute_shader_derivatives does not work with the NV mesh shader extension. That was fixed up. Fixes microsoft#7478
1 parent d751c82 commit a11702e

File tree

4 files changed

+115
-19
lines changed

4 files changed

+115
-19
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4399,9 +4399,7 @@ SpirvEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr,
43994399
spvBuilder.createImageQuery(spv::Op::OpImageQueryLod, queryResultType,
44004400
expr->getExprLoc(), sampledImage, coordinate);
44014401

4402-
if (spvContext.isCS() || spvContext.isNode()) {
4403-
addDerivativeGroupExecutionMode();
4404-
}
4402+
addDerivativeGroupExecutionMode();
44054403
// The first component of the float2 contains the mipmap array layer.
44064404
// The second component of the float2 represents the unclamped lod.
44074405
return spvBuilder.createCompositeExtract(astContext.FloatTy, query,
@@ -5780,9 +5778,7 @@ SpirvEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr,
57805778

57815779
const auto retType = expr->getDirectCallee()->getReturnType();
57825780
if (isSample) {
5783-
if (spvContext.isCS() || spvContext.isNode()) {
5784-
addDerivativeGroupExecutionMode();
5785-
}
5781+
addDerivativeGroupExecutionMode();
57865782
return createImageSample(retType, imageType, image, sampler, coordinate,
57875783
/*compareVal*/ nullptr, /*bias*/ nullptr,
57885784
/*lod*/ nullptr, std::make_pair(nullptr, nullptr),
@@ -5870,9 +5866,9 @@ SpirvEmitter::processTextureSampleBiasLevel(const CXXMemberCallExpr *expr,
58705866

58715867
const auto retType = expr->getDirectCallee()->getReturnType();
58725868

5873-
if (!lod && (spvContext.isCS() || spvContext.isNode())) {
5869+
if (!lod)
58745870
addDerivativeGroupExecutionMode();
5875-
}
5871+
58765872
return createImageSample(
58775873
retType, imageType, image, sampler, coordinate,
58785874
/*compareVal*/ nullptr, bias, lod, std::make_pair(nullptr, nullptr),
@@ -5992,9 +5988,7 @@ SpirvEmitter::processTextureSampleCmp(const CXXMemberCallExpr *expr) {
59925988
const auto retType = expr->getDirectCallee()->getReturnType();
59935989
const auto imageType = imageExpr->getType();
59945990

5995-
if (spvContext.isCS()) {
5996-
addDerivativeGroupExecutionMode();
5997-
}
5991+
addDerivativeGroupExecutionMode();
59985992

59995993
return createImageSample(
60005994
retType, imageType, image, sampler, coordinate, compareVal,
@@ -6047,9 +6041,7 @@ SpirvEmitter::processTextureSampleCmpBias(const CXXMemberCallExpr *expr) {
60476041
const auto retType = expr->getDirectCallee()->getReturnType();
60486042
const auto imageType = imageExpr->getType();
60496043

6050-
if (spvContext.isCS()) {
6051-
addDerivativeGroupExecutionMode();
6052-
}
6044+
addDerivativeGroupExecutionMode();
60536045

60546046
return createImageSample(
60556047
retType, imageType, image, sampler, coordinate, compareVal, bias,
@@ -9782,8 +9774,7 @@ SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic(
97829774
QualType returnType = arg->getAstResultType();
97839775
assert(isFloatOrVecOfFloatType(returnType));
97849776

9785-
if (!spvContext.isPS())
9786-
addDerivativeGroupExecutionMode();
9777+
addDerivativeGroupExecutionMode();
97879778
needsLegalization = true;
97889779

97899780
QualType B32Type = astContext.FloatTy;
@@ -12512,8 +12503,7 @@ SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
1251212503
case spv::Op::OpFwidth:
1251312504
case spv::Op::OpFwidthFine:
1251412505
case spv::Op::OpFwidthCoarse:
12515-
if (spvContext.isCS() || spvContext.isNode())
12516-
addDerivativeGroupExecutionMode();
12506+
addDerivativeGroupExecutionMode();
1251712507
needsLegalization = true;
1251812508
break;
1251912509
default:
@@ -15771,8 +15761,29 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
1577115761
return tools.Validate(mod->data(), mod->size(), options);
1577215762
}
1577315763

15764+
static bool canUseDerivativeGroupExecutionMode(SpirvContext::ShaderModelKind sm,
15765+
bool usingEXTMeshShader) {
15766+
switch (sm) {
15767+
case SpirvContext::ShaderModelKind::Compute:
15768+
case SpirvContext::ShaderModelKind::Node:
15769+
return true;
15770+
15771+
// The KHR extension that allows derivative instruction in mesh and task
15772+
// (amplification) shader does not work with SPV_NV_mesh_shader extesion.
15773+
case SpirvContext::ShaderModelKind::Mesh:
15774+
case SpirvContext::ShaderModelKind::Amplification:
15775+
return usingEXTMeshShader;
15776+
default:
15777+
return false;
15778+
}
15779+
}
15780+
1577415781
void SpirvEmitter::addDerivativeGroupExecutionMode() {
15775-
assert(spvContext.isCS());
15782+
bool usingEXTMeshShader =
15783+
featureManager.isExtensionEnabled(Extension::EXT_mesh_shader);
15784+
SpirvContext::ShaderModelKind sm = spvContext.getCurrentShaderModelKind();
15785+
if (!canUseDerivativeGroupExecutionMode(sm, usingEXTMeshShader))
15786+
return;
1577615787

1577715788
SpirvExecutionMode *numThreadsEm =
1577815789
cast<SpirvExecutionMode>(spvBuilder.getModule()->findExecutionMode(
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %dxc -T as_6_5 -E main -fspv-target-env=vulkan1.3 %s -spirv | FileCheck %s --check-prefix=VK13
2+
// RUN: %dxc -T as_6_5 -E main -fspv-target-env=vulkan1.1 -Vd %s -spirv | FileCheck %s --check-prefix=VK11
3+
4+
// VK13-DAG: OpCapability ComputeDerivativeGroupLinearKHR
5+
// VK13-DAG: OpCapability DerivativeControl
6+
// VK13-DAG: OpCapability MeshShadingEXT
7+
// VK13-DAG: OpExtension "SPV_EXT_mesh_shader"
8+
// VK13-DAG: OpExtension "SPV_KHR_compute_shader_derivatives"
9+
// VK13: OpEntryPoint TaskEXT %main "main"
10+
// VK13: OpExecutionMode %main DerivativeGroupLinearKHR
11+
12+
// VK11-DAG: OpExtension "SPV_NV_mesh_shader"
13+
// VK11: OpEntryPoint TaskNV %main "main"
14+
// VK11-NOT: OpExecutionMode %main DerivativeGroup
15+
16+
struct AmplificationPayload
17+
{
18+
float4 value;
19+
};
20+
21+
groupshared AmplificationPayload payload;
22+
23+
[numthreads(4, 1, 1)]
24+
void main(in uint tid : SV_GroupThreadID, in uint gtid : SV_GroupID)
25+
{
26+
payload.value = ddx_coarse(float4(tid, 0, 0, 0));
27+
DispatchMesh(1,1,1, payload);
28+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %dxc -T ms_6_5 -E main -fspv-target-env=vulkan1.3 %s -spirv | FileCheck %s --check-prefix=VK13
2+
// RUN: %dxc -T ms_6_5 -E main -fspv-target-env=vulkan1.1 -Vd %s -spirv | FileCheck %s --check-prefix=VK11
3+
4+
// VK13-DAG: OpCapability ComputeDerivativeGroupLinearKHR
5+
// VK13-DAG: OpCapability DerivativeControl
6+
// vk13-DAG: OpCapability MeshShadingEXT
7+
// VK13-DAG: OpExtension "SPV_EXT_mesh_shader"
8+
// VK13-DAG: OpExtension "SPV_KHR_compute_shader_derivatives"
9+
// VK13: OpEntryPoint MeshEXT %main "main"
10+
// VK13: OpExecutionMode %main DerivativeGroupLinearKHR
11+
12+
// VK11-DAG: OpExtension "SPV_NV_mesh_shader"
13+
// VK11: OpEntryPoint MeshNV %main "main"
14+
// VK11-NOT: OpExecutionMode %main DerivativeGroup
15+
16+
struct VSOut
17+
{
18+
float4 pos : SV_Position;
19+
};
20+
21+
[numthreads(4, 1, 1)]
22+
[outputtopology("triangle")]
23+
void main(in uint tid : SV_GroupThreadID, out vertices VSOut verts[3], out indices uint3 tris[1])
24+
{
25+
SetMeshOutputCounts(3, 1);
26+
27+
float4 val = ddx_coarse(float4(tid, 0, 0, 0));
28+
29+
verts[0].pos = val;
30+
verts[1].pos = val + float4(0,1,0,0);
31+
verts[2].pos = val + float4(1,0,0,0);
32+
33+
tris[0] = uint3(0,1,2);
34+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %dxc -T vs_6_0 -E main -DCOND=false -fspv-target-env=vulkan1.3 %s -spirv | FileCheck %s
2+
// CHECK-NOT: OpCapability DerivativeControl
3+
// CHECK-NOT: OpExtension "SPV_KHR_compute_shader_derivatives"
4+
5+
// RUN: not %dxc -T vs_6_0 -E main -DCOND=true -fspv-target-env=vulkan1.3 %s -spirv 2>&1 | FileCheck %s -check-prefix=ERROR
6+
// ERROR: generated SPIR-V is invalid:
7+
// ERROR-NEXT: Derivative instructions require Fragment, GLCompute, MeshEXT or TaskEXT execution model: DPdx
8+
9+
struct VSOut
10+
{
11+
float4 pos : SV_Position;
12+
};
13+
14+
VSOut main(float4 pos : POSITION)
15+
{
16+
VSOut output;
17+
output.pos = pos;
18+
if (COND)
19+
{
20+
output.pos += ddx(pos);
21+
}
22+
return output;
23+
}

0 commit comments

Comments
 (0)