Skip to content

Commit fea057d

Browse files
committed
Add LLAPI that allows adding ops to a DAG from string.
1 parent 9d4b1db commit fea057d

File tree

16 files changed

+367
-279
lines changed

16 files changed

+367
-279
lines changed

src/DAG/dag.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
130130
}
131131
// Only if we got until here, tensor is saved in keyspace.
132132
RedisAI_ReplicateTensorSet(ctx, demangled_key_name, tensor);
133+
RedisModule_CloseKey(key);
133134
ret = REDISMODULE_OK;
134135

135136
clean_up:

src/DAG/dag_builder.c

Lines changed: 62 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,21 @@
11
#include "dag_builder.h"
22
#include "run_info.h"
3+
#include "dag_parser.h"
34
#include "string_utils.h"
45
#include "modelRun_ctx.h"
56

6-
static int _LoadTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName,
7-
RedisModuleKey **key, RAI_Tensor **tensor, RAI_Error *err) {
8-
9-
int res = REDISMODULE_ERR;
10-
*key = RedisModule_OpenKey(ctx, keyName, REDISMODULE_READ);
11-
if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
12-
RAI_SetError(err, RAI_EDAGBUILDER, "ERR tensor key is empty");
13-
goto end;
14-
}
15-
if (RedisModule_ModuleTypeGetType(*key) != RedisAI_TensorType) {
16-
RAI_SetError(err, RAI_EDAGBUILDER, REDISMODULE_ERRORMSG_WRONGTYPE);
17-
goto end;
18-
}
19-
*tensor = RedisModule_ModuleTypeGetValue(*key);
20-
res = REDISMODULE_OK;
21-
22-
end:
23-
RedisModule_CloseKey(*key);
24-
return res;
25-
}
26-
27-
static int _RAI_DagLoadTensor(RAI_DAGRunCtx *run_info, RedisModuleString *key_name,
28-
RAI_Error *err) {
7+
int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *tensor) {
298

309
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
31-
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
32-
RAI_Tensor *t;
33-
RedisModuleKey *key;
34-
if (_LoadTensorFromKeyspace(ctx, key_name, &key, &t, err) == REDISMODULE_ERR) {
35-
RedisModule_FreeString(NULL, key_name);
36-
RedisModule_FreeThreadSafeContext(ctx);
37-
return REDISMODULE_ERR;
38-
}
10+
RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name));
3911
// Add the tensor under its "mangled" key name to the DAG local context dict.
4012
char buf[16];
4113
sprintf(buf, "%04d", 1);
4214
RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf));
43-
AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name, (void *)RAI_TensorGetShallowCopy(t));
15+
AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name,
16+
(void *)RAI_TensorGetShallowCopy(tensor));
4417
RedisModule_FreeString(NULL, key_name);
45-
RedisModule_FreeThreadSafeContext(ctx);
18+
4619
return REDISMODULE_OK;
4720
}
4821

@@ -112,18 +85,6 @@ int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err
11285
return REDISMODULE_OK;
11386
}
11487

115-
int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
116-
117-
RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name));
118-
return _RAI_DagLoadTensor(run_info, key_name, err);
119-
}
120-
121-
int RAI_DAGLoadTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err) {
122-
123-
RedisModuleString *key_name = RedisModule_CreateStringFromString(NULL, t_name);
124-
return _RAI_DagLoadTensor(run_info, key_name, err);
125-
}
126-
12788
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
12889

12990
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
@@ -151,6 +112,62 @@ int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor
151112
return REDISMODULE_OK;
152113
}
153114

