@@ -25,41 +25,47 @@ float3 unpackNormals3x10(uint32_t v)
2525 return clamp (float3 (pn) / 511.0 , -1.0 , 1.0 );
2626}
2727
28- float3 calculateSmoothNormals ( int instID, int primID, SGeomInfo geom, float2 bary)
28+ float3 calculateNormals ( int primID, SGeomInfo geom, float2 bary)
2929{
3030 const uint indexType = geom.indexType;
31- const uint objType = geom.objType ;
31+ const uint normalType = geom.normalType ;
3232
3333 const uint64_t vertexBufferAddress = geom.vertexBufferAddress;
3434 const uint64_t indexBufferAddress = geom.indexBufferAddress;
3535 const uint64_t normalBufferAddress = geom.normalBufferAddress;
3636
3737 uint32_t3 indices;
38- switch (indexType )
38+ if (indexBufferAddress == 0 )
3939 {
40- case 0 : // EIT_16BIT
41- indices = uint32_t3 ((nbl::hlsl::bda::__ptr<uint16_t3>::create (indexBufferAddress)+primID).deref ().load ());
42- break ;
43- case 1 : // EIT_32BIT
44- indices = uint32_t3 ((nbl::hlsl::bda::__ptr<uint32_t3>::create (indexBufferAddress)+primID).deref ().load ());
45- break ;
46- default : // EIT_NONE
40+ indices[0 ] = primID * 3 ;
41+ indices[1 ] = indices[0 ] + 1 ;
42+ indices[2 ] = indices[0 ] + 2 ;
43+ }
44+ else {
45+ switch (indexType)
4746 {
48- indices[0 ] = primID * 3 ;
49- indices[1 ] = indices[0 ] + 1 ;
50- indices[2 ] = indices[0 ] + 2 ;
47+ case 0 : // EIT_16BIT
48+ indices = uint32_t3 ((nbl::hlsl::bda::__ptr<uint16_t3>::create (indexBufferAddress)+primID).deref ().load ());
49+ break ;
50+ case 1 : // EIT_32BIT
51+ indices = uint32_t3 ((nbl::hlsl::bda::__ptr<uint32_t3>::create (indexBufferAddress)+primID).deref ().load ());
52+ break ;
5153 }
5254 }
5355
56+ if (normalBufferAddress == 0 || normalType == NT_UNKNOWN)
57+ {
58+ float3 v0 = vk::RawBufferLoad<float3 >(vertexBufferAddress + indices[0 ] * 12 );
59+ float3 v1 = vk::RawBufferLoad<float3 >(vertexBufferAddress + indices[1 ] * 12 );
60+ float3 v2 = vk::RawBufferLoad<float3 >(vertexBufferAddress + indices[2 ] * 12 );
61+
62+ return normalize (cross (v2 - v0, v1 - v0));
63+ }
64+
5465 float3 n0, n1, n2;
55- switch (objType )
66+ switch (normalType )
5667 {
57- case OT_CUBE:
58- case OT_SPHERE:
59- case OT_RECTANGLE:
60- case OT_CYLINDER:
61- //case OT_ARROW:
62- case OT_CONE:
68+ case NT_R8G8B8A8_SNORM:
6369 {
6470 uint32_t v0 = vk::RawBufferLoad<uint32_t>(normalBufferAddress + indices[0 ] * 4 );
6571 uint32_t v1 = vk::RawBufferLoad<uint32_t>(normalBufferAddress + indices[1 ] * 4 );
@@ -70,13 +76,13 @@ float3 calculateSmoothNormals(int instID, int primID, SGeomInfo geom, float2 bar
7076 n2 = normalize (nbl::hlsl::spirv::unpackSnorm4x8 (v2).xyz);
7177 }
7278 break ;
73- case OT_ICOSPHERE:
74- default :
79+ case NT_R32G32B32_SFLOAT:
7580 {
7681 n0 = normalize (vk::RawBufferLoad<float3 >(normalBufferAddress + indices[0 ] * 12 ));
7782 n1 = normalize (vk::RawBufferLoad<float3 >(normalBufferAddress + indices[1 ] * 12 ));
7883 n2 = normalize (vk::RawBufferLoad<float3 >(normalBufferAddress + indices[2 ] * 12 ));
7984 }
85+ break ;
8086 }
8187
8288 float3 barycentrics = float3 (0.0 , bary);
@@ -113,15 +119,16 @@ void main(uint32_t3 threadID : SV_DispatchThreadID)
113119
114120 if (spirv::rayQueryGetIntersectionTypeKHR (query, true ) == spv::RayQueryCommittedIntersectionTypeRayQueryCommittedIntersectionTriangleKHR)
115121 {
116- const int instID = spirv::rayQueryGetIntersectionInstanceIdKHR (query, true );
122+ const int instanceCustomIndex = spirv::rayQueryGetIntersectionInstanceCustomIndexKHR (query, true );
123+ const int geometryIndex = spirv::rayQueryGetIntersectionGeometryIndexKHR (query, true );
117124 const int primID = spirv::rayQueryGetIntersectionPrimitiveIndexKHR (query, true );
118125
119126 // TODO: candidate for `bda::__ptr<SGeomInfo>`
120- const SGeomInfo geom = vk::RawBufferLoad<SGeomInfo>(pc.geometryInfoBuffer + instID * sizeof (SGeomInfo),8 );
127+ const SGeomInfo geom = vk::RawBufferLoad<SGeomInfo>(pc.geometryInfoBuffer + (instanceCustomIndex + geometryIndex) * sizeof (SGeomInfo), 8 );
121128
122129 float3 normals;
123130 float2 barycentrics = spirv::rayQueryGetIntersectionBarycentricsKHR (query, true );
124- normals = calculateSmoothNormals (instID, primID, geom, barycentrics);
131+ normals = calculateNormals ( primID, geom, barycentrics);
125132
126133 normals = normalize (normals) * 0.5 + 0.5 ;
127134 color = float4 (normals, 1.0 );
0 commit comments