55#pragma wave shader_stage (compute)
66
77#include "common.h"
8+ #include "nbl/builtin/hlsl/glsl_compat/subgroup_ballot.hlsl"
89using namespace nbl::hlsl;
910
10- [[vk::push_constant]] ConstantBuffer < PushConstants> Constants;
11+ [[vk::push_constant]] PushConstants Constants;
1112[[vk::binding (0 )]] StructuredBuffer<uint > Histogram;
1213[[vk::binding (1 )]] RWStructuredBuffer <uint > Output;
1314
1415static const uint32_t GroupsharedSize = 256 ;
1516
17+ uint getNextPowerOfTwo (uint number) {
18+ return 2 << firstbithigh (number - 1 );
19+ }
20+
21+ uint getLaneWithFirstBitSet (bool condition) {
22+ uint4 ballot = WaveActiveBallot (condition);
23+ if (all (ballot == 0 )) {
24+ return WaveGetLaneCount ();
25+ }
26+ return nbl::hlsl::glsl::subgroupBallotFindLSB (ballot);
27+ }
28+
29+ // findValue must be the same across the entire wave
30+ // Could use something like WaveReadFirstLane to be fully sure
31+ uint binarySearchLowerBoundFindValue (uint findValue, StructuredBuffer<uint > searchBuffer, uint searchBufferSize) {
32+ uint lane = WaveGetLaneIndex ();
33+
34+ uint left = 0 ;
35+ uint right = searchBufferSize - 1 ;
36+
37+ uint32_t range = getNextPowerOfTwo (right - left);
38+ // do pivots as long as we can't coalesced load
39+ while (range > WaveGetLaneCount ())
40+ {
41+ // there must be at least 1 gap between subsequent pivots
42+ const uint32_t step = range / WaveGetLaneCount ();
43+ const uint32_t halfStep = step >> 1 ;
44+ const uint32_t pivotOffset = lane * step+halfStep;
45+ const uint32_t pivotIndex = left + pivotOffset;
46+
47+ uint4 notGreaterPivots = WaveActiveBallot (pivotIndex < right && !(findValue < searchBuffer[pivotIndex]));
48+ uint partition = nbl::hlsl::glsl::subgroupBallotBitCount (notGreaterPivots);
49+ // only move left if needed
50+ if (partition != 0 )
51+ left += partition * step - halfStep;
52+ // if we go into final half partition, the range becomes less too
53+ range = partition != WaveGetLaneCount () ? step : halfStep;
54+ }
55+
56+ uint threadSearchIndex = left + lane;
57+ bool laneValid = threadSearchIndex < searchBufferSize;
58+ uint histAtIndex = laneValid ? searchBuffer[threadSearchIndex] : -1 ;
59+ uint firstLaneGreaterThan = getLaneWithFirstBitSet (histAtIndex > findValue);
60+
61+ return left + firstLaneGreaterThan - 1 ;
62+ }
63+
64+ groupshared uint shared_groupSearchBufferMinIndex;
65+ groupshared uint shared_groupSearchBufferMaxIndex;
66+ groupshared uint shared_groupSearchValues[GroupsharedSize];
67+
68+ // Binary search using the entire workgroup, making it log32 or log64 (every iteration, the possible set of
69+ // values is divided by the number of lanes in a wave)
70+ uint binarySearchLowerBoundCooperative (uint groupIndex, uint groupThread, StructuredBuffer<uint > searchBuffer, uint searchBufferSize) {
71+ uint minSearchValue = groupIndex.x * GroupsharedSize;
72+ uint maxSearchValue = ((groupIndex.x + 1 ) * GroupsharedSize) - 1 ;
73+
74+ // On each workgroup, two subgroups do the search
75+ // - One searches for the minimum, the other searches for the maximum
76+ // - Store the minimum and maximum on groupshared memory, then do a barrier
77+ uint wave = groupThread / WaveGetLaneCount ();
78+ if (wave < 2 ) {
79+ uint search = wave == 0 ? minSearchValue : maxSearchValue;
80+ uint searchResult = binarySearchLowerBoundFindValue (search, searchBuffer, searchBufferSize);
81+ if (WaveIsFirstLane ()) {
82+ if (wave == 0 ) shared_groupSearchBufferMinIndex = searchResult;
83+ else shared_groupSearchBufferMaxIndex = searchResult;
84+ }
85+ }
86+ GroupMemoryBarrierWithGroupSync ();
87+
88+ // Since every instance has at least one triangle, we know that having workgroup values
89+ // for each value in the range of minimum to maximum will suffice.
90+
91+ // Write every value in the range to groupshared memory and barrier.
92+ uint idx = shared_groupSearchBufferMinIndex + groupThread.x;
93+ if (idx <= shared_groupSearchBufferMaxIndex) {
94+ shared_groupSearchValues[groupThread.x] = searchBuffer[idx];
95+ }
96+ GroupMemoryBarrierWithGroupSync ();
97+
98+ uint maxValueIndex = shared_groupSearchBufferMaxIndex - shared_groupSearchBufferMinIndex;
99+
100+ uint searchValue = minSearchValue + groupThread;
101+ uint currentSearchValueIndex = 0 ;
102+ uint laneValue = shared_groupSearchBufferMaxIndex;
103+ while (currentSearchValueIndex <= maxValueIndex) {
104+ uint curValue = shared_groupSearchValues[currentSearchValueIndex];
105+ if (curValue > searchValue) {
106+ laneValue = shared_groupSearchBufferMinIndex + currentSearchValueIndex - 1 ;
107+ break ;
108+ }
109+ currentSearchValueIndex ++;
110+ }
111+
112+ return laneValue;
113+ }
114+
16115[numthreads (256 , 1 , 1 )]
17116void main (const uint3 thread : SV_DispatchThreadID , const uint3 groupThread : SV_GroupThreadID , const uint3 group : SV_GroupID )
18117{
19-
118+ Output[thread.x] = binarySearchLowerBoundCooperative (group.x, groupThread.x, Histogram, Constants.EntityCount);
20119}
0 commit comments