115+
int RAI_DAGAddOpsFromString(RAI_DAGRunCtx *run_info, const char *dag, RAI_Error *err) {
116+
117+
int res = REDISMODULE_ERR;
118+
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
119+
int argc = 0;
120+
char dag_string[strlen(dag) + 1];
121+
strcpy(dag_string, dag);
122+
123+
char *token = strtok(dag_string, " ");
124+
if (strcmp(token, "|>") != 0) {
125+
RAI_SetError(err, RAI_EDAGBUILDER, "DAG op should start with: '|>' ");
126+
return res;
127+
}
128+
RedisModuleString **argv = array_new(RedisModuleString *, 2);
129+
while (token != NULL) {
130+
RedisModuleString *RS_token = RedisModule_CreateString(NULL, token, strlen(token));
131+
argv = array_append(argv, RS_token);
132+
argc++;
133+
token = strtok(NULL, " ");
134+
}
135+
136+
size_t num_ops_before = array_len(rinfo->dagOps);
137+
size_t new_ops = 0;
138+
RAI_DagOp *op;
139+
for (size_t i = 0; i < argc; i++) {
140+
const char *arg_string = RedisModule_StringPtrLen(argv[i], NULL);
141+
if (strcmp(arg_string, "|>") == 0 && i < argc - 1) {
142+
RAI_InitDagOp(&op);
143+
rinfo->dagOps = array_append(rinfo->dagOps, op);
144+
new_ops++;
145+
op->argv = &argv[i + 1];
146+
} else {
147+
op->argc++;
148+
}
149+
}
150+
151+
if (ParseDAGOps(rinfo, num_ops_before, new_ops) != REDISMODULE_OK) {
152+
// Remove all ops that where added before the error and go back to the initial state.
153+
RAI_SetError(err, RAI_GetErrorCode(rinfo->err), RAI_GetError(rinfo->err));
154+
for (size_t i = num_ops_before; i < array_len(rinfo->dagOps); i++) {
155+
RAI_FreeDagOp(rinfo->dagOps[i]);
156+
}
157+
rinfo->dagOps = array_trimm_len(rinfo->dagOps, num_ops_before);
158+
goto cleanup;
159+
}
160+
rinfo->dagOpCount = array_len(rinfo->dagOps);
161+
res = REDISMODULE_OK;
162+
163+
cleanup:
164+
for (size_t i = 0; i < argc; i++) {
165+
RedisModule_FreeString(NULL, argv[i]);
166+
}
167+
array_free(argv);
168+
return res;
169+
}
170+
154171
size_t RAI_DAGNumOps(RAI_DAGRunCtx *run_info) {
155172
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
156173
return array_len(rinfo->dagOps);

src/DAG/dag_builder.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,12 @@ int RAI_DAGRunOpAddOutput(RAI_DAGRunOp *DAGOp, const char *output);
4545
int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err);
4646

4747
/**
48-
* @brief Load a tensor from keyspace to the DAG local context.
48+
* @brief Load a given tensor to the DAG local context.
4949
* @param runInfo The DAG to load the tensor into.
5050
* @param tname The tensor key.
51-
* @param err Error is returned in case that the key does not exist, or not holding a tensor type.
51+
* @param tensor The tensor to load to the DAG (we load a shallow copy).
5252
*/
53-
int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
54-
55-
/**
56-
* @brief Load a tensor from keyspace to the DAG local context.
57-
* @param runInfo The DAG to load the tensor into.
58-
* @param tname The tensor key (can hold any binary string).
59-
* @param err Error is returned in case that the key does not exist, or not holding a tensor type.
60-
*/
61-
int RAI_DAGLoadTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err);
53+
int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *tensor);
6254

6355
/**
6456
* @brief Append a TENSORSET op to a DAG (can use to load an intermediate tensors)
@@ -74,6 +66,16 @@ int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor
7466
*/
7567
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
7668

69+
/**
70+
* @brief Add ops to a DAG from string (according to the command syntax). In case of a valid
71+
* string, the ops are added to the DAG run info, and otherwise all the ops are discarded.
72+
* @param runInfo The DAG to insert the ops into.
73+
* @param dag The string representing the DAG ops to add.
74+
* @param err Error is returned in case of a MODELRUN op if the number of inputs and outputs
75+
* given to the op does not match to the number of inputs and outputs in the model definition.
76+
*/
77+
int RAI_DAGAddOpsFromString(RAI_DAGRunCtx *run_info, const char *dag, RAI_Error *err);
78+
7779
/**
7880
* @brief Returns the number of ops in a DAG.
7981
*/

0 commit comments

Comments
 (0)