@@ -16,9 +16,10 @@ int RAI_InitGlobalRunSessionsORT() {
     OnnxRunSessionCtx **run_sessions_array =
         array_new(OnnxRunSessionCtx *, RAI_working_threads_num);
     for (size_t i = 0; i < RAI_working_threads_num; i++) {
-        OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx));
-        entry->runState = RedisModule_Calloc(1, sizeof(entry->runState));
+        OnnxRunSessionCtx *entry = RedisModule_Alloc(sizeof(OnnxRunSessionCtx));
+        entry->runState = RedisModule_Alloc(sizeof(entry->runState));
         *entry->runState = RUN_SESSION_AVAILABLE;
+        entry->queuingTime = LLONG_MAX;
         run_sessions_array = array_append(run_sessions_array, entry);
     }
     onnx_global_run_sessions->OnnxRunSessions = run_sessions_array;
@@ -44,9 +45,10 @@ int RAI_AddNewDeviceORT(const char *device_str) {
     // initialized to NULL.
     size_t size = RedisAI_GetNumThreadsPerQueue();
     for (size_t i = 0; i < size; i++) {
-        OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx));
-        entry->runState = RedisModule_Calloc(1, sizeof(entry->runState));
+        OnnxRunSessionCtx *entry = RedisModule_Alloc(sizeof(OnnxRunSessionCtx));
+        entry->runState = RedisModule_Alloc(sizeof(entry->runState));
         *entry->runState = RUN_SESSION_AVAILABLE;
+        entry->queuingTime = LLONG_MAX;
         run_sessions_array = array_append(run_sessions_array, entry);
     }
     onnx_global_run_sessions->OnnxRunSessions = run_sessions_array;
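Both initializers above rely on the same idea: an entry that holds no active session keeps `queuingTime` at `LLONG_MAX`, so the watchdog's `curr_time - queuingTime > timeout` test can never fire for an idle slot. A minimal, standalone sketch of that check (the `slot_t` / `slot_timed_out` names are illustrative, not part of the module):

```c
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>

/* Hypothetical miniature of a run-session entry: only the field the watchdog reads. */
typedef struct {
    long long queuingTime; /* LLONG_MAX while the slot holds no active session */
} slot_t;

/* Same shape as the check in RAI_EnforceTimeoutORT: relaxed atomic read, then compare. */
static bool slot_timed_out(slot_t *s, long long curr_time, long long timeout) {
    long long qt = __atomic_load_n(&s->queuingTime, __ATOMIC_RELAXED);
    return curr_time - qt > timeout;
}

int main(void) {
    slot_t idle = {.queuingTime = LLONG_MAX};
    slot_t busy = {.queuingTime = 1000};
    /* curr_time - LLONG_MAX is hugely negative, so an idle slot can never look timed out. */
    printf("idle: %d\n", slot_timed_out(&idle, 10000, 5000)); /* 0 */
    printf("busy: %d\n", slot_timed_out(&busy, 10000, 5000)); /* 1 */
    return 0;
}
```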
@@ -65,7 +67,10 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s
     long long timeout = RedisAI_GetModelExecutionTimeout();
     for (size_t i = 0; i < len; i++) {
         // Check if a session has been running for too long, and kill it if it is still active.
-        if (curr_time - run_sessions_ctx[i]->queuingTime > timeout) {
+        // If the entry doesn't contain an active session, its queuing time is LLONG_MAX
+        // (thus the following condition will always evaluate to false).
+        if (curr_time - __atomic_load_n(&(run_sessions_ctx[i]->queuingTime), __ATOMIC_RELAXED) >
+            timeout) {
             if (__sync_bool_compare_and_swap(run_sessions_ctx[i]->runState, RUN_SESSION_ACTIVE,
                                              RUN_SESSION_INVALID)) {
                 // Set termination flag, validate that ONNX API succeeded (returns NULL)
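The compare-and-swap on `runState` is what makes this kill switch safe to call from the main-thread callback: only the caller that wins the ACTIVE -> INVALID transition goes on to terminate the run, and (judging by the reset path further down) it presumably marks the entry TERMINATED once the ONNX termination call returns. A hedged sketch of that pattern using plain GCC builtins (the `try_kill` helper is illustrative, not the module's API):

```c
#include <stdbool.h>

/* Illustrative copy of the entry state machine; names mirror the patch. */
typedef enum {
    RUN_SESSION_AVAILABLE,
    RUN_SESSION_ACTIVE,
    RUN_SESSION_INVALID,    /* watchdog claimed the entry, termination in progress */
    RUN_SESSION_TERMINATED, /* termination finished, worker may reclaim the entry */
} RunSessionState;

/* Hypothetical kill switch: only the caller that wins the ACTIVE -> INVALID CAS
 * proceeds, so termination is attempted at most once per run. */
static bool try_kill(RunSessionState *runState) {
    if (__sync_bool_compare_and_swap(runState, RUN_SESSION_ACTIVE, RUN_SESSION_INVALID)) {
        /* ... the real code would ask the ONNX runtime to terminate the run here ... */
        __atomic_store_n(runState, RUN_SESSION_TERMINATED, __ATOMIC_RELAXED);
        return true;
    }
    return false; /* entry was idle, already reclaimed, or already being killed */
}
```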
@@ -91,10 +96,11 @@ void RAI_ActivateRunSessionCtxORT(OrtRunOptions *new_run_options, long *run_sess
     }
     OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[*run_session_index];
     RedisModule_Assert(*entry->runState == RUN_SESSION_AVAILABLE);
+    RedisModule_Assert(entry->queuingTime == LLONG_MAX);
 
     // Update the entry with the current session data.
     entry->runOptions = new_run_options;
-    entry->queuingTime = mstime();
+    __atomic_store_n(&(entry->queuingTime), mstime(), __ATOMIC_RELAXED);
     __atomic_store_n(entry->runState, RUN_SESSION_ACTIVE, __ATOMIC_RELAXED);
     pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock));
 }
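The activation side mirrors the watchdog's relaxed atomic reads: the added assertion checks that a reclaimed entry was left with the `LLONG_MAX` sentinel, and both `queuingTime` and `runState` are now published with relaxed atomic stores while the global read lock is held. A compact sketch of that publish step (hypothetical `slot_activate`, GCC builtins only):

```c
#include <assert.h>
#include <limits.h>

/* Hypothetical slot with the two fields the activation path touches. */
typedef struct {
    int *runState;         /* holds a RUN_SESSION_* value; 1 stands in for ACTIVE */
    long long queuingTime; /* LLONG_MAX while idle */
} slot_t;

enum { SLOT_AVAILABLE = 0, SLOT_ACTIVE = 1 };

/* Hypothetical activation: assert the idle invariants, then store the queuing time
 * and flip the state to ACTIVE, both with relaxed atomic stores (as in the patch). */
static void slot_activate(slot_t *s, long long now_ms) {
    assert(*s->runState == SLOT_AVAILABLE);
    assert(s->queuingTime == LLONG_MAX);
    __atomic_store_n(&s->queuingTime, now_ms, __ATOMIC_RELAXED);
    __atomic_store_n(s->runState, SLOT_ACTIVE, __ATOMIC_RELAXED);
}
```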
@@ -104,14 +110,16 @@ void RAI_ResetRunSessionCtxORT(long run_session_index) {
     pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock));
     OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index];
 
-    // Busy wait until we get a valid state, as we might access this entry from
-    // the main thread callback and call ONNX API to terminate the run session.
-    RunSessionState state;
-    do {
-        state = __atomic_load_n(entry->runState, __ATOMIC_RELAXED);
-    } while (state != RUN_SESSION_ACTIVE && state != RUN_SESSION_TERMINATED);
-
+    // In most cases, the state will be ACTIVE at this point, and we want to turn it to
+    // AVAILABLE atomically, so that we won't call the kill switch at the same time.
+    if (!__sync_bool_compare_and_swap(entry->runState, RUN_SESSION_ACTIVE, RUN_SESSION_AVAILABLE)) {
+        // If the state was not ACTIVE, it is INVALID/TERMINATED due to a timeout that
+        // has occurred. We busy wait until the state is set to TERMINATED.
+        while (!__sync_bool_compare_and_swap(entry->runState, RUN_SESSION_TERMINATED,
+                                             RUN_SESSION_AVAILABLE))
+            ;
+    }
+    __atomic_store_n(&(entry->queuingTime), LLONG_MAX, __ATOMIC_RELAXED);
     ort->ReleaseRunOptions(entry->runOptions);
-    __atomic_store_n(entry->runState, RUN_SESSION_AVAILABLE, __ATOMIC_RELAXED);
     pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock));
 }
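This reset path replaces the old busy-wait loop with a two-step handshake against the kill switch: a single CAS from ACTIVE to AVAILABLE reclaims the entry in the common case, and only if that fails (the watchdog already moved the entry to INVALID) does the worker spin until the watchdog publishes TERMINATED. A self-contained sketch of the handshake under those assumptions (names are illustrative):

```c
#include <limits.h>

typedef enum {
    RUN_SESSION_AVAILABLE,
    RUN_SESSION_ACTIVE,
    RUN_SESSION_INVALID,
    RUN_SESSION_TERMINATED,
} RunSessionState;

typedef struct {
    RunSessionState *runState;
    long long queuingTime;
} slot_t;

/* Hypothetical reset: either win the ACTIVE -> AVAILABLE race against the kill
 * switch, or wait until the kill switch publishes TERMINATED before reclaiming. */
static void slot_reset(slot_t *s) {
    if (!__sync_bool_compare_and_swap(s->runState, RUN_SESSION_ACTIVE, RUN_SESSION_AVAILABLE)) {
        /* Watchdog won the race: state is INVALID while termination runs, then TERMINATED. */
        while (!__sync_bool_compare_and_swap(s->runState, RUN_SESSION_TERMINATED,
                                             RUN_SESSION_AVAILABLE))
            ;
    }
    /* Restore the idle sentinel so the timeout check ignores this entry again. */
    __atomic_store_n(&s->queuingTime, LLONG_MAX, __ATOMIC_RELAXED);
}
```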