1313 if ((status = (x)) != NULL) \
1414 goto error;
1515
16- // If we run on GPU, we do not use the custom allocator (redis allocator), but
17- // the default allocator returned by ORT api. If we run on CPU, we do not need to
18- // use this api, as the global allocator is already set to be the custom allocator.
19- #ifdef RAI_ONNXRUNTIME_USE_CUDA
20- #define GET_GLOBAL_ALLOCATOR \
21- ONNX_VALIDATE_STATUS(ort->GetAllocatorWithDefaultOptions(&global_allocator))
22- #else
23- #define GET_GLOBAL_ALLOCATOR
24- #endif
25-
2616OrtEnv * env = NULL ;
17+
18+ // For model that run on GPU, onnx will not use the custom allocator (redis allocator), but
19+ // the onnx allocator for GPU. But for the auxilery allocations of the input and output names,
20+ // we will use the custom global allocator for models that run on GPU as well
2721OrtAllocator * global_allocator = NULL ;
2822unsigned long long OnnxMemory = 0 ;
2923unsigned long long OnnxMemoryAccessCounter = 0 ;
@@ -182,7 +176,6 @@ OrtValue *RAI_OrtValueFromTensors(RAI_Tensor **ts, size_t count, RAI_Error *erro
182176 batched_shape [0 ] = batch_size ;
183177
184178 OrtValue * out ;
185- GET_GLOBAL_ALLOCATOR
186179 if (count > 1 ) {
187180 ONNX_VALIDATE_STATUS (
188181 ort -> CreateTensorAsOrtValue (global_allocator , batched_shape , t0 -> tensor .dl_tensor .ndim ,
@@ -321,18 +314,16 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
321314
322315 if (env == NULL ) {
323316 ONNX_VALIDATE_STATUS (ort -> CreateEnv (ORT_LOGGING_LEVEL_WARNING , "test" , & env ))
324- #ifndef RAI_ONNXRUNTIME_USE_CUDA
325317 ONNX_VALIDATE_STATUS (ort -> CreateCustomDeviceAllocator (
326318 ORT_API_VERSION , AllocatorAlloc , AllocatorFree , AllocatorInfo , & global_allocator ))
327319 ONNX_VALIDATE_STATUS (ort -> RegisterCustomDeviceAllocator (env , global_allocator ))
328- #endif
329320 }
330321
331322 ONNX_VALIDATE_STATUS (ort -> CreateSessionOptions (& session_options ))
332323
333324#ifndef RAI_ONNXRUNTIME_USE_CUDA
334- // These are required to ensure that onnx will use the registered REDIS allocator (for CPU
335- // only ).
325+ // These are required to ensure that onnx will use the registered REDIS allocator (for
326+ // a model that defined to run on CPU ).
336327 ONNX_VALIDATE_STATUS (
337328 ort -> AddSessionConfigEntry (session_options , "session.use_env_allocators" , "1" ))
338329 ONNX_VALIDATE_STATUS (ort -> DisableCpuMemArena (session_options ))
@@ -357,7 +348,6 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
357348 size_t n_output_nodes ;
358349 ONNX_VALIDATE_STATUS (ort -> SessionGetOutputCount (session , & n_output_nodes ))
359350
360- GET_GLOBAL_ALLOCATOR
361351 inputs_ = array_new (char * , n_input_nodes );
362352 for (long long i = 0 ; i < n_input_nodes ; i ++ ) {
363353 char * input_name ;
@@ -424,7 +414,6 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
424414 const OrtApi * ort = OrtGetApiBase ()-> GetApi (1 );
425415 OrtStatus * status = NULL ;
426416
427- GET_GLOBAL_ALLOCATOR
428417 for (uint32_t i = 0 ; i < model -> ninputs ; i ++ ) {
429418 ONNX_VALIDATE_STATUS (ort -> AllocatorFree (global_allocator , model -> inputs [i ]))
430419 }
@@ -481,8 +470,7 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
481470 ONNX_VALIDATE_STATUS (ort -> SessionGetInputCount (session , & n_input_nodes ))
482471
483472 size_t n_output_nodes ;
484- ONNX_VALIDATE_STATUS (ort -> SessionGetOutputCount (session , & n_output_nodes ))
485- GET_GLOBAL_ALLOCATOR {
473+ ONNX_VALIDATE_STATUS (ort -> SessionGetOutputCount (session , & n_output_nodes )) {
486474 const char * input_names [n_input_nodes ];
487475 const char * output_names [n_output_nodes ];
488476
0 commit comments