11#include "dag_builder.h"
22#include "run_info.h"
33#include "string_utils.h"
4+ #include "modelRun_ctx.h"
45
5- int _LoadTensorFromKeyspace (RedisModuleCtx * ctx , RedisModuleString * keyName , RedisModuleKey * * key ,
6- RAI_Tensor * * tensor , RAI_Error * err ) {
6+ static int _LoadTensorFromKeyspace (RedisModuleCtx * ctx , RedisModuleString * keyName ,
7+ RedisModuleKey * * key , RAI_Tensor * * tensor , RAI_Error * err ) {
78
9+ int res = REDISMODULE_ERR ;
10+ // RedisModule_ThreadSafeContextLock(ctx);
811 * key = RedisModule_OpenKey (ctx , keyName , REDISMODULE_READ );
912 if (RedisModule_KeyType (* key ) == REDISMODULE_KEYTYPE_EMPTY ) {
10- RedisModule_CloseKey (* key );
1113 RAI_SetError (err , RAI_EDAGBUILDER , "ERR tensor key is empty" );
12- return REDISMODULE_ERR ;
14+ goto end ;
1315 }
1416 if (RedisModule_ModuleTypeGetType (* key ) != RedisAI_TensorType ) {
15- RedisModule_CloseKey (* key );
1617 RAI_SetError (err , RAI_EDAGBUILDER , REDISMODULE_ERRORMSG_WRONGTYPE );
17- return REDISMODULE_ERR ;
18+ goto end ;
1819 }
1920 * tensor = RedisModule_ModuleTypeGetValue (* key );
21+ res = REDISMODULE_OK ;
22+
23+ end :
2024 RedisModule_CloseKey (* key );
25+ // RedisModule_ThreadSafeContextUnlock(ctx);
26+ return res ;
27+ }
28+
29+ static int _RAI_DagLoadTensor (RAI_DAGRunCtx * run_info , RedisModuleString * key_name ,
30+ RAI_Error * err ) {
31+
32+ RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )run_info ;
33+ RedisModuleCtx * ctx = RedisModule_GetThreadSafeContext (NULL );
34+ RAI_Tensor * t ;
35+ RedisModuleKey * key ;
36+ if (_LoadTensorFromKeyspace (ctx , key_name , & key , & t , err ) == REDISMODULE_ERR ) {
37+ RedisModule_FreeString (NULL , key_name );
38+ RedisModule_FreeThreadSafeContext (ctx );
39+ return REDISMODULE_ERR ;
40+ }
41+ // Add the tensor under its "mangled" key name to the DAG local context dict.
42+ char buf [16 ];
43+ sprintf (buf , "%04d" , 1 );
44+ RedisModule_StringAppendBuffer (NULL , key_name , buf , strlen (buf ));
45+ AI_dictAdd (rinfo -> dagTensorsContext , (void * )key_name , (void * )RAI_TensorGetShallowCopy (t ));
46+ RedisModule_FreeString (NULL , key_name );
47+ RedisModule_FreeThreadSafeContext (ctx );
2148 return REDISMODULE_OK ;
2249}
2350
24- RAI_DAGRunCtx * RAI_DagRunCtxCreate (void ) {
51+ RAI_DAGRunCtx * RAI_DAGRunCtxCreate (void ) {
2552 RedisAI_RunInfo * rinfo ;
2653 RAI_InitRunInfo (& rinfo );
2754 return (RAI_DAGRunCtx * )rinfo ;
2855}
2956
30- int RAI_DagAddModelRun_ (RAI_DAGRunCtx * run_info , RAI_ModelRunCtx * mctx , RedisModuleString * * inputs ,
31- RedisModuleString * * outputs , RAI_Error * err ) {
32- if (array_len (mctx -> inputs ) != 0 || array_len (mctx -> outputs ) != 0 ) {
33- RAI_SetError (
34- err , RAI_EDAGBUILDER ,
35- "Model run context cannot contain inputs or outputs when it is a part of a DAG" );
36- return REDISMODULE_ERR ;
37- }
38- RAI_Model * model = mctx -> model ;
39- if (model -> ninputs != array_len (inputs )) {
40- RAI_SetError (err , RAI_EDAGBUILDER ,
41- "Number of keys given as INPUTS does not match model definition" );
42- return REDISMODULE_ERR ;
43- }
44- if (model -> noutputs != array_len (outputs )) {
45- RAI_SetError (err , RAI_EDAGBUILDER ,
46- "Number of keys given as OUTPUTS does not match model definition" );
47- return REDISMODULE_ERR ;
48- }
49-
57+ RAI_DAGRunOp * RAI_DAGCreateModelRunOp (RAI_DAGRunCtx * run_info , RAI_Model * model ) {
5058 RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )run_info ;
59+ RAI_ModelRunCtx * mctx = RAI_ModelRunCtxCreate (model );
5160 RAI_DagOp * op ;
5261 RAI_InitDagOp (& op );
53- rinfo -> dagOps = array_append (rinfo -> dagOps , op );
5462
5563 op -> commandType = REDISAI_DAG_CMD_MODELRUN ;
5664 op -> mctx = mctx ;
5765 op -> devicestr = model -> devicestr ;
58- op -> inkeys = inputs ;
59- op -> outkeys = outputs ;
6066 op -> runkey = RAI_HoldString (NULL , (RedisModuleString * )model -> infokey );
61- return REDISMODULE_OK ;
67+ return ( RAI_DAGRunOp * ) op ;
6268}
6369
64- int RAI_DagAddModelRun (RAI_DAGRunCtx * run_info , RAI_ModelRunCtx * mctx , const char * * inputs ,
65- size_t ninputs , const char * * outputs , size_t noutputs , RAI_Error * err ) {
70+ int RAI_DAGRunOpAddInput (RAI_DAGRunOp * DAGOp , const char * input ) {
71+ RAI_DagOp * op = (RAI_DagOp * )DAGOp ;
72+ RedisModuleString * inkey = RedisModule_CreateString (NULL , input , strlen (input ));
73+ op -> inkeys = array_append (op -> inkeys , inkey );
74+ return REDISMODULE_OK ;
75+ }
6676
67- RedisModuleString * * inkeys = array_new (RedisModuleString * , 1 );
68- for (size_t i = 0 ; i < ninputs ; i ++ ) {
69- RedisModuleString * inkey = RedisModule_CreateString (NULL , inputs [i ], strlen (inputs [i ]));
70- inkeys = array_append (inkeys , inkey );
71- }
72- RedisModuleString * * outkeys = array_new (RedisModuleString * , 1 );
73- for (size_t i = 0 ; i < noutputs ; i ++ ) {
74- RedisModuleString * outkey = RedisModule_CreateString (NULL , outputs [i ], strlen (outputs [i ]));
75- outkeys = array_append (outkeys , outkey );
76- }
77- return RAI_DagAddModelRun_ (run_info , mctx , inkeys , outkeys , err );
77+ int RAI_DAGRunOpAddOutput (RAI_DAGRunOp * DAGOp , const char * output ) {
78+ RAI_DagOp * op = (RAI_DagOp * )DAGOp ;
79+ RedisModuleString * outkey = RedisModule_CreateString (NULL , output , strlen (output ));
80+ op -> outkeys = array_append (op -> outkeys , outkey );
81+ return REDISMODULE_OK ;
7882}
7983
80- int RedisAI_DagAddLoadPhase_ (RAI_DAGRunCtx * run_info , RedisModuleString * * keys_to_load ,
81- RAI_Error * err ) {
84+ int RAI_DAGAddRunOp (RAI_DAGRunCtx * run_info , RAI_DAGRunOp * DAGop , RAI_Error * err ) {
8285
83- int status = REDISMODULE_ERR ;
86+ RAI_DagOp * op = ( RAI_DagOp * ) DAGop ;
8487 RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )run_info ;
85- RedisModuleCtx * ctx = RedisModule_GetThreadSafeContext (NULL );
86- RedisModule_ThreadSafeContextLock (ctx );
87- size_t n_keys = array_len (keys_to_load );
88-
89- for (size_t i = 0 ; i < n_keys ; i ++ ) {
90- RAI_Tensor * t ;
91- RedisModuleKey * key ;
92- if (_LoadTensorFromKeyspace (ctx , keys_to_load [i ], & key , & t , err ) == REDISMODULE_ERR ) {
93- goto cleanup ;
88+ if (op -> mctx ) {
89+ RAI_Model * model = op -> mctx -> model ;
90+ if (model -> ninputs != array_len (op -> inkeys )) {
91+ RAI_SetError (err , RAI_EDAGBUILDER ,
92+ "Number of keys given as INPUTS does not match model definition" );
93+ return REDISMODULE_ERR ;
94+ }
95+ if (model -> noutputs != array_len (op -> outkeys )) {
96+ RAI_SetError (err , RAI_EDAGBUILDER ,
97+ "Number of keys given as OUTPUTS does not match model definition" );
98+ return REDISMODULE_ERR ;
9499 }
95- // Add the tensor under its "mangled" key name to the DAG local context dict.
96- char buf [16 ];
97- sprintf (buf , "%04d" , 1 );
98- RedisModule_StringAppendBuffer (NULL , keys_to_load [i ], buf , strlen (buf ));
99- AI_dictAdd (rinfo -> dagTensorsContext , (void * )keys_to_load [i ],
100- (void * )RAI_TensorGetShallowCopy (t ));
101100 }
102- status = REDISMODULE_OK ;
101+ rinfo -> dagOps = array_append ( rinfo -> dagOps , op ) ;
103102
104- cleanup :
105- RedisModule_ThreadSafeContextUnlock (ctx );
106- for (size_t i = 0 ; i < n_keys ; i ++ ) {
107- RedisModule_FreeString (NULL , keys_to_load [i ]);
108- }
109- array_free (keys_to_load );
110- return status ;
103+ return REDISMODULE_OK ;
111104}
112105
113- int RedisAI_DagAddLoadPhase (RAI_DAGRunCtx * run_info , const char * * t_names , uint n , RAI_Error * err ) {
114- if (n == 0 ) {
115- RAI_SetError (err , RAI_EDAGBUILDER , "Number of keys to LOAD must be positive" );
116- return REDISMODULE_ERR ;
117- }
118- RedisModuleString * * keys_to_load = array_new (RedisModuleString * , 1 );
119- for (size_t i = 0 ; i < n ; i ++ ) {
120- RedisModuleString * key = RedisModule_CreateString (NULL , t_names [i ], strlen (t_names [i ]));
121- keys_to_load = array_append (keys_to_load , key );
122- }
123- return RedisAI_DagAddLoadPhase_ (run_info , keys_to_load , err );
106+ int RAI_DAGLoadTensor (RAI_DAGRunCtx * run_info , const char * t_name , RAI_Error * err ) {
107+
108+ RedisModuleString * key_name = RedisModule_CreateString (NULL , t_name , strlen (t_name ));
109+ return _RAI_DagLoadTensor (run_info , key_name , err );
124110}
125111
126- int RAI_DagAddTensorGet (RAI_DAGRunCtx * run_info , const char * t_name , RAI_Error * err ) {
112+ int RAI_DAGLoadTensorRS (RAI_DAGRunCtx * run_info , RedisModuleString * t_name , RAI_Error * err ) {
113+
114+ RedisModuleString * key_name = RedisModule_CreateStringFromString (NULL , t_name );
115+ return _RAI_DagLoadTensor (run_info , key_name , err );
116+ }
117+
118+ int RAI_DAGAddTensorGet (RAI_DAGRunCtx * run_info , const char * t_name , RAI_Error * err ) {
127119
128120 RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )run_info ;
129121 RAI_DagOp * op ;
@@ -134,4 +126,14 @@ int RAI_DagAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *
134126 RedisModuleString * name = RedisModule_CreateString (NULL , t_name , strlen (t_name ));
135127 op -> inkeys = array_append (op -> inkeys , name );
136128 return REDISMODULE_OK ;
137- }
129+ }
130+
131+ void RAI_DAGRunOpFree (RAI_DAGRunOp * dagOp ) {
132+ RAI_DagOp * op = (RAI_DagOp * )dagOp ;
133+ RAI_FreeDagOp (op );
134+ }
135+
136+ void RAI_DAGFree (RAI_DAGRunCtx * run_info ) {
137+ RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )run_info ;
138+ RAI_FreeRunInfo (rinfo );
139+ }
0 commit comments