Skip to content

Commit 052158c

Browse files
author
kevyuu
committed
Use spirv intrinsics for raytracing command and builtin
1 parent 3de2363 commit 052158c

File tree

6 files changed

+40
-14
lines changed

6 files changed

+40
-14
lines changed

71_RayTracingPipeline/app_resources/raytrace.rahit.hlsl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
4+
5+
using namespace nbl::hlsl;
6+
37
[[vk::push_constant]] SPushConstants pc;
48

59
[shader("anyhit")]

71_RayTracingPipeline/app_resources/raytrace.rchit.hlsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
34
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"
45

6+
using namespace nbl::hlsl;
7+
58
[[vk::push_constant]] SPushConstants pc;
69

710
float3 calculateNormals(int primID, STriangleGeomInfo geom, float2 bary)
@@ -75,15 +78,15 @@ float3 calculateNormals(int primID, STriangleGeomInfo geom, float2 bary)
7578
void main(inout PrimaryPayload payload, in BuiltInTriangleIntersectionAttributes attribs)
7679
{
7780
const int primID = PrimitiveIndex();
78-
const int instanceCustomIndex = InstanceIndex();
79-
const int geometryIndex = GeometryIndex();
81+
const int instanceCustomIndex = spirv::InstanceCustomIndexKHR;
82+
const int geometryIndex = spirv::RayGeometryIndexKHR;
8083
const STriangleGeomInfo geom = vk::RawBufferLoad < STriangleGeomInfo > (pc.triangleGeomInfoBuffer + (instanceCustomIndex + geometryIndex) * sizeof(STriangleGeomInfo));
8184
const float32_t3 vertexNormal = calculateNormals(primID, geom, attribs.barycentrics);
82-
const float32_t3 worldNormal = normalize(mul(vertexNormal, WorldToObject3x4()).xyz);
85+
const float32_t3 worldNormal = normalize(mul(vertexNormal, transpose(spirv::WorldToObjectKHR)).xyz);
8386

8487
payload.materialId = MaterialId::createTriangle(instanceCustomIndex);
8588

8689
payload.worldNormal = worldNormal;
87-
payload.rayDistance = RayTCurrent();
90+
payload.rayDistance = spirv::RayTmaxKHR;
8891

8992
}

71_RayTracingPipeline/app_resources/raytrace.rgen.hlsl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
static const int32_t s_sampleCount = 10;
1010
static const float32_t3 s_clearColor = float32_t3(0.3, 0.3, 0.8);
1111

12+
using namespace nbl::hlsl;
13+
1214
[[vk::push_constant]] SPushConstants pc;
1315

