From 4eb6115c06ad91c020c2f29ba851f4a3b2fc8048 Mon Sep 17 00:00:00 2001 From: David Rohr Date: Mon, 14 Apr 2025 16:01:21 +0200 Subject: [PATCH] ONNX: Use CMake defines not env variables --- Common/ML/CMakeLists.txt | 16 ++++++++-------- Common/ML/src/OrtInterface.cxx | 32 +++++++++++++------------------- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/Common/ML/CMakeLists.txt b/Common/ML/CMakeLists.txt index 74be306c8b6a5..540fe8ebf271c 100644 --- a/Common/ML/CMakeLists.txt +++ b/Common/ML/CMakeLists.txt @@ -10,17 +10,17 @@ # or submit itself to any jurisdiction. # Pass ORT variables as a preprocessor definition -if(DEFINED ENV{ORT_ROCM_BUILD}) - add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD}) +if(ORT_ROCM_BUILD) + add_compile_definitions(ORT_ROCM_BUILD=1) endif() -if(DEFINED ENV{ORT_CUDA_BUILD}) - add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD}) +if(ORT_CUDA_BUILD) + add_compile_definitions(ORT_CUDA_BUILD=1) endif() -if(DEFINED ENV{ORT_MIGRAPHX_BUILD}) - add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD}) +if(ORT_MIGRAPHX_BUILD) + add_compile_definitions(ORT_MIGRAPHX_BUILD=1) endif() -if(DEFINED ENV{ORT_TENSORRT_BUILD}) - add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD}) +if(ORT_TENSORRT_BUILD) + add_compile_definitions(ORT_TENSORRT_BUILD=1) endif() o2_add_library(ML diff --git a/Common/ML/src/OrtInterface.cxx b/Common/ML/src/OrtInterface.cxx index fc784dd14d2dc..88f548bd4fe7b 100644 --- a/Common/ML/src/OrtInterface.cxx +++ b/Common/ML/src/OrtInterface.cxx @@ -59,29 +59,23 @@ void OrtModel::reset(std::unordered_map optionsMap) std::string dev_mem_str = "Hip"; #if defined(ORT_ROCM_BUILD) -#if ORT_ROCM_BUILD == 1 - if (device == "ROCM") { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId)); - LOG(info) << "(ORT) ROCM execution provider set"; - } -#endif + if (device == "ROCM") { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId)); + LOG(info) << "(ORT) ROCM execution provider set"; + } #endif #if defined(ORT_MIGRAPHX_BUILD) -#if ORT_MIGRAPHX_BUILD == 1 - if (device == "MIGRAPHX") { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId)); - LOG(info) << "(ORT) MIGraphX execution provider set"; - } -#endif + if (device == "MIGRAPHX") { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId)); + LOG(info) << "(ORT) MIGraphX execution provider set"; + } #endif #if defined(ORT_CUDA_BUILD) -#if ORT_CUDA_BUILD == 1 - if (device == "CUDA") { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId)); - LOG(info) << "(ORT) CUDA execution provider set"; - dev_mem_str = "Cuda"; - } -#endif + if (device == "CUDA") { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId)); + LOG(info) << "(ORT) CUDA execution provider set"; + dev_mem_str = "Cuda"; + } #endif if (allocateDeviceMemory) {