 
 #include "onnxruntime_c_api.h"
 
-#define ONNX_API(x) \
+// Use as a wrapper for an ORT API call. If the call returns a non-NULL status, it has failed.
+#define ONNX_VALIDATE_STATUS(x) \
     if ((status = (x)) != NULL) \
         goto error;
 
+// When running on GPU we do not use the custom (Redis) allocator but the default allocator
+// returned by the ORT API. When running on CPU this call is not needed, since the global
+// allocator is already set to be the custom allocator.
 #ifdef RAI_ONNXRUNTIME_USE_CUDA
-#define GET_GLOBAL_ALLOCATOR ONNX_API(ort->GetAllocatorWithDefaultOptions(&global_allocator))
+#define GET_GLOBAL_ALLOCATOR \
+    ONNX_VALIDATE_STATUS(ort->GetAllocatorWithDefaultOptions(&global_allocator))
 #else
 #define GET_GLOBAL_ALLOCATOR
 #endif
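For readers unfamiliar with the pattern, here is a minimal sketch (not part of the diff) of the calling convention `ONNX_VALIDATE_STATUS` assumes: the enclosing function declares an `OrtStatus *status` and provides an `error:` label, exactly as the functions changed below do. The label body and the caller are illustrative only, assuming the file-level `ort` OrtApi pointer.

```c
// Minimal sketch of a caller of ONNX_VALIDATE_STATUS (illustration, not project code).
static int example_caller(OrtSession *session) {
    OrtStatus *status = NULL;
    size_t n_inputs;
    ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_inputs))
    return 0;

error:
    // A failing call jumps here with its OrtStatus in `status`; report and release it.
    ort->ReleaseStatus(status);
    return 1;
}
```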
@@ -45,7 +50,7 @@ void *AllocatorAlloc(OrtAllocator *ptr, size_t size) {
 void AllocatorFree(OrtAllocator *ptr, void *p) {
     (void)ptr;
     size_t allocated_size = RedisModule_MallocSize(p);
-    atomic_fetch_add(&OnnxMemory, -1 * allocated_size);
+    atomic_fetch_sub(&OnnxMemory, allocated_size);
     atomic_fetch_add(&OnnxMemoryAccessCounter, 1);
     return RedisModule_Free(p);
 }
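The only behavioral line in this hunk replaces `atomic_fetch_add(&OnnxMemory, -1 * allocated_size)` with `atomic_fetch_sub`. Assuming `OnnxMemory` is an unsigned atomic counter, the two produce the same result (the negated value wraps modulo the type width), so this is a readability fix rather than a semantic one. A standalone illustration:

```c
#include <assert.h>
#include <stdatomic.h>
#include <stddef.h>

int main(void) {
    // Hypothetical counter standing in for OnnxMemory (assumed unsigned here).
    atomic_size_t used = 100;
    size_t n = 40;
    atomic_fetch_add(&used, -1 * n); // old style: relies on unsigned wrap-around
    atomic_fetch_sub(&used, n);      // new style: states the intent directly
    assert(atomic_load(&used) == 20); // both calls subtracted 40
    return 0;
}
```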
@@ -179,19 +184,19 @@ OrtValue *RAI_OrtValueFromTensors(RAI_Tensor **ts, size_t count, RAI_Error *erro
     OrtValue *out;
     GET_GLOBAL_ALLOCATOR
     if (count > 1) {
-        ONNX_API(
+        ONNX_VALIDATE_STATUS(
             ort->CreateTensorAsOrtValue(global_allocator, batched_shape, t0->tensor.dl_tensor.ndim,
                                         RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype), &out))
 
         char *ort_data;
-        ONNX_API(ort->GetTensorMutableData(out, (void **)&ort_data))
+        ONNX_VALIDATE_STATUS(ort->GetTensorMutableData(out, (void **)&ort_data))
         size_t offset = 0;
         for (size_t i = 0; i < count; i++) {
             memcpy(ort_data + offset, RAI_TensorData(ts[i]), RAI_TensorByteSize(ts[i]));
             offset += RAI_TensorByteSize(ts[i]);
         }
     } else {
-        ONNX_API(ort->CreateTensorWithDataAsOrtValue(
+        ONNX_VALIDATE_STATUS(ort->CreateTensorWithDataAsOrtValue(
             global_allocator->Info(global_allocator), t0->tensor.dl_tensor.data,
             RAI_TensorByteSize(t0), t0->tensor.dl_tensor.shape, t0->tensor.dl_tensor.ndim,
             RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype), &out))
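The batched branch above copies each input tensor's bytes back-to-back into one ORT tensor whose shape, `batched_shape`, is computed before this hunk. As a hedged sketch, assuming batching concatenates along dimension 0 and reusing the surrounding function's `t0`, `ts`, and `count` (the real computation may differ):

```c
// Keep t0's shape but let dimension 0 be the sum of all inputs' leading dims.
int ndim = t0->tensor.dl_tensor.ndim;
int64_t batched_shape[ndim];
memcpy(batched_shape, t0->tensor.dl_tensor.shape, ndim * sizeof(int64_t));
batched_shape[0] = 0;
for (size_t i = 0; i < count; i++) {
    batched_shape[0] += ts[i]->tensor.dl_tensor.shape[0];
}
```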
@@ -213,7 +218,7 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
     int64_t *strides = NULL;
 
     int is_tensor;
-    ONNX_API(ort->IsTensor(v, &is_tensor))
+    ONNX_VALIDATE_STATUS(ort->IsTensor(v, &is_tensor))
     if (!is_tensor) {
         // TODO: if not tensor, flatten the data structure (sequence or map) and store it in a
         // tensor. If return value is string, emit warning.
@@ -223,15 +228,15 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
     ret = RAI_TensorNew();
     DLContext ctx = (DLContext){.device_type = kDLCPU, .device_id = 0};
     OrtTensorTypeAndShapeInfo *info;
-    ONNX_API(ort->GetTensorTypeAndShape(v, &info))
+    ONNX_VALIDATE_STATUS(ort->GetTensorTypeAndShape(v, &info))
 
     {
         size_t ndims;
-        ONNX_API(ort->GetDimensionsCount(info, &ndims))
+        ONNX_VALIDATE_STATUS(ort->GetDimensionsCount(info, &ndims))
         int64_t dims[ndims];
-        ONNX_API(ort->GetDimensions(info, dims, ndims))
+        ONNX_VALIDATE_STATUS(ort->GetDimensions(info, dims, ndims))
         enum ONNXTensorElementDataType ort_dtype;
-        ONNX_API(ort->GetTensorElementType(info, &ort_dtype))
+        ONNX_VALIDATE_STATUS(ort->GetTensorElementType(info, &ort_dtype))
         int64_t total_batch_size = dims[0];
         total_batch_size = total_batch_size > 0 ? total_batch_size : 1;
 
@@ -253,9 +258,9 @@ RAI_Tensor *RAI_TensorCreateFromOrtValue(OrtValue *v, size_t batch_offset, long
         DLDataType dtype = RAI_GetDLDataTypeFromORT(ort_dtype);
 #ifdef RAI_COPY_RUN_OUTPUT
         char *ort_data;
-        ONNX_API(ort->GetTensorMutableData(v, (void **)&ort_data))
+        ONNX_VALIDATE_STATUS(ort->GetTensorMutableData(v, (void **)&ort_data))
         size_t elem_count;
-        ONNX_API(ort->GetTensorShapeElementCount(info, &elem_count))
+        ONNX_VALIDATE_STATUS(ort->GetTensorShapeElementCount(info, &elem_count))
 
         const size_t len = dtype.bits * elem_count / 8;
         const size_t total_bytesize = len * sizeof(char);
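A quick worked example of the size computation in the last two context lines: for a FLOAT (32-bit) output with an `elem_count` of 6, `len` is 32 * 6 / 8 = 24, and `total_bytesize` equals `len` since `sizeof(char)` is 1; that is the number of bytes copied out of ORT's buffer when RAI_COPY_RUN_OUTPUT is defined.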
@@ -314,50 +319,55 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
     OrtStatus *status = NULL;
 
     if (env == NULL) {
-        ONNX_API(ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env))
+        ONNX_VALIDATE_STATUS(ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env))
 #ifndef RAI_ONNXRUNTIME_USE_CUDA
-        ONNX_API(ort->CreateCustomDeviceAllocator(ORT_API_VERSION, AllocatorAlloc, AllocatorFree,
-                                                  AllocatorInfo, &global_allocator))
-        ONNX_API(ort->RegisterCustomDeviceAllocator(env, global_allocator))
+        ONNX_VALIDATE_STATUS(ort->CreateCustomDeviceAllocator(
+            ORT_API_VERSION, AllocatorAlloc, AllocatorFree, AllocatorInfo, &global_allocator))
+        ONNX_VALIDATE_STATUS(ort->RegisterCustomDeviceAllocator(env, global_allocator))
 #endif
     }
 
-    ONNX_API(ort->CreateSessionOptions(&session_options))
+    ONNX_VALIDATE_STATUS(ort->CreateSessionOptions(&session_options))
 
 #ifndef RAI_ONNXRUNTIME_USE_CUDA
     // These are required to ensure that onnx will use the registered REDIS allocator (for CPU
     // only).
-    ONNX_API(ort->AddSessionConfigEntry(session_options, "session.use_env_allocators", "1"))
-    ONNX_API(ort->DisableCpuMemArena(session_options))
+    ONNX_VALIDATE_STATUS(
+        ort->AddSessionConfigEntry(session_options, "session.use_env_allocators", "1"))
+    ONNX_VALIDATE_STATUS(ort->DisableCpuMemArena(session_options))
 #endif
 
     // TODO: these options could be configured at the AI.CONFIG level
-    ONNX_API(ort->SetSessionGraphOptimizationLevel(session_options, 1))
-    ONNX_API(ort->SetIntraOpNumThreads(session_options, (int)opts.backends_intra_op_parallelism))
-    ONNX_API(ort->SetInterOpNumThreads(session_options, (int)opts.backends_inter_op_parallelism))
+    ONNX_VALIDATE_STATUS(ort->SetSessionGraphOptimizationLevel(session_options, 1))
+    ONNX_VALIDATE_STATUS(
+        ort->SetIntraOpNumThreads(session_options, (int)opts.backends_intra_op_parallelism))
+    ONNX_VALIDATE_STATUS(
+        ort->SetInterOpNumThreads(session_options, (int)opts.backends_inter_op_parallelism))
     if (!setDeviceId(devicestr, session_options, error)) {
+        ort->ReleaseSessionOptions(session_options);
         return NULL;
     }
 
-    ONNX_API(ort->CreateSessionFromArray(env, modeldef, modellen, session_options, &session))
+    ONNX_VALIDATE_STATUS(
+        ort->CreateSessionFromArray(env, modeldef, modellen, session_options, &session))
 
     size_t n_input_nodes;
-    ONNX_API(ort->SessionGetInputCount(session, &n_input_nodes))
+    ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes))
     size_t n_output_nodes;
-    ONNX_API(ort->SessionGetOutputCount(session, &n_output_nodes))
+    ONNX_VALIDATE_STATUS(ort->SessionGetOutputCount(session, &n_output_nodes))
 
     GET_GLOBAL_ALLOCATOR
     inputs_ = array_new(char *, n_input_nodes);
     for (long long i = 0; i < n_input_nodes; i++) {
         char *input_name;
-        ONNX_API(ort->SessionGetInputName(session, i, global_allocator, &input_name))
+        ONNX_VALIDATE_STATUS(ort->SessionGetInputName(session, i, global_allocator, &input_name))
         inputs_ = array_append(inputs_, input_name);
     }
 
     outputs_ = array_new(char *, n_output_nodes);
     for (long long i = 0; i < n_output_nodes; i++) {
         char *output_name;
-        ONNX_API(ort->SessionGetOutputName(session, i, global_allocator, &output_name))
+        ONNX_VALIDATE_STATUS(ort->SessionGetOutputName(session, i, global_allocator, &output_name))
         outputs_ = array_append(outputs_, output_name);
     }
 
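Besides the macro rename, the one substantive change in this hunk is releasing `session_options` before the early return when `setDeviceId` fails; previously that path leaked the options object. All the wrapped calls jump to a shared `error:` label that sits after the code shown here. A hedged sketch of what such a label typically does (the real one may differ in detail):

```c
error:
    // Convert the failing OrtStatus into an RAI error and release it (illustration only).
    RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
    ort->ReleaseStatus(status);
    return NULL;
```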
@@ -415,12 +425,12 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
 
     GET_GLOBAL_ALLOCATOR
     for (uint32_t i = 0; i < model->ninputs; i++) {
-        ONNX_API(ort->AllocatorFree(global_allocator, model->inputs[i]))
+        ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->inputs[i]))
     }
     array_free(model->inputs);
 
     for (uint32_t i = 0; i < model->noutputs; i++) {
-        ONNX_API(ort->AllocatorFree(global_allocator, model->outputs[i]))
+        ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, model->outputs[i]))
     }
     array_free(model->outputs);
 
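These loops free the input and output name strings handed out by `SessionGetInputName`/`SessionGetOutputName` during model creation. The pairing the code relies on is that whatever allocator produced a string must also free it, which is `global_allocator` in both builds (ORT's default allocator on CUDA, the registered Redis allocator on CPU). A minimal sketch of the pairing, with an illustrative index and variable name:

```c
char *name;
ONNX_VALIDATE_STATUS(ort->SessionGetInputName(session, 0, global_allocator, &name))
// ... use name ...
ONNX_VALIDATE_STATUS(ort->AllocatorFree(global_allocator, name))
```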
@@ -467,10 +477,10 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
 
     OrtStatus *status = NULL;
     size_t n_input_nodes;
-    ONNX_API(ort->SessionGetInputCount(session, &n_input_nodes))
+    ONNX_VALIDATE_STATUS(ort->SessionGetInputCount(session, &n_input_nodes))
 
     size_t n_output_nodes;
-    ONNX_API(ort->SessionGetOutputCount(session, &n_output_nodes))
+    ONNX_VALIDATE_STATUS(ort->SessionGetOutputCount(session, &n_output_nodes))
     GET_GLOBAL_ALLOCATOR {
         const char *input_names[n_input_nodes];
         const char *output_names[n_output_nodes];
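Note what the `GET_GLOBAL_ALLOCATOR {` line expands to on each build, since it looks unusual on its own: on CPU builds the macro is empty, so only the opening brace of the block scope remains; on CUDA builds the default-allocator fetch runs first. An illustration of the single-step expansion:

```c
/* Single-step expansion of the line "GET_GLOBAL_ALLOCATOR {" above:
 *
 *   CPU build (RAI_ONNXRUNTIME_USE_CUDA undefined) - the macro is empty:
 *       {
 *
 *   CUDA build - the default ORT allocator is fetched first:
 *       ONNX_VALIDATE_STATUS(ort->GetAllocatorWithDefaultOptions(&global_allocator)) {
 */
```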
@@ -497,7 +507,8 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
 
         for (size_t i = 0; i < n_input_nodes; i++) {
             char *input_name;
-            ONNX_API(ort->SessionGetInputName(session, i, global_allocator, &input_name))
+            ONNX_VALIDATE_STATUS(
+                ort->SessionGetInputName(session, i, global_allocator, &input_name))
             input_names[i] = input_name;
 
             RAI_Tensor *batched_input_tensors[nbatches];
@@ -514,23 +525,25 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
 
         for (size_t i = 0; i < n_output_nodes; i++) {
             char *output_name;
-            ONNX_API(ort->SessionGetOutputName(session, i, global_allocator, &output_name))
+            ONNX_VALIDATE_STATUS(
+                ort->SessionGetOutputName(session, i, global_allocator, &output_name))
             output_names[i] = output_name;
             outputs[i] = NULL;
         }
 
         OrtRunOptions *run_options = NULL;
-        ONNX_API(ort->Run(session, run_options, input_names, (const OrtValue *const *)inputs,
-                          n_input_nodes, output_names, n_output_nodes, outputs))
+        ONNX_VALIDATE_STATUS(ort->Run(session, run_options, input_names,
+                                      (const OrtValue *const *)inputs, n_input_nodes, output_names,
+                                      n_output_nodes, outputs))
 
         for (size_t i = 0; i < n_output_nodes; i++) {
             if (nbatches > 1) {
                 OrtTensorTypeAndShapeInfo *info;
-                ONNX_API(ort->GetTensorTypeAndShape(outputs[i], &info))
+                ONNX_VALIDATE_STATUS(ort->GetTensorTypeAndShape(outputs[i], &info))
                 size_t ndims;
-                ONNX_API(ort->GetDimensionsCount(info, &ndims))
+                ONNX_VALIDATE_STATUS(ort->GetDimensionsCount(info, &ndims))
                 int64_t dims[ndims];
-                ONNX_API(ort->GetDimensions(info, dims, ndims))
+                ONNX_VALIDATE_STATUS(ort->GetDimensions(info, dims, ndims))
                 if (dims[0] != total_batch_size) {
                     RAI_SetError(error, RAI_EMODELRUN,
                                  "ERR Model did not generate the expected batch size");