Skip to content

Commit 7c696d3

Browse files
committed
LOAD, MODELRUN, TENSORSET in DAG - test LLAPI (with/without errors) via test module (passed valgrind)
1 parent 5e42d0b commit 7c696d3

File tree

17 files changed

+478
-280
lines changed

17 files changed

+478
-280
lines changed

src/DAG/dag.c

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
5757
for (uint i = 0; i < n_inkeys; i++) {
5858
RAI_Tensor *inputTensor;
5959
const int get_result = RAI_getTensorFromLocalContext(
60-
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
60+
rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
6161
if (get_result == REDISMODULE_ERR) {
6262
// We check for this outside the function
6363
// this check cannot be covered by tests
@@ -198,7 +198,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
198198
for (uint i = 0; i < n_inkeys; i++) {
199199
RAI_Tensor *inputTensor;
200200
const int get_result = RAI_getTensorFromLocalContext(
201-
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
201+
rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
202202
if (get_result == REDISMODULE_ERR) {
203203
// We check for this outside the function
204204
// this check cannot be covered by tests
@@ -255,8 +255,7 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
255255
if (rinfo->single_op_dag) {
256256
input = op->mctx->inputs[i].tensor;
257257
} else {
258-
RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext, op->inkeys[i], &input,
259-
op->err);
258+
RAI_getTensorFromLocalContext(rinfo->dagTensorsContext, op->inkeys[i], &input, op->err);
260259
}
261260
// We are expecting input != NULL, because we only reach this function if all inputs
262261
// are available in context for the current dagOp. We could be more defensive
@@ -304,14 +303,14 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
304303
if (rinfo1->single_op_dag == 1) {
305304
input1 = op1->mctx->inputs[i].tensor;
306305
} else {
307-
RAI_getTensorFromLocalContext(NULL, rinfo1->dagTensorsContext, op1->inkeys[i], &input1,
306+
RAI_getTensorFromLocalContext(rinfo1->dagTensorsContext, op1->inkeys[i], &input1,
308307
op1->err);
309308
}
310309
RAI_Tensor *input2;
311310
if (rinfo2->single_op_dag == 1) {
312311
input2 = op2->mctx->inputs[i].tensor;
313312
} else {
314-
RAI_getTensorFromLocalContext(NULL, rinfo2->dagTensorsContext, op2->inkeys[i], &input2,
313+
RAI_getTensorFromLocalContext(rinfo2->dagTensorsContext, op2->inkeys[i], &input2,
315314
op2->err);
316315
}
317316
if (input1 == NULL || input2 == NULL) {
@@ -637,8 +636,8 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
637636
case REDISAI_DAG_CMD_TENSORGET: {
638637
rinfo->dagReplyLength++;
639638
RAI_Tensor *t;
640-
int res = RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext,
641-
currentOp->inkeys[0], &t, currentOp->err);
639+
int res = RAI_getTensorFromLocalContext(rinfo->dagTensorsContext, currentOp->inkeys[0],
640+
&t, currentOp->err);
642641
if (res != REDISMODULE_OK) {
643642
RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline);
644643
dag_error = 1;

src/DAG/dag_builder.c

Lines changed: 85 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,129 +1,121 @@
11
#include "dag_builder.h"
22
#include "run_info.h"
33
#include "string_utils.h"
4+
#include "modelRun_ctx.h"
45

5-
int _LoadTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
6-
RAI_Tensor **tensor, RAI_Error *err) {
6+
static int _LoadTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName,
7+
RedisModuleKey **key, RAI_Tensor **tensor, RAI_Error *err) {
78

9+
int res = REDISMODULE_ERR;
10+
// RedisModule_ThreadSafeContextLock(ctx);
811
*key = RedisModule_OpenKey(ctx, keyName, REDISMODULE_READ);
912
if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
10-
RedisModule_CloseKey(*key);
1113
RAI_SetError(err, RAI_EDAGBUILDER, "ERR tensor key is empty");
12-
return REDISMODULE_ERR;
14+
goto end;
1315
}
1416
if (RedisModule_ModuleTypeGetType(*key) != RedisAI_TensorType) {
15-
RedisModule_CloseKey(*key);
1617
RAI_SetError(err, RAI_EDAGBUILDER, REDISMODULE_ERRORMSG_WRONGTYPE);
17-
return REDISMODULE_ERR;
18+
goto end;
1819
}
1920
*tensor = RedisModule_ModuleTypeGetValue(*key);
21+
res = REDISMODULE_OK;
22+
23+
end:
2024
RedisModule_CloseKey(*key);
25+
// RedisModule_ThreadSafeContextUnlock(ctx);
26+
return res;
27+
}
28+
29+
static int _RAI_DagLoadTensor(RAI_DAGRunCtx *run_info, RedisModuleString *key_name,
30+
RAI_Error *err) {
31+
32+
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
33+
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
34+
RAI_Tensor *t;
35+
RedisModuleKey *key;
36+
if (_LoadTensorFromKeyspace(ctx, key_name, &key, &t, err) == REDISMODULE_ERR) {
37+
RedisModule_FreeString(NULL, key_name);
38+
RedisModule_FreeThreadSafeContext(ctx);
39+
return REDISMODULE_ERR;
40+
}
41+
// Add the tensor under its "mangled" key name to the DAG local context dict.
42+
char buf[16];
43+
sprintf(buf, "%04d", 1);
44+
RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf));
45+
AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name, (void *)RAI_TensorGetShallowCopy(t));
46+
RedisModule_FreeString(NULL, key_name);
47+
RedisModule_FreeThreadSafeContext(ctx);
2148
return REDISMODULE_OK;
2249
}
2350

24-
RAI_DAGRunCtx *RAI_DagRunCtxCreate(void) {
51+
RAI_DAGRunCtx *RAI_DAGRunCtxCreate(void) {
2552
RedisAI_RunInfo *rinfo;
2653
RAI_InitRunInfo(&rinfo);
2754
return (RAI_DAGRunCtx *)rinfo;
2855
}
2956

30-
int RAI_DagAddModelRun_(RAI_DAGRunCtx *run_info, RAI_ModelRunCtx *mctx, RedisModuleString **inputs,
31-
RedisModuleString **outputs, RAI_Error *err) {
32-
if (array_len(mctx->inputs) != 0 || array_len(mctx->outputs) != 0) {
33-
RAI_SetError(
34-
err, RAI_EDAGBUILDER,
35-
"Model run context cannot contain inputs or outputs when it is a part of a DAG");
36-
return REDISMODULE_ERR;
37-
}
38-
RAI_Model *model = mctx->model;
39-
if (model->ninputs != array_len(inputs)) {
40-
RAI_SetError(err, RAI_EDAGBUILDER,
41-
"Number of keys given as INPUTS does not match model definition");
42-
return REDISMODULE_ERR;
43-
}
44-
if (model->noutputs != array_len(outputs)) {
45-
RAI_SetError(err, RAI_EDAGBUILDER,
46-
"Number of keys given as OUTPUTS does not match model definition");
47-
return REDISMODULE_ERR;
48-
}
49-
57+
RAI_DAGRunOp *RAI_DAGCreateModelRunOp(RAI_DAGRunCtx *run_info, RAI_Model *model) {
5058
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
59+
RAI_ModelRunCtx *mctx = RAI_ModelRunCtxCreate(model);
5160
RAI_DagOp *op;
5261
RAI_InitDagOp(&op);
53-
rinfo->dagOps = array_append(rinfo->dagOps, op);
5462

5563
op->commandType = REDISAI_DAG_CMD_MODELRUN;
5664
op->mctx = mctx;
5765
op->devicestr = model->devicestr;
58-
op->inkeys = inputs;
59-
op->outkeys = outputs;
6066
op->runkey = RAI_HoldString(NULL, (RedisModuleString *)model->infokey);
61-
return REDISMODULE_OK;
67+
return (RAI_DAGRunOp *)op;
6268
}
6369

64-
int RAI_DagAddModelRun(RAI_DAGRunCtx *run_info, RAI_ModelRunCtx *mctx, const char **inputs,
65-
size_t ninputs, const char **outputs, size_t noutputs, RAI_Error *err) {
70+
int RAI_DAGRunOpAddInput(RAI_DAGRunOp *DAGOp, const char *input) {
71+
RAI_DagOp *op = (RAI_DagOp *)DAGOp;
72+
RedisModuleString *inkey = RedisModule_CreateString(NULL, input, strlen(input));
73+
op->inkeys = array_append(op->inkeys, inkey);
74+
return REDISMODULE_OK;
75+
}
6676

67-
RedisModuleString **inkeys = array_new(RedisModuleString *, 1);
68-
for (size_t i = 0; i < ninputs; i++) {
69-
RedisModuleString *inkey = RedisModule_CreateString(NULL, inputs[i], strlen(inputs[i]));
70-
inkeys = array_append(inkeys, inkey);
71-
}
72-
RedisModuleString **outkeys = array_new(RedisModuleString *, 1);
73-
for (size_t i = 0; i < noutputs; i++) {
74-
RedisModuleString *outkey = RedisModule_CreateString(NULL, outputs[i], strlen(outputs[i]));
75-
outkeys = array_append(outkeys, outkey);
76-
}
77-
return RAI_DagAddModelRun_(run_info, mctx, inkeys, outkeys, err);
77+
int RAI_DAGRunOpAddOutput(RAI_DAGRunOp *DAGOp, const char *output) {
78+
RAI_DagOp *op = (RAI_DagOp *)DAGOp;
79+
RedisModuleString *outkey = RedisModule_CreateString(NULL, output, strlen(output));
80+
op->outkeys = array_append(op->outkeys, outkey);
81+
return REDISMODULE_OK;
7882
}
7983

80-
int RedisAI_DagAddLoadPhase_(RAI_DAGRunCtx *run_info, RedisModuleString **keys_to_load,
81-
RAI_Error *err) {
84+
int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err) {
8285

83-
int status = REDISMODULE_ERR;
86+
RAI_DagOp *op = (RAI_DagOp *)DAGop;
8487
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
85-
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
86-
RedisModule_ThreadSafeContextLock(ctx);
87-
size_t n_keys = array_len(keys_to_load);
88-
89-
for (size_t i = 0; i < n_keys; i++) {
90-
RAI_Tensor *t;
91-
RedisModuleKey *key;
92-
if (_LoadTensorFromKeyspace(ctx, keys_to_load[i], &key, &t, err) == REDISMODULE_ERR) {
93-
goto cleanup;
88+
if (op->mctx) {
89+
RAI_Model *model = op->mctx->model;
90+
if (model->ninputs != array_len(op->inkeys)) {
91+
RAI_SetError(err, RAI_EDAGBUILDER,
92+
"Number of keys given as INPUTS does not match model definition");
93+
return REDISMODULE_ERR;
94+
}
95+
if (model->noutputs != array_len(op->outkeys)) {
96+
RAI_SetError(err, RAI_EDAGBUILDER,
97+
"Number of keys given as OUTPUTS does not match model definition");
98+
return REDISMODULE_ERR;
9499
}
95-
// Add the tensor under its "mangled" key name to the DAG local context dict.
96-
char buf[16];
97-
sprintf(buf, "%04d", 1);
98-
RedisModule_StringAppendBuffer(NULL, keys_to_load[i], buf, strlen(buf));
99-
AI_dictAdd(rinfo->dagTensorsContext, (void *)keys_to_load[i],
100-
(void *)RAI_TensorGetShallowCopy(t));
101100
}
102-
status = REDISMODULE_OK;
101+
rinfo->dagOps = array_append(rinfo->dagOps, op);
103102

104-
cleanup:
105-
RedisModule_ThreadSafeContextUnlock(ctx);
106-
for (size_t i = 0; i < n_keys; i++) {
107-
RedisModule_FreeString(NULL, keys_to_load[i]);
108-
}
109-
array_free(keys_to_load);
110-
return status;
103+
return REDISMODULE_OK;
111104
}
112105

113-
int RedisAI_DagAddLoadPhase(RAI_DAGRunCtx *run_info, const char **t_names, uint n, RAI_Error *err) {
114-
if (n == 0) {
115-
RAI_SetError(err, RAI_EDAGBUILDER, "Number of keys to LOAD must be positive");
116-
return REDISMODULE_ERR;
117-
}
118-
RedisModuleString **keys_to_load = array_new(RedisModuleString *, 1);
119-
for (size_t i = 0; i < n; i++) {
120-
RedisModuleString *key = RedisModule_CreateString(NULL, t_names[i], strlen(t_names[i]));
121-
keys_to_load = array_append(keys_to_load, key);
122-
}
123-
return RedisAI_DagAddLoadPhase_(run_info, keys_to_load, err);
106+
int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
107+
108+
RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name));
109+
return _RAI_DagLoadTensor(run_info, key_name, err);
124110
}
125111

126-
int RAI_DagAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
112+
int RAI_DAGLoadTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err) {
113+
114+
RedisModuleString *key_name = RedisModule_CreateStringFromString(NULL, t_name);
115+
return _RAI_DagLoadTensor(run_info, key_name, err);
116+
}
117+
118+
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
127119

128120
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
129121
RAI_DagOp *op;
@@ -134,4 +126,14 @@ int RAI_DagAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *
134126
RedisModuleString *name = RedisModule_CreateString(NULL, t_name, strlen(t_name));
135127
op->inkeys = array_append(op->inkeys, name);
136128
return REDISMODULE_OK;
137-
}
129+
}
130+
131+
void RAI_DAGRunOpFree(RAI_DAGRunOp *dagOp) {
132+
RAI_DagOp *op = (RAI_DagOp *)dagOp;
133+
RAI_FreeDagOp(op);
134+
}
135+
136+
void RAI_DAGFree(RAI_DAGRunCtx *run_info) {
137+
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
138+
RAI_FreeRunInfo(rinfo);
139+
}

src/DAG/dag_builder.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,22 @@
22

33
#include "redisai.h"
44

5-
RAI_DAGRunCtx *RedisAI_DagRunCtxCreate(void);
5+
RAI_DAGRunCtx *RAI_DAGRunCtxCreate(void);
66

7-
int RAI_DagAddModelRun_(RAI_DAGRunCtx *run_info, RAI_ModelRunCtx *mctx, RedisModuleString **inputs,
8-
RedisModuleString **outputs, RAI_Error *err);
7+
RAI_DAGRunOp *RAI_DAGCreateModelRunOp(RAI_DAGRunCtx *run_info, RAI_Model *model);
98

10-
int RAI_DagAddModelRun(RAI_DAGRunCtx *run_info, RAI_ModelRunCtx *mctx, const char **inputs,
11-
size_t ninputs, const char **outputs, size_t noutputs, RAI_Error *err);
9+
int RAI_DAGRunOpAddInput(RAI_DAGRunOp *DAGOp, const char *input);
1210

13-
int RAI_DagAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
11+
int RAI_DAGRunOpAddOutput(RAI_DAGRunOp *DAGOp, const char *output);
12+
13+
int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err);
14+
15+
int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
16+
17+
int RAI_DAGLoadTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err);
18+
19+
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
20+
21+
void RAI_DAGFree(RAI_DAGRunCtx *run_info);
22+
23+
void RAI_DAGRunOpFree(RAI_DAGRunOp *dagOp);

src/DAG/dag_execute.c

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "background_workers.h"
44
#include "util/string_utils.h"
55

6-
int MangleTensorsNames(RedisAI_RunInfo *rinfo, RAI_Error *err) {
6+
int MangleTensorsNames(RedisAI_RunInfo *rinfo) {
77

88
int res = REDISMODULE_ERR;
99
AI_dict *mangled_tensors = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
@@ -35,7 +35,7 @@ int MangleTensorsNames(RedisAI_RunInfo *rinfo, RAI_Error *err) {
3535
AI_dictEntry *entry = AI_dictFind(mangled_tensors, key);
3636
if (!entry) {
3737
array_free(mangled_inkeys);
38-
RAI_SetError(err, RAI_EDAGRUN, "ERR INPUT key cannot be found in DAG");
38+
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR INPUT key cannot be found in DAG");
3939
goto cleanup;
4040
}
4141
int *instance = AI_dictGetVal(entry);
@@ -95,7 +95,7 @@ int MangleTensorsNames(RedisAI_RunInfo *rinfo, RAI_Error *err) {
9595
if (!mangled_entry) {
9696
AI_dictRelease(mangled_persisted);
9797
AI_dictReleaseIterator(iter);
98-
RAI_SetError(err, RAI_EDAGRUN, "ERR PERSIST key cannot be found in DAG");
98+
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST key cannot be found in DAG");
9999
goto cleanup;
100100
}
101101
int *instance = AI_dictGetVal(mangled_entry);
@@ -211,8 +211,8 @@ int DAG_InsertDAGToQueue(RedisAI_RunInfo *rinfo) {
211211
return REDISMODULE_OK;
212212
}
213213

214-
int RAI_DagRunAsync(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *private_data,
215-
RAI_Error *err) {
214+
int RAI_DAGRun(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *private_data,
215+
RAI_Error *err) {
216216

217217
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
218218
rinfo->dagOpCount = array_len(rinfo->dagOps);
@@ -222,15 +222,45 @@ int RAI_DagRunAsync(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void
222222
}
223223
// Make the inkeys and outkeys of the DAG ops unique, to ensure that the operations
224224
// will be execute in the right order.
225-
if (MangleTensorsNames(rinfo, err) != REDISMODULE_OK) {
225+
if (MangleTensorsNames(rinfo) != REDISMODULE_OK) {
226+
RAI_SetError(err, rinfo->err->code, rinfo->err->detail);
226227
return REDISMODULE_ERR;
227228
}
228229
rinfo->OnFinish = (RedisAI_OnFinishCB)DAGAsyncFinish;
229230
rinfo->private_data = private_data;
230231
if (DAG_InsertDAGToQueue(rinfo) != REDISMODULE_OK) {
231232
RAI_SetError(err, rinfo->err->code, rinfo->err->detail);
232-
RAI_ClearError(rinfo->err);
233233
return REDISMODULE_ERR;
234234
}
235235
return REDISMODULE_OK;
236-
}
236+
}
237+
238+
size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx) {
239+
size_t n_outputs = 0;
240+
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)finish_ctx;
241+
for (size_t i = 0; i < rinfo->dagOpCount; i++) {
242+
if (rinfo->dagOps[i]->commandType == REDISAI_DAG_CMD_TENSORGET) {
243+
n_outputs++;
244+
}
245+
}
246+
return n_outputs;
247+
}
248+
249+
RAI_Tensor *RAI_DAGOutputTensor(RAI_OnFinishCtx *finish_ctx, size_t index) {
250+
size_t tensor_get_op_ind = -1;
251+
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)finish_ctx;
252+
for (size_t i = 0; i < rinfo->dagOpCount; i++) {
253+
RAI_DagOp *op = rinfo->dagOps[i];
254+
if (op->commandType == REDISAI_DAG_CMD_TENSORGET) {
255+
tensor_get_op_ind++;
256+
if (tensor_get_op_ind == index) {
257+
RAI_Tensor *t;
258+
int res = RAI_getTensorFromLocalContext(rinfo->dagTensorsContext, op->inkeys[0], &t,
259+
op->err);
260+
RedisModule_Assert(res == REDISMODULE_OK);
261+
return t;
262+
}
263+
}
264+
}
265+
return NULL;
266+
}

0 commit comments

Comments
 (0)