Skip to content

Commit 1c56eb0

Browse files
committed
reduce workgroup macro definitions, use config string
1 parent 683aa87 commit 1c56eb0

File tree

3 files changed

+24
-39
lines changed

3 files changed

+24
-39
lines changed

23_Arithmetic2UnitTest/app_resources/shaderCommon.hlsl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,6 @@
33
using namespace nbl;
44
using namespace hlsl;
55

6-
// https://github.com/microsoft/DirectXShaderCompiler/issues/6144
7-
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize() {return uint32_t3(WORKGROUP_SIZE,1,1);}
8-
9-
#ifndef ITEMS_PER_INVOCATION
10-
#error "Define ITEMS_PER_INVOCATION!"
11-
#endif
12-
136
[[vk::push_constant]] PushConstantData pc;
147

158
struct device_capabilities
@@ -24,7 +17,3 @@ struct device_capabilities
2417
#ifndef OPERATION
2518
#error "Define OPERATION!"
2619
#endif
27-
28-
#ifndef SUBGROUP_SIZE_LOG2
29-
#error "Define SUBGROUP_SIZE_LOG2!"
30-
#endif

23_Arithmetic2UnitTest/app_resources/testWorkgroup.comp.hlsl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl"
66
#include "nbl/builtin/hlsl/workgroup2/arithmetic.hlsl"
77

8-
static const uint32_t WORKGROUP_SIZE = 1u << WORKGROUP_SIZE_LOG2;
8+
using config_t = WORKGROUP_CONFIG_T;
99

1010
#include "shaderCommon.hlsl"
1111

12-
using config_t = workgroup2::ArithmeticConfiguration<WORKGROUP_SIZE_LOG2, SUBGROUP_SIZE_LOG2, ITEMS_PER_INVOCATION>;
13-
1412
typedef vector<uint32_t, config_t::ItemsPerInvocation_0> type_t;
1513

1614
// final (level 1/2) scan needs to fit in one subgroup exactly
@@ -52,7 +50,7 @@ struct operation_t
5250
template<class Binop>
5351
static void subtest()
5452
{
55-
assert(glsl::gl_SubgroupSize() == 1u<<SUBGROUP_SIZE_LOG2)
53+
assert(glsl::gl_SubgroupSize() == config_t::SubgroupSize)
5654

5755
operation_t<Binop,device_capabilities> func;
5856
func();
@@ -69,7 +67,7 @@ void test()
6967
subtest<arithmetic::maximum<uint32_t> >();
7068
}
7169

72-
[numthreads(WORKGROUP_SIZE,1,1)]
70+
[numthreads(config_t::WorkgroupSize,1,1)]
7371
void main()
7472
{
7573
test();

23_Arithmetic2UnitTest/main.cpp

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class Workgroup2ScanTestApp final : public application_templates::BasicMultiQueu
186186
for (auto subgroupSize = MinSubgroupSize; subgroupSize <= MaxSubgroupSize; subgroupSize *= 2u)
187187
{
188188
const uint8_t subgroupSizeLog2 = hlsl::findMSB(subgroupSize);
189-
for (uint32_t workgroupSize = subgroupSize; workgroupSize <= MaxWorkgroupSize; workgroupSize *= 2u)
189+
for (uint32_t workgroupSize = 64; workgroupSize <= MaxWorkgroupSize; workgroupSize *= 2u)
190190
{
191191
// make sure renderdoc captures everything for debugging
192192
m_api->startCapture();
@@ -198,14 +198,15 @@ class Workgroup2ScanTestApp final : public application_templates::BasicMultiQueu
198198
uint32_t itemsPerWG = workgroupSize * itemsPerInvocation;
199199
m_logger->log("Testing Items per Invocation %u", ILogger::ELL_INFO, itemsPerInvocation);
200200
bool passed = true;
201-
passed = runTest<emulatedReduction, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
202-
logTestOutcome(passed, itemsPerWG);
203-
passed = runTest<emulatedScanInclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
204-
logTestOutcome(passed, itemsPerWG);
205-
passed = runTest<emulatedScanExclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
206-
logTestOutcome(passed, itemsPerWG);
207-
208-
hlsl::workgroup2::SArithmeticConfiguration wgConfig = hlsl::workgroup2::SArithmeticConfiguration::create(hlsl::findMSB(workgroupSize), subgroupSizeLog2, itemsPerInvocation);
201+
//passed = runTest<emulatedReduction, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
202+
//logTestOutcome(passed, itemsPerWG);
203+
//passed = runTest<emulatedScanInclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
204+
//logTestOutcome(passed, itemsPerWG);
205+
//passed = runTest<emulatedScanExclusive, false>(subgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
206+
//logTestOutcome(passed, itemsPerWG);
207+
208+
hlsl::workgroup2::SArithmeticConfiguration wgConfig;
209+
wgConfig.init(hlsl::findMSB(workgroupSize), subgroupSizeLog2, itemsPerInvocation);
209210
itemsPerWG = wgConfig.VirtualWorkgroupSize * wgConfig.ItemsPerInvocation_0;
210211
m_logger->log("Testing Item Count %u", ILogger::ELL_INFO, itemsPerWG);
211212
passed = runTest<emulatedReduction, true>(workgroupTestSource, elementCount, subgroupSizeLog2, workgroupSize, bool(useNative), itemsPerWG, itemsPerInvocation) && passed;
@@ -306,28 +307,25 @@ class Workgroup2ScanTestApp final : public application_templates::BasicMultiQueu
306307
smart_refctd_ptr<ICPUShader> overriddenUnspecialized;
307308
if constexpr (WorkgroupTest)
308309
{
309-
const std::string definitions[6] = {
310+
hlsl::workgroup2::SArithmeticConfiguration wgConfig;
311+
wgConfig.init(hlsl::findMSB(workgroupSize), subgroupSizeLog2, itemsPerInvoc);
312+
313+
const std::string definitions[3] = {
310314
"workgroup2::" + arith_name,
311-
std::to_string(workgroupSizeLog2),
312-
std::to_string(itemsPerWG),
313-
std::to_string(itemsPerInvoc),
314-
std::to_string(subgroupSizeLog2),
315+
wgConfig.getConfigTemplateStructString(),
315316
std::to_string(arith_name=="reduction")
316317
};
317318

318-
const IShaderCompiler::SMacroDefinition defines[7] = {
319+
const IShaderCompiler::SMacroDefinition defines[4] = {
319320
{ "OPERATION", definitions[0] },
320-
{ "WORKGROUP_SIZE_LOG2", definitions[1] },
321-
{ "ITEMS_PER_WG", definitions[2] },
322-
{ "ITEMS_PER_INVOCATION", definitions[3] },
323-
{ "SUBGROUP_SIZE_LOG2", definitions[4] },
324-
{ "IS_REDUCTION", definitions[5] },
321+
{ "WORKGROUP_CONFIG_T", definitions[1] },
322+
{ "IS_REDUCTION", definitions[2] },
325323
{ "TEST_NATIVE", "1" }
326324
};
327325
if (useNative)
328-
options.preprocessorOptions.extraDefines = { defines, defines + 7 };
326+
options.preprocessorOptions.extraDefines = { defines, defines + 4 };
329327
else
330-
options.preprocessorOptions.extraDefines = { defines, defines + 6 };
328+
options.preprocessorOptions.extraDefines = { defines, defines + 3 };
331329

332330
overriddenUnspecialized = compiler->compileToSPIRV((const char*)source->getContent()->getPointer(), options);
333331
}
@@ -358,7 +356,7 @@ class Workgroup2ScanTestApp final : public application_templates::BasicMultiQueu
358356
auto pipeline = createPipeline(overriddenUnspecialized.get(),subgroupSizeLog2);
359357

360358
// TODO: overlap dispatches with memory readbacks (requires multiple copies of `buffers`)
361-
uint32_t workgroupCount = min(elementCount / itemsPerWG, m_physicalDevice->getLimits().maxComputeWorkGroupCount[0]);
359+
uint32_t workgroupCount = 1;// min(elementCount / itemsPerWG, m_physicalDevice->getLimits().maxComputeWorkGroupCount[0]);
362360

363361
cmdbuf->begin(IGPUCommandBuffer::USAGE::NONE);
364362
cmdbuf->bindComputePipeline(pipeline.get());

0 commit comments

Comments
 (0)