Skip to content

Commit b0ea914

Browse files
committed
Merge with master
2 parents fea057d + ffd4084 commit b0ea914

36 files changed

+1172
-687
lines changed

docs/commands.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Depending on the specified reply format:
7979
1. The tensor's shape as an Array consisting of an item per dimension
8080
* **BLOB**: the tensor's binary data as a String. If used together with the **META** option, the binary data string will put after the metadata in the array reply.
8181
* **VALUES**: Array containing the numerical representation of the tensor's data. If used together with the **META** option, the binary data string will put after the metadata in the array reply.
82-
82+
* Default: **META** and **BLOB** are returned by default, in case that non of the arguments above is specified.
8383

8484

8585
**Examples**

opt/redis_valgrind.sup

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
obj:*/libtensorflow.so.*
66
}
77

8+
{
9+
ignore_unversioned_libs
10+
Memcheck:Leak
11+
...
12+
obj:*/libtensorflow_framework.so.*
13+
}
14+
815
{
916
ignore_unversioned_libs
1017
Memcheck:Leak
@@ -54,17 +61,3 @@
5461
fun:RAI_LoadBackend
5562
}
5663

57-
{
58-
<tf-operator new>
59-
Memcheck:Leak
60-
...
61-
fun:clone
62-
}
63-
64-
{
65-
<malloc>
66-
Memcheck:Leak
67-
fun:malloc
68-
...
69-
fun:clone
70-
}

src/DAG/dag.c

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
9191

9292
static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) {
9393

94-
RAI_ContextReadLock(rinfo);
94+
RAI_ContextWriteLock(rinfo);
9595
const size_t noutputs = RAI_ModelRunCtxNumOutputs(currentOp->mctx);
9696
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
9797
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(currentOp->mctx, outputNumber);
@@ -284,6 +284,9 @@ void RedisAI_BatchedDagRunSession_ModelRun_Step(RedisAI_RunInfo **batched_rinfo,
284284
if (rinfo->single_op_dag == 0)
285285
Dag_StoreOutputsFromModelRunCtx(rinfo, currentOp);
286286
}
287+
// Clear the result in case of an error.
288+
if (result == REDISMODULE_ERR)
289+
RAI_ClearError(&err);
287290
}
288291

289292
/**
@@ -452,16 +455,20 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
452455
return 1;
453456
}
454457

455-
int RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo) {
458+
bool RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo) {
456459
return rinfo->dagDeviceCompleteOpCount == rinfo->dagDeviceOpCount;
457460
}
458461

459-
int RedisAI_DagComplete(RedisAI_RunInfo *rinfo) {
462+
bool RedisAI_DagComplete(RedisAI_RunInfo *rinfo) {
460463
int completeOpCount = __atomic_load_n(rinfo->dagCompleteOpCount, __ATOMIC_RELAXED);
461464

462465
return completeOpCount == rinfo->dagOpCount;
463466
}
464467

468+
bool RedisAI_DagError(RedisAI_RunInfo *rinfo) {
469+
return __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED) != 0;
470+
}
471+
465472
RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
466473
if (rinfo->dagDeviceCompleteOpCount == rinfo->dagDeviceOpCount) {
467474
return NULL;
@@ -470,21 +477,21 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
470477
return rinfo->dagDeviceOps[rinfo->dagDeviceCompleteOpCount];
471478
}
472479

473-
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
474-
int *currentOpBatchable) {
480+
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, bool *currentOpReady,
481+
bool *currentOpBatchable) {
475482
RAI_DagOp *currentOp_ = RedisAI_DagCurrentOp(rinfo);
476483

477-
*currentOpReady = 0;
478-
*currentOpBatchable = 0;
484+
*currentOpReady = false;
485+
*currentOpBatchable = false;
479486

480487
if (currentOp_ == NULL) {
481488
return;
482489
}
483490

484491
if (currentOp_->mctx && currentOp_->mctx->model->opts.batchsize > 0) {
485-
*currentOpBatchable = 1;
492+
*currentOpBatchable = true;
486493
}
487-
*currentOpReady = 1;
494+
*currentOpReady = true;
488495
// If this is a single op dag, the op is definitely ready.
489496
if (rinfo->single_op_dag == 1)
490497
return;
@@ -495,7 +502,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
495502
for (int i = 0; i < n_inkeys; i++) {
496503
if (AI_dictFind(rinfo->dagTensorsContext, currentOp_->inkeys[i]) == NULL) {
497504
RAI_ContextUnlock(rinfo);
498-
*currentOpReady = 0;
505+
*currentOpReady = false;
499506
return;
500507
}
501508
}
@@ -604,13 +611,11 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
604611

605612
if (*rinfo->timedOut) {
606613
RedisModule_ReplyWithSimpleString(ctx, "TIMEDOUT");
607-
RAI_FreeRunInfo(rinfo);
608614
return REDISMODULE_OK;
609615
}
610616

611617
if (RAI_GetErrorCode(rinfo->err) == RAI_EDAGRUN) {
612618
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(rinfo->err));
613-
RAI_FreeRunInfo(rinfo);
614619
return REDISMODULE_OK;
615620
}
616621
int dag_error = 0;
@@ -717,7 +722,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
717722
if (!rinfo->single_op_dag) {
718723
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
719724
}
720-
RAI_FreeRunInfo(rinfo);
721725
return REDISMODULE_OK;
722726
}
723727

@@ -745,15 +749,12 @@ int RedisAI_DagRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, RedisMo
745749
return REDISMODULE_OK;
746750
}
747751

748-
void RunInfo_FreeData(RedisModuleCtx *ctx, void *rinfo) {}
749-
750-
void RedisAI_Disconnected(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc) {
751-
RedisModule_Log(ctx, "warning", "Blocked client %p disconnected!", (void *)bc);
752-
}
752+
void RunInfo_FreeData(RedisModuleCtx *ctx, void *rinfo) { RAI_FreeRunInfo(rinfo); }
753753

754754
void DAG_ReplyAndUnblock(RedisAI_OnFinishCtx *ctx, void *private_data) {
755755

756756
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)ctx;
757-
if (rinfo->client)
757+
if (rinfo->client) {
758758
RedisModule_UnblockClient(rinfo->client, rinfo);
759+
}
759760
}

src/DAG/dag.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,25 @@
1919
* successfully. Since rinfo carries information on what queue
2020
* it has been placed in, there's no need to pass the device identifier.
2121
* @param rinfo context in which RedisAI blocking commands operate.
22-
* @return nonzero if all ops are complete for device, 0 otherwise
22+
* @return true if all ops are complete for device, 0 otherwise
2323
*/
24-
int RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo);
24+
bool RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo);
2525

