Skip to content

Commit e56ede3

Browse files
committed
Remove Persist. Add LLAPI for getting the status of each DAG op (RAI_Error) + tests via test module.
1 parent b9d4608 commit e56ede3

File tree

8 files changed

+265
-202
lines changed

8 files changed

+265
-202
lines changed

src/DAG/dag.c

Lines changed: 119 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,112 @@ static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *c
101101
RAI_ContextUnlock(rinfo);
102102
}
103103

104+
static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
105+
RedisModuleString *persist_key_name, bool mangled_name) {
106+
107+
int ret = REDISMODULE_ERR;
108+
RedisModuleKey *key;
109+
size_t persist_key_len;
110+
const char *persist_key_str = RedisModule_StringPtrLen(persist_key_name, &persist_key_len);
111+
112+
RedisModuleString *demangled_key_name;
113+
if (mangled_name) {
114+
demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len - 4);
115+
} else {
116+
demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len);
117+
}
118+
119+
const int status =
120+
RAI_OpenKey_Tensor(ctx, demangled_key_name, &key, REDISMODULE_READ | REDISMODULE_WRITE);
121+
if (status == REDISMODULE_ERR) {
122+
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
123+
goto clean_up;
124+
}
125+
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, RAI_TensorGetShallowCopy(tensor)) !=
126+
REDISMODULE_OK) {
127+
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
128+
RedisModule_CloseKey(key);
129+
goto clean_up;
130+
}
131+
// Only if we got until here, tensor is saved in keyspace.
132+
RedisAI_ReplicateTensorSet(ctx, demangled_key_name, tensor);
133+
ret = REDISMODULE_OK;
134+
135+
clean_up:
136+
RedisModule_FreeString(NULL, demangled_key_name);
137+
return ret;
138+
}
139+
140+
static void _DAG_PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
141+
142+
AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext);
143+
AI_dictEntry *persist_entry = AI_dictNext(persist_iter);
144+
145+
while (persist_entry) {
146+
RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry);
147+
AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name);
148+
RedisModule_Assert(tensor_entry);
149+
RAI_Tensor *tensor = AI_dictGetVal(tensor_entry);
150+
if (tensor == NULL) {
151+
persist_entry = AI_dictNext(persist_iter);
152+
continue;
153+
}
154+
if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, true) == REDISMODULE_ERR) {
155+
*rinfo->dagError = 1;
156+
RedisModule_Log(ctx, "warning",
157+
"Could not persist tensor under the key (%s) after executing DAGRUN "
158+
"command, persist stopped",
159+
RedisModule_StringPtrLen(persist_key_name, NULL));
160+
AI_dictReleaseIterator(persist_iter);
161+
return;
162+
}
163+
persist_entry = AI_dictNext(persist_iter);
164+
}
165+
AI_dictReleaseIterator(persist_iter);
166+
}
167+
168+
static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
169+
170+
const size_t noutputs = RAI_ModelRunCtxNumOutputs(op->mctx);
171+
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
172+
RedisModuleString *persist_key_name = op->outkeys[outputNumber];
173+
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(op->mctx, outputNumber);
174+
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
175+
if (!tensor)
176+
continue;
177+
178+
if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, false) == REDISMODULE_ERR) {
179+
RedisModule_Log(ctx, "warning",
180+
"Could not persist tensor under the key (%s) after executing DAGRUN "
181+
"command, persist stopped",
182+
RedisModule_StringPtrLen(persist_key_name, NULL));
183+
op->result = REDISMODULE_ERR;
184+
return;
185+
}
186+
}
187+
}
188+
189+
static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
190+
191+
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(op->sctx);
192+
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
193+
RedisModuleString *persist_key_name = op->outkeys[outputNumber];
194+
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(op->sctx, outputNumber);
195+
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
196+
if (!tensor)
197+
continue;
198+
199+
if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, false) == REDISMODULE_ERR) {
200+
RedisModule_Log(ctx, "warning",
201+
"Could not persist tensor under the key (%s) after executing DAGRUN "
202+
"command, persist stopped",
203+
RedisModule_StringPtrLen(persist_key_name, NULL));
204+
op->result = REDISMODULE_ERR;
205+
return;
206+
}
207+
}
208+
}
209+
104210
/**
105211
* Execution of a MODELRUN DAG step.
106212
* If an error occurs, it is recorded in the DagOp struct.
@@ -490,128 +596,24 @@ void RedisAI_BatchedDagRunSessionStep(RedisAI_RunInfo **batched_rinfo, const cha
490596
return;
491597
}
492598

493-
static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
494-
RedisModuleString *persist_key_name, bool mangled_name) {
495-
496-
int ret = REDISMODULE_ERR;
497-
RedisModuleKey *key;
498-
size_t persist_key_len;
499-
const char *persist_key_str = RedisModule_StringPtrLen(persist_key_name, &persist_key_len);
500-
501-
RedisModuleString *demangled_key_name;
502-
if (mangled_name) {
503-
demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len - 4);
504-
} else {
505-
demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len);
506-
}
507-
508-
const int status =
509-
RAI_OpenKey_Tensor(ctx, demangled_key_name, &key, REDISMODULE_READ | REDISMODULE_WRITE);
510-
if (status == REDISMODULE_ERR) {
511-
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
512-
goto clean_up;
513-
} else {
514-
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType,
515-
RAI_TensorGetShallowCopy(tensor)) != REDISMODULE_OK) {
516-
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
517-
goto clean_up;
518-
}
519-
}
520-
ret = REDISMODULE_OK;
521-
522-
clean_up:
523-
RedisModule_CloseKey(key);
524-
RedisAI_ReplicateTensorSet(ctx, demangled_key_name, tensor);
525-
RedisModule_FreeString(NULL, demangled_key_name);
526-
return ret;
527-
}
528-
529-
static void _PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
530-
531-
AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext);
532-
AI_dictEntry *persist_entry = AI_dictNext(persist_iter);
533-
534-
while (persist_entry) {
535-
RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry);
536-
AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name);
537-
if (tensor_entry) {
538-
RAI_Tensor *tensor = AI_dictGetVal(tensor_entry);
539-
if (tensor == NULL) {
540-
persist_entry = AI_dictNext(persist_iter);
541-
continue;
542-
}
543-
if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, true) == REDISMODULE_ERR)
544-
rinfo->dagReplyLength++;
545-
546-
} else {
547-
RedisModule_ReplyWithError(ctx,
548-
"ERR specified persistent key that was not used in DAG");
549-
rinfo->dagReplyLength++;
550-
RedisModule_Log(ctx, "warning",
551-
"on DAGRUN's PERSIST specified persistent key (%s) that "
552-
"was not used on DAG. Logging all local context keys",
553-
RedisModule_StringPtrLen(persist_key_name, NULL));
554-
AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
555-
AI_dictEntry *local_entry = AI_dictNext(local_iter);
556-
557-
while (local_entry) {
558-
RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry);
559-
RedisModule_Log(ctx, "warning", "DAG's local context key (%s)",
560-
RedisModule_StringPtrLen(localcontext_key_name, NULL));
561-
local_entry = AI_dictNext(local_iter);
562-
}
563-
AI_dictReleaseIterator(local_iter);
564-
565-
for (size_t opN = 0; opN < array_len(rinfo->dagOps); opN++) {
566-
RedisModule_Log(ctx, "warning", "DAG's op n# %zu - cmdType %d ( argc %d )", opN,
567-
rinfo->dagOps[opN]->commandType, rinfo->dagOps[opN]->argc);
568-
}
569-
}
570-
persist_entry = AI_dictNext(persist_iter);
571-
}
572-
AI_dictReleaseIterator(persist_iter);
573-
}
574-
575-
static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
576-
const size_t noutputs = RAI_ModelRunCtxNumOutputs(op->mctx);
577-
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
578-
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(op->mctx, outputNumber);
579-
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
580-
if (tensor)
581-
_StoreTensorInKeySpace(ctx, tensor, op->outkeys[outputNumber], false);
582-
}
583-
}
584-
585-
static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
586-
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(op->sctx);
587-
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
588-
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(op->sctx, outputNumber);
589-
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
590-
if (tensor)
591-
_StoreTensorInKeySpace(ctx, tensor, op->outkeys[outputNumber], false);
592-
}
593-
}
594-
595599
int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
596600
REDISMODULE_NOT_USED(argv);
597601
REDISMODULE_NOT_USED(argc);
598602
RedisAI_RunInfo *rinfo = RedisModule_GetBlockedClientPrivateData(ctx);
599603

600-
if (RAI_GetErrorCode(rinfo->err) == RAI_EDAGRUN) {
601-
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(rinfo->err));
604+
if (*rinfo->timedOut) {
605+
RedisModule_ReplyWithSimpleString(ctx, "TIMEDOUT");
602606
RAI_FreeRunInfo(rinfo);
603-
return REDISMODULE_ERR;
607+
return REDISMODULE_OK;
604608
}
605-
int dag_error = 0;
606-
char *detail_oneline;
607609

608-
size_t n_dagOps = array_len(rinfo->dagOps);
609-
610-
if (*rinfo->timedOut) {
611-
RedisModule_ReplyWithSimpleString(ctx, "TIMEDOUT");
610+
if (RAI_GetErrorCode(rinfo->err) == RAI_EDAGRUN) {
611+
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(rinfo->err));
612612
RAI_FreeRunInfo(rinfo);
613613
return REDISMODULE_OK;
614614
}
615+
int dag_error = 0;
616+
size_t n_dagOps = array_len(rinfo->dagOps);
615617

616618
if (!rinfo->single_op_dag) {
617619
RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_ARRAY_LEN);
@@ -697,17 +699,10 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
697699
}
698700

699701
if (dag_error) {
700-
if (rinfo->single_op_dag == 0) {
701-
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
702-
}
703-
RAI_FreeRunInfo(rinfo);
704-
return REDISMODULE_ERR;
702+
goto cleanup;
705703
}
706-
707704
if (!rinfo->single_op_dag) {
708-
// Save the required tensors in redis key space.
709-
_PersistTensors(ctx, rinfo);
710-
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
705+
_DAG_PersistTensors(ctx, rinfo);
711706
} else {
712707
if (rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_MODELRUN) {
713708
_ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]);
@@ -717,6 +712,10 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
717712
}
718713
}
719714

715+
cleanup:
716+
if (!rinfo->single_op_dag) {
717+
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
718+
}
720719
RAI_FreeRunInfo(rinfo);
721720
return REDISMODULE_OK;
722721
}

src/DAG/dag_builder.c

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ static int _LoadTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyNa
77
RedisModuleKey **key, RAI_Tensor **tensor, RAI_Error *err) {
88

99
int res = REDISMODULE_ERR;
10-
// RedisModule_ThreadSafeContextLock(ctx);
1110
*key = RedisModule_OpenKey(ctx, keyName, REDISMODULE_READ);
1211
if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
1312
RAI_SetError(err, RAI_EDAGBUILDER, "ERR tensor key is empty");
@@ -22,7 +21,6 @@ static int _LoadTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyNa
2221

2322
end:
2423
RedisModule_CloseKey(*key);
25-
// RedisModule_ThreadSafeContextUnlock(ctx);
2624
return res;
2725
}
2826

@@ -126,30 +124,6 @@ int RAI_DAGLoadTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_
126124
return _RAI_DagLoadTensor(run_info, key_name, err);
127125
}
128126

129-
// todo: Persist tensors should not be part of dag reply, but before...
130-
int RAI_DAGAddPersistTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err) {
131-
132-
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
133-
if (AI_dictAdd(rinfo->dagTensorsPersistedContext, (void *)t_name, (void *)1) != DICT_OK) {
134-
RAI_SetError(err, RAI_EDAGBUILDER, "Tensor key to persist has already given");
135-
return REDISMODULE_ERR;
136-
}
137-
return REDISMODULE_OK;
138-
}
139-
140-
int RAI_DAGAddPersistTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
141-
142-
RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name));
143-
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
144-
if (AI_dictAdd(rinfo->dagTensorsPersistedContext, (void *)key_name, (void *)1) != DICT_OK) {
145-
RAI_SetError(err, RAI_EDAGBUILDER, "Tensor key to persist has already given");
146-
RedisModule_FreeString(NULL, key_name);
147-
return REDISMODULE_ERR;
148-
}
149-
RedisModule_FreeString(NULL, key_name);
150-
return REDISMODULE_OK;
151-
}
152-
153127
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
154128

155129
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
@@ -177,6 +151,11 @@ int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor
177151
return REDISMODULE_OK;
178152
}
179153

154+
size_t RAI_DAGNumOps(RAI_DAGRunCtx *run_info) {
155+
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
156+
return array_len(rinfo->dagOps);
157+
}
158+
180159
void RAI_DAGRunOpFree(RAI_DAGRunOp *dagOp) {
181160
RAI_DagOp *op = (RAI_DagOp *)dagOp;
182161
RAI_FreeDagOp(op);

src/DAG/dag_builder.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@ int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *er
1818

1919
int RAI_DAGLoadTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err);
2020

21-
int RAI_DAGAddPersistTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
22-
23-
int RAI_DAGAddPersistTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err);
24-
2521
int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *tensor);
2622

2723
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
2824

25+
size_t RAI_DAGNumOps(RAI_DAGRunCtx *run_info);
26+
2927
void RAI_DAGFree(RAI_DAGRunCtx *run_info);
3028

3129
void RAI_DAGRunOpFree(RAI_DAGRunOp *dagOp);

src/DAG/dag_execute.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ int MangleTensorsNames(RedisAI_RunInfo *rinfo) {
109109
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST key cannot be found in DAG");
110110
goto cleanup;
111111
}
112+
if (AI_dictFind(mangled_persisted, key) != NULL) {
113+
AI_dictRelease(mangled_persisted);
114+
AI_dictReleaseIterator(iter);
115+
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST keys must be unique");
116+
goto cleanup;
117+
}
112118
int *instance = AI_dictGetVal(mangled_entry);
113119
char buf[16];
114120
sprintf(buf, "%04d", *instance);
@@ -278,3 +284,17 @@ RAI_Tensor *RAI_DAGOutputTensor(RAI_OnFinishCtx *finish_ctx, size_t index) {
278284
}
279285
return NULL;
280286
}
287+
288+
int RAI_DAGRunError(RAI_OnFinishCtx *finish_ctx) {
289+
return *((RedisAI_RunInfo *)finish_ctx)->dagError;
290+
}
291+
292+
RAI_Error *RAI_DAGCopyOpStatus(RAI_OnFinishCtx *finish_ctx, size_t index) {
293+
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)finish_ctx;
294+
RedisModule_Assert(index < rinfo->dagOpCount);
295+
RAI_Error *err;
296+
RAI_InitError(&err);
297+
RAI_SetError(err, RAI_GetErrorCode(rinfo->dagOps[index]->err),
298+
RAI_GetError(rinfo->dagOps[index]->err));
299+
return err;
300+
}

src/DAG/dag_execute.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@ int RAI_DAGRun(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *pri
2222
size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx);
2323

2424
RAI_Tensor *RAI_DAGOutputTensor(RAI_OnFinishCtx *finish_ctx, size_t index);
25+
26+
int RAI_DAGRunError(RAI_OnFinishCtx *finish_ctx);
27+
28+
RAI_Error *RAI_DAGCopyOpStatus(RAI_OnFinishCtx *finish_ctx, size_t index);

0 commit comments

Comments
 (0)