Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions unified-runtime/source/adapters/level_zero/v2/queue_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "kernel.hpp"
#include "lockable.hpp"
#include "memory.hpp"
#include "ur.hpp"
// #include "ur/ur.hpp"

#include "../common/latency_tracker.hpp"
#include "../helpers/kernel_helpers.hpp"
Expand All @@ -34,6 +34,8 @@

namespace v2 {

// constexpr uint64_t initialSlotsForBatches = 10;

ur_queue_batched_t::ur_queue_batched_t(
ur_context_handle_t hContext, ur_device_handle_t hDevice, uint32_t ordinal,
ze_command_queue_priority_t priority, std::optional<int32_t> index,
Expand Down Expand Up @@ -123,10 +125,10 @@ ur_result_t batch_manager::renewRegularUnlocked(
// Prepares the queue for further batching after the current batch was handed
// off.
//
// When the batch manager reports that the limit of used command lists has been
// reached, the queue is finished through queueFinishUnlocked() and that result
// is returned. Otherwise a new regular command list is obtained from
// getNewRegularCmdList() and installed via renewRegularUnlocked().
//
// batchLocked: the batch manager, already locked by the caller.
// Returns the ur_result_t produced by whichever path was taken.
//
// NOTE(review): on the limit-reached path no new command list is requested;
// presumably finishing the queue leaves the active list reusable - confirm
// against batch_manager::batchFinish.
ur_result_t
ur_queue_batched_t::renewBatchUnlocked(locked<batch_manager> &batchLocked) {
  if (batchLocked->isLimitOfUsedCommandListsReached()) {
    // A single finish is sufficient; propagate its result directly instead of
    // finishing the queue a second time.
    return queueFinishUnlocked(batchLocked);
  }

  return batchLocked->renewRegularUnlocked(getNewRegularCmdList());
}

ur_result_t batch_manager::enqueueCurrentBatchUnlocked() {
Expand Down Expand Up @@ -214,20 +216,26 @@ ur_result_t batch_manager::batchFinish() {

UR_CALL(activeBatch.releaseSubmittedKernels());

{
if (!isActiveBatchEmpty()) {
TRACK_SCOPE_LATENCY("ur_queue_batched_t::resetRegCmdlist");
ZE2UR_CALL(zeCommandListReset, (activeBatch.getZeCommandList()));

setBatchEmpty();
regularGenerationNumber++;
}

runBatches.clear();
setBatchEmpty();
// regularGenerationNumber++;

return UR_RESULT_SUCCESS;
}

ur_result_t
ur_queue_batched_t::queueFinishUnlocked(locked<batch_manager> &batchLocked) {
UR_CALL(batchLocked->enqueueCurrentBatchUnlocked());
if (!batchLocked->isActiveBatchEmpty()) {
UR_CALL(batchLocked->enqueueCurrentBatchUnlocked());
}

UR_CALL(batchLocked->hostSynchronize());

UR_CALL(queueFinishPoolsUnlocked());
Expand Down Expand Up @@ -1070,7 +1078,12 @@ ur_queue_batched_t::queueFlushUnlocked(locked<batch_manager> &batchLocked) {

// Flushes the operations recorded in the current batch.
//
// Takes the lock on currentCmdLists. If the batch manager reports the active
// batch as empty, returns UR_RESULT_SUCCESS without submitting anything;
// otherwise forwards to queueFlushUnlocked() and returns its result.
ur_result_t ur_queue_batched_t::queueFlush() {
  auto batchLocked = currentCmdLists.lock();

  // Nothing was recorded in the active batch, so there is nothing to submit.
  if (batchLocked->isActiveBatchEmpty()) {
    return UR_RESULT_SUCCESS;
  }

  return queueFlushUnlocked(batchLocked);
}

} // namespace v2
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@

namespace v2 {

// Upper bound on the number of regular command lists kept for execution. Once
// it is exceeded, the vector holding them is cleared as part of queueFinish and
// the slots are renewed.
inline constexpr uint64_t initialSlotsForBatches{10};

// Starting value of batch_manager's regularGenerationNumber. The comment on
// that member (below) explains what generation numbers are used for.
inline constexpr ur_event_generation_t initialGenerationNumber{0};

struct batch_manager {
private:
// The currently active regular command list, which may be replaced in the
Expand Down Expand Up @@ -75,9 +83,6 @@ struct batch_manager {
// associated with the event has already been submitted for execution and
// additional submission of the current batch is not needed.
ur_event_generation_t regularGenerationNumber;
// The limit of regular command lists stored for execution; if exceeded, the
// vector is cleared as part of queueFinish and slots are renewed.
static constexpr uint64_t initialSlotsForBatches = 10;
// Whether any operation has been enqueued on the current batch
bool isEmpty = true;

Expand All @@ -91,7 +96,7 @@ struct batch_manager {
immediateList(context, device,
std::forward<v2::raii::command_list_unique_handle>(
commandListImmediate)),
regularGenerationNumber(0) {
regularGenerationNumber(initialGenerationNumber) {
runBatches.reserve(initialSlotsForBatches);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ ur_result_t urQueueCreate(ur_context_handle_t hContext,
flags |= UR_QUEUE_FLAG_SUBMISSION_BATCHED;
}

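// FIXME(review): temporary override for CI testing - the next line
// unconditionally adds UR_QUEUE_FLAG_SUBMISSION_BATCHED to the requested
// flags for every queue; confirm it is removed before this change is merged.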
flags |= UR_QUEUE_FLAG_SUBMISSION_BATCHED;

auto zeIndex = v2::getZeIndex(pProperties);

bool immediate = true;
Expand All @@ -95,6 +98,8 @@ ur_result_t urQueueCreate(ur_context_handle_t hContext,
"urQueueCreate called with both UR_QUEUE_FLAG_SUBMISSION_BATCHED "
"and UR_QUEUE_FLAG_SUBMISSION_IMMEDIATE in ur_queue_flags_t. "
"Defaulting to the immediate submission mode.");

flags &= ~UR_QUEUE_FLAG_SUBMISSION_BATCHED;
}

immediate = true;
Expand Down
6 changes: 6 additions & 0 deletions unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <thread>

#include "ur_api.h"
#include "uur/utils.h"
#include <uur/fixtures.h>

struct EnqueueAllocTestParam {
Expand Down Expand Up @@ -81,6 +82,7 @@ struct urL0EnqueueAllocMultiQueueSameDeviceTest
for (size_t i = 0; i < param.numQueues; i++) {
ur_queue_handle_t queue = nullptr;
ASSERT_SUCCESS(urQueueCreate(context, device, 0, &queue));
SKIP_IF_BATCHED_QUEUE(queue);
queues.push_back(queue);
}
}
Expand Down Expand Up @@ -353,6 +355,10 @@ TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessMt) {
const auto checkUSMSupportFunc =
std::get<1>(this->GetParam()).funcParams.checkUSMSupportFunc;

if (numQueues > 0) {
SKIP_IF_BATCHED_QUEUE(queues[0]);;
}

ur_device_usm_access_capability_flags_t USMSupport = 0;
ASSERT_SUCCESS(checkUSMSupportFunc(device, USMSupport));
if (!(USMSupport & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
Expand Down
16 changes: 16 additions & 0 deletions unified-runtime/test/adapters/level_zero/v2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ add_l0_v2_devices_test(memory_residency
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/ur_level_zero.cpp
)

# Registers the "batched_queue" test through add_l0_v2_devices_test. The test
# source batched_queue_test.cpp is listed together with Level Zero adapter
# sources (adapter, common, device, platform, ur_level_zero), the v2 event
# pool / event provider / event sources, the v2 command list cache and the
# shared ur.cpp - presumably these are compiled directly into the test target;
# see the definition of add_l0_v2_devices_test to confirm.
add_l0_v2_devices_test(batched_queue
batched_queue_test.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/adapter.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/common.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/device.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/platform.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/ur_level_zero.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/event_pool_cache.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/event_pool.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/event_provider_counter.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/event_provider_normal.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/event.cpp
${PROJECT_SOURCE_DIR}/source/ur/ur.cpp
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/command_list_cache.cpp
)

if(NOT UR_FOUND_DPCXX)
# Tests that require kernels can't be used if we aren't generating
# device binaries
Expand Down
Loading
Loading