1111pthread_mutex_t global_lock = PTHREAD_MUTEX_INITIALIZER ;
1212pthread_cond_t global_cond = PTHREAD_COND_INITIALIZER ;
1313
14- static RAI_Model * _getModelFromKeySpace (RedisModuleCtx * ctx , const char * keyNameStr ) {
14+ static void * _getFromKeySpace (RedisModuleCtx * ctx , const char * keyNameStr ) {
1515
1616 RedisModuleString * keyRedisStr = RedisModule_CreateString (ctx , keyNameStr , strlen (keyNameStr ));
1717 RedisModuleKey * key = RedisModule_OpenKey (ctx , keyRedisStr , REDISMODULE_READ );
18- RAI_Model * model = RedisModule_ModuleTypeGetValue (key );
1918 RedisModule_FreeString (ctx , keyRedisStr );
2019 RedisModule_CloseKey (key );
21- return model ;
20+ return RedisModule_ModuleTypeGetValue ( key ) ;
2221}
2322
2423static void _DAGFinishFuncError (RAI_OnFinishCtx * onFinishCtx , void * private_data ) {
@@ -58,13 +57,29 @@ static int _testLoadError(RAI_DAGRunCtx *run_info) {
5857 return LLAPIMODULE_ERR ;
5958}
6059
60+ static int _testPersistError (RAI_DAGRunCtx * run_info ) {
61+
62+ RAI_Error * err ;
63+ RedisAI_InitError (& err );
64+ int status = RedisAI_DAGAddPersistTensor (run_info , "t1" , err );
65+ RedisModule_Assert (status == REDISMODULE_OK );
66+ status = RedisAI_DAGAddPersistTensor (run_info , "t1" , err );
67+ if (strcmp (RedisAI_GetError (err ), "Tensor key to persist has already given" ) == 0 ) {
68+ RedisModule_Assert (status == REDISMODULE_ERR );
69+ RedisAI_FreeError (err );
70+ return LLAPIMODULE_OK ;
71+ }
72+ RedisAI_FreeError (err );
73+ return LLAPIMODULE_ERR ;
74+ }
75+
6176static int _testModelRunOpError (RedisModuleCtx * ctx , RAI_DAGRunCtx * run_info ) {
6277
6378 RAI_Error * err ;
6479 RedisAI_InitError (& err );
6580 // The model m{1} should exist in key space.
66- RAI_Model * model = _getModelFromKeySpace (ctx , "m{1}" );
67- RAI_DAGRunOp * op = RedisAI_DAGCreateModelRunOp (run_info , model );
81+ RAI_Model * model = _getFromKeySpace (ctx , "m{1}" );
82+ RAI_DAGRunOp * op = RedisAI_DAGCreateModelRunOp (model );
6883 RedisAI_DAGRunOpAddInput (op , "first_input" );
6984
7085 // This model expect for 2 inputs not 1.
@@ -157,8 +172,8 @@ static int _testSimpleDAGRun(RedisModuleCtx *ctx, RAI_DAGRunCtx *run_info) {
157172 }
158173
159174 // The model m{1} should exist in key space.
160- RAI_Model * model = _getModelFromKeySpace (ctx , "m{1}" );
161- RAI_DAGRunOp * op = RedisAI_DAGCreateModelRunOp (run_info , model );
175+ RAI_Model * model = _getFromKeySpace (ctx , "m{1}" );
176+ RAI_DAGRunOp * op = RedisAI_DAGCreateModelRunOp (model );
162177 RedisAI_DAGRunOpAddInput (op , "a{1}" );
163178 RedisAI_DAGRunOpAddInput (op , "b{1}" );
164179 RedisAI_DAGRunOpAddOutput (op , "output" );
@@ -175,6 +190,7 @@ static int _testSimpleDAGRun(RedisModuleCtx *ctx, RAI_DAGRunCtx *run_info) {
175190 }
176191 // Wait until the onFinish callback returns.
177192 pthread_cond_wait (& global_cond , & global_lock );
193+ pthread_mutex_unlock (& global_lock );
178194
179195 // Verify that we received the expected tensor at the end of the run.
180196 RedisModule_Assert (array_len (outputs ) == 1 );
@@ -199,6 +215,62 @@ static int _testSimpleDAGRun(RedisModuleCtx *ctx, RAI_DAGRunCtx *run_info) {
199215 return res ;
200216}
201217
218+ static int _testSimpleDAGRun2 (RedisModuleCtx * ctx , RAI_DAGRunCtx * run_info ) {
219+
220+ RAI_Error * err ;
221+ RedisAI_InitError (& err );
222+ RAI_Tensor * * outputs = array_new (RAI_Tensor * , 1 );
223+ int res = LLAPIMODULE_ERR ;
224+
225+ RAI_Tensor * tensor = _getFromKeySpace (ctx , "a{1}" );
226+ RedisAI_DAGAddTensorSet (run_info , "input1" , tensor );
227+ tensor = _getFromKeySpace (ctx , "b{1}" );
228+ RedisAI_DAGAddTensorSet (run_info , "input2" , tensor );
229+
230+ // The script myscript{1} should exist in key space.
231+ RAI_Script * script = _getFromKeySpace (ctx , "myscript{1}" );
232+ RAI_DAGRunOp * op = RedisAI_DAGCreateScriptRunOp (script , "bar" );
233+ RedisAI_DAGRunOpAddInput (op , "input1" );
234+ RedisAI_DAGRunOpAddInput (op , "input2" );
235+ RedisAI_DAGRunOpAddOutput (op , "output" );
236+ int status = RedisAI_DAGAddRunOp (run_info , op , err );
237+ if (status != REDISMODULE_OK ) {
238+ goto cleanup ;
239+ }
240+
241+ RedisAI_DAGAddTensorGet (run_info , "output" , err );
242+ pthread_mutex_lock (& global_lock );
243+ if (RedisAI_DAGRun (run_info , _DAGFinishFunc , & outputs , err ) != REDISMODULE_OK ) {
244+ pthread_mutex_unlock (& global_lock );
245+ goto cleanup ;
246+ }
247+ // Wait until the onFinish callback returns.
248+ pthread_cond_wait (& global_cond , & global_lock );
249+ pthread_mutex_unlock (& global_lock );
250+
251+ // Verify that we received the expected tensor at the end of the run.
252+ RedisModule_Assert (array_len (outputs ) == 1 );
253+ RAI_Tensor * out_tensor = outputs [0 ];
254+ double expceted [4 ] = {4 , 6 , 4 , 6 };
255+ double val [4 ];
256+ for (long long i = 0 ; i < 4 ; i ++ ) {
257+ if (RedisAI_TensorGetValueAsDouble (out_tensor , i , & val [i ]) != 0 ) {
258+ goto cleanup ;
259+ }
260+ if (expceted [i ] != val [i ]) {
261+ goto cleanup ;
262+ }
263+ }
264+ RedisAI_TensorFree (out_tensor );
265+ res = LLAPIMODULE_OK ;
266+
267+ cleanup :
268+ RedisAI_FreeError (err );
269+ array_free (outputs );
270+ RedisAI_DAGFree (run_info );
271+ return res ;
272+ }
273+
202274int RAI_llapi_DAGRun (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc ) {
203275 REDISMODULE_NOT_USED (argv );
204276
@@ -213,6 +285,11 @@ int RAI_llapi_DAGRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
213285 RedisAI_DAGFree (run_info );
214286 return RedisModule_ReplyWithSimpleString (ctx , "LOAD error test failed" );
215287 }
288+ // Test the case of a failure due to persist two tensors with the same name.
289+ if (_testPersistError (run_info ) != LLAPIMODULE_OK ) {
290+ RedisAI_DAGFree (run_info );
291+ return RedisModule_ReplyWithSimpleString (ctx , "PERSIST error test failed" );
292+ }
216293 // Test the case of a failure due to addition of a non compatible MODELRUN op.
217294 if (_testModelRunOpError (ctx , run_info ) != LLAPIMODULE_OK ) {
218295 RedisAI_DAGFree (run_info );
@@ -234,5 +311,10 @@ int RAI_llapi_DAGRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
234311 if (_testSimpleDAGRun (ctx , run_info ) != LLAPIMODULE_OK ) {
235312 return RedisModule_ReplyWithSimpleString (ctx , "Simple DAG run test failed" );
236313 }
314+ run_info = RedisAI_DAGRunCtxCreate ();
315+ // Test the case of building and running a DAG with TENSORSET, SCRIPTRUN and PERSIST ops.
316+ if (_testSimpleDAGRun2 (ctx , run_info ) != LLAPIMODULE_OK ) {
317+ return RedisModule_ReplyWithSimpleString (ctx , "Simple DAG run2 test failed" );
318+ }
237319 return RedisModule_ReplyWithSimpleString (ctx , "DAG run success" );
238320}
0 commit comments