Skip to content

Commit ba93a39

Browse files
authored
Valgrind fixes (#462)
* Fix leak * Fix shallow copying load and persisted tensors * Fix typo
1 parent a939fc9 commit ba93a39

File tree

4 files changed

+21
-37
lines changed

4 files changed

+21
-37
lines changed

get_deps.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ if [[ $WITH_PT != 0 ]]; then
253253

254254
echo "Done."
255255
else
256-
echo "librotch is in place."
256+
echo "libtorch is in place."
257257
fi
258258
else
259259
echo "SKipping libtorch."

src/dag.c

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
447447
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
448448
rinfo->dagReplyLength++;
449449
} else {
450-
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) !=
450+
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, RAI_TensorGetShallowCopy(tensor)) !=
451451
REDISMODULE_OK) {
452452
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
453453
rinfo->dagReplyLength++;
@@ -473,6 +473,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
473473
localcontext_key_name);
474474
local_entry = AI_dictNext(local_iter);
475475
}
476+
AI_dictReleaseIterator(local_iter);
476477

477478
for (size_t opN = 0; opN < array_len(rinfo->dagOps); opN++) {
478479
RedisModule_Log(
@@ -532,7 +533,7 @@ int RAI_parseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv,
532533
RedisModule_CloseKey(key);
533534
char *dictKey = (char*) RedisModule_Alloc((strlen(arg_string) + 5)*sizeof(char));
534535
sprintf(dictKey, "%s%04d", arg_string, 1);
535-
AI_dictAdd(*localContextDict, (void*)dictKey, (void *)t);
536+
AI_dictAdd(*localContextDict, (void*)dictKey, (void *)RAI_TensorGetShallowCopy(t));
536537
AI_dictAdd(*loadedContextDict, (void*)dictKey, (void *)1);
537538
RedisModule_Free(dictKey);
538539
number_loaded_keys++;
@@ -796,6 +797,7 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
796797
const char* key = RedisModule_StringPtrLen(currentOp->inkeys[j], NULL);
797798
AI_dictEntry *entry = AI_dictFind(mangled_tensors, key);
798799
if (!entry) {
800+
AI_dictRelease(mangled_tensors);
799801
return RedisModule_ReplyWithError(ctx,
800802
"ERR INPUT key cannot be found in DAG");
801803
}
@@ -837,6 +839,8 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
837839
char *key = (char *)AI_dictGetKey(entry);
838840
AI_dictEntry *mangled_entry = AI_dictFind(mangled_tensors, key);
839841
if (!mangled_entry) {
842+
AI_dictRelease(mangled_tensors);
843+
AI_dictRelease(mangled_persisted);
840844
return RedisModule_ReplyWithError(ctx,
841845
"ERR PERSIST key cannot be found in DAG");
842846
}
@@ -849,6 +853,7 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
849853
AI_dictReleaseIterator(iter);
850854
}
851855

856+
AI_dictRelease(rinfo->dagTensorsPersistedContext);
852857
rinfo->dagTensorsPersistedContext = mangled_persisted;
853858

854859
{

src/run_info.c

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -246,36 +246,9 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
246246
}
247247

248248
if (rinfo->dagTensorsContext) {
249-
AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
250-
AI_dictEntry *entry = AI_dictNext(iter);
251-
RAI_Tensor *tensor = NULL;
252-
253-
while (entry) {
254-
tensor = AI_dictGetVal(entry);
255-
char *key = (char *)AI_dictGetKey(entry);
256-
257-
if (tensor && key != NULL) {
258-
// if the key is persisted then we should not delete it
259-
AI_dictEntry *persisted_entry =
260-
AI_dictFind(rinfo->dagTensorsPersistedContext, key);
261-
// if the key was loaded from the keyspace then we should not delete it
262-
AI_dictEntry *loaded_entry =
263-
AI_dictFind(rinfo->dagTensorsLoadedContext, key);
264-
265-
if (persisted_entry == NULL && loaded_entry == NULL) {
266-
AI_dictDelete(rinfo->dagTensorsContext, key);
267-
}
268-
269-
if (persisted_entry) {
270-
AI_dictDelete(rinfo->dagTensorsPersistedContext, key);
271-
}
272-
if (loaded_entry) {
273-
AI_dictDelete(rinfo->dagTensorsLoadedContext, key);
274-
}
275-
}
276-
entry = AI_dictNext(iter);
277-
}
278-
AI_dictReleaseIterator(iter);
249+
AI_dictRelease(rinfo->dagTensorsContext);
250+
AI_dictRelease(rinfo->dagTensorsLoadedContext);
251+
AI_dictRelease(rinfo->dagTensorsPersistedContext);
279252
}
280253

281254
if (rinfo->dagOps) {

src/tensor.c

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,7 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
980980
int meta = 0;
981981
int blob = 0;
982982
int values = 0;
983+
int fmt_error = 0;
983984
for (int i=2; i<argc; i++) {
984985
const char *fmtstr = RedisModule_StringPtrLen(argv[i], NULL);
985986
if (!strcasecmp(fmtstr, "BLOB")) {
@@ -992,11 +993,15 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
992993
meta = 1;
993994
datafmt = REDISAI_DATA_NONE;
994995
} else {
995-
RedisModule_ReplyWithError(ctx, "ERR unsupported data format");
996-
return -1;
996+
fmt_error = 1;
997997
}
998998
}
999999

1000+
if (fmt_error) {
1001+
RedisModule_ReplyWithError(ctx, "ERR unsupported data format");
1002+
return -1;
1003+
}
1004+
10001005
if (blob && values) {
10011006
RedisModule_ReplyWithError(ctx, "ERR both BLOB and VALUES specified");
10021007
return -1;
@@ -1033,14 +1038,15 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
10331038

10341039
const long long ndims = RAI_TensorNumDims(t);
10351040

1036-
RedisModule_ReplyWithArray(ctx, resplen);
1037-
10381041
char *dtypestr = NULL;
10391042
const int dtypestr_result = Tensor_DataTypeStr(RAI_TensorDataType(t), &dtypestr);
10401043
if(dtypestr_result==REDISMODULE_ERR){
10411044
RedisModule_ReplyWithError(ctx, "ERR unsupported dtype");
10421045
return -1;
10431046
}
1047+
1048+
RedisModule_ReplyWithArray(ctx, resplen);
1049+
10441050
RedisModule_ReplyWithCString(ctx, "dtype");
10451051
RedisModule_ReplyWithCString(ctx, dtypestr);
10461052

0 commit comments

Comments
 (0)