Skip to content

Commit 9d4b1db

Browse files
committed
Add and use getters for having model number of model inputs and outputs.
1 parent ff2754b commit 9d4b1db

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

src/DAG/dag_builder.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err
9696
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
9797
if (op->mctx) {
9898
RAI_Model *model = op->mctx->model;
99-
if (model->ninputs != array_len(op->inkeys)) {
99+
if (ModelGetNumInputs(model) != array_len(op->inkeys)) {
100100
RAI_SetError(err, RAI_EDAGBUILDER,
101101
"Number of keys given as INPUTS does not match model definition");
102102
return REDISMODULE_ERR;
103103
}
104-
if (model->noutputs != array_len(op->outkeys)) {
104+
if (ModelGetNumOutputs(model) != array_len(op->outkeys)) {
105105
RAI_SetError(err, RAI_EDAGBUILDER,
106106
"Number of keys given as OUTPUTS does not match model definition");
107107
return REDISMODULE_ERR;

src/model.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ int RedisAI_ModelRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, Redis
254254

255255
RedisModuleType *RAI_ModelRedisType(void) { return RedisAI_ModelType; }
256256

257+
size_t ModelGetNumInputs(RAI_Model *model) { return model->ninputs; }
258+
259+
size_t ModelGetNumOutputs(RAI_Model *model) { return model->noutputs; }
260+
257261
int RAI_ModelRunAsync(RAI_ModelRunCtx *mctx, RAI_OnFinishCB ModelAsyncFinish, void *private_data) {
258262

259263
RedisAI_RunInfo *rinfo = NULL;

src/model.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,15 @@ int RedisAI_ModelRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, Redis
144144
*/
145145
RedisModuleType *RAI_ModelRedisType(void);
146146

147+
/**
148+
* @brief Returns the number of inputs in the model definition.
149+
*/
150+
size_t ModelGetNumInputs(RAI_Model *model);
151+
152+
/**
153+
* @brief Returns the number of outputs in the model definition.
154+
*/
155+
size_t ModelGetNumOutputs(RAI_Model *model);
147156
/**
148157
* Insert the ModelRunCtx to the run queues so it will run asynchronously.
149158
*

0 commit comments

Comments
 (0)