1416
[[vk::binding(0, 0)]] RaytracingAccelerationStructure topLevelAS;
@@ -23,8 +25,8 @@ float32_t nextRandomUnorm(inout nbl::hlsl::Xoroshiro64StarStar rnd)
2325
[shader("raygeneration")]
2426
void main()
2527
{
26-
const uint32_t3 launchID = DispatchRaysIndex();
27-
const uint32_t3 launchSize = DispatchRaysDimensions();
28+
const uint32_t3 launchID = spirv::LaunchIdKHR;
29+
const uint32_t3 launchSize = spirv::LaunchSizeKHR;
2830
const uint32_t2 coords = launchID.xy;
2931

3032
const uint32_t seed1 = nbl::hlsl::random::Pcg::create(pc.frameCounter)();
@@ -53,9 +55,11 @@ void main()
5355
rayDesc.TMin = 0.01;
5456
rayDesc.TMax = 10000.0;
5557

58+
[[vk::ext_storage_class(spv::StorageClassRayPayloadKHR)]]
5659
PrimaryPayload payload;
5760
payload.pcg = PrimaryPayload::generator_t::create(rnd());
58-
TraceRay(topLevelAS, RAY_FLAG_NONE, 0xff, ERT_PRIMARY, 0, EMT_PRIMARY, rayDesc, payload);
61+
spirv::traceRayKHR(topLevelAS, spv::RayFlagsMaskNone, 0xff, ERT_PRIMARY, 0, EMT_PRIMARY, rayDesc.Origin, rayDesc.TMin, rayDesc.Direction, rayDesc.TMax, payload);
62+
// TraceRay(topLevelAS, RAY_FLAG_NONE, 0xff, ERT_PRIMARY, 0, EMT_PRIMARY, rayDesc, payload);
5963

6064
const float32_t rayDistance = payload.rayDistance;
6165
if (rayDistance < 0)
@@ -67,9 +71,10 @@ void main()
6771
const float32_t3 worldPosition = pc.camPos + (camDirection * rayDistance);
6872

6973
// make sure to call with least live state
74+
[[vk::ext_storage_class(spv::StorageClassCallableDataKHR)]]
7075
RayLight cLight;
7176
cLight.inHitPosition = worldPosition;
72-
CallShader(pc.light.type, cLight);
77+
spirv::executeCallable(pc.light.type, cLight);
7378

7479
const float32_t3 worldNormal = payload.worldNormal;
7580

@@ -97,12 +102,16 @@ void main()
97102
rayDesc.TMin = 0.01;
98103
rayDesc.TMax = cLight.outLightDistance;
99104

105+
[[vk::ext_storage_class(spv::StorageClassRayPayloadKHR)]]
100106
OcclusionPayload occlusionPayload;
101107
// negative means its a hit, the miss shader will flip it back around to positive
102108
occlusionPayload.attenuation = -1.f;
103109
// abuse of miss shader to mean "not hit shader" solves us having to call closest hit shaders
104-
uint32_t shadowRayFlags = RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH | RAY_FLAG_SKIP_CLOSEST_HIT_SHADER;
105-
TraceRay(topLevelAS, shadowRayFlags, 0xFF, ERT_OCCLUSION, 0, EMT_OCCLUSION, rayDesc, occlusionPayload);
110+
uint32_t shadowRayFlags = spv::RayFlagsTerminateOnFirstHitKHRMask | spv::RayFlagsSkipClosestHitShaderKHRMask;
111+
spirv::traceRayKHR(topLevelAS, shadowRayFlags, 0xFF, ERT_OCCLUSION, 0, EMT_OCCLUSION, rayDesc.Origin, rayDesc.TMin, rayDesc.Direction, rayDesc.TMax, occlusionPayload);
112+
113+
// uint32_t shadowRayFlags = RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH | RAY_FLAG_SKIP_CLOSEST_HIT_SHADER;
114+
// TraceRay(topLevelAS, shadowRayFlags, 0xFF, ERT_OCCLUSION, 0, EMT_OCCLUSION, rayDesc, occlusionPayload);
106115

107116
attenuation = occlusionPayload.attenuation;
108117
if (occlusionPayload.attenuation > 1.f/1024.f)

71_RayTracingPipeline/app_resources/raytrace.rint.hlsl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
4+
5+
using namespace nbl::hlsl;
6+
37
[[vk::push_constant]] SPushConstants pc;
48

59
struct Ray
@@ -26,8 +30,8 @@ float32_t hitSphere(SProceduralGeomInfo s, Ray r)
2630
void main()
2731
{
2832
Ray ray;
29-
ray.origin = WorldRayOrigin();
30-
ray.direction = WorldRayDirection();
33+
ray.origin = spirv::WorldRayOriginKHR;
34+
ray.direction = spirv::WorldRayDirectionKHR;
3135

3236
const int primID = PrimitiveIndex();
3337

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
#include "common.hlsl"
22

3+
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
4+
using namespace nbl::hlsl;
5+
36
[[vk::push_constant]] SPushConstants pc;
47

58
[shader("closesthit")]
69
void main(inout PrimaryPayload payload, in ProceduralHitAttribute attrib)
710
{
8-
const float32_t3 worldPosition = WorldRayOrigin() + WorldRayDirection() * RayTCurrent();
11+
const float32_t3 worldPosition = spirv::WorldRayOriginKHR + spirv::WorldRayDirectionKHR * spirv::RayTmaxKHR;
912
const float32_t3 worldNormal = normalize(worldPosition - attrib.center);
1013

1114
payload.materialId = MaterialId::createProcedural(PrimitiveIndex()); // we use negative value to indicate that this is procedural
1215

1316
payload.worldNormal = worldNormal;
14-
payload.rayDistance = RayTCurrent();
17+
payload.rayDistance = spirv::RayTmaxKHR;
1518

1619
}

71_RayTracingPipeline/app_resources/raytrace_shadow.rahit.hlsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#include "common.hlsl"
22
#include "nbl/builtin/hlsl/spirv_intrinsics/raytracing.hlsl"
3+
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"
4+
5+
using namespace nbl::hlsl;
36

47
[[vk::push_constant]] SPushConstants pc;
58

0 commit comments

Comments
 (0)