@@ -101,6 +101,112 @@ static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *c
101101 RAI_ContextUnlock (rinfo );
102102}
103103
104+ static int _StoreTensorInKeySpace (RedisModuleCtx * ctx , RAI_Tensor * tensor ,
105+ RedisModuleString * persist_key_name , bool mangled_name ) {
106+
107+ int ret = REDISMODULE_ERR ;
108+ RedisModuleKey * key ;
109+ size_t persist_key_len ;
110+ const char * persist_key_str = RedisModule_StringPtrLen (persist_key_name , & persist_key_len );
111+
112+ RedisModuleString * demangled_key_name ;
113+ if (mangled_name ) {
114+ demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len - 4 );
115+ } else {
116+ demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len );
117+ }
118+
119+ const int status =
120+ RAI_OpenKey_Tensor (ctx , demangled_key_name , & key , REDISMODULE_READ | REDISMODULE_WRITE );
121+ if (status == REDISMODULE_ERR ) {
122+ RedisModule_ReplyWithError (ctx , "ERR could not save tensor" );
123+ goto clean_up ;
124+ }
125+ if (RedisModule_ModuleTypeSetValue (key , RedisAI_TensorType , RAI_TensorGetShallowCopy (tensor )) !=
126+ REDISMODULE_OK ) {
127+ RedisModule_ReplyWithError (ctx , "ERR could not save tensor" );
128+ RedisModule_CloseKey (key );
129+ goto clean_up ;
130+ }
131+ // Only if we got until here, tensor is saved in keyspace.
132+ RedisAI_ReplicateTensorSet (ctx , demangled_key_name , tensor );
133+ ret = REDISMODULE_OK ;
134+
135+ clean_up :
136+ RedisModule_FreeString (NULL , demangled_key_name );
137+ return ret ;
138+ }
139+
140+ static void _DAG_PersistTensors (RedisModuleCtx * ctx , RedisAI_RunInfo * rinfo ) {
141+
142+ AI_dictIterator * persist_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
143+ AI_dictEntry * persist_entry = AI_dictNext (persist_iter );
144+
145+ while (persist_entry ) {
146+ RedisModuleString * persist_key_name = AI_dictGetKey (persist_entry );
147+ AI_dictEntry * tensor_entry = AI_dictFind (rinfo -> dagTensorsContext , persist_key_name );
148+ RedisModule_Assert (tensor_entry );
149+ RAI_Tensor * tensor = AI_dictGetVal (tensor_entry );
150+ if (tensor == NULL ) {
151+ persist_entry = AI_dictNext (persist_iter );
152+ continue ;
153+ }
154+ if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , true) == REDISMODULE_ERR ) {
155+ * rinfo -> dagError = 1 ;
156+ RedisModule_Log (ctx , "warning" ,
157+ "Could not persist tensor under the key (%s) after executing DAGRUN "
158+ "command, persist stopped" ,
159+ RedisModule_StringPtrLen (persist_key_name , NULL ));
160+ AI_dictReleaseIterator (persist_iter );
161+ return ;
162+ }
163+ persist_entry = AI_dictNext (persist_iter );
164+ }
165+ AI_dictReleaseIterator (persist_iter );
166+ }
167+
168+ static void _ModelSingleOp_PersistTensors (RedisModuleCtx * ctx , RAI_DagOp * op ) {
169+
170+ const size_t noutputs = RAI_ModelRunCtxNumOutputs (op -> mctx );
171+ for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
172+ RedisModuleString * persist_key_name = op -> outkeys [outputNumber ];
173+ RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (op -> mctx , outputNumber );
174+ tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
175+ if (!tensor )
176+ continue ;
177+
178+ if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , false) == REDISMODULE_ERR ) {
179+ RedisModule_Log (ctx , "warning" ,
180+ "Could not persist tensor under the key (%s) after executing DAGRUN "
181+ "command, persist stopped" ,
182+ RedisModule_StringPtrLen (persist_key_name , NULL ));
183+ op -> result = REDISMODULE_ERR ;
184+ return ;
185+ }
186+ }
187+ }
188+
189+ static void _ScriptSingleOp_PersistTensors (RedisModuleCtx * ctx , RAI_DagOp * op ) {
190+
191+ const size_t noutputs = RAI_ScriptRunCtxNumOutputs (op -> sctx );
192+ for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
193+ RedisModuleString * persist_key_name = op -> outkeys [outputNumber ];
194+ RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (op -> sctx , outputNumber );
195+ tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
196+ if (!tensor )
197+ continue ;
198+
199+ if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , false) == REDISMODULE_ERR ) {
200+ RedisModule_Log (ctx , "warning" ,
201+ "Could not persist tensor under the key (%s) after executing DAGRUN "
202+ "command, persist stopped" ,
203+ RedisModule_StringPtrLen (persist_key_name , NULL ));
204+ op -> result = REDISMODULE_ERR ;
205+ return ;
206+ }
207+ }
208+ }
209+
104210/**
105211 * Execution of a MODELRUN DAG step.
106212 * If an error occurs, it is recorded in the DagOp struct.
@@ -490,128 +596,24 @@ void RedisAI_BatchedDagRunSessionStep(RedisAI_RunInfo **batched_rinfo, const cha
490596 return ;
491597}
492598
493- static int _StoreTensorInKeySpace (RedisModuleCtx * ctx , RAI_Tensor * tensor ,
494- RedisModuleString * persist_key_name , bool mangled_name ) {
495-
496- int ret = REDISMODULE_ERR ;
497- RedisModuleKey * key ;
498- size_t persist_key_len ;
499- const char * persist_key_str = RedisModule_StringPtrLen (persist_key_name , & persist_key_len );
500-
501- RedisModuleString * demangled_key_name ;
502- if (mangled_name ) {
503- demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len - 4 );
504- } else {
505- demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len );
506- }
507-
508- const int status =
509- RAI_OpenKey_Tensor (ctx , demangled_key_name , & key , REDISMODULE_READ | REDISMODULE_WRITE );
510- if (status == REDISMODULE_ERR ) {
511- RedisModule_ReplyWithError (ctx , "ERR could not save tensor" );
512- goto clean_up ;
513- } else {
514- if (RedisModule_ModuleTypeSetValue (key , RedisAI_TensorType ,
515- RAI_TensorGetShallowCopy (tensor )) != REDISMODULE_OK ) {
516- RedisModule_ReplyWithError (ctx , "ERR could not save tensor" );
517- goto clean_up ;
518- }
519- }
520- ret = REDISMODULE_OK ;
521-
522- clean_up :
523- RedisModule_CloseKey (key );
524- RedisAI_ReplicateTensorSet (ctx , demangled_key_name , tensor );
525- RedisModule_FreeString (NULL , demangled_key_name );
526- return ret ;
527- }
528-
529- static void _PersistTensors (RedisModuleCtx * ctx , RedisAI_RunInfo * rinfo ) {
530-
531- AI_dictIterator * persist_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
532- AI_dictEntry * persist_entry = AI_dictNext (persist_iter );
533-
534- while (persist_entry ) {
535- RedisModuleString * persist_key_name = AI_dictGetKey (persist_entry );
536- AI_dictEntry * tensor_entry = AI_dictFind (rinfo -> dagTensorsContext , persist_key_name );
537- if (tensor_entry ) {
538- RAI_Tensor * tensor = AI_dictGetVal (tensor_entry );
539- if (tensor == NULL ) {
540- persist_entry = AI_dictNext (persist_iter );
541- continue ;
542- }
543- if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , true) == REDISMODULE_ERR )
544- rinfo -> dagReplyLength ++ ;
545-
546- } else {
547- RedisModule_ReplyWithError (ctx ,
548- "ERR specified persistent key that was not used in DAG" );
549- rinfo -> dagReplyLength ++ ;
550- RedisModule_Log (ctx , "warning" ,
551- "on DAGRUN's PERSIST specified persistent key (%s) that "
552- "was not used on DAG. Logging all local context keys" ,
553- RedisModule_StringPtrLen (persist_key_name , NULL ));
554- AI_dictIterator * local_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
555- AI_dictEntry * local_entry = AI_dictNext (local_iter );
556-
557- while (local_entry ) {
558- RedisModuleString * localcontext_key_name = AI_dictGetKey (local_entry );
559- RedisModule_Log (ctx , "warning" , "DAG's local context key (%s)" ,
560- RedisModule_StringPtrLen (localcontext_key_name , NULL ));
561- local_entry = AI_dictNext (local_iter );
562- }
563- AI_dictReleaseIterator (local_iter );
564-
565- for (size_t opN = 0 ; opN < array_len (rinfo -> dagOps ); opN ++ ) {
566- RedisModule_Log (ctx , "warning" , "DAG's op n# %zu - cmdType %d ( argc %d )" , opN ,
567- rinfo -> dagOps [opN ]-> commandType , rinfo -> dagOps [opN ]-> argc );
568- }
569- }
570- persist_entry = AI_dictNext (persist_iter );
571- }
572- AI_dictReleaseIterator (persist_iter );
573- }
574-
575- static void _ModelSingleOp_PersistTensors (RedisModuleCtx * ctx , RAI_DagOp * op ) {
576- const size_t noutputs = RAI_ModelRunCtxNumOutputs (op -> mctx );
577- for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
578- RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (op -> mctx , outputNumber );
579- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
580- if (tensor )
581- _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
582- }
583- }
584-
585- static void _ScriptSingleOp_PersistTensors (RedisModuleCtx * ctx , RAI_DagOp * op ) {
586- const size_t noutputs = RAI_ScriptRunCtxNumOutputs (op -> sctx );
587- for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
588- RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (op -> sctx , outputNumber );
589- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
590- if (tensor )
591- _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
592- }
593- }
594-
595599int RedisAI_DagRun_Reply (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc ) {
596600 REDISMODULE_NOT_USED (argv );
597601 REDISMODULE_NOT_USED (argc );
598602 RedisAI_RunInfo * rinfo = RedisModule_GetBlockedClientPrivateData (ctx );
599603
600- if (RAI_GetErrorCode ( rinfo -> err ) == RAI_EDAGRUN ) {
601- RedisModule_ReplyWithError (ctx , RAI_GetErrorOneLine ( rinfo -> err ) );
604+ if (* rinfo -> timedOut ) {
605+ RedisModule_ReplyWithSimpleString (ctx , "TIMEDOUT" );
602606 RAI_FreeRunInfo (rinfo );
603- return REDISMODULE_ERR ;
607+ return REDISMODULE_OK ;
604608 }
605- int dag_error = 0 ;
606- char * detail_oneline ;
607609
608- size_t n_dagOps = array_len (rinfo -> dagOps );
609-
610- if (* rinfo -> timedOut ) {
611- RedisModule_ReplyWithSimpleString (ctx , "TIMEDOUT" );
610+ if (RAI_GetErrorCode (rinfo -> err ) == RAI_EDAGRUN ) {
611+ RedisModule_ReplyWithError (ctx , RAI_GetErrorOneLine (rinfo -> err ));
612612 RAI_FreeRunInfo (rinfo );
613613 return REDISMODULE_OK ;
614614 }
615+ int dag_error = 0 ;
616+ size_t n_dagOps = array_len (rinfo -> dagOps );
615617
616618 if (!rinfo -> single_op_dag ) {
617619 RedisModule_ReplyWithArray (ctx , REDISMODULE_POSTPONED_ARRAY_LEN );
@@ -697,17 +699,10 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
697699 }
698700
699701 if (dag_error ) {
700- if (rinfo -> single_op_dag == 0 ) {
701- RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
702- }
703- RAI_FreeRunInfo (rinfo );
704- return REDISMODULE_ERR ;
702+ goto cleanup ;
705703 }
706-
707704 if (!rinfo -> single_op_dag ) {
708- // Save the required tensors in redis key space.
709- _PersistTensors (ctx , rinfo );
710- RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
705+ _DAG_PersistTensors (ctx , rinfo );
711706 } else {
712707 if (rinfo -> dagOps [0 ]-> commandType == REDISAI_DAG_CMD_MODELRUN ) {
713708 _ModelSingleOp_PersistTensors (ctx , rinfo -> dagOps [0 ]);
@@ -717,6 +712,10 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
717712 }
718713 }
719714
715+ cleanup :
716+ if (!rinfo -> single_op_dag ) {
717+ RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
718+ }
720719 RAI_FreeRunInfo (rinfo );
721720 return REDISMODULE_OK ;
722721}
0 commit comments