2626
/**
2727
* Get whether all DAG ops have been executed successfully irrespective
2828
* of the device, i.e. if the DAG has been completely executed.
2929
* @param rinfo context in which RedisAI blocking commands operate.
30-
* @return nonzero of all ops in DAG are complete, 0 otherwise
30+
* @return true of all ops in DAG are complete, 0 otherwise
3131
*/
32-
int RedisAI_DagComplete(RedisAI_RunInfo *rinfo);
32+
bool RedisAI_DagComplete(RedisAI_RunInfo *rinfo);
33+
34+
/**
35+
* @brief Get an indication if an error happend during the dag run.
36+
*
37+
* @param rinfo context in which RedisAI blocking commands operate.
38+
* @return true if there was an error
39+
*/
40+
bool RedisAI_DagError(RedisAI_RunInfo *rinfo);
3341

3442
/**
3543
* Get current DAG op for the given device. An op is current if it's
@@ -50,7 +58,8 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo);
5058
* a MODELRUN and is BATCHSIZE greater than zero
5159
* @return
5260
*/
53-
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady, int *currentOpBatchable);
61+
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, bool *currentOpReady,
62+
bool *currentOpBatchable);
5463

5564
/**
5665
* Get batching information about a DAG op.
@@ -142,9 +151,4 @@ int DAG_InsertDAGToQueue(RedisAI_RunInfo *rinfo);
142151
*/
143152
void RunInfo_FreeData(RedisModuleCtx *ctx, void *rinfo);
144153

145-
/**
146-
* @brief A callback to send to BlockClient.
147-
*/
148-
void RedisAI_Disconnected(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc);
149-
150154
#endif /* SRC_DAG_H_ */

src/backends.c

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ RedisModuleString *RAI_GetBackendsPath(RedisModuleCtx *ctx) {
3939
RedisModuleString *module_path = RAI_GetModulePath(ctx);
4040
backends_path = RedisModule_CreateStringPrintf(ctx, "%s/backends",
4141
RedisModule_StringPtrLen(module_path, NULL));
42+
RedisModule_FreeString(ctx, module_path);
4243
}
4344

4445
return backends_path;
@@ -422,21 +423,26 @@ int RAI_LoadBackend(RedisModuleCtx *ctx, int backend, const char *path) {
422423
RedisModuleString *backends_path = RAI_GetBackendsPath(ctx);
423424
fullpath = RedisModule_CreateStringPrintf(
424425
ctx, "%s/%s", RedisModule_StringPtrLen(backends_path, NULL), path);
426+
RedisModule_FreeString(ctx, backends_path);
425427
}
426428

