|
1 | 1 | #include "dag_builder.h" |
2 | 2 | #include "run_info.h" |
| 3 | +#include "dag_parser.h" |
3 | 4 | #include "string_utils.h" |
4 | 5 | #include "modelRun_ctx.h" |
5 | 6 |
|
6 | | -static int _LoadTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, |
7 | | - RedisModuleKey **key, RAI_Tensor **tensor, RAI_Error *err) { |
8 | | - |
9 | | - int res = REDISMODULE_ERR; |
10 | | - *key = RedisModule_OpenKey(ctx, keyName, REDISMODULE_READ); |
11 | | - if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) { |
12 | | - RAI_SetError(err, RAI_EDAGBUILDER, "ERR tensor key is empty"); |
13 | | - goto end; |
14 | | - } |
15 | | - if (RedisModule_ModuleTypeGetType(*key) != RedisAI_TensorType) { |
16 | | - RAI_SetError(err, RAI_EDAGBUILDER, REDISMODULE_ERRORMSG_WRONGTYPE); |
17 | | - goto end; |
18 | | - } |
19 | | - *tensor = RedisModule_ModuleTypeGetValue(*key); |
20 | | - res = REDISMODULE_OK; |
21 | | - |
22 | | -end: |
23 | | - RedisModule_CloseKey(*key); |
24 | | - return res; |
25 | | -} |
26 | | - |
27 | | -static int _RAI_DagLoadTensor(RAI_DAGRunCtx *run_info, RedisModuleString *key_name, |
28 | | - RAI_Error *err) { |
| 7 | +int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *tensor) { |
29 | 8 |
|
30 | 9 | RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
31 | | - RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL); |
32 | | - RAI_Tensor *t; |
33 | | - RedisModuleKey *key; |
34 | | - if (_LoadTensorFromKeyspace(ctx, key_name, &key, &t, err) == REDISMODULE_ERR) { |
35 | | - RedisModule_FreeString(NULL, key_name); |
36 | | - RedisModule_FreeThreadSafeContext(ctx); |
37 | | - return REDISMODULE_ERR; |
38 | | - } |
| 10 | + RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name)); |
39 | 11 | // Add the tensor under its "mangled" key name to the DAG local context dict. |
40 | 12 | char buf[16]; |
41 | 13 | sprintf(buf, "%04d", 1); |
42 | 14 | RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf)); |
43 | | - AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name, (void *)RAI_TensorGetShallowCopy(t)); |
| 15 | + AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name, |
| 16 | + (void *)RAI_TensorGetShallowCopy(tensor)); |
44 | 17 | RedisModule_FreeString(NULL, key_name); |
45 | | - RedisModule_FreeThreadSafeContext(ctx); |
| 18 | + |
46 | 19 | return REDISMODULE_OK; |
47 | 20 | } |
48 | 21 |
|
@@ -112,18 +85,6 @@ int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err |
112 | 85 | return REDISMODULE_OK; |
113 | 86 | } |
114 | 87 |
|
115 | | -int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) { |
116 | | - |
117 | | - RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name)); |
118 | | - return _RAI_DagLoadTensor(run_info, key_name, err); |
119 | | -} |
120 | | - |
121 | | -int RAI_DAGLoadTensorRS(RAI_DAGRunCtx *run_info, RedisModuleString *t_name, RAI_Error *err) { |
122 | | - |
123 | | - RedisModuleString *key_name = RedisModule_CreateStringFromString(NULL, t_name); |
124 | | - return _RAI_DagLoadTensor(run_info, key_name, err); |
125 | | -} |
126 | | - |
127 | 88 | int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) { |
128 | 89 |
|
129 | 90 | RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
@@ -151,6 +112,62 @@ int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor |
151 | 112 | return REDISMODULE_OK; |
152 | 113 | } |
153 | 114 |
|
| 115 | +int RAI_DAGAddOpsFromString(RAI_DAGRunCtx *run_info, const char *dag, RAI_Error *err) { |
| 116 | + |
| 117 | + int res = REDISMODULE_ERR; |
| 118 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 119 | + int argc = 0; |
| 120 | + char dag_string[strlen(dag) + 1]; |
| 121 | + strcpy(dag_string, dag); |
| 122 | + |
| 123 | + char *token = strtok(dag_string, " "); |
| 124 | + if (strcmp(token, "|>") != 0) { |
| 125 | + RAI_SetError(err, RAI_EDAGBUILDER, "DAG op should start with: '|>' "); |
| 126 | + return res; |
| 127 | + } |
| 128 | + RedisModuleString **argv = array_new(RedisModuleString *, 2); |
| 129 | + while (token != NULL) { |
| 130 | + RedisModuleString *RS_token = RedisModule_CreateString(NULL, token, strlen(token)); |
| 131 | + argv = array_append(argv, RS_token); |
| 132 | + argc++; |
| 133 | + token = strtok(NULL, " "); |
| 134 | + } |
| 135 | + |
| 136 | + size_t num_ops_before = array_len(rinfo->dagOps); |
| 137 | + size_t new_ops = 0; |
| 138 | + RAI_DagOp *op; |
| 139 | + for (size_t i = 0; i < argc; i++) { |
| 140 | + const char *arg_string = RedisModule_StringPtrLen(argv[i], NULL); |
| 141 | + if (strcmp(arg_string, "|>") == 0 && i < argc - 1) { |
| 142 | + RAI_InitDagOp(&op); |
| 143 | + rinfo->dagOps = array_append(rinfo->dagOps, op); |
| 144 | + new_ops++; |
| 145 | + op->argv = &argv[i + 1]; |
| 146 | + } else { |
| 147 | + op->argc++; |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + if (ParseDAGOps(rinfo, num_ops_before, new_ops) != REDISMODULE_OK) { |
| 152 | + // Remove all ops that where added before the error and go back to the initial state. |
| 153 | + RAI_SetError(err, RAI_GetErrorCode(rinfo->err), RAI_GetError(rinfo->err)); |
| 154 | + for (size_t i = num_ops_before; i < array_len(rinfo->dagOps); i++) { |
| 155 | + RAI_FreeDagOp(rinfo->dagOps[i]); |
| 156 | + } |
| 157 | + rinfo->dagOps = array_trimm_len(rinfo->dagOps, num_ops_before); |
| 158 | + goto cleanup; |
| 159 | + } |
| 160 | + rinfo->dagOpCount = array_len(rinfo->dagOps); |
| 161 | + res = REDISMODULE_OK; |
| 162 | + |
| 163 | +cleanup: |
| 164 | + for (size_t i = 0; i < argc; i++) { |
| 165 | + RedisModule_FreeString(NULL, argv[i]); |
| 166 | + } |
| 167 | + array_free(argv); |
| 168 | + return res; |
| 169 | +} |
| 170 | + |
154 | 171 | size_t RAI_DAGNumOps(RAI_DAGRunCtx *run_info) { |
155 | 172 | RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
156 | 173 | return array_len(rinfo->dagOps); |
|
0 commit comments