Skip to content

Commit e546428

Browse files
committed
Patch things for cooperative binary search test
1 parent 4969227 commit e546428

File tree

3 files changed

+1316
-7
lines changed

3 files changed

+1316
-7
lines changed

72_CooperativeBinarySearch/app_resources/binarySearch.comp.hlsl

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,115 @@
55
#pragma wave shader_stage(compute)
66

77
#include "common.h"
8+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_ballot.hlsl"
89
using namespace nbl::hlsl;
910

10-
[[vk::push_constant]] ConstantBuffer<PushConstants> Constants;
11+
[[vk::push_constant]] PushConstants Constants;
1112
[[vk::binding(0)]] StructuredBuffer<uint> Histogram;
1213
[[vk::binding(1)]] RWStructuredBuffer<uint> Output;
1314

1415
static const uint32_t GroupsharedSize = 256;
1516

17+
// Returns the smallest power of two that is greater than or equal to `number`.
//
// Inputs 0 and 1 both yield 1. They are answered directly because the shift-based
// formula below is only meaningful for `number >= 2`: `firstbithigh(0)` reports
// "no bit set" (-1), which would turn into an out-of-range shift amount.
// For `number > 0x80000000` the result does not fit in 32 bits and wraps to 0.
uint getNextPowerOfTwo(uint number) {
    if (number <= 1u)
        return 1u;
    // firstbithigh(number - 1) is floor(log2(number - 1)); the bit above it is the answer.
    return 2u << firstbithigh(number - 1u);
}
20+
21+
// Ballots `condition` across the wave and returns the position of the least
// significant set bit of that ballot, or WaveGetLaneCount() when the ballot
// is empty (no lane voted true).
uint getLaneWithFirstBitSet(bool condition) {
    const uint4 laneMask = WaveActiveBallot(condition);
    const bool anyLaneSet = any(laneMask != 0);
    if (anyLaneSet) {
        return nbl::hlsl::glsl::subgroupBallotFindLSB(laneMask);
    }
    // Out-of-range sentinel: one past the last lane index.
    return WaveGetLaneCount();
}
28+
29+
// Wave-cooperative search over `searchBuffer`, which must hold `searchBufferSize`
// values in ascending order. Every lane of the wave takes part and every lane
// gets the same result.
//
// Returns the index of the LAST element that is not greater than `findValue`
// (despite the name this is "upper bound minus one"). If every element is greater
// than `findValue`, or the buffer is empty, the result wraps to 0xFFFFFFFF.
// Buffers larger than 0x80000000 elements are not supported (the power-of-two
// range would overflow).
//
// findValue must be the same across the entire wave
// Could use something like WaveReadFirstLane to be fully sure
uint binarySearchLowerBoundFindValue(uint findValue, StructuredBuffer<uint> searchBuffer, uint searchBufferSize) {
    const uint lane = WaveGetLaneIndex();
    const uint laneCount = WaveGetLaneCount();

    uint left = 0;
    // Exclusive end of the buffer. Using the element count (not the last index) keeps
    // the last element inside the searched range and lets a pivot land exactly on it.
    const uint right = searchBufferSize;

    // Invariant: the result index lies in [left, left + range) and range is a power of two.
    // The max() keeps getNextPowerOfTwo away from inputs 0 and 1.
    uint32_t range = getNextPowerOfTwo(max(right, 2u));
    // do pivots as long as we can't coalesced load
    while (range > laneCount)
    {
        // there must be at least 1 gap between subsequent pivots
        const uint32_t step = range / laneCount;
        const uint32_t halfStep = step >> 1;
        // One pivot per lane, centred in that lane's slice of the current range.
        const uint32_t pivotIndex = left + lane * step + halfStep;

        // Pivots past the end of the buffer count as "greater than findValue".
        // NOTE(review): the bounds check only prevents the out-of-range read if `&&`
        // short-circuits (HLSL 2021 behaviour) - confirm against the compile flags.
        const uint4 notGreaterPivots = WaveActiveBallot(pivotIndex < right && !(findValue < searchBuffer[pivotIndex]));
        // The buffer is sorted, so the passing pivots form a prefix of the lanes.
        const uint partition = nbl::hlsl::glsl::subgroupBallotBitCount(notGreaterPivots);
        // only move left if needed: it jumps onto the last pivot that passed
        if (partition != 0)
            left += partition * step - halfStep;
        // if we go into final half partition, the range becomes less too
        range = partition != laneCount ? step : halfStep;
    }

    // Final pass: the remaining range fits in one wave, so each lane checks one element.
    const uint threadSearchIndex = left + lane;
    const bool laneValid = threadSearchIndex < searchBufferSize;
    const uint histAtIndex = laneValid ? searchBuffer[threadSearchIndex] : 0xFFFFFFFFu;
    // Lanes past the end are flagged explicitly so the scan also stops at the end of
    // the buffer when findValue is 0xFFFFFFFF (nothing compares greater than that).
    const uint firstLaneGreaterThan = getLaneWithFirstBitSet(!laneValid || histAtIndex > findValue);

    // One before the first greater element; wraps to 0xFFFFFFFF when that is lane 0 at left == 0.
    return left + firstLaneGreaterThan - 1;
}
63+
64+
// Buffer indices found for the smallest and the largest value this workgroup looks up.
groupshared uint shared_groupSearchBufferMinIndex;
groupshared uint shared_groupSearchBufferMaxIndex;
// Copy of searchBuffer[min index .. max index], one slot per thread of the workgroup.
groupshared uint shared_groupSearchValues[GroupsharedSize];