427429
int ret;
428430
switch (backend) {
429431
case RAI_BACKEND_TENSORFLOW:
430-
return RAI_LoadBackend_TensorFlow(ctx, RedisModule_StringPtrLen(fullpath, NULL));
432+
ret = RAI_LoadBackend_TensorFlow(ctx, RedisModule_StringPtrLen(fullpath, NULL));
433+
break;
431434
case RAI_BACKEND_TFLITE:
432-
return RAI_LoadBackend_TFLite(ctx, RedisModule_StringPtrLen(fullpath, NULL));
435+
ret = RAI_LoadBackend_TFLite(ctx, RedisModule_StringPtrLen(fullpath, NULL));
436+
break;
433437
case RAI_BACKEND_TORCH:
434-
return RAI_LoadBackend_Torch(ctx, RedisModule_StringPtrLen(fullpath, NULL));
438+
ret = RAI_LoadBackend_Torch(ctx, RedisModule_StringPtrLen(fullpath, NULL));
439+
break;
435440
case RAI_BACKEND_ONNXRUNTIME:
436-
return RAI_LoadBackend_ONNXRuntime(ctx, RedisModule_StringPtrLen(fullpath, NULL));
441+
ret = RAI_LoadBackend_ONNXRuntime(ctx, RedisModule_StringPtrLen(fullpath, NULL));
442+
break;
437443
}
438-
439-
return REDISMODULE_ERR;
444+
RedisModule_FreeString(ctx, fullpath);
445+
return ret;
440446
}
441447

442448
int RAI_LoadDefaultBackend(RedisModuleCtx *ctx, int backend) {

src/backends/onnxruntime.c

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
288288

289289
RAI_Device device;
290290
int64_t deviceid;
291+
char **inputs_ = NULL;
292+
char **outputs_ = NULL;
291293

292294
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
293295
RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device");
@@ -352,6 +354,41 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
352354
goto error;
353355
}
354356

357+
size_t n_input_nodes;
358+
status = ort->SessionGetInputCount(session, &n_input_nodes);
359+
if (status != NULL) {
360+
goto error;
361+
}
362+
363+
size_t n_output_nodes;
364+
status = ort->SessionGetOutputCount(session, &n_output_nodes);
365+
if (status != NULL) {
366+
goto error;
367+
}
368+
369+
OrtAllocator *allocator;
370+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
371+
372+
inputs_ = array_new(char *, n_input_nodes);
373+
for (long long i = 0; i < n_input_nodes; i++) {
374+
char *input_name;
375+
status = ort->SessionGetInputName(session, i, allocator, &input_name);
376+
if (status != NULL) {
377+
goto error;
378+
}
379+
inputs_ = array_append(inputs_, input_name);
380+
}
381+
382+
outputs_ = array_new(char *, n_output_nodes);
383+
for (long long i = 0; i < n_output_nodes; i++) {
384+
char *output_name;
385+
status = ort->SessionGetOutputName(session, i, allocator, &output_name);
386+
if (status != NULL) {
387+
goto error;
388+
}
389+
outputs_ = array_append(outputs_, output_name);
390+
}
391+
355392
// Since ONNXRuntime doesn't have a re-serialization function,
356393
// we cache the blob in order to re-serialize it.
357394
// Not optimal for storage purposes, but again, it may be temporary
@@ -367,11 +404,29 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
367404
ret->opts = opts;
368405
ret->data = buffer;
369406
ret->datalen = modellen;
407+
ret->ninputs = n_input_nodes;
408+
ret->noutputs = n_output_nodes;
409+
ret->inputs = inputs_;
410+
ret->outputs = outputs_;
370411

371412
return ret;
372413

373414
error:
374415
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
416+
if (inputs_) {
417+
n_input_nodes = array_len(inputs_);
418+
for (uint32_t i = 0; i < n_input_nodes; i++) {
419+
status = ort->AllocatorFree(allocator, inputs_[i]);
420+
}
421+
array_free(inputs_);
422+
}
423+
if (outputs_) {
424+
n_output_nodes = array_len(outputs_);
425+
for (uint32_t i = 0; i < n_output_nodes; i++) {
426+
status = ort->AllocatorFree(allocator, outputs_[i]);
427+
}
428+
array_free(outputs_);
429+
}
375430
ort->ReleaseStatus(status);
376431
return NULL;
377432
}
@@ -381,6 +436,19 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
381436

382437
RedisModule_Free(model->data);
383438
RedisModule_Free(model->devicestr);
439+
OrtAllocator *allocator;
440+
OrtStatus *status = NULL;
441+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
442+
for (uint32_t i = 0; i < model->ninputs; i++) {
443+
status = ort->AllocatorFree(allocator, model->inputs[i]);
444+
}
445+
array_free(model->inputs);
446+
447+
for (uint32_t i = 0; i < model->noutputs; i++) {
448+
status = ort->AllocatorFree(allocator, model->outputs[i]);
449+
}
450+
array_free(model->outputs);
451+
384452
ort->ReleaseSession(model->session);
385453

386454
model->model = NULL;

0 commit comments

Comments
 (0)