Skip to content

Commit 53b9ae1

Browse files
author
devsh
committed
Merge remote-tracking branch 'remotes/origin/ray_tracing_spirv_intrinsics'
2 parents f67c026 + e1972c7 commit 53b9ae1

File tree

6 files changed

+54
-20
lines changed

6 files changed

+54
-20
lines changed
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
4+
5+
using namespace nbl::hlsl;
6+
37
[[vk::push_constant]] SPushConstants pc;
48

59
[shader("anyhit")]
610
void main(inout PrimaryPayload payload, in BuiltInTriangleIntersectionAttributes attribs)
711
{
8-
const int instID = InstanceID();
12+
const int instID = spirv::InstanceCustomIndexKHR;
913
const STriangleGeomInfo geom = vk::RawBufferLoad < STriangleGeomInfo > (pc.triangleGeomInfoBuffer + instID * sizeof(STriangleGeomInfo));
1014

1115
const uint32_t bitpattern = payload.pcg();
16+
// Cannot use spirv::ignoreIntersectionKHR and spirv::terminateRayKHR due to https://github.com/microsoft/DirectXShaderCompiler/issues/7279
1217
if (geom.material.alphaTest(bitpattern))
1318
IgnoreHit();
1419
}

71_RayTracingPipeline/app_resources/raytrace.rchit.hlsl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"
4+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
35
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"
46

7+
using namespace nbl::hlsl;
8+
59
[[vk::push_constant]] SPushConstants pc;
610

711
float3 calculateNormals(int primID, STriangleGeomInfo geom, float2 bary)
@@ -74,16 +78,16 @@ float3 calculateNormals(int primID, STriangleGeomInfo geom, float2 bary)
7478
[shader("closesthit")]
7579
void main(inout PrimaryPayload payload, in BuiltInTriangleIntersectionAttributes attribs)
7680
{
77-
const int primID = PrimitiveIndex();
78-
const int instanceCustomIndex = InstanceIndex();
79-
const int geometryIndex = GeometryIndex();
81+
const int primID = spirv::PrimitiveId;
82+
const int instanceCustomIndex = spirv::InstanceCustomIndexKHR;
83+
const int geometryIndex = spirv::RayGeometryIndexKHR;
8084
const STriangleGeomInfo geom = vk::RawBufferLoad < STriangleGeomInfo > (pc.triangleGeomInfoBuffer + (instanceCustomIndex + geometryIndex) * sizeof(STriangleGeomInfo));
8185
const float32_t3 vertexNormal = calculateNormals(primID, geom, attribs.barycentrics);
82-
const float32_t3 worldNormal = normalize(mul(vertexNormal, WorldToObject3x4()).xyz);
86+
const float32_t3 worldNormal = normalize(mul(vertexNormal, transpose(spirv::WorldToObjectKHR)).xyz);
8387

8488
payload.materialId = MaterialId::createTriangle(instanceCustomIndex);
8589

8690
payload.worldNormal = worldNormal;
87-
payload.rayDistance = RayTCurrent();
91+
payload.rayDistance = spirv::RayTmaxKHR;
8892

8993
}

71_RayTracingPipeline/app_resources/raytrace.rgen.hlsl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
static const int32_t s_sampleCount = 10;
1010
static const float32_t3 s_clearColor = float32_t3(0.3, 0.3, 0.8);
1111

12+
using namespace nbl::hlsl;
13+
1214
[[vk::push_constant]] SPushConstants pc;
1315

