@@ -313,10 +313,10 @@ static WEX::Common::String getInputValueSetName(size_t Index) {
313313 return ValueSetName;
314314}
315315
316- std::string getCompilerOptionsString (const Operation &Operation,
317- const DataType &OpDataType,
318- const DataType &OutDataType,
319- size_t VectorSize ) {
316+ std::string getCompilerOptionsString (
317+ const Operation &Operation, const DataType &OpDataType,
318+ const DataType &OutDataType, size_t VectorSize ,
319+ std::optional<std::string> AdditionalOptions = std:: nullopt ) {
320320 std::stringstream CompilerOptions;
321321
322322 if (OpDataType.Is16Bit || OutDataType.Is16Bit )
@@ -337,6 +337,9 @@ std::string getCompilerOptionsString(const Operation &Operation,
337337
338338 CompilerOptions << " -DBASIC_OP_TYPE=0x" << std::hex << Operation.Arity ;
339339
340+ if (AdditionalOptions)
341+ CompilerOptions << " " << AdditionalOptions.value ();
342+
340343 return CompilerOptions.str ();
341344}
342345
@@ -387,7 +390,8 @@ template <typename OUT_TYPE, typename T>
387390std::optional<std::vector<OUT_TYPE>>
388391runTest (ID3D12Device *D3DDevice, bool VerboseLogging,
389392 const Operation &Operation, const InputSets<T> &Inputs,
390- size_t ExpectedOutputSize) {
393+ size_t ExpectedOutputSize,
394+ std::optional<std::string> AdditionalCompilerOptions) {
391395 DXASSERT_NOMSG (Inputs.size () == Operation.Arity );
392396
393397 if (VerboseLogging) {
@@ -403,8 +407,9 @@ runTest(ID3D12Device *D3DDevice, bool VerboseLogging,
403407
404408 // We have to construct the string outside of the lambda. Otherwise it's
405409 // cleaned up when the lambda finishes executing but before the shader runs.
406- std::string CompilerOptionsString = getCompilerOptionsString (
407- Operation, OpDataType, OutDataType, Inputs[0 ].size ());
410+ std::string CompilerOptionsString =
411+ getCompilerOptionsString (Operation, OpDataType, OutDataType,
412+ Inputs[0 ].size (), AdditionalCompilerOptions);
408413
409414 dxc::SpecificDllLoader DxilDllLoader;
410415 CComPtr<IStream> TestXML;
@@ -570,13 +575,15 @@ struct ValidationConfig {
570575};
571576
572577template <typename T, typename OUT_TYPE>
573- void runAndVerify (ID3D12Device *D3DDevice, bool VerboseLogging,
574- const Operation &Operation, const InputSets<T> &Inputs,
575- const std::vector<OUT_TYPE> &Expected,
576- const ValidationConfig &ValidationConfig) {
578+ void runAndVerify (
579+ ID3D12Device *D3DDevice, bool VerboseLogging, const Operation &Operation,
580+ const InputSets<T> &Inputs, const std::vector<OUT_TYPE> &Expected,
581+ const ValidationConfig &ValidationConfig,
582+ std::optional<std::string> AdditionalCompilerOptions = std::nullopt ) {
577583
578- std::optional<std::vector<OUT_TYPE>> Actual = runTest<OUT_TYPE>(
579- D3DDevice, VerboseLogging, Operation, Inputs, Expected.size ());
584+ std::optional<std::vector<OUT_TYPE>> Actual =
585+ runTest<OUT_TYPE>(D3DDevice, VerboseLogging, Operation, Inputs,
586+ Expected.size (), AdditionalCompilerOptions);
580587
581588 // If the test didn't run, don't verify anything.
582589 if (!Actual)
@@ -1253,6 +1260,19 @@ FLOAT_SPECIAL_OP(OpType::IsInf, (std::isinf(A)));
12531260FLOAT_SPECIAL_OP (OpType::IsNan, (std::isnan(A)));
12541261#undef FLOAT_SPECIAL_OP
12551262
1263+ //
1264+ // Wave Ops
1265+ //
1266+
1267+ #define WAVE_ACTIVE_OP (OP, IMPL ) \
1268+ template <typename T> struct Op <OP, T, 1 > : DefaultValidation<T> { \
1269+ T operator ()(T A, T WaveSize) { return IMPL; } \
1270+ };
1271+
1272+ WAVE_ACTIVE_OP (OpType::WaveActiveSum, (A * WaveSize));
1273+
1274+ #undef WAVE_ACTIVE_OP
1275+
12561276//
12571277// dispatchTest
12581278//
@@ -1296,9 +1316,25 @@ template <OpType OP, typename T> struct ExpectedBuilder {
12961316 }
12971317};
12981318
1319+ template <OpType OP, typename T> struct WaveOpExpectedBuilder {
1320+
1321+ static auto buildExpected (Op<OP, T, 1 > Op, const InputSets<T> &Inputs,
1322+ UINT WaveSize) {
1323+ DXASSERT_NOMSG (Inputs.size () == 1 );
1324+ const T WaveSizeT = static_cast <T>(WaveSize);
1325+
1326+ std::vector<decltype (Op (T (), WaveSizeT))> Expected;
1327+ Expected.reserve (Inputs[0 ].size ());
1328+
1329+ for (size_t I = 0 ; I < Inputs[0 ].size (); ++I)
1330+ Expected.push_back (Op (Inputs[0 ][I], WaveSizeT));
1331+
1332+ return Expected;
1333+ }
1334+ };
1335+
12991336template <typename T, OpType OP>
1300- void dispatchTest (ID3D12Device *D3DDevice, bool VerboseLogging,
1301- size_t OverrideInputSize) {
1337+ std::vector<size_t > getInputSizesToTest (size_t OverrideInputSize) {
13021338 std::vector<size_t > InputVectorSizes;
13031339 const std::array<size_t , 8 > DefaultInputSizes = {3 , 5 , 16 , 17 ,
13041340 35 , 100 , 256 , 1024 };
@@ -1319,8 +1355,17 @@ void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
13191355 InputVectorSizes.push_back (MaxInputSize);
13201356 }
13211357
1322- constexpr const Operation &Operation = getOperation (OP);
1358+ return InputVectorSizes;
1359+ }
1360+
1361+ template <typename T, OpType OP>
1362+ void dispatchTest (ID3D12Device *D3DDevice, bool VerboseLogging,
1363+ size_t OverrideInputSize) {
13231364
1365+ const std::vector<size_t > InputVectorSizes =
1366+ getInputSizesToTest<T, OP>(OverrideInputSize);
1367+
1368+ constexpr const Operation &Operation = getOperation (OP);
13241369 Op<OP, T, Operation.Arity > Op;
13251370
13261371 for (size_t VectorSize : InputVectorSizes) {
@@ -1334,6 +1379,32 @@ void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
13341379 }
13351380}
13361381
1382+ template <typename T, OpType OP>
1383+ void dispatchWaveOpTest (ID3D12Device *D3DDevice, bool VerboseLogging,
1384+ size_t OverrideInputSize, UINT WaveSize) {
1385+
1386+ const std::vector<size_t > InputVectorSizes =
1387+ getInputSizesToTest<T, OP>(OverrideInputSize);
1388+
1389+ constexpr const Operation &Operation = getOperation (OP);
1390+ Op<OP, T, Operation.Arity > Op;
1391+
1392+ const std::string AdditionalCompilerOptions =
1393+ " -DWAVE_SIZE=" + std::to_string (WaveSize) +
1394+ " -DNUMTHREADS_X=" + std::to_string (WaveSize);
1395+
1396+ for (size_t VectorSize : InputVectorSizes) {
1397+ std::vector<std::vector<T>> Inputs =
1398+ buildTestInputs<T>(VectorSize, Operation.InputSets , Operation.Arity );
1399+
1400+ auto Expected =
1401+ WaveOpExpectedBuilder<OP, T>::buildExpected (Op, Inputs, WaveSize);
1402+
1403+ runAndVerify (D3DDevice, VerboseLogging, Operation, Inputs, Expected,
1404+ Op.ValidationConfig , AdditionalCompilerOptions);
1405+ }
1406+ }
1407+
13371408} // namespace LongVector
13381409
13391410using namespace LongVector ;
@@ -1342,6 +1413,14 @@ using namespace LongVector;
13421413#define HLK_TEST (Op, DataType ) \
13431414 TEST_METHOD (Op##_##DataType) { runTest<DataType, OpType::Op>(); }
13441415
1416+ #define HLK_WAVEOP_TEST (Op, DataType ) \
1417+ TEST_METHOD (Op##_##DataType) { \
1418+ BEGIN_TEST_METHOD_PROPERTIES () \
1419+ TEST_METHOD_PROPERTY (L" Priority" , L" 2" ) \
1420+ END_TEST_METHOD_PROPERTIES () \
1421+ runWaveOpTest<DataType, OpType::Op>(); \
1422+ }
1423+
13451424class DxilConf_SM69_Vectorized {
13461425public:
13471426 BEGIN_TEST_CLASS (DxilConf_SM69_Vectorized)
@@ -1405,6 +1484,9 @@ class DxilConf_SM69_Vectorized {
14051484 WEX::TestExecution::RuntimeParameters::TryGetValue (L" InputSize" ,
14061485 OverrideInputSize);
14071486
1487+ WEX::TestExecution::RuntimeParameters::TryGetValue (L" WaveLaneCount" ,
1488+ OverrideWaveLaneCount);
1489+
14081490 bool IsRITP = false ;
14091491 WEX::TestExecution::RuntimeParameters::TryGetValue (L" RITP" , IsRITP);
14101492
@@ -1428,16 +1510,47 @@ class DxilConf_SM69_Vectorized {
14281510 return true ;
14291511 }
14301512
1431- template <typename T, OpType OP> void runTest () {
1432- WEX::TestExecution::SetVerifyOutput verifySettings (
1433- WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
1434-
1513+ TEST_METHOD_SETUP (methodSetup) {
14351514 // It's possible a previous test case caused a device removal. If it did we
14361515 // need to try and create a new device.
1437- if (!D3DDevice || D3DDevice->GetDeviceRemovedReason () != S_OK)
1516+ if (!D3DDevice || D3DDevice->GetDeviceRemovedReason () != S_OK) {
1517+ hlsl_test::LogCommentFmt (
1518+ L" Device was lost: Attempting to create a new D3D12 device." );
14381519 VERIFY_IS_TRUE (
14391520 createDevice (&D3DDevice, ExecTestUtils::D3D_SHADER_MODEL_6_9, false ));
1521+ }
14401522
1523+ return true ;
1524+ }
1525+
1526+ template <typename T, OpType OP> void runWaveOpTest () {
1527+ WEX::TestExecution::SetVerifyOutput VerifySettings (
1528+ WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
1529+
1530+ UINT WaveSize = 0 ;
1531+
1532+ if (OverrideWaveLaneCount > 0 ) {
1533+ WaveSize = OverrideWaveLaneCount;
1534+ hlsl_test::LogCommentFmt (
1535+ L" Using overridden WaveLaneCount of %d for this test." , WaveSize);
1536+ } else {
1537+ D3D12_FEATURE_DATA_D3D12_OPTIONS1 WaveOpts;
1538+ VERIFY_SUCCEEDED (D3DDevice->CheckFeatureSupport (
1539+ D3D12_FEATURE_D3D12_OPTIONS1, &WaveOpts, sizeof (WaveOpts)));
1540+
1541+ WaveSize = WaveOpts.WaveLaneCountMin ;
1542+ }
1543+
1544+ DXASSERT_NOMSG (WaveSize > 0 );
1545+ DXASSERT ((WaveSize & (WaveSize - 1 )) == 0 , " must be a power of 2" );
1546+
1547+ dispatchWaveOpTest<T, OP>(D3DDevice, VerboseLogging, OverrideInputSize,
1548+ WaveSize);
1549+ }
1550+
1551+ template <typename T, OpType OP> void runTest () {
1552+ WEX::TestExecution::SetVerifyOutput verifySettings (
1553+ WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
14411554 dispatchTest<T, OP>(D3DDevice, VerboseLogging, OverrideInputSize);
14421555 }
14431556
@@ -2052,9 +2165,22 @@ class DxilConf_SM69_Vectorized {
20522165 HLK_TEST (LoadAndStore_RD_SB_SRV, double );
20532166 HLK_TEST (LoadAndStore_RD_SB_UAV, double );
20542167
2168+ HLK_WAVEOP_TEST (WaveActiveSum, int16_t );
2169+ HLK_WAVEOP_TEST (WaveActiveSum, int32_t );
2170+ HLK_WAVEOP_TEST (WaveActiveSum, int64_t );
2171+
2172+ HLK_WAVEOP_TEST (WaveActiveSum, uint16_t );
2173+ HLK_WAVEOP_TEST (WaveActiveSum, uint32_t );
2174+ HLK_WAVEOP_TEST (WaveActiveSum, uint64_t );
2175+
2176+ HLK_WAVEOP_TEST (WaveActiveSum, HLSLHalf_t);
2177+ HLK_WAVEOP_TEST (WaveActiveSum, float );
2178+ HLK_WAVEOP_TEST (WaveActiveSum, double );
2179+
20552180private:
20562181 bool Initialized = false ;
20572182 bool VerboseLogging = false ;
20582183 size_t OverrideInputSize = 0 ;
2184+ UINT OverrideWaveLaneCount = 0 ;
20592185 CComPtr<ID3D12Device> D3DDevice;
20602186};
0 commit comments