@@ -91,7 +91,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
9191
9292static void Dag_StoreOutputsFromModelRunCtx (RedisAI_RunInfo * rinfo , RAI_DagOp * currentOp ) {
9393
94- RAI_ContextReadLock (rinfo );
94+ RAI_ContextWriteLock (rinfo );
9595 const size_t noutputs = RAI_ModelRunCtxNumOutputs (currentOp -> mctx );
9696 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
9797 RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (currentOp -> mctx , outputNumber );
@@ -284,6 +284,9 @@ void RedisAI_BatchedDagRunSession_ModelRun_Step(RedisAI_RunInfo **batched_rinfo,
284284 if (rinfo -> single_op_dag == 0 )
285285 Dag_StoreOutputsFromModelRunCtx (rinfo , currentOp );
286286 }
287+ // Clear the result in case of an error.
288+ if (result == REDISMODULE_ERR )
289+ RAI_ClearError (& err );
287290}
288291
289292/**
@@ -452,16 +455,20 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
452455 return 1 ;
453456}
454457
455- int RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
458+ bool RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
456459 return rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ;
457460}
458461
459- int RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
462+ bool RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
460463 int completeOpCount = __atomic_load_n (rinfo -> dagCompleteOpCount , __ATOMIC_RELAXED );
461464
462465 return completeOpCount == rinfo -> dagOpCount ;
463466}
464467
468+ bool RedisAI_DagError (RedisAI_RunInfo * rinfo ) {
469+ return __atomic_load_n (rinfo -> dagError , __ATOMIC_RELAXED ) != 0 ;
470+ }
471+
465472RAI_DagOp * RedisAI_DagCurrentOp (RedisAI_RunInfo * rinfo ) {
466473 if (rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ) {
467474 return NULL ;
@@ -470,21 +477,21 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
470477 return rinfo -> dagDeviceOps [rinfo -> dagDeviceCompleteOpCount ];
471478}
472479
473- void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , int * currentOpReady ,
474- int * currentOpBatchable ) {
480+ void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , bool * currentOpReady ,
481+ bool * currentOpBatchable ) {
475482 RAI_DagOp * currentOp_ = RedisAI_DagCurrentOp (rinfo );
476483
477- * currentOpReady = 0 ;
478- * currentOpBatchable = 0 ;
484+ * currentOpReady = false ;
485+ * currentOpBatchable = false ;
479486
480487 if (currentOp_ == NULL ) {
481488 return ;
482489 }
483490
484491 if (currentOp_ -> mctx && currentOp_ -> mctx -> model -> opts .batchsize > 0 ) {
485- * currentOpBatchable = 1 ;
492+ * currentOpBatchable = true ;
486493 }
487- * currentOpReady = 1 ;
494+ * currentOpReady = true ;
488495 // If this is a single op dag, the op is definitely ready.
489496 if (rinfo -> single_op_dag == 1 )
490497 return ;
@@ -495,7 +502,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
495502 for (int i = 0 ; i < n_inkeys ; i ++ ) {
496503 if (AI_dictFind (rinfo -> dagTensorsContext , currentOp_ -> inkeys [i ]) == NULL ) {
497504 RAI_ContextUnlock (rinfo );
498- * currentOpReady = 0 ;
505+ * currentOpReady = false ;
499506 return ;
500507 }
501508 }
@@ -604,13 +611,11 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
604611
605612 if (* rinfo -> timedOut ) {
606613 RedisModule_ReplyWithSimpleString (ctx , "TIMEDOUT" );
607- RAI_FreeRunInfo (rinfo );
608614 return REDISMODULE_OK ;
609615 }
610616
611617 if (RAI_GetErrorCode (rinfo -> err ) == RAI_EDAGRUN ) {
612618 RedisModule_ReplyWithError (ctx , RAI_GetErrorOneLine (rinfo -> err ));
613- RAI_FreeRunInfo (rinfo );
614619 return REDISMODULE_OK ;
615620 }
616621 int dag_error = 0 ;
@@ -717,7 +722,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
717722 if (!rinfo -> single_op_dag ) {
718723 RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
719724 }
720- RAI_FreeRunInfo (rinfo );
721725 return REDISMODULE_OK ;
722726}
723727
@@ -745,15 +749,12 @@ int RedisAI_DagRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, RedisMo
745749 return REDISMODULE_OK ;
746750}
747751
748- void RunInfo_FreeData (RedisModuleCtx * ctx , void * rinfo ) {}
749-
750- void RedisAI_Disconnected (RedisModuleCtx * ctx , RedisModuleBlockedClient * bc ) {
751- RedisModule_Log (ctx , "warning" , "Blocked client %p disconnected!" , (void * )bc );
752- }
752+ void RunInfo_FreeData (RedisModuleCtx * ctx , void * rinfo ) { RAI_FreeRunInfo (rinfo ); }
753753
754754void DAG_ReplyAndUnblock (RedisAI_OnFinishCtx * ctx , void * private_data ) {
755755
756756 RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )ctx ;
757- if (rinfo -> client )
757+ if (rinfo -> client ) {
758758 RedisModule_UnblockClient (rinfo -> client , rinfo );
759+ }
759760}
0 commit comments