1416
[[vk::binding(0, 0)]] RaytracingAccelerationStructure topLevelAS;
@@ -23,8 +25,8 @@ float32_t nextRandomUnorm(inout nbl::hlsl::Xoroshiro64StarStar rnd)
2325
[shader("raygeneration")]
2426
void main()
2527
{
26-
const uint32_t3 launchID = DispatchRaysIndex();
27-
const uint32_t3 launchSize = DispatchRaysDimensions();
28+
const uint32_t3 launchID = spirv::LaunchIdKHR;
29+
const uint32_t3 launchSize = spirv::LaunchSizeKHR;
2830
const uint32_t2 coords = launchID.xy;
2931

3032
const uint32_t seed1 = nbl::hlsl::random::Pcg::create(pc.frameCounter)();
@@ -53,9 +55,11 @@ void main()
5355
rayDesc.TMin = 0.01;
5456
rayDesc.TMax = 10000.0;
5557

58+
[[vk::ext_storage_class(spv::StorageClassRayPayloadKHR)]]
5659
PrimaryPayload payload;
5760
payload.pcg = PrimaryPayload::generator_t::create(rnd());
58-
TraceRay(topLevelAS, RAY_FLAG_NONE, 0xff, ERT_PRIMARY, 0, EMT_PRIMARY, rayDesc, payload);
61+
spirv::traceRayKHR(topLevelAS, spv::RayFlagsMaskNone, 0xff, ERT_PRIMARY, 0, EMT_PRIMARY, rayDesc.Origin, rayDesc.TMin, rayDesc.Direction, rayDesc.TMax, payload);
62+
// TraceRay(topLevelAS, RAY_FLAG_NONE, 0xff, ERT_PRIMARY, 0, EMT_PRIMARY, rayDesc, payload);
5963

6064
const float32_t rayDistance = payload.rayDistance;
6165
if (rayDistance < 0)
@@ -67,9 +71,10 @@ void main()
6771
const float32_t3 worldPosition = pc.camPos + (camDirection * rayDistance);
6872

6973
// make sure to call with least live state
74+
[[vk::ext_storage_class(spv::StorageClassCallableDataKHR)]]
7075
RayLight cLight;
7176
cLight.inHitPosition = worldPosition;
72-
CallShader(pc.light.type, cLight);
77+
spirv::executeCallable(pc.light.type, cLight);
7378

7479
const float32_t3 worldNormal = payload.worldNormal;
7580

@@ -97,12 +102,16 @@ void main()
97102
rayDesc.TMin = 0.01;
98103
rayDesc.TMax = cLight.outLightDistance;
99104

105+
[[vk::ext_storage_class(spv::StorageClassRayPayloadKHR)]]
100106
OcclusionPayload occlusionPayload;
101107
// negative means its a hit, the miss shader will flip it back around to positive
102108
occlusionPayload.attenuation = -1.f;
103109
// abuse of miss shader to mean "not hit shader" solves us having to call closest hit shaders
104-
uint32_t shadowRayFlags = RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH | RAY_FLAG_SKIP_CLOSEST_HIT_SHADER;
105-
TraceRay(topLevelAS, shadowRayFlags, 0xFF, ERT_OCCLUSION, 0, EMT_OCCLUSION, rayDesc, occlusionPayload);
110+
uint32_t shadowRayFlags = spv::RayFlagsTerminateOnFirstHitKHRMask | spv::RayFlagsSkipClosestHitShaderKHRMask;
111+
spirv::traceRayKHR(topLevelAS, shadowRayFlags, 0xFF, ERT_OCCLUSION, 0, EMT_OCCLUSION, rayDesc.Origin, rayDesc.TMin, rayDesc.Direction, rayDesc.TMax, occlusionPayload);
112+
113+
// uint32_t shadowRayFlags = RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH | RAY_FLAG_SKIP_CLOSEST_HIT_SHADER;
114+
// TraceRay(topLevelAS, shadowRayFlags, 0xFF, ERT_OCCLUSION, 0, EMT_OCCLUSION, rayDesc, occlusionPayload);
106115

107116
attenuation = occlusionPayload.attenuation;
108117
if (occlusionPayload.attenuation > 1.f/1024.f)

71_RayTracingPipeline/app_resources/raytrace.rint.hlsl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"
4+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
5+
6+
using namespace nbl::hlsl;
7+
38
[[vk::push_constant]] SPushConstants pc;
49

510
struct Ray
@@ -26,22 +31,23 @@ float32_t hitSphere(SProceduralGeomInfo s, Ray r)
2631
void main()
2732
{
2833
Ray ray;
29-
ray.origin = WorldRayOrigin();
30-
ray.direction = WorldRayDirection();
34+
ray.origin = spirv::WorldRayOriginKHR;
35+
ray.direction = spirv::WorldRayDirectionKHR;
3136

32-
const int primID = PrimitiveIndex();
37+
const int primID = spirv::PrimitiveId;
3338

3439
// Sphere data
3540
SProceduralGeomInfo sphere = vk::RawBufferLoad<SProceduralGeomInfo>(pc.proceduralGeomInfoBuffer + primID * sizeof(SProceduralGeomInfo));
3641

3742
const float32_t tHit = hitSphere(sphere, ray);
3843

44+
[[vk::ext_storage_class(spv::StorageClassHitAttributeKHR)]]
3945
ProceduralHitAttribute hitAttrib;
4046

4147
// Report hit point
4248
if (tHit > 0)
4349
{
4450
hitAttrib.center = sphere.center;
45-
ReportHit(tHit, 0, hitAttrib);
51+
spirv::reportIntersectionKHR(tHit, 0);
4652
}
4753
}
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"
4+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
5+
using namespace nbl::hlsl;
6+
37
[[vk::push_constant]] SPushConstants pc;
48

59
[shader("closesthit")]
610
void main(inout PrimaryPayload payload, in ProceduralHitAttribute attrib)
711
{
8-
const float32_t3 worldPosition = WorldRayOrigin() + WorldRayDirection() * RayTCurrent();
12+
const float32_t3 worldPosition = spirv::WorldRayOriginKHR + spirv::WorldRayDirectionKHR * spirv::RayTmaxKHR;
913
const float32_t3 worldNormal = normalize(worldPosition - attrib.center);
1014

11-
payload.materialId = MaterialId::createProcedural(PrimitiveIndex()); // we use negative value to indicate that this is procedural
15+
payload.materialId = MaterialId::createProcedural(spirv::PrimitiveId); // we use negative value to indicate that this is procedural
1216

1317
payload.worldNormal = worldNormal;
14-
payload.rayDistance = RayTCurrent();
18+
payload.rayDistance = spirv::RayTmaxKHR;
1519

1620
}

71_RayTracingPipeline/app_resources/raytrace_shadow.rahit.hlsl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
#include "common.hlsl"
22
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
3+
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"
4+
5+
using namespace nbl::hlsl;
36

47
[[vk::push_constant]] SPushConstants pc;
58

69
[shader("anyhit")]
710
void main(inout OcclusionPayload payload, in BuiltInTriangleIntersectionAttributes attribs)
811
{
9-
const int instID = InstanceID();
12+
const int instID = spirv::InstanceCustomIndexKHR;
1013
const STriangleGeomInfo geom = vk::RawBufferLoad < STriangleGeomInfo > (pc.triangleGeomInfoBuffer + instID * sizeof(STriangleGeomInfo));
1114
const Material material = nbl::hlsl::_static_cast<Material>(geom.material);
1215

1316
const float attenuation = (1.f-material.alpha) * payload.attenuation;
1417
// DXC cogegens weird things in the presence of termination instructions
1518
payload.attenuation = attenuation;
19+
20+
21+
// Cannot use spirv::ignoreIntersectionKHR and spirv::terminateRayKHR due to https://github.com/microsoft/DirectXShaderCompiler/issues/7279
1622
// arbitrary constant, whatever you want the smallest attenuation to be. Remember until miss, the attenuatio is negative
1723
if (attenuation > -1.f/1024.f)
1824
AcceptHitAndEndSearch();

0 commit comments

Comments
 (0)