@@ -85,7 +85,7 @@ class CooperativeBinarySearch final : public application_templates::MonoDeviceAp
8585 SPushConstantRange pcRange = {};
8686 pcRange.stageFlags = IShader::E_SHADER_STAGE::ESS_COMPUTE;
8787 pcRange.offset = 0u ;
88- pcRange.size = 2 * sizeof (uint32_t );
88+ pcRange.size = sizeof (nbl::hlsl::PushConstants );
8989 auto layout = m_device->createPipelineLayout ({ &pcRange,1 }, smart_refctd_ptr (m_descriptorSetLayout));
9090 IGPUComputePipeline::SCreationParams params = {};
9191 params.layout = layout.get ();
@@ -186,11 +186,18 @@ class CooperativeBinarySearch final : public application_templates::MonoDeviceAp
186186 m_cmdbuf->pipelineBarrier (EDF_NONE, depInfo);
187187
188188
189- const uint32_t pushConstants[2 ] = { 1920 , 1080 };
190189 const IGPUDescriptorSet* set = m_descriptorSet.get ();
190+ const uint32_t numIndices = sizeof (TestCaseIndices) / sizeof (TestCaseIndices[0 ]);
191+ const uint32_t lastValue = TestCaseIndices[numIndices - 1 ];
192+ const uint32_t totalValues = lastValue + 100 ;
193+ nbl::hlsl::PushConstants coopBinarySearchPC = {
194+ .EntityCount = numIndices,
195+ };
196+
191197 m_cmdbuf->bindComputePipeline (m_pipeline.get ());
192198 m_cmdbuf->bindDescriptorSets (EPBP_COMPUTE, m_pipeline->getLayout (), 0u , 1u , &set);
193- m_cmdbuf->dispatch (240 , 135 , 1u );
199+ m_cmdbuf->pushConstants (m_pipeline->getLayout (), nbl::hlsl::ShaderStage::ESS_COMPUTE, 0u , sizeof (nbl::hlsl::PushConstants), &coopBinarySearchPC);
200+ m_cmdbuf->dispatch ((totalValues + 255u ) / 256u , 1u , 1u );
194201
195202 layoutBufferBarrier[0 ].barrier .dep = layoutBufferBarrier[0 ].barrier .dep .nextBarrier (PIPELINE_STAGE_FLAGS::COPY_BIT,ACCESS_FLAGS::TRANSFER_READ_BIT);
196203 m_cmdbuf->pipelineBarrier (EDF_NONE,depInfo);
@@ -216,7 +223,14 @@ class CooperativeBinarySearch final : public application_templates::MonoDeviceAp
216223
217224 auto ptr = m_allocations[1 ].memory ->getMappedPointer ();
218225 assert (ptr);
219- printf (" readback ptr %p\n " , ptr);
226+
227+ uint32_t * valuesPtr = reinterpret_cast <uint32_t *>(ptr);
228+ for (uint32_t i = 0 ; i < totalValues; i++) {
229+ uint32_t value = valuesPtr[i];
230+ const uint32_t * binarySearchResult = std::upper_bound (TestCaseIndices, TestCaseIndices + numIndices, i);
231+ uint32_t lowerBoundIndex = std::distance (TestCaseIndices, binarySearchResult) - 1 ;
232+ assert (value == lowerBoundIndex);
233+ }
220234
221235 m_keepRunning = false ;
222236 }
0 commit comments