33#include "background_workers.h"
44#include "util/string_utils.h"
55
6- void _DAG_SetTensorsInLocalContext (RedisAI_RunInfo * rinfo ) {
7- for (size_t i = 0 ; i < rinfo -> dagOpCount ; i ++ ) {
8- RAI_DagOp * op = rinfo -> dagOps [i ];
9- if (op -> commandType == REDISAI_DAG_CMD_TENSORSET ) {
10- // Insert the tensor with its mangled (unique) name.
11- void * t = (void * )RAI_TensorGetShallowCopy (op -> outTensor );
12- AI_dictReplace (rinfo -> dagTensorsContext , (void * )op -> outkeys [0 ], t );
13- }
14- }
15- }
16-
17- int MangleTensorsNames (RedisAI_RunInfo * rinfo ) {
18-
19- int res = REDISMODULE_ERR ;
20- AI_dict * mangled_tensors = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
6+ int ValidatePersistKeys (RedisAI_RunInfo * rinfo , AI_dict * tensorsNamesToInd ,
7+ AI_dict * persistTensorsNames ) {
218
229 {
23- AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
24- AI_dictEntry * entry = AI_dictNext (iter );
25- while (entry ) {
26- RedisModuleString * key = (RedisModuleString * )AI_dictGetKey (entry );
27- size_t key_len ;
28- const char * key_str = RedisModule_StringPtrLen (key , & key_len );
29- RedisModuleString * demangled_key = RedisModule_CreateString (NULL , key_str , key_len - 4 );
30- int * instance = RedisModule_Alloc (sizeof (int ));
31- * instance = 1 ;
32- AI_dictAdd (mangled_tensors , (void * )demangled_key , (void * )instance );
33- RedisModule_FreeString (NULL , demangled_key );
34- entry = AI_dictNext (iter );
10+ AI_dictIterator * iter = AI_dictGetSafeIterator (persistTensorsNames );
11+ AI_dictEntry * persist_entry ;
12+ while ((persist_entry = AI_dictNext (iter ))) {
13+ RedisModuleString * persist_key = (RedisModuleString * )AI_dictGetKey (persist_entry );
14+ AI_dictEntry * entry = AI_dictFind (tensorsNamesToInd , persist_key );
15+ if (!entry ) {
16+ RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR PERSIST key cannot be found in DAG" );
17+ return REDISMODULE_ERR ;
18+ }
3519 }
3620 AI_dictReleaseIterator (iter );
3721 }
22+ return REDISMODULE_OK ;
23+ }
24+
25+ int MapTensorsKeysToIndices (RedisAI_RunInfo * rinfo , AI_dict * tensorsNamesToInd ) {
3826
3927 for (long long i = 0 ; i < array_len (rinfo -> dagOps ); i ++ ) {
4028 RAI_DagOp * currentOp = rinfo -> dagOps [i ];
4129
42- RedisModuleString * * mangled_inkeys =
43- array_new (RedisModuleString * , array_len (currentOp -> inkeys ));
4430 for (long long j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
4531 RedisModuleString * key = currentOp -> inkeys [j ];
46- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
32+ AI_dictEntry * entry = AI_dictFind (tensorsNamesToInd , key );
4733 if (!entry ) {
48- array_free (mangled_inkeys );
4934 RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR INPUT key cannot be found in DAG" );
50- goto cleanup ;
35+ return REDISMODULE_ERR ;
5136 }
52- int * instance = AI_dictGetVal (entry );
53- char buf [16 ];
54- sprintf (buf , "%04d" , * instance );
55- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
56- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
57- mangled_inkeys = array_append (mangled_inkeys , mangled_key );
37+ int * ind = AI_dictGetVal (entry );
38+ currentOp -> inkeys_indices = array_append (currentOp -> inkeys_indices , * ind );
5839 }
5940
60- RedisModuleString * * mangled_outkeys =
61- array_new (RedisModuleString * , array_len (currentOp -> outkeys ));
6241 for (long long j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
6342 RedisModuleString * key = currentOp -> outkeys [j ];
64- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
65- int * instance = NULL ;
66- if (entry ) {
67- instance = AI_dictGetVal (entry );
68- * instance += 1 ;
43+ int * ind = RedisModule_Alloc (sizeof (int ));
44+ * ind = array_len (rinfo -> dagSharedTensors );
45+
46+ // Add a new empty place holder in the array for an output tensor.
47+ // If this is MODELSET op, the tensor is already realized.
48+ if (currentOp -> commandType == REDISAI_DAG_CMD_TENSORSET ) {
49+ RAI_Tensor * t = RAI_TensorGetShallowCopy (currentOp -> outTensor );
50+ rinfo -> dagSharedTensors = array_append (rinfo -> dagSharedTensors , t );
6951 } else {
70- instance = RedisModule_Alloc (sizeof (int ));
71- * instance = 1 ;
72- AI_dictAdd (mangled_tensors , (void * )key , (void * )instance );
73- }
74- char buf [16 ];
75- sprintf (buf , "%04d" , * instance );
76- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
77- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
78- mangled_outkeys = array_append (mangled_outkeys , mangled_key );
79- }
80-
81- if (currentOp -> inkeys ) {
82- for (size_t j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
83- RedisModule_FreeString (NULL , currentOp -> inkeys [j ]);
84- }
85- array_free (currentOp -> inkeys );
86- }
87-
88- if (currentOp -> outkeys ) {
89- for (size_t j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
90- RedisModule_FreeString (NULL , currentOp -> outkeys [j ]);
52+ rinfo -> dagSharedTensors = array_append (rinfo -> dagSharedTensors , NULL );
9153 }
92- array_free (currentOp -> outkeys );
54+ currentOp -> outkeys_indices = array_append (currentOp -> outkeys_indices , * ind );
55+ AI_dictReplace (tensorsNamesToInd , (void * )key , (void * )ind );
9356 }
94-
95- currentOp -> inkeys = mangled_inkeys ;
96- currentOp -> outkeys = mangled_outkeys ;
97- }
98-
99- AI_dict * mangled_persisted = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
100- {
101- AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
102- AI_dictEntry * entry = AI_dictNext (iter );
103- while (entry ) {
104- RedisModuleString * key = (RedisModuleString * )AI_dictGetKey (entry );
105- AI_dictEntry * mangled_entry = AI_dictFind (mangled_tensors , key );
106- if (!mangled_entry ) {
107- AI_dictRelease (mangled_persisted );
108- AI_dictReleaseIterator (iter );
109- RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR PERSIST key cannot be found in DAG" );
110- goto cleanup ;
111- }
112- if (AI_dictFind (mangled_persisted , key ) != NULL ) {
113- AI_dictRelease (mangled_persisted );
114- AI_dictReleaseIterator (iter );
115- RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR PERSIST keys must be unique" );
116- goto cleanup ;
117- }
118- int * instance = AI_dictGetVal (mangled_entry );
119- char buf [16 ];
120- sprintf (buf , "%04d" , * instance );
121- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
122- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
123- AI_dictAdd (mangled_persisted , (void * )mangled_key , (void * )1 );
124- RedisModule_FreeString (NULL , mangled_key );
125- entry = AI_dictNext (iter );
126- }
127- AI_dictReleaseIterator (iter );
12857 }
129-
130- AI_dictRelease (rinfo -> dagTensorsPersistedContext );
131- rinfo -> dagTensorsPersistedContext = mangled_persisted ;
132-
133- for (long long i = 0 ; i < array_len (rinfo -> dagOps ); i ++ ) {
134- if (rinfo -> dagOps [i ]-> devicestr == NULL ) {
135- rinfo -> dagOps [i ]-> devicestr = "CPU" ;
136- }
137- }
138- // Tensors from TENSORSET ops are ready to be put in DAG local context under their mangled
139- // names.
140- _DAG_SetTensorsInLocalContext (rinfo );
141- res = REDISMODULE_OK ;
142-
143- cleanup : {
144- AI_dictIterator * iter = AI_dictGetSafeIterator (mangled_tensors );
145- AI_dictEntry * entry = AI_dictNext (iter );
146- while (entry ) {
147- int * val = (int * )AI_dictGetVal (entry );
148- RedisModule_Free (val );
149- entry = AI_dictNext (iter );
150- }
151- AI_dictReleaseIterator (iter );
152- }
153- AI_dictRelease (mangled_tensors );
154- return res ;
58+ return REDISMODULE_OK ;
15559}
15660
15761// Add Shallow copies of the DAG run info to the devices' queues.
@@ -242,7 +146,7 @@ int RAI_DAGRun(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *pri
242146 }
243147 // Make the inkeys and outkeys of the DAG ops unique, to ensure that the operations
244148 // will be execute in the right order.
245- if (MangleTensorsNames (rinfo ) != REDISMODULE_OK ) {
149+ if (MapTensorsKeysToIndices (rinfo , rinfo -> tensorsNamesToIndices ) != REDISMODULE_OK ) {
246150 RAI_SetError (err , rinfo -> err -> code , rinfo -> err -> detail );
247151 return REDISMODULE_ERR ;
248152 }
@@ -269,16 +173,13 @@ size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx) {
269173const RAI_Tensor * RAI_DAGOutputTensor (RAI_OnFinishCtx * finish_ctx , size_t index ) {
270174 size_t tensor_get_op_ind = -1 ;
271175 RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )finish_ctx ;
176+
272177 for (size_t i = 0 ; i < rinfo -> dagOpCount ; i ++ ) {
273178 RAI_DagOp * op = rinfo -> dagOps [i ];
274179 if (op -> commandType == REDISAI_DAG_CMD_TENSORGET ) {
275180 tensor_get_op_ind ++ ;
276181 if (tensor_get_op_ind == index ) {
277- RAI_Tensor * t ;
278- int res = RAI_getTensorFromLocalContext (rinfo -> dagTensorsContext , op -> inkeys [0 ], & t ,
279- op -> err );
280- RedisModule_Assert (res == REDISMODULE_OK );
281- return t ;
182+ return Dag_GetInternalTensor (rinfo , op -> inkeys_indices [0 ]);
282183 }
283184 }
284185 }
0 commit comments