// Binary search using the entire workgroup. Each wave-level search step divides the
// candidate range by the number of lanes in a wave, so the depth is logarithmic in
// base 32 or 64 rather than base 2.
//
// Workgroup `groupIndex` resolves the GroupsharedSize consecutive values starting at
// `groupIndex * GroupsharedSize`; thread `groupThread` resolves the value
// `groupIndex * GroupsharedSize + groupThread` and returns the index of the last
// element of `searchBuffer` that is not greater than it.
// Must be called by every thread of the workgroup (it contains workgroup barriers).
uint binarySearchLowerBoundCooperative(uint groupIndex, uint groupThread, StructuredBuffer<uint> searchBuffer, uint searchBufferSize) {
    const uint minSearchValue = groupIndex * GroupsharedSize;
    const uint maxSearchValue = minSearchValue + (GroupsharedSize - 1);

    // On each workgroup, two subgroups do the search
    // - One searches for the minimum, the other searches for the maximum
    // - Store the minimum and maximum on groupshared memory, then do a barrier
    // NOTE(review): this assumes consecutive group threads share a wave - confirm.
    const uint wave = groupThread / WaveGetLaneCount();
    if (wave < 2) {
        const uint search = wave == 0 ? minSearchValue : maxSearchValue;
        const uint searchResult = binarySearchLowerBoundFindValue(search, searchBuffer, searchBufferSize);
        // A single lane publishes the wave's result.
        if (WaveIsFirstLane()) {
            if (wave == 0) shared_groupSearchBufferMinIndex = searchResult;
            else shared_groupSearchBufferMaxIndex = searchResult;
        }
    }
    GroupMemoryBarrierWithGroupSync();

    // Since every instance has at least one triangle, we know that having workgroup values
    // for each value in the range of minimum to maximum will suffice.

    // Write every value in the range to groupshared memory and barrier.
    const uint idx = shared_groupSearchBufferMinIndex + groupThread;
    if (idx <= shared_groupSearchBufferMaxIndex) {
        shared_groupSearchValues[groupThread] = searchBuffer[idx];
    }
    GroupMemoryBarrierWithGroupSync();

    // Last groupshared slot to scan. Clamped to the array size so that input which
    // violates the assumption above can never index past shared_groupSearchValues.
    const uint maxValueIndex = min(shared_groupSearchBufferMaxIndex - shared_groupSearchBufferMinIndex, GroupsharedSize - 1);

    // Linear scan of the cached window for the first entry greater than this thread's
    // value; the answer is the entry just before it.
    const uint searchValue = minSearchValue + groupThread;
    // Default when nothing in the window is greater: the window's last index.
    uint laneValue = shared_groupSearchBufferMaxIndex;
    for (uint slot = 0; slot <= maxValueIndex; ++slot) {
        if (shared_groupSearchValues[slot] > searchValue) {
            laneValue = shared_groupSearchBufferMinIndex + slot - 1;
            break;
        }
    }

    return laneValue;
}
114+
16115
// Compute entry point: every thread runs the workgroup-cooperative search for its own
// value over Histogram (Constants.EntityCount elements) and stores the resulting index
// in Output at its dispatch-thread index.
// The workgroup size here matches GroupsharedSize (256), which the cooperative search
// uses to size its groupshared cache.
[numthreads(256, 1, 1)]
void main(const uint3 thread : SV_DispatchThreadID, const uint3 groupThread : SV_GroupThreadID, const uint3 group : SV_GroupID)
{
    const uint foundIndex = binarySearchLowerBoundCooperative(group.x, groupThread.x, Histogram, Constants.EntityCount);
    Output[thread.x] = foundIndex;
}

72_CooperativeBinarySearch/main.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ using namespace nbl::examples;
2222

2323
//using namespace glm;
2424

25+
// Test input values; the initializer list is pulled in at compile time from
// testCaseData.h. The array is memcpy'd into the mapped memory of m_buffers[0]
// during initialization.
static constexpr uint32_t TestCaseIndices[] = {
26+
#include "testCaseData.h"
27+
};
28+
29+
2530
void cpu_tests();
2631

2732
class CooperativeBinarySearch final : public application_templates::MonoDeviceApplication, public BuiltinResourcesApplication
@@ -101,14 +106,19 @@ class CooperativeBinarySearch final : public application_templates::MonoDeviceAp
101106

102107
auto reqs = m_buffers[i]->getMemoryReqs();
103108
reqs.memoryTypeBits &= m_device->getPhysicalDevice()->getHostVisibleMemoryTypeBits();
104-
m_device->allocate(reqs, m_buffers[i].get());
109+
110+
m_allocations[i] = m_device->allocate(reqs, m_buffers[i].get());
111+
112+
auto allocationType = i == 0 ? IDeviceMemoryAllocation::EMCAF_WRITE : IDeviceMemoryAllocation::EMCAF_READ;
113+
auto mapResult = m_allocations[i].memory->map({ 0ull,m_allocations[i].memory->getAllocationSize() }, allocationType);
114+
assert(mapResult);
105115
}
106116

107117
smart_refctd_ptr<IDescriptorPool> descriptorPool = nullptr;
108118
{
109119
IDescriptorPool::SCreateInfo createInfo = {};
110120
createInfo.maxSets = 1;
111-
createInfo.maxDescriptorCount[static_cast<uint32_t>(IDescriptor::E_TYPE::ET_STORAGE_BUFFER)] = 1;
121+
createInfo.maxDescriptorCount[static_cast<uint32_t>(IDescriptor::E_TYPE::ET_STORAGE_BUFFER)] = bindingCount;
112122
descriptorPool = m_device->createDescriptorPool(std::move(createInfo));
113123
}
114124

@@ -130,6 +140,14 @@ class CooperativeBinarySearch final : public application_templates::MonoDeviceAp
130140

131141
m_device->updateDescriptorSets(bindingCount, writeDescriptorSets, 0u, nullptr);
132142

143+
// Copy the test data into m_buffers[0] through its mapped pointer
144+
auto outPtr = m_allocations[0].memory->getMappedPointer();
145+
assert(outPtr);
146+
memcpy(
147+
reinterpret_cast<void*>(outPtr),
148+
reinterpret_cast<const void*>(&TestCaseIndices[0]),
149+
sizeof(TestCaseIndices));
150+
133151
// In contrast to fences, we just need one semaphore to rule all dispatches
134152
return true;
135153
}
@@ -196,9 +214,8 @@ class CooperativeBinarySearch final : public application_templates::MonoDeviceAp
196214
m_device->blockForSemaphores(waitInfos);
197215
}
198216

199-
auto mem = m_buffers[1]->getBoundMemory();
200-
assert(mem.memory->isMappable());
201-
auto* ptr = mem.memory->map({ .offset = 0, .length = mem.memory->getAllocationSize() });
217+
auto ptr = m_allocations[1].memory->getMappedPointer();
218+
assert(ptr);
202219
printf("readback ptr %p\n", ptr);
203220

204221
m_keepRunning = false;
@@ -216,6 +233,7 @@ class CooperativeBinarySearch final : public application_templates::MonoDeviceAp
216233
smart_refctd_ptr<IGPUDescriptorSet> m_descriptorSet;
217234

218235
smart_refctd_ptr<IGPUBuffer> m_buffers[2];
236+
nbl::video::IDeviceMemoryAllocator::SAllocation m_allocations[2] = {};
219237
smart_refctd_ptr<IGPUCommandBuffer> m_cmdbuf = nullptr;
220238
IQueue* m_queue;
221239
smart_refctd_ptr<IGPUCommandPool> m_commandPool;

0 commit comments

Comments
 (0)