Skip to content

Commit e3d2db2

Browse files
committed
Validate that we don't load more than one tensor under the same name to DAG in LLAPI.
1 parent 322890a commit e3d2db2

File tree

5 files changed

+37
-4
lines changed

5 files changed

+37
-4
lines changed

src/DAG/dag_builder.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *t
3030
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
3131
RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name));
3232

33+
// Cannot load more than one tensor under the same name
34+
if (AI_dictFind(rinfo->tensorsNamesToIndices, key_name) != NULL) {
35+
RedisModule_FreeString(NULL, key_name);
36+
return REDISMODULE_ERR;
37+
}
38+
3339
// Add the tensor to the DAG shared tensors and map its name to the relevant index.
3440
size_t index = array_len(rinfo->dagSharedTensors);
3541
AI_dictAdd(rinfo->tensorsNamesToIndices, (void *)key_name, (void *)index);

src/DAG/dag_parser.c

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS
246246
const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL);
247247

248248
if (!strcasecmp(arg_string, "LOAD") && !load_complete && chainingOpCount == 0) {
249-
/* Load the required tensors from key space and store them in both
250-
dagTensorsLoadedContext and dagTensorsContext dicts. */
249+
/* Load the required tensors from key space to the dag shared tensors
250+
* array, and save a mapping of their names to the corresponding indices. */
251251
const int parse_result =
252252
_ParseDAGLoadArgs(ctx, &argv[arg_pos], argc - arg_pos, rinfo->tensorsNamesToIndices,
253253
&rinfo->dagSharedTensors, "|>", rinfo->err);
@@ -263,8 +263,9 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS
263263
"ERR PERSIST cannot be specified in a read-only DAG");
264264
goto cleanup;
265265
}
266-
/* Store the keys to persist in dagTensorsPersistedContext dict.
267-
These keys will be populated later on with actual tensors. */
266+
/* Store the keys to persist in persistTensors dict, these keys will
267+
* be mapped later to the indices in the dagSharedTensors array in which the
268+
* tensors to persist will be found by the end of the DAG run. */
268269
const int parse_result = _ParseDAGPersistArgs(&argv[arg_pos], argc - arg_pos,
269270
rinfo->persistTensors, "|>", rinfo->err);
270271
if (parse_result <= 0)

tests/module/DAG_utils.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,26 @@ static void _DAGFinishFunc(RAI_OnFinishCtx *onFinishCtx, void *private_data) {
7272
pthread_cond_signal(&global_cond);
7373
}
7474

75+
int testLoadTensor(RedisModuleCtx *ctx) {
76+
RAI_DAGRunCtx *run_info = RedisAI_DAGRunCtxCreate();
77+
int res = LLAPIMODULE_ERR;
78+
RAI_Tensor *t = (RAI_Tensor *)_getFromKeySpace(ctx, "a{1}");
79+
if (RedisAI_DAGLoadTensor(run_info, "input", t) != REDISMODULE_OK) {
80+
goto cleanup;
81+
}
82+
t = (RAI_Tensor *)_getFromKeySpace(ctx, "b{1}");
83+
84+
// cannot load more than one tensor under the same name.
85+
if (RedisAI_DAGLoadTensor(run_info, "input", t) != REDISMODULE_ERR) {
86+
goto cleanup;
87+
}
88+
res = LLAPIMODULE_OK;
89+
90+
cleanup:
91+
RedisAI_DAGFree(run_info);
92+
return res;
93+
}
94+
7595
int testModelRunOpError(RedisModuleCtx *ctx) {
7696

7797
RAI_DAGRunCtx *run_info = RedisAI_DAGRunCtxCreate();

tests/module/DAG_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ typedef struct RAI_RunResults {
1111
RAI_Error *error;
1212
} RAI_RunResults;
1313

14+
int testLoadTensor(RedisModuleCtx *ctx);
15+
1416
int testModelRunOpError(RedisModuleCtx *ctx);
1517

1618
int testEmptyDAGError(RedisModuleCtx *ctx);

tests/module/LLAPI.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,10 @@ int RAI_llapi_DAGRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
243243
return REDISMODULE_OK;
244244
}
245245

246+
// Test the case a successful and failure tensor load input to DAG.
247+
if(testLoadTensor(ctx) != LLAPIMODULE_OK) {
248+
return RedisModule_ReplyWithSimpleString(ctx, "LOAD tensor test failed");
249+
}
246250
// Test the case of a failure due to addition of a non compatible MODELRUN op.
247251
if(testModelRunOpError(ctx) != LLAPIMODULE_OK) {
248252
return RedisModule_ReplyWithSimpleString(ctx, "MODELRUN op error test failed");

0 commit comments

Comments
 (0)