Skip to content

Commit eccd3d8

Browse files
alsepkowdamyanpgithub-actions[bot]
authored
Execution Tests: Long Vector WaveActiveSum (microsoft#7878)
Adds the basic framework for WaveActiveOp tests and the test cases for WaveActiveSum. This partially addresses microsoft#7472 WARP requires an update for this test to pass so this test will not run in automation for now (no priority set in TAEF metadata). --------- Co-authored-by: Damyan Pepper <damyanp@microsoft.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent facd05a commit eccd3d8

File tree

4 files changed

+167
-22
lines changed

4 files changed

+167
-22
lines changed

tools/clang/unittests/HLSLExec/LongVectorOps.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,6 @@ OP_LOAD_AND_STORE_SB(LoadAndStore_RD_SB_SRV, "RootDescriptor_SRV")
193193
#undef OP_LOAD_AND_STORE
194194
#undef OP_LOAD_AND_STORE_DEFINES
195195

196+
OP_DEFAULT(Wave, WaveActiveSum, 1, "WaveActiveSum", "")
197+
196198
#undef OP

tools/clang/unittests/HLSLExec/LongVectorTestData.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ struct HLSLHalf_t {
114114

115115
Val = DirectX::PackedVector::XMConvertFloatToHalf(F);
116116
}
117+
HLSLHalf_t(const uint32_t U) {
118+
float F = static_cast<float>(U);
119+
Val = DirectX::PackedVector::XMConvertFloatToHalf(F);
120+
}
117121

118122
// PackedVector::HALF is a uint16. Make sure we don't ever accidentally
119123
// convert one of these to a HLSLHalf_t by arithmetically converting it to a

tools/clang/unittests/HLSLExec/LongVectors.cpp

Lines changed: 147 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,10 @@ static WEX::Common::String getInputValueSetName(size_t Index) {
313313
return ValueSetName;
314314
}
315315

316-
std::string getCompilerOptionsString(const Operation &Operation,
317-
const DataType &OpDataType,
318-
const DataType &OutDataType,
319-
size_t VectorSize) {
316+
std::string getCompilerOptionsString(
317+
const Operation &Operation, const DataType &OpDataType,
318+
const DataType &OutDataType, size_t VectorSize,
319+
std::optional<std::string> AdditionalOptions = std::nullopt) {
320320
std::stringstream CompilerOptions;
321321

322322
if (OpDataType.Is16Bit || OutDataType.Is16Bit)
@@ -337,6 +337,9 @@ std::string getCompilerOptionsString(const Operation &Operation,
337337

338338
CompilerOptions << " -DBASIC_OP_TYPE=0x" << std::hex << Operation.Arity;
339339

340+
if (AdditionalOptions)
341+
CompilerOptions << " " << AdditionalOptions.value();
342+
340343
return CompilerOptions.str();
341344
}
342345

@@ -387,7 +390,8 @@ template <typename OUT_TYPE, typename T>
387390
std::optional<std::vector<OUT_TYPE>>
388391
runTest(ID3D12Device *D3DDevice, bool VerboseLogging,
389392
const Operation &Operation, const InputSets<T> &Inputs,
390-
size_t ExpectedOutputSize) {
393+
size_t ExpectedOutputSize,
394+
std::optional<std::string> AdditionalCompilerOptions) {
391395
DXASSERT_NOMSG(Inputs.size() == Operation.Arity);
392396

393397
if (VerboseLogging) {
@@ -403,8 +407,9 @@ runTest(ID3D12Device *D3DDevice, bool VerboseLogging,
403407

404408
// We have to construct the string outside of the lambda. Otherwise it's
405409
// cleaned up when the lambda finishes executing but before the shader runs.
406-
std::string CompilerOptionsString = getCompilerOptionsString(
407-
Operation, OpDataType, OutDataType, Inputs[0].size());
410+
std::string CompilerOptionsString =
411+
getCompilerOptionsString(Operation, OpDataType, OutDataType,
412+
Inputs[0].size(), AdditionalCompilerOptions);
408413

409414
dxc::SpecificDllLoader DxilDllLoader;
410415
CComPtr<IStream> TestXML;
@@ -570,13 +575,15 @@ struct ValidationConfig {
570575
};
571576

572577
template <typename T, typename OUT_TYPE>
573-
void runAndVerify(ID3D12Device *D3DDevice, bool VerboseLogging,
574-
const Operation &Operation, const InputSets<T> &Inputs,
575-
const std::vector<OUT_TYPE> &Expected,
576-
const ValidationConfig &ValidationConfig) {
578+
void runAndVerify(
579+
ID3D12Device *D3DDevice, bool VerboseLogging, const Operation &Operation,
580+
const InputSets<T> &Inputs, const std::vector<OUT_TYPE> &Expected,
581+
const ValidationConfig &ValidationConfig,
582+
std::optional<std::string> AdditionalCompilerOptions = std::nullopt) {
577583

578-
std::optional<std::vector<OUT_TYPE>> Actual = runTest<OUT_TYPE>(
579-
D3DDevice, VerboseLogging, Operation, Inputs, Expected.size());
584+
std::optional<std::vector<OUT_TYPE>> Actual =
585+
runTest<OUT_TYPE>(D3DDevice, VerboseLogging, Operation, Inputs,
586+
Expected.size(), AdditionalCompilerOptions);
580587

581588
// If the test didn't run, don't verify anything.
582589
if (!Actual)
@@ -1253,6 +1260,19 @@ FLOAT_SPECIAL_OP(OpType::IsInf, (std::isinf(A)));
12531260
FLOAT_SPECIAL_OP(OpType::IsNan, (std::isnan(A)));
12541261
#undef FLOAT_SPECIAL_OP
12551262

1263+
//
1264+
// Wave Ops
1265+
//
1266+
1267+
#define WAVE_ACTIVE_OP(OP, IMPL) \
1268+
template <typename T> struct Op<OP, T, 1> : DefaultValidation<T> { \
1269+
T operator()(T A, T WaveSize) { return IMPL; } \
1270+
};
1271+
1272+
WAVE_ACTIVE_OP(OpType::WaveActiveSum, (A * WaveSize));
1273+
1274+
#undef WAVE_ACTIVE_OP
1275+
12561276
//
12571277
// dispatchTest
12581278
//
@@ -1296,9 +1316,25 @@ template <OpType OP, typename T> struct ExpectedBuilder {
12961316
}
12971317
};
12981318

1319+
template <OpType OP, typename T> struct WaveOpExpectedBuilder {
1320+
1321+
static auto buildExpected(Op<OP, T, 1> Op, const InputSets<T> &Inputs,
1322+
UINT WaveSize) {
1323+
DXASSERT_NOMSG(Inputs.size() == 1);
1324+
const T WaveSizeT = static_cast<T>(WaveSize);
1325+
1326+
std::vector<decltype(Op(T(), WaveSizeT))> Expected;
1327+
Expected.reserve(Inputs[0].size());
1328+
1329+
for (size_t I = 0; I < Inputs[0].size(); ++I)
1330+
Expected.push_back(Op(Inputs[0][I], WaveSizeT));
1331+
1332+
return Expected;
1333+
}
1334+
};
1335+
12991336
template <typename T, OpType OP>
1300-
void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
1301-
size_t OverrideInputSize) {
1337+
std::vector<size_t> getInputSizesToTest(size_t OverrideInputSize) {
13021338
std::vector<size_t> InputVectorSizes;
13031339
const std::array<size_t, 8> DefaultInputSizes = {3, 5, 16, 17,
13041340
35, 100, 256, 1024};
@@ -1319,8 +1355,17 @@ void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
13191355
InputVectorSizes.push_back(MaxInputSize);
13201356
}
13211357

1322-
constexpr const Operation &Operation = getOperation(OP);
1358+
return InputVectorSizes;
1359+
}
1360+
1361+
template <typename T, OpType OP>
1362+
void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
1363+
size_t OverrideInputSize) {
13231364

1365+
const std::vector<size_t> InputVectorSizes =
1366+
getInputSizesToTest<T, OP>(OverrideInputSize);
1367+
1368+
constexpr const Operation &Operation = getOperation(OP);
13241369
Op<OP, T, Operation.Arity> Op;
13251370

13261371
for (size_t VectorSize : InputVectorSizes) {
@@ -1334,6 +1379,32 @@ void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
13341379
}
13351380
}
13361381

1382+
template <typename T, OpType OP>
1383+
void dispatchWaveOpTest(ID3D12Device *D3DDevice, bool VerboseLogging,
1384+
size_t OverrideInputSize, UINT WaveSize) {
1385+
1386+
const std::vector<size_t> InputVectorSizes =
1387+
getInputSizesToTest<T, OP>(OverrideInputSize);
1388+
1389+
constexpr const Operation &Operation = getOperation(OP);
1390+
Op<OP, T, Operation.Arity> Op;
1391+
1392+
const std::string AdditionalCompilerOptions =
1393+
"-DWAVE_SIZE=" + std::to_string(WaveSize) +
1394+
" -DNUMTHREADS_X=" + std::to_string(WaveSize);
1395+
1396+
for (size_t VectorSize : InputVectorSizes) {
1397+
std::vector<std::vector<T>> Inputs =
1398+
buildTestInputs<T>(VectorSize, Operation.InputSets, Operation.Arity);
1399+
1400+
auto Expected =
1401+
WaveOpExpectedBuilder<OP, T>::buildExpected(Op, Inputs, WaveSize);
1402+
1403+
runAndVerify(D3DDevice, VerboseLogging, Operation, Inputs, Expected,
1404+
Op.ValidationConfig, AdditionalCompilerOptions);
1405+
}
1406+
}
1407+
13371408
} // namespace LongVector
13381409

13391410
using namespace LongVector;
@@ -1342,6 +1413,14 @@ using namespace LongVector;
13421413
#define HLK_TEST(Op, DataType) \
13431414
TEST_METHOD(Op##_##DataType) { runTest<DataType, OpType::Op>(); }
13441415

1416+
#define HLK_WAVEOP_TEST(Op, DataType) \
1417+
TEST_METHOD(Op##_##DataType) { \
1418+
BEGIN_TEST_METHOD_PROPERTIES() \
1419+
TEST_METHOD_PROPERTY(L"Priority", L"2") \
1420+
END_TEST_METHOD_PROPERTIES() \
1421+
runWaveOpTest<DataType, OpType::Op>(); \
1422+
}
1423+
13451424
class DxilConf_SM69_Vectorized {
13461425
public:
13471426
BEGIN_TEST_CLASS(DxilConf_SM69_Vectorized)
@@ -1405,6 +1484,9 @@ class DxilConf_SM69_Vectorized {
14051484
WEX::TestExecution::RuntimeParameters::TryGetValue(L"InputSize",
14061485
OverrideInputSize);
14071486

1487+
WEX::TestExecution::RuntimeParameters::TryGetValue(L"WaveLaneCount",
1488+
OverrideWaveLaneCount);
1489+
14081490
bool IsRITP = false;
14091491
WEX::TestExecution::RuntimeParameters::TryGetValue(L"RITP", IsRITP);
14101492

@@ -1428,16 +1510,47 @@ class DxilConf_SM69_Vectorized {
14281510
return true;
14291511
}
14301512

1431-
template <typename T, OpType OP> void runTest() {
1432-
WEX::TestExecution::SetVerifyOutput verifySettings(
1433-
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
1434-
1513+
TEST_METHOD_SETUP(methodSetup) {
14351514
// It's possible a previous test case caused a device removal. If it did we
14361515
// need to try and create a new device.
1437-
if (!D3DDevice || D3DDevice->GetDeviceRemovedReason() != S_OK)
1516+
if (!D3DDevice || D3DDevice->GetDeviceRemovedReason() != S_OK) {
1517+
hlsl_test::LogCommentFmt(
1518+
L"Device was lost: Attempting to create a new D3D12 device.");
14381519
VERIFY_IS_TRUE(
14391520
createDevice(&D3DDevice, ExecTestUtils::D3D_SHADER_MODEL_6_9, false));
1521+
}
14401522

1523+
return true;
1524+
}
1525+
1526+
template <typename T, OpType OP> void runWaveOpTest() {
1527+
WEX::TestExecution::SetVerifyOutput VerifySettings(
1528+
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
1529+
1530+
UINT WaveSize = 0;
1531+
1532+
if (OverrideWaveLaneCount > 0) {
1533+
WaveSize = OverrideWaveLaneCount;
1534+
hlsl_test::LogCommentFmt(
1535+
L"Using overridden WaveLaneCount of %d for this test.", WaveSize);
1536+
} else {
1537+
D3D12_FEATURE_DATA_D3D12_OPTIONS1 WaveOpts;
1538+
VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport(
1539+
D3D12_FEATURE_D3D12_OPTIONS1, &WaveOpts, sizeof(WaveOpts)));
1540+
1541+
WaveSize = WaveOpts.WaveLaneCountMin;
1542+
}
1543+
1544+
DXASSERT_NOMSG(WaveSize > 0);
1545+
DXASSERT((WaveSize & (WaveSize - 1)) == 0, "must be a power of 2");
1546+
1547+
dispatchWaveOpTest<T, OP>(D3DDevice, VerboseLogging, OverrideInputSize,
1548+
WaveSize);
1549+
}
1550+
1551+
template <typename T, OpType OP> void runTest() {
1552+
WEX::TestExecution::SetVerifyOutput verifySettings(
1553+
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
14411554
dispatchTest<T, OP>(D3DDevice, VerboseLogging, OverrideInputSize);
14421555
}
14431556

@@ -2052,9 +2165,22 @@ class DxilConf_SM69_Vectorized {
20522165
HLK_TEST(LoadAndStore_RD_SB_SRV, double);
20532166
HLK_TEST(LoadAndStore_RD_SB_UAV, double);
20542167

2168+
HLK_WAVEOP_TEST(WaveActiveSum, int16_t);
2169+
HLK_WAVEOP_TEST(WaveActiveSum, int32_t);
2170+
HLK_WAVEOP_TEST(WaveActiveSum, int64_t);
2171+
2172+
HLK_WAVEOP_TEST(WaveActiveSum, uint16_t);
2173+
HLK_WAVEOP_TEST(WaveActiveSum, uint32_t);
2174+
HLK_WAVEOP_TEST(WaveActiveSum, uint64_t);
2175+
2176+
HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t);
2177+
HLK_WAVEOP_TEST(WaveActiveSum, float);
2178+
HLK_WAVEOP_TEST(WaveActiveSum, double);
2179+
20552180
private:
20562181
bool Initialized = false;
20572182
bool VerboseLogging = false;
20582183
size_t OverrideInputSize = 0;
2184+
UINT OverrideWaveLaneCount = 0;
20592185
CComPtr<ID3D12Device> D3DDevice;
20602186
};

tools/clang/unittests/HLSLExec/ShaderOpArith.xml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4101,7 +4101,20 @@ void MSMain(uint GID : SV_GroupIndex,
41014101
}
41024102
#endif
41034103
4104-
[numthreads(1,1,1)]
4104+
#ifdef NUMTHREADS_X
4105+
#define NUMTHREADS_ATTR [numthreads(NUMTHREADS_X, 1, 1)]
4106+
#else
4107+
#define NUMTHREADS_ATTR [numthreads(1, 1, 1)]
4108+
#endif
4109+
4110+
#ifdef WAVE_SIZE
4111+
#define WAVE_SIZE_ATTR [WaveSize(WAVE_SIZE)]
4112+
#else
4113+
#define WAVE_SIZE_ATTR
4114+
#endif
4115+
4116+
WAVE_SIZE_ATTR
4117+
NUMTHREADS_ATTR
41054118
void main(uint GI : SV_GroupIndex) {
41064119
41074120
#ifdef FUNC_SHUFFLE_VECTOR

0 commit comments

Comments